├── .github └── workflows │ ├── build.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── doc ├── Makefile └── source │ ├── _templates │ └── classtemplate.rst │ ├── api.rst │ ├── conf.py │ ├── example_notebooks.rst │ ├── examples │ ├── draw_bspline.ipynb │ ├── draw_legendre.ipynb │ ├── factorization_machine.ipynb │ ├── kan_bspline_rat.ipynb │ ├── kan_legendre_rat.ipynb │ └── transformer_mixed_curves.ipynb │ ├── index.rst │ ├── torchcurves.functional.rst │ └── torchcurves.rst ├── logo.png ├── logo_small.png ├── pyproject.toml ├── src └── torchcurves │ ├── __init__.py │ ├── functional │ ├── __init__.py │ ├── _bspline.py │ ├── _legendre.py │ └── _normalization.py │ ├── modules │ ├── __init__.py │ ├── _bspline.py │ ├── _kan_tools.py │ ├── _legendre.py │ └── _normalization.py │ └── types.py ├── tests ├── __init__.py ├── conftest.py ├── test_bspline.py └── test_legendre.py └── uv.lock /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Install uv 19 | uses: astral-sh/setup-uv@v3 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | run: uv python install ${{ matrix.python-version }} 23 | 24 | - name: Install build dependencies 25 | run: | 26 | uv venv 27 | uv sync --all-groups 28 | 29 | - name: Build package 30 | run: | 31 | uv build 32 | 33 | - name: Archive package 34 | uses: actions/upload-artifact@v4 35 | with: 36 | name: wheel_and_source_${{ matrix.python-version }} 37 | path: | 38 | dist 39 | 40 | - name: Build documentation 41 | run: | 42 | sudo apt-get -qq update 43 | sudo apt-get install -y pandoc 44 | cd doc 45 | make html 
46 | 47 | - name: Archive documentation 48 | uses: actions/upload-artifact@v4 49 | with: 50 | name: docs_${{ matrix.python-version }} 51 | path: | 52 | doc/build/html 53 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v3 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | run: uv python install ${{ matrix.python-version }} 24 | 25 | - name: Install dependencies 26 | run: | 27 | uv venv 28 | uv sync --all-groups 29 | 30 | - name: Lint with ruff 31 | run: | 32 | uv run ruff check src/ tests/ 33 | 34 | - name: Format with black 35 | run: | 36 | uv run black --check --line-length 120 src/ tests/ 37 | 38 | - name: Type check with mypy 39 | run: | 40 | uv run mypy src/ 41 | 42 | - name: Test with pytest 43 | run: | 44 | uv run pytest tests/ -v --cov=torchcurves --cov-report=term-missing 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | MANIFEST 23 | 24 | # Virtual environments 25 | venv/ 26 | ENV/ 27 | env/ 28 | .venv 29 | 30 | # IDEs 31 | .vscode/ 32 | .idea/ 33 | *.swp 34 | *.swo 35 | *~ 36 | 37 | # Testing 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 |
.cache 42 | .pytest_cache/ 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Jupyter 49 | .ipynb_checkpoints/ 50 | 51 | # Documentation 52 | docs/_build/ 53 | docs/_static/ 54 | docs/_templates/ 55 | 56 | # OS 57 | .DS_Store 58 | Thumbs.db 59 | 60 | # UV 61 | .python-version 62 | .ruff_cache 63 | 64 | # Sphinx 65 | doc/source/generated 66 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | # Ruff version. 12 | rev: v0.9.3 13 | hooks: 14 | # Run the linter. 15 | - id: ruff 16 | types_or: [ python, pyi ] 17 | args: [ --fix ] 18 | # Run the formatter. 19 | - id: ruff-format 20 | types_or: [ python, pyi ] 21 | - repo: https://github.com/astral-sh/uv-pre-commit 22 | rev: 0.5.25 23 | hooks: 24 | - id: uv-lock 25 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: "3.13" 12 | apt_packages: 13 | - pandoc 14 | jobs: 15 | pre_create_environment: 16 | - asdf plugin add uv 17 | - asdf install uv latest 18 | - asdf global uv latest 19 | create_environment: 20 | - uv venv "${READTHEDOCS_VIRTUALENV_PATH}" 21 | install: 22 | - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --all-groups 23 | 24 | sphinx: 25 | 
configuration: doc/source/conf.py 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchcurves 2 | 3 | ![torchcurves logo](https://raw.githubusercontent.com/alexshtf/torchcurves/master/logo.png) 4 | 5 | A PyTorch module for differentiable parametric curves with learnable coefficients, 6 | such as a B-Spline curve with learnable control points. 7 | 8 | This package provides fully differentiable curve implementations that integrate 9 | seamlessly with PyTorch's autograd system. It streamlines use cases such as 10 | continuous numerical embeddings [4] for factorization machines [6] or transformers 11 | [2,3], Kolmogorov-Arnold networks [1], or path planning in robotics. 12 | 13 | ## Docs 14 | - [Documentation site](https://torchcurves.readthedocs.io/en/latest/). 15 | - [Example notebooks](https://torchcurves.readthedocs.io/en/latest/example_notebooks.html) for you to try out. 16 | 17 | ## Features 18 | 19 | - **Fully Differentiable**: Custom autograd function ensures gradients flow 20 | properly through the curve evaluation.
21 | - **Batch Processing**: Vectorized operations for efficient batch evaluation. 22 | 23 | ## Installation 24 | 25 | ```bash 26 | pip install torchcurves 27 | ``` 28 | 29 | ```bash 30 | uv add torchcurves 31 | ``` 32 | 33 | ## Use cases 34 | 35 | There are examples in the `doc/source/examples` directory showing how to build models using 36 | this library. Here we show a few short code snippets to give a taste of the library. 37 | 38 | ## Use case 1 - continuous embeddings 39 | 40 | ```python 41 | import torchcurves as tc 42 | from torch import nn 43 | import torch 44 | 45 | 46 | class Net(nn.Module): 47 | def __init__(self, num_categorical, num_numerical, dim, num_knots=10): 48 | super().__init__() 49 | self.cat_emb = nn.Embedding(num_categorical, dim) 50 | self.num_emb = tc.BSplineEmbeddings(num_numerical, dim, knots_config=num_knots) 51 | self.my_super_duper_transformer = MySuperDuperTransformer() 52 | 53 | def forward(self, x_categorical, x_numerical): 54 | embeddings = torch.cat([self.cat_emb(x_categorical), self.num_emb(x_numerical)], dim=-2) 55 | return self.my_super_duper_transformer(embeddings) 56 | ``` 57 | 58 | ## Use case 2 - Kolmogorov-Arnold networks 59 | 60 | A KAN [1] based on the B-Spline basis, along the lines of the original paper: 61 | 62 | ```python 63 | import torchcurves as tc 64 | from torch import nn 65 | 66 | input_dim = 2 67 | intermediate_dim = 5 68 | num_control_points = 10 69 | 70 | kan = nn.Sequential( 71 | # layer 1 72 | tc.BSplineCurve(input_dim, intermediate_dim, knots_config=num_control_points), 73 | tc.Sum(dim=-2), 74 | # layer 2 75 | tc.BSplineCurve(intermediate_dim, intermediate_dim, knots_config=num_control_points), 76 | tc.Sum(dim=-2), 77 | # layer 3 78 | tc.BSplineCurve(intermediate_dim, 1, knots_config=num_control_points), 79 | tc.Sum(dim=-2), 80 | ) 81 | ``` 82 | Yes, we know the original KAN paper used a different curve parametrization, 83 | B-Spline + SiLU, but the whole point of this repo is showing that KAN 84 | activations can be
parametrized in arbitrary ways. 85 | 86 | For example, here is a KAN based on Legendre polynomials of degree 5: 87 | 88 | ```python 89 | import torchcurves as tc 90 | from torch import nn 91 | 92 | input_dim = 2 93 | intermediate_dim = 5 94 | degree = 5 95 | 96 | kan = nn.Sequential( 97 | # layer 1 98 | tc.LegendreCurve(input_dim, intermediate_dim, degree=degree), 99 | tc.Sum(dim=-2), 100 | # layer 2 101 | tc.LegendreCurve(intermediate_dim, intermediate_dim, degree=degree), 102 | tc.Sum(dim=-2), 103 | # layer 3 104 | tc.LegendreCurve(intermediate_dim, 1, degree=degree), 105 | tc.Sum(dim=-2), 106 | ) 107 | ``` 108 | 109 | Since KANs are the primary use case for the `tc.Sum()` layer, `dim=-2` is its default and can be omitted; it is passed explicitly 110 | here for clarity. 111 | 112 | ## Advanced features 113 | 114 | The curves we provide here expect their inputs to lie in a compact 115 | interval, typically [-1, 1]. Arbitrary inputs need to be normalized to this 116 | interval. We provide two simple out-of-the-box normalization strategies 117 | described below. 118 | 119 | ## Rational scaling 120 | 121 | This is the default strategy. It computes 122 | 123 | ```math 124 | x \to \frac{x}{\sqrt{s^2 + x^2}}, 125 | ``` 126 | 127 | where $s$ is the normalization scale. This maps the whole real line into $(-1, 1)$ and is based on the paper 128 | >Wang, Z.Q. and Guo, B.Y., 2004. Modified Legendre rational spectral method for the whole line. Journal of Computational Mathematics, pp.457-474. 129 | 130 | In Python it looks like this: 131 | 132 | ```python 133 | tc.BSplineCurve(curve_dim, normalization_fn='rational', normalization_scale=s) 134 | ``` 135 | 136 | ## Clamping 137 | 138 | The inputs are simply clipped to [-1, 1] after scaling, i.e.
139 | 140 | ```math 141 | x \to \max(\min(1, x / s), -1) 142 | ``` 143 | 144 | In Python it looks like this: 145 | 146 | ```python 147 | tc.BSplineCurve(curve_dim, normalization_fn='clamp', normalization_scale=s) 148 | ``` 149 | 150 | ## Custom normalization 151 | 152 | Provide a custom function that scales its input and maps it to the designated range `[out_min, out_max]`, 153 | following the signature used in the example below: 154 | 155 | ```python 156 | def erf_clamp(x: Tensor, scale: float = 1, out_min: float = -1, out_max: float = 1) -> Tensor: 157 | mapped = torch.special.erf(x / scale) 158 | return ((mapped + 1) * (out_max - out_min)) / 2 + out_min 159 | 160 | tc.BSplineCurve(curve_dim, normalization_fn=erf_clamp, normalization_scale=s) 161 | ``` 162 | 163 | ## Example: B-Spline KAN with clamping 164 | 165 | A KAN based on a clamped B-Spline basis with the default scale of $s=1$: 166 | 167 | ```python 168 | spline_kan = nn.Sequential( 169 | # layer 1 170 | tc.BSplineCurve(input_dim, intermediate_dim, knots_config=knots, normalization_fn='clamp'), 171 | tc.Sum(), 172 | # layer 2 173 | tc.BSplineCurve(intermediate_dim, intermediate_dim, knots_config=knots, normalization_fn='clamp'), 174 | tc.Sum(), 175 | # layer 3 176 | tc.BSplineCurve(intermediate_dim, 1, knots_config=knots, normalization_fn='clamp'), 177 | tc.Sum(), 178 | ) 179 | ``` 180 | 181 | ## Example: Legendre KAN with clamping 182 | 183 | ```python 184 | import torchcurves as tc 185 | from torch import nn 186 | 187 | input_dim = 2 188 | intermediate_dim = 5 189 | degree = 5 190 | 191 | config = dict(degree=degree, normalization_fn="clamp") 192 | kan = nn.Sequential( 193 | # layer 1 194 | tc.LegendreCurve(input_dim, intermediate_dim, **config), 195 | tc.Sum(), 196 | # layer 2 197 | tc.LegendreCurve(intermediate_dim, intermediate_dim, **config), 198 | tc.Sum(), 199 | # layer 3 200 | tc.LegendreCurve(intermediate_dim, 1, **config), 201 | tc.Sum(), 202 | ) 203 | ``` 204 | 205 | 206 | ## Development 207 | 208 | ## Development Installation 209 | 210 | Using
[uv](https://github.com/astral-sh/uv) (recommended): 211 | 212 | ```bash 213 | # Clone the repository 214 | git clone https://github.com/alexshtf/torchcurves.git 215 | cd torchcurves 216 | 217 | # Create virtual environment and install 218 | uv venv 219 | uv sync --all-groups 220 | ``` 221 | 222 | ## Running Tests 223 | 224 | ```bash 225 | # Run all tests 226 | uv run pytest 227 | 228 | # Run with coverage 229 | uv run pytest --cov=torchcurves 230 | 231 | # Run specific test file 232 | uv run pytest tests/test_bspline.py -v 233 | ``` 234 | 235 | ## Building the docs 236 | 237 | ```bash 238 | # Build the HTML docs (requires pandoc for the example notebooks; output goes to doc/build/html) 239 | cd doc 240 | make html 241 | ``` 242 | 243 | ## Citation 244 | 245 | If you use this package in your research, please cite: 246 | 247 | ```bibtex 248 | @software{torchcurves, 249 | author = {Shtoff, Alex}, 250 | title = {torchcurves: Differentiable Parametric Curves in PyTorch}, 251 | year = {2025}, 252 | publisher = {GitHub}, 253 | url = {https://github.com/alexshtf/torchcurves} 254 | } 255 | ``` 256 | 257 | ## References 258 | 259 | [1]: Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljacic, Thomas Y. Hou, Max Tegmark. "KAN: Kolmogorov–Arnold Networks." *ICLR* (2025). \ 260 | [2]: Juergen Schmidhuber. "Learning to control fast-weight memories: An alternative to dynamic recurrent networks." *Neural Computation*, 4(1), pp.131-139. (1992) \ 261 | [3]: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." *Advances in neural information processing systems* 30 (2017). \ 262 | [4]: Alex Shtoff, Elie Abboud, Rotem Stram, and Oren Somekh. "Function Basis Encoding of Numerical Features in Factorization Machines." *Transactions on Machine Learning Research*. \ 263 | [5]: Rügamer, David. "Scalable Higher-Order Tensor Product Spline Models." In *International Conference on Artificial Intelligence and Statistics*, pp. 1-9.
PMLR, 2024. \ 264 | [6]: Steffen Rendle. "Factorization machines." In *2010 IEEE International conference on data mining*, pp. 995-1000. IEEE, 2010. 265 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= uv run sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/source/_templates/classtemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | -------------------------------------------------------------------------------- /doc/source/api.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ================= 3 | 4 | .. 
toctree:: 5 | torchcurves 6 | torchcurves.functional 7 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | # Make sure we add the source code root to the path 5 | PROJECT_ROOT = Path(__file__).resolve().parents[2] # two levels up from conf.py 6 | sys.path.insert(0, str(PROJECT_ROOT / "src")) # make `import torchcurves` work 7 | 8 | # Configuration file for the Sphinx documentation builder. 9 | # 10 | # For the full list of built-in configuration values, see the documentation: 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 12 | 13 | # -- Project information ----------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 15 | 16 | project = "TorchCurves" 17 | copyright = "2025, Alex Shtoff" 18 | author = "Alex Shtoff" 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | 24 | extensions = [ 25 | "sphinx.ext.mathjax", 26 | "sphinx.ext.autodoc", 27 | "sphinx.ext.autosummary", 28 | "sphinx.ext.napoleon", 29 | "myst_parser", 30 | "sphinx_autodoc_typehints", 31 | "nbsphinx", 32 | "sphinx_copybutton", 33 | "sphinxext.opengraph", 34 | ] 35 | 36 | 37 | napoleon_google_docstring = True 38 | napoleon_numpy_docstring = False 39 | 40 | autodoc_typehints = "description" # rely on PEP-484 annotations 41 | 42 | templates_path = ["_templates"] 43 | exclude_patterns = [] 44 | pygments_style = "sphinx" 45 | # mathjax_path = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/3.0.0/es5/latest?tex-mml-chtml.js" 46 | # mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@2/MathJax.js?config=TeX-AMS_CHTML" 47 | mathjax_path = 
"https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" 48 | 49 | 50 | language = "en" 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 54 | 55 | html_theme = "pydata_sphinx_theme" 56 | html_static_path = ["_static"] 57 | -------------------------------------------------------------------------------- /doc/source/example_notebooks.rst: -------------------------------------------------------------------------------- 1 | Example notebooks 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | examples/draw_bspline 8 | examples/draw_legendre 9 | examples/kan_bspline_rat 10 | examples/kan_legendre_rat 11 | examples/factorization_machine 12 | examples/transformer_mixed_curves 13 | -------------------------------------------------------------------------------- /doc/source/examples/draw_legendre.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "dc2e0d2d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Legendre curve plotting demo\n", 9 | "In this notebook we show the spectral nature of Legendre curves. The parameters we learn are a kind of a frequency\n", 10 | "domain, defining the spectrum of the curves. They oscilate more close to the origin, and less farther away from\n", 11 | "the origin." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 11, 17 | "id": "88b0fbb8", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import torch\n", 22 | "import torch.nn.functional as F\n", 23 | "import torchcurves.functional as tcf\n", 24 | "import matplotlib.pyplot as plt" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "9d949c41", 30 | "metadata": {}, 31 | "source": [ 32 | "## Define parameters" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "28e80e56", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "degree = 10\n", 43 | "n_coefficients = 1 + degree" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "d72eac2f", 49 | "metadata": {}, 50 | "source": [ 51 | "## Define coefficients of various curves" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 66, 57 | "id": "d356a784", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "torch.Size([6, 3, 2])" 64 | ] 65 | }, 66 | "execution_count": 66, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "num_curves = 3\n", 73 | "dim = 2\n", 74 | "\n", 75 | "t = torch.linspace(-1, 1, n_coefficients)\n", 76 | "\n", 77 | "freq = torch.pi * n_coefficients\n", 78 | "first_coef = torch.stack([torch.sin(freq * t), torch.cos(freq * t)], dim=1)\n", 79 | "second_coef = torch.stack([torch.sin(freq * t) / F.softplus(t), torch.cos(freq * t) / F.softplus(t)], dim=1)\n", 80 | "third_coef = torch.stack([torch.exp(t) / (1 + 5 * (1 + t)), torch.sin(freq * t) / (1 + 10 * (1 + t))], dim=1)\n", 81 | "coefs = torch.stack([first_coef, second_coef, third_coef], dim=1)\n", 82 | "coefs.shape" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "id": "dc241927", 88 | "metadata": {}, 89 | "source": [ 90 | "## Sample and draw the Legendre curves with 100 sample points from -1 to 1" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | 
"execution_count": 67, 96 | "id": "742d7cd8", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "sample_points = torch.torch.linspace(-1, 1, 1000)\n", 101 | "curve_args = sample_points.reshape(-1, 1).expand(-1, 3)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 68, 107 | "id": "e60014e9", 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "data": { 112 | "text/plain": [ 113 | "torch.Size([1000, 3, 2])" 114 | ] 115 | }, 116 | "execution_count": 68, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "curve_points = tcf.legendre_curves(curve_args, coefs)\n", 123 | "curve_points.shape" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 69, 129 | "id": "35cb8884", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "first_curve, second_curve, third_curve = curve_points.unbind(dim=1)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 70, 139 | "id": "4f5abaed", 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "image/png": "", 145 | "text/plain": [ 146 | "
" 147 | ] 148 | }, 149 | "metadata": {}, 150 | "output_type": "display_data" 151 | } 152 | ], 153 | "source": [ 154 | "plt.plot(*first_curve.unbind(dim=1), label='Curve')\n", 155 | "plt.legend()\n", 156 | "plt.show()" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 71, 162 | "id": "42ff728f", 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "image/png": "", 168 | "text/plain": [ 169 | "
" 170 | ] 171 | }, 172 | "metadata": {}, 173 | "output_type": "display_data" 174 | } 175 | ], 176 | "source": [ 177 | "plt.plot(*second_curve.unbind(dim=1), label='Curve')\n", 178 | "plt.legend()\n", 179 | "plt.show()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 72, 185 | "id": "beafc604", 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "image/png": "", 191 | "text/plain": [ 192 | "
" 193 | ] 194 | }, 195 | "metadata": {}, 196 | "output_type": "display_data" 197 | } 198 | ], 199 | "source": [ 200 | "plt.plot(*third_curve.unbind(dim=1), label='Curve')\n", 201 | "plt.legend()\n", 202 | "plt.show()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "6a7e0cf4", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "torchcurves", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.12.8" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 5 235 | } 236 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | TorchCurves documentation 2 | ========================= 3 | .. toctree:: 4 | :caption: API 5 | :maxdepth: 1 6 | 7 | torchcurves 8 | torchcurves.functional 9 | 10 | .. toctree:: 11 | :caption: Examples 12 | :maxdepth: 2 13 | 14 | example_notebooks 15 | -------------------------------------------------------------------------------- /doc/source/torchcurves.functional.rst: -------------------------------------------------------------------------------- 1 | torchcurves.functional 2 | ====================== 3 | 4 | .. currentmodule:: torchcurves.functional 5 | 6 | Normalization 7 | ------------- 8 | Functions for normalizing inputs to the :math:`[-1, 1]` interval, required by most parametric curves. 9 | 10 | .. 
autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | 14 | clamp 15 | rational 16 | 17 | 18 | Parametrized curves 19 | ------------------- 20 | Vectorized parametric curve evaluation functions. 21 | 22 | .. autosummary:: 23 | :toctree: generated 24 | :nosignatures: 25 | 26 | bspline_curves 27 | bspline_embeddings 28 | legendre_curves 29 | 30 | 31 | Utilities 32 | --------- 33 | 34 | .. autosummary:: 35 | :toctree: generated 36 | :nosignatures: 37 | 38 | uniform_augmented_knots 39 | -------------------------------------------------------------------------------- /doc/source/torchcurves.rst: -------------------------------------------------------------------------------- 1 | torchcurves 2 | =========== 3 | 4 | .. automodule:: torchcurves 5 | .. automodule:: torchcurves.modules 6 | 7 | 8 | .. contents:: torchcurves 9 | :depth: 2 10 | :local: 11 | :backlinks: top 12 | 13 | 14 | .. currentmodule:: torchcurves 15 | 16 | Layers 17 | ------ 18 | 19 | .. autosummary:: 20 | :toctree: generated 21 | :nosignatures: 22 | :template: classtemplate.rst 23 | 24 | BSplineEmbeddings 25 | BSplineCurve 26 | LegendreCurve 27 | Sum 28 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexshtf/torchcurves/78e2c4c193edc65aedfdf4804bd068afff3c5214/logo.png -------------------------------------------------------------------------------- /logo_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexshtf/torchcurves/78e2c4c193edc65aedfdf4804bd068afff3c5214/logo_small.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "torchcurves" 3 | version = "0.1.1" 4 | description = "PyTorch module for 
differentiable parametric curves with learnable coefficients" 5 | authors = [ 6 | { name = "Alex Shtoff", email = "alex.shtf@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | license = { text = "Apache 2.0" } 10 | requires-python = ">=3.9" 11 | keywords = ["pytorch", "bspline", "curves", "differentiable", "deep-learning", "geometric-deep-learning"] 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Science/Research", 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | "Topic :: Scientific/Engineering :: Mathematics", 24 | ] 25 | 26 | dependencies = [ 27 | "torch>=1.10.0", 28 | ] 29 | 30 | [dependency-groups] 31 | dev = [ 32 | "numpy>=1.17.0", 33 | "pytest>=7.0.0", 34 | "pytest-cov>=4.0.0", 35 | "black>=23.0.0", 36 | "ruff>=0.1.0", 37 | "mypy>=1.0.0", 38 | "pre-commit>=3.0.0", 39 | ] 40 | doc = [ 41 | "sphinx>=7.4.7", 42 | "myst-parser>=3.0.1", 43 | "sphinx-autodoc-typehints>=2.3.0", 44 | "sphinx-copybutton>=0.5.2", 45 | "sphinx-autobuild>=2024.10.3", 46 | "sphinxext-opengraph>=0.10.0", 47 | "pydata-sphinx-theme>=0.16.1", 48 | "ipython>=8.18.1", 49 | "nbsphinx>=0.9.7", 50 | "pandoc>=2.4", 51 | ] 52 | examples = [ 53 | "ipykernel>=6.29.0", 54 | "matplotlib>=3.8.0", 55 | "ipywidgets>=8.1.7", 56 | "kagglehub[pandas-datasets]>=0.3.12", 57 | "pandas>=2.3.0", 58 | "scikit-learn>=1.6.1", 59 | ] 60 | 61 | [project.urls] 62 | Homepage = "https://github.com/alexshtf/torchcurves" 63 | Repository = "https://github.com/alexshtf/torchcurves" 64 | Issues = "https://github.com/alexshtf/torchcurves/issues" 65 | 66 | [build-system] 67 | requires = ["hatchling"] 68 | build-backend = "hatchling.build" 69 | 70 | 
[tool.hatch.build.targets.wheel] 71 | packages = ["src/torchcurves"] 72 | 73 | [tool.pytest.ini_options] 74 | testpaths = ["tests"] 75 | python_files = "test_*.py" 76 | python_classes = "Test*" 77 | python_functions = "test_*" 78 | addopts = "-v --cov=torchcurves --cov-report=term-missing" 79 | 80 | [tool.ruff] 81 | line-length = 120 82 | 83 | [tool.ruff.lint] 84 | select = [ "E", "W", "F", "I", "B", "C4", "N", "D", "SIM",] 85 | extend-select = [ "RUF100",] 86 | ignore = [ "E101", "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", "B024",] 87 | fixable = [ "ALL",] 88 | -------------------------------------------------------------------------------- /src/torchcurves/__init__.py: -------------------------------------------------------------------------------- 1 | """Differentiable parametric curves in arbitrary dimensions.""" 2 | 3 | from torchcurves import types as types 4 | 5 | from .modules import * # noqa: F403 6 | -------------------------------------------------------------------------------- /src/torchcurves/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from ._bspline import bspline_curves, bspline_embeddings, uniform_augmented_knots 2 | from ._legendre import legendre_curves 3 | from ._normalization import clamp, rational 4 | 5 | __all__ = ["clamp", "rational", "uniform_augmented_knots", "bspline_curves", "bspline_embeddings", "legendre_curves"] 6 | -------------------------------------------------------------------------------- /src/torchcurves/functional/_bspline.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F # noqa: N812 5 | 6 | 7 | def uniform_augmented_knots( 8 | n_control_points: int, degree: int, dtype=torch.float32, device: Union[torch.device, str, None] = None 9 | ) -> torch.Tensor: 10 | """Generate an augmented knot vector with
uniform spacing in [-1, 1] for B-spline curves. 11 | 12 | This function returns a 1D tensor containing knot values. The internal knots are computed uniformly in the interval 13 | [-1, 1] for the given number of control points and degree. The head and tail each consist of (degree + 1) identical 14 | knots (-1 and 1 respectively), which makes the knot vector clamped (open uniform). 15 | 16 | Args: 17 | n_control_points (int): The total number of control points for the B-spline. 18 | Must be at least (degree + 1) to have a valid knot vector. 19 | degree (int): The degree of the B-spline. 20 | dtype (torch.dtype, optional): The desired data type of the output tensor. 21 | Defaults to torch.float32. 22 | device (torch.device, str, or None, optional): The device on which the knot vector will reside. 23 | 24 | Returns: 25 | torch.Tensor: A 1D tensor of ``n_control_points + degree + 1`` knots consisting of head knots, 26 | uniformly spaced internal knots, and tail knots, all in the range [-1.0, 1.0]. 27 | 28 | Raises: 29 | ValueError: If the number of control points is less than (degree + 1), indicating 30 | that there are not enough points to form a valid knot vector.
31 | 32 | """ 33 | if n_control_points < 1 + degree: 34 | raise ValueError("Not enough control points for the given degree to form internal knots.") 35 | 36 | # Generates knots in [-1, 1] 37 | k_min, k_max = -1.0, 1.0 # Target range for normalized u 38 | 39 | head_knots = torch.full((degree + 1,), k_min, dtype=dtype, device=device) 40 | tail_knots = torch.full((degree + 1,), k_max, dtype=dtype, device=device) 41 | # Internal knots: num_internal_knots + 2 evenly spaced points with both endpoints dropped ([1:-1]) 42 | num_internal_knots = n_control_points - degree - 1 43 | if num_internal_knots == 0: 44 | internal_knots = torch.empty(0, dtype=dtype, device=device) 45 | else: 46 | internal_knots = torch.linspace(k_min, k_max, num_internal_knots + 2, dtype=dtype, device=device)[1:-1] 47 | 48 | return torch.cat([head_knots, internal_knots, tail_knots]) 49 | 50 | 51 | class _BSplineFunction(torch.autograd.Function): 52 | ZERO_TOLERANCE = 1e-12 53 | ONE_TOLERANCE = 1.0 - ZERO_TOLERANCE # NOTE(review): not referenced anywhere in this module 54 | # NOTE(review): the string below follows the class attributes, so it does not become the class __doc__. 55 | """Custom autograd function for B-spline evaluation and differentiation (Vectorized for multiple curves).""" 56 | 57 | @staticmethod 58 | def find_spans(u: torch.Tensor, knots: torch.Tensor, degree: int, n_control_points: int) -> torch.Tensor: 59 | """Find the knot span index for each parameter value (vectorized). 60 | 61 | Args: 62 | u: Parameter values, shape (N, M) or (N,). N samples, M curves. 63 | A 1D u of shape (N,) yields spans of shape (N,). 64 | Values are expected to be in the range defined by the knots (e.g., [0,1] or [-1,1]). 65 | knots: Knot vector, shape (num_total_knots,). Expected to be a clamped knot vector. 66 | degree: B-spline degree (p). 67 | n_control_points: Number of control points per curve (c). 68 | 69 | Returns: 70 | Span indices, shape (N, M) or (N,). Each span_idx `s` means u falls in [knots[s], knots[s+1]). 71 | 72 | """ 73 | # Note: The original ZERO_TOLERANCE and ONE_TOLERANCE assumed u in [0,1] and knots clamped to [0,1]. 74 | # If knots are e.g.
[-1,1], the boundary checks below still apply, since they compare against the actual 75 | # boundary knot values knots[degree] and knots[n_control_points] rather than against 0 and 1. 76 | # u outside [knots[degree], knots[n_control_points]] is assigned to the first/last valid span. 77 | # ZERO_TOLERANCE is absolute; for float32 values near +-1 it is below the float spacing (effectively exact). 78 | 79 | spans = torch.searchsorted(knots, u, side="right") - 1 80 | 81 | # Handle boundaries based on the actual knot values for robustness 82 | # This assumes knots is sorted and clamped: knots[0]..knots[degree] are same, 83 | # and knots[n_control_points]..knots[n_control_points+degree] are same. 84 | min_knot_val = knots[degree] 85 | max_knot_val = knots[n_control_points] # This is the start of the last segment of p+1 knots 86 | 87 | # u at or below the first parameter value (within ZERO_TOLERANCE) -> first valid span 88 | spans[u <= min_knot_val + _BSplineFunction.ZERO_TOLERANCE] = degree 89 | # u at or above the last parameter value (within ZERO_TOLERANCE) -> last valid span 90 | spans[u >= max_knot_val - _BSplineFunction.ZERO_TOLERANCE] = n_control_points - 1 91 | 92 | spans = torch.clamp(spans, min=degree, max=n_control_points - 1) 93 | return spans 94 | 95 | @staticmethod 96 | def cox_de_boor(u: torch.Tensor, knots: torch.Tensor, spans: torch.Tensor, degree: int) -> torch.Tensor: 97 | """Compute B-spline basis functions using Cox-de Boor recursion. 98 | 99 | Args: 100 | u: Parameter values, shape (N, M). N samples, M curves. 101 | knots: Knot vector, shape (num_total_knots,). 102 | spans: Knot span indices, shape (N, M), e.g. as returned by find_spans. 103 | degree: B-spline degree (p). 104 | 105 | Returns: 106 | Basis function values N_batch, shape (N, M, degree+1). 107 | N_batch[n, m, j] = B_{spans[n,m]-degree+j, degree}(u[n,m]).
108 | 109 | """ 110 | num_samples_n, num_curves_m = u.shape 111 | device, dtype = u.device, u.dtype 112 | 113 | # batch_nonzero_basis[n, m, k] will store B_{spans[n,m]-degree+k, degree}(u[n,m]) 114 | batch_nonzero_basis = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype) 115 | # left_dist_all_p[..., j] = u - knots[s+1-j] and right_dist_all_p[..., j] = knots[s+j] - u, with s = spans 116 | left_dist_all_p = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype) 117 | right_dist_all_p = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype) 118 | 119 | batch_nonzero_basis[..., 0] = 1.0 120 | 121 | for p_iter in range(1, degree + 1): # p_iter is 'j' in Piegl & Tiller A2.2 122 | # knots is 1D. We gather using indices derived from spans (N,M) 123 | # Each distance slice written below has shape (N,M); indices are clamped to the knot range 124 | idx_knot_left = (spans + 1 - p_iter).clamp(min=0, max=knots.shape[0] - 1) 125 | left_dist_all_p[..., p_iter] = u - knots[idx_knot_left] 126 | 127 | idx_knot_right = (spans + p_iter).clamp(min=0, max=knots.shape[0] - 1) 128 | right_dist_all_p[..., p_iter] = knots[idx_knot_right] - u 129 | 130 | saved_val = torch.zeros(num_samples_n, num_curves_m, device=device, dtype=dtype) 131 | 132 | for r_iter in range(p_iter): 133 | denominator_batch = right_dist_all_p[..., r_iter + 1] + left_dist_all_p[..., p_iter - r_iter] 134 | # A zero denominator (repeated knots) yields inf/NaN; such ratios are replaced by 0 below 135 | ratios = batch_nonzero_basis[..., r_iter] / denominator_batch 136 | ratios = torch.where(torch.isfinite(ratios), ratios, torch.zeros_like(ratios)) 137 | 138 | batch_nonzero_basis[..., r_iter] = saved_val + right_dist_all_p[..., r_iter + 1] * ratios 139 | saved_val = left_dist_all_p[..., p_iter - r_iter] * ratios 140 | 141 | batch_nonzero_basis[..., p_iter] = saved_val 142 | return batch_nonzero_basis 143 | 144 | @staticmethod 145 | def evaluate_curve( 146 | basis: torch.Tensor, # shape (N, M, degree+1) 147 | control_points: torch.Tensor, # shape (M, C, D) C=n_control_points 148 | spans: torch.Tensor, # shape (N, M) 149 | degree: int, 150 | ) -> torch.Tensor: 151 |
"""Evaluate B-spline curves (vectorized for multiple curves). 152 | 153 | Args: 154 | basis: Basis function values. basis[n,m,j] = N_{spans[n,m]-degree+j, degree}(u[n,m]). 155 | control_points: Control points, shape (M, C, D): C control points in R^D for each of the M curves. 156 | spans: Knot span indices, shape (N, M). 157 | degree: B-spline degree. 158 | 159 | Returns: 160 | Points on curves, shape (N, M, D). 161 | 162 | """ 163 | num_samples_n, num_curves_m = spans.shape 164 | # C = num_control_points_per_curve, D = dim 165 | # NOTE(review): control_points.shape[0] is assumed to equal num_curves_m; 166 | # this is not validated here. 167 | 168 | # control_point_indices: indices into C dimension of control_points 169 | # Shape: (N, M, degree+1) 170 | degrees_range = torch.arange(degree + 1, device=spans.device).view(1, 1, -1) 171 | control_point_indices = spans.unsqueeze(-1) - degree + degrees_range 172 | 173 | # Clamp indices to be valid for control_points' C dimension 174 | clamped_cp_indices = torch.clamp(control_point_indices, 0, control_points.shape[1] - 1) 175 | 176 | # Gather control points: gathered_control_points[n, m, i, d] = control_points[m, clamped_cp_indices[n,m,i], d] 177 | # Need to create m_indices for gathering from control_points' M dimension 178 | # m_indices_for_gather shape: (N, M, degree+1) 179 | m_indices_for_gather = torch.arange(num_curves_m, device=control_points.device).view(1, -1, 1) 180 | m_indices_for_gather = m_indices_for_gather.expand(num_samples_n, -1, degree + 1) 181 | 182 | gathered_control_points = control_points[ 183 | m_indices_for_gather, # Selects the curve from M dimension of control_points 184 | clamped_cp_indices, # Selects the control points from C dimension 185 | :, # Selects all D dimensions 186 | ] # Shape (N, M, degree+1, D) 187 | 188 | # Compute points: points[n,m,d] = sum_i basis[n,m,i] * gathered_control_points[n,m,i,d] 189 | # basis.unsqueeze(-1) gives (N, M, degree+1, 1) 190 | return (basis.unsqueeze(-1) * gathered_control_points).sum(dim=2) # Sum over degree+1 dim 191 | 192 | @staticmethod 193 | def
basis_derivative_coefficients( 194 | knots: torch.Tensor, spans: torch.Tensor, degree: int 195 | ) -> Tuple[torch.Tensor, torch.Tensor]: 196 | """Compute coefficients for basis function derivatives (vectorized for multiple curves). 197 | 198 | Args: 199 | knots: Knot vector, shape (num_total_knots,). 200 | spans: Knot span indices, shape (N, M). 201 | degree: B-spline degree (p). 202 | 203 | Returns: 204 | alpha_coeffs_batch, beta_coeffs_batch: shape (N, M, degree+1); zero where the knot difference vanishes. 205 | 206 | """ 207 | num_samples_n, num_curves_m = spans.shape 208 | device, dtype = spans.device, knots.dtype # Use knot's dtype for coeffs 209 | 210 | degrees_range = torch.arange(degree + 1, device=device).view(1, 1, -1) 211 | knots_idx = spans.unsqueeze(-1) - degree + degrees_range # (N, M, degree+1) 212 | 213 | # Gather t[k], t[k+p], t[k+1], t[k+p+1] for k = spans - degree + i (indices clamped to the knot range) 214 | knots_k = knots[knots_idx.clamp(min=0, max=knots.shape[0] - 1)] 215 | knots_k_plus_deg = knots[(knots_idx + degree).clamp(min=0, max=knots.shape[0] - 1)] 216 | knots_k_plus_1 = knots[(knots_idx + 1).clamp(min=0, max=knots.shape[0] - 1)] 217 | knots_k_plus_deg_plus_1 = knots[(knots_idx + degree + 1).clamp(min=0, max=knots.shape[0] - 1)] 218 | 219 | alpha_coeffs_batch = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype) 220 | beta_coeffs_batch = torch.zeros(num_samples_n, num_curves_m, degree + 1, device=device, dtype=dtype) 221 | # alpha = p / (t[k+p] - t[k]), beta = p / (t[k+p+1] - t[k+1]); entries stay 0 where the denominator is ~0 222 | denom_alpha = knots_k_plus_deg - knots_k 223 | mask_alpha = torch.abs(denom_alpha) > _BSplineFunction.ZERO_TOLERANCE 224 | alpha_coeffs_batch[mask_alpha] = degree / denom_alpha[mask_alpha] 225 | 226 | denom_beta = knots_k_plus_deg_plus_1 - knots_k_plus_1 227 | mask_beta = torch.abs(denom_beta) > _BSplineFunction.ZERO_TOLERANCE 228 | beta_coeffs_batch[mask_beta] = degree / denom_beta[mask_beta] 229 | 230 | return alpha_coeffs_batch, beta_coeffs_batch 231 | 232 | @staticmethod 233 | def compute_basis_derivatives( 234 | u: torch.Tensor, knots: torch.Tensor, spans: torch.Tensor, degree:
int 235 | ) -> torch.Tensor: 236 | """Compute derivatives of B-spline basis functions (vectorized for multiple curves). 237 | 238 | Output basis_deriv[n,m,i] = B'_{spans[n,m]-degree+i, degree}(u[n,m]). 239 | Shape: (N, M, degree+1) 240 | """ 241 | if degree == 0: 242 | return torch.zeros(*u.shape, 1, device=u.device, dtype=u.dtype) 243 | 244 | # lower_deg_basis shape: (N, M, degree) 245 | lower_deg_basis = _BSplineFunction.cox_de_boor(u, knots, spans, degree - 1) 246 | 247 | # alpha, beta have shape (N, M, degree+1) 248 | alpha, beta = _BSplineFunction.basis_derivative_coefficients(knots, spans, degree) 249 | 250 | # Pad lower_deg_basis's last dimension to (degree+1) 251 | # Pad (0,1) means add 1 zero to the right: [N0,...,N(deg-1), 0] 252 | lower_pad_right = F.pad(lower_deg_basis, (0, 1)) 253 | # Pad (1,0) means add 1 zero to the left: [0, N0,...,N(deg-1)] 254 | lower_pad_left = F.pad(lower_deg_basis, (1, 0)) 255 | # B'_{k,p} = alpha_k * B_{k,p-1} - beta_k * B_{k+1,p-1}, with k = spans - degree + i 256 | basis_deriv = alpha * lower_pad_left - beta * lower_pad_right 257 | return basis_deriv 258 | 259 | @staticmethod 260 | def forward( 261 | ctx, 262 | u: torch.Tensor, # shape (N, M) 263 | control_points: torch.Tensor, # shape (M, C, D) 264 | knots: torch.Tensor, # shape (num_total_knots,) 265 | degree: int, 266 | ) -> torch.Tensor: 267 | # NOTE(review): u.shape[1] is assumed to equal control_points.shape[0] (the number of curves M); 268 | # this is not validated here. 269 | 270 | 271 | n_control_points_per_curve = control_points.shape[1] # C 272 | 273 | spans = _BSplineFunction.find_spans(u, knots, degree, n_control_points_per_curve) # (N,M) 274 | basis_funcs = _BSplineFunction.cox_de_boor(u, knots, spans, degree) # (N,M,degree+1) 275 | points = _BSplineFunction.evaluate_curve(basis_funcs, control_points, spans, degree) # (N,M,D) 276 | 277 | ctx.save_for_backward(u, control_points, knots, spans, basis_funcs) 278 | ctx.degree = degree 279 | ctx.n_control_points_per_curve = n_control_points_per_curve # C 280 | 281 | # For re-computing
control_point_indices in backward 282 | degrees_range = torch.arange(degree + 1, device=spans.device).view(1, 1, -1) 283 | ctx.control_point_indices = spans.unsqueeze(-1) - degree + degrees_range # (N,M,degree+1) 284 | 285 | return points 286 | 287 | @staticmethod 288 | def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None, None]: # type: ignore 289 | # grad_output shape: (N, M, D) 290 | u, control_points, knots, spans, basis_funcs = ctx.saved_tensors 291 | # u: (N,M), control_points: (M,C,D), knots: (K,), spans: (N,M), basis_funcs: (N,M,deg+1) 292 | 293 | degree = ctx.degree 294 | n_control_points_per_curve = ctx.n_control_points_per_curve # C 295 | control_point_indices = ctx.control_point_indices # (N,M,deg+1) 296 | 297 | num_samples_n, num_curves_m = u.shape 298 | # NOTE(review): grad_u is computed even when u does not require grad (ctx.needs_input_grad is not consulted). 299 | 300 | # Gradient with respect to u 301 | # basis_deriv shape: (N, M, degree+1) 302 | basis_deriv = _BSplineFunction.compute_basis_derivatives(u, knots, spans, degree) 303 | 304 | clamped_cp_indices = torch.clamp(control_point_indices, 0, n_control_points_per_curve - 1) # (N,M,deg+1) 305 | 306 | # Gather control points for d_points_du calculation 307 | # m_indices_for_gather shape: (N, M, degree+1) 308 | m_indices_for_gather = torch.arange(num_curves_m, device=u.device).view(1, -1, 1) 309 | m_indices_for_gather = m_indices_for_gather.expand(num_samples_n, -1, degree + 1) 310 | 311 | # gathered_cps shape: (N, M, degree+1, D) 312 | gathered_cps = control_points[m_indices_for_gather, clamped_cp_indices, :] 313 | 314 | # d_points_du[n,m,d] = sum_i basis_deriv[n,m,i] * gathered_cps[n,m,i,d] 315 | d_points_du = torch.einsum("nmi,nmid->nmd", basis_deriv, gathered_cps) # Shape (N, M, D) 316 | 317 | # grad_u[n,m] = sum_d grad_output[n,m,d] * d_points_du[n,m,d] 318 | grad_u = (grad_output * d_points_du).sum(dim=-1) # Shape (N, M) 319 | 320 | # Gradient with respect to control_points 321 | # grad_control_points shape: (M, C, D) 322 |
grad_control_points = torch.zeros_like(control_points) 323 | 324 | # update_values[n,m,i,d] = grad_output[n,m,d] * basis_funcs[n,m,i] 325 | # grad_output.unsqueeze(2): (N,M,1,D) 326 | # basis_funcs.unsqueeze(3): (N,M,deg+1,1) 327 | update_values = grad_output.unsqueeze(2) * basis_funcs.unsqueeze(3) # (N,M,deg+1,D) 328 | 329 | # Permute for scatter_add_: target grad_control_points[m_idx, c_idx, d_idx] 330 | # update_values: (N, M, deg+1, D) -> (M, N, deg+1, D) 331 | update_values_perm = update_values.permute(1, 0, 2, 3) 332 | # clamped_cp_indices: (N, M, deg+1) -> (M, N, deg+1) 333 | clamped_cp_indices_perm = clamped_cp_indices.permute(1, 0, 2) 334 | 335 | # Flatten N and deg+1 dimensions 336 | # uv_flat: (M, N*(deg+1), D) 337 | uv_flat = update_values_perm.reshape(num_curves_m, -1, grad_output.shape[-1]) 338 | # idx_flat: (M, N*(deg+1)) 339 | idx_flat = clamped_cp_indices_perm.reshape(num_curves_m, -1) 340 | 341 | # Expand idx_flat to match uv_flat for scatter_add_ 342 | # idx_expanded_for_scatter: (M, N*(deg+1), D) 343 | idx_expanded_for_scatter = idx_flat.unsqueeze(-1).expand_as(uv_flat) 344 | 345 | # Scatter add along dimension C (index 1); samples sharing a control point accumulate their contributions 346 | grad_control_points.scatter_add_(1, idx_expanded_for_scatter, uv_flat) 347 | 348 | return grad_u, grad_control_points, None, None 349 | 350 | 351 | def bspline_curves( 352 | u: torch.Tensor, control_points: torch.Tensor, knots: Optional[torch.Tensor] = None, degree: int = 3 353 | ): 354 | r"""Evaluate multiple B-Spline curves, each with its own control points, sharing the same knots and degree. 355 | 356 | This function allows back-propagating through both the control points and the argument. Useful as a layer in 357 | a neural network. 358 | 359 | Args: 360 | u: A tensor of size :math:`(B, M)` of values between ``knots.min()`` and ``knots.max()``, representing 361 | a mini-batch of :math:`B` arguments for sampling each of the :math:`M` curves.
362 | control_points: A tensor of size :math:`(M, C, D)` describing :math:`M` curves with :math:`C` control 363 | points each, embedded in :math:`\mathbb{R}^D`. 364 | knots: A 1D tensor of size :math:`C + P + 1` representing the spline function's 365 | knot vector, where :math:`P` is the degree of the piecewise polynomials defining the spline function. 366 | ``None`` means uniformly-spaced internal knots in :math:`[-1, 1]` with :math:`P + 1` repeated knots 367 | at each end, i.e. a clamped knot vector. (default: ``None``) 368 | degree: The degree :math:`P` of the B-Spline function. (default: ``3`` meaning a cubic spline) 369 | 370 | Returns: 371 | A tensor of size :math:`(B, M, D)`, representing a mini-batch of size :math:`B`, corresponding to samples from 372 | :math:`M` curves in :math:`\mathbb{R}^D`. 373 | 374 | """ 375 | if knots is None: 376 | n_control_points = control_points.shape[1] 377 | knots = uniform_augmented_knots( 378 | n_control_points, degree, dtype=control_points.dtype, device=control_points.device 379 | ) 380 | 381 | return _BSplineFunction.apply( 382 | u, 383 | control_points, 384 | knots, 385 | degree, 386 | ) 387 | 388 | 389 | def bspline_embeddings( 390 | u: torch.Tensor, control_points: torch.Tensor, knots: Optional[torch.Tensor] = None, degree: int = 3 391 | ): 392 | r"""Evaluate multiple B-Spline curves, each with its own control points, sharing the same knots and degree. 393 | 394 | Unlike :func:`bspline_curves`, this function does not use a custom autograd function; gradients flow through 395 | regular tensor operations. Useful as the input layer in a neural network, whose arguments come from a data-set 396 | that requires no back-prop, so that only the control points need gradients. 397 | 398 | Args: 399 | u: A tensor of size :math:`(B, M)` of values between ``knots.min()`` and ``knots.max()``, representing 400 | a mini-batch of :math:`B` arguments for sampling each of the :math:`M` curves.
401 | control_points: A tensor of size :math:`(M, C, D)` describing :math:`M` curves with :math:`C` control 402 | points each, embedded in :math:`\mathbb{R}^D`. 403 | knots: A 1D tensor of size :math:`C + P + 1` representing the spline function's 404 | knot vector, where :math:`P` is the degree of the piecewise polynomials defining the spline function. 405 | ``None`` means uniformly-spaced internal knots in :math:`[-1, 1]` with :math:`P + 1` repeated knots 406 | at each end, i.e. a clamped knot vector. (default: ``None``) 407 | degree: The degree :math:`P` of the B-Spline function. (default: ``3`` meaning a cubic spline) 408 | 409 | Returns: 410 | A tensor of size :math:`(B, M, D)`, representing a mini-batch of size :math:`B`, corresponding to samples from 411 | :math:`M` curves in :math:`\mathbb{R}^D`. 412 | 413 | """ 414 | n_control_points = control_points.shape[1] 415 | if knots is None: 416 | knots = uniform_augmented_knots( 417 | n_control_points, degree, dtype=control_points.dtype, device=control_points.device 418 | ) 419 | 420 | spans = _BSplineFunction.find_spans(u, knots, degree, n_control_points) # (N,M) 421 | basis_funcs = _BSplineFunction.cox_de_boor(u, knots, spans, degree) # (N,M,deg+1) 422 | return _BSplineFunction.evaluate_curve(basis_funcs, control_points, spans, degree) # (N,M,D) 423 | -------------------------------------------------------------------------------- /src/torchcurves/functional/_legendre.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def legendre_curves(x: torch.Tensor, coefficients: torch.Tensor) -> torch.Tensor: 5 | r"""Evaluate curves parametrized by Legendre polynomials. 6 | 7 | Args: 8 | coefficients: A tensor of size :math:`(N, C, D)` of curve coefficients, of a set of :math:`C` polynomial curves 9 | in :math:`\mathbb{R}^D` of degree :math:`N-1`, represented in the Legendre basis. 10 | x: Batch of size :math:`(B, C)`, where ``x[:, j]`` is the batch of inputs for the j-th curve in the batch.
11 | 12 | Returns: 13 | Evaluated points on the curves, shape :math:`(B, C, D)`. 14 | 15 | Note: 16 | Uses the Clenshaw recursive algorithm, and thus requires :math:`O(N)` time. Implementation is vectorized along 17 | the :math:`B` and :math:`D` dimensions, but the algorithm requires a loop over the polynomial degree. 18 | 19 | """ 20 | n, c, m = coefficients.shape # n - number of coefficients, c - number of curves, m - curve dimension 21 | x = x.unsqueeze(-1).expand(-1, -1, m) # (b × c × m), b = batch size 22 | b2 = torch.zeros_like(x) # (b × c × m) 23 | b1 = torch.zeros_like(x) # (b × c × m) 24 | for k in reversed(range(n)): 25 | alpha = (2 * k + 1) / (k + 1) 26 | beta = (k + 1) / (k + 2) 27 | curr_coef = coefficients[k].unsqueeze(0) # (1 x c x m) 28 | bnext = torch.add(torch.addcmul(curr_coef, x, b1, value=alpha), b2, alpha=-beta) 29 | b2, b1 = b1, bnext 30 | return b1 31 | -------------------------------------------------------------------------------- /src/torchcurves/functional/_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..types import TensorLike 4 | 5 | 6 | def rational(x: TensorLike, scale: float = 1, out_min: float = -1, out_max: float = 1) -> torch.Tensor: 7 | r"""Normalize values using the "Legendre Rational Functions" [1] normalization method. 8 | 9 | The normalization is performed with the formula 10 | 11 | .. math:: 12 | x_{\mathrm{norm}} = \frac{x}{\sqrt{\mathrm{scale}^2 + x^2}}, 13 | 14 | where `scale` is a scaling factor. 15 | 16 | Args: 17 | x: Input tensor to be normalized. 18 | scale: Scale factor for normalization. (default=1) 19 | out_min: Lower bound of the output interval (default=-1) 20 | out_max: Upper bound of the output interval (default=1) 21 | 22 | Returns: 23 | Normalized tensor. 24 | 25 | **References** 26 | 27 | [1] Wang, Z.Q. and Guo, B.Y., 2004. 
28 | *Modified Legendre rational spectral method for the whole line.* 29 | Journal of Computational Mathematics, pp.457-474. 30 | 31 | """ 32 | x = torch.as_tensor(x) 33 | result = x / torch.sqrt(scale**2 + x.square()) 34 | out_scaled = ((out_max - out_min) * result + out_max + out_min) / 2 35 | return torch.clip(out_scaled, out_min, out_max) 36 | 37 | 38 | def clamp(x: TensorLike, scale: float = 1, out_min: float = -1, out_max: float = 1) -> torch.Tensor: 39 | r"""Clamp values in a tensor to a specified range. 40 | 41 | The function clamps the values of the input tensor `x` to be within the output range, after dividing by the 42 | `scale` factor, by the formula: 43 | 44 | .. math:: 45 | x_{\mathrm{norm}} = \min(\mathrm{out\_max}, \max(\mathrm{out\_min}, x / \mathrm{scale})) 46 | 47 | Args: 48 | x: Input tensor to be normalized. 49 | scale: Scale factor for normalization. (default=1) 50 | out_min: Lower bound of the output interval (default=-1) 51 | out_max: Upper bound of the output interval (default=1) 52 | 53 | Returns: 54 | Normalized tensor.
55 | 56 | """ 57 | x = torch.as_tensor(x) 58 | return torch.clip(x / scale, out_min, out_max) 59 | -------------------------------------------------------------------------------- /src/torchcurves/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from ._bspline import BSplineCurve, BSplineEmbeddings 2 | from ._kan_tools import Sum 3 | from ._legendre import LegendreCurve 4 | 5 | __all__ = ["BSplineEmbeddings", "BSplineCurve", "LegendreCurve", "Sum"] 6 | -------------------------------------------------------------------------------- /src/torchcurves/modules/_bspline.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..functional import bspline_curves, bspline_embeddings, uniform_augmented_knots 7 | from ..types import NormalizationFn 8 | from ._normalization import _normalization_catalogue 9 | 10 | 11 | class BSplineCurveBase(nn.Module): 12 | """Base PyTorch module for B-spline curves, supporting a batch of multiple curves.""" 13 | 14 | knots: torch.Tensor # explicit annotation for type-checking 15 | 16 | def __init__( 17 | self, 18 | num_curves: int, 19 | dim: int, 20 | degree: int = 3, 21 | knots_config: Union[int, torch.Tensor] = 10, # This is n_control_points_per_curve 22 | normalize_fn: Union[Literal["clamp", "rational"], NormalizationFn] = "rational", 23 | normalization_scale: float = 1.0, 24 | ): 25 | super().__init__() 26 | 27 | if not isinstance(num_curves, int) or num_curves <= 0: 28 | raise ValueError("num_curves must be a positive integer.") 29 | if not isinstance(dim, int) or dim <= 0: 30 | raise ValueError("dim must be a positive integer.") 31 | if not isinstance(degree, int) or degree < 0: 32 | raise ValueError("degree must be a non-negative integer.") 33 | 34 | self.num_curves = num_curves # m 35 | self.dim = dim # d 36 | self.degree = degree # p 37 | 38 | if 
isinstance(normalize_fn, str): 39 | normalize_fn_callable = _normalization_catalogue.get(normalize_fn) 40 | if normalize_fn_callable is None: 41 | raise ValueError(f"Unknown normalization {normalize_fn}") 42 | self.normalize_fn = normalize_fn_callable 43 | else: 44 | self.normalize_fn = normalize_fn 45 | 46 | self.normalization_scale = normalization_scale 47 | if self.normalization_scale <= 0: 48 | raise ValueError(f"Normalization scale must be positive, but {normalization_scale} was given.") 49 | 50 | if isinstance(knots_config, int): 51 | n_control_points_per_curve = knots_config # c 52 | elif isinstance(knots_config, torch.Tensor): 53 | if knots_config.ndim != 1: 54 | raise ValueError("Provided knots_config tensor must be 1D.") 55 | num_knots_from_tensor = knots_config.shape[0] 56 | n_control_points_per_curve = num_knots_from_tensor - self.degree - 1 57 | else: 58 | raise TypeError( 59 | "knots_config must be an int (number of control points per curve) or a torch.Tensor (knot vector)." 60 | ) 61 | 62 | if n_control_points_per_curve <= self.degree: 63 | raise ValueError( 64 | f"Number of control points per curve ({n_control_points_per_curve}) must be greater " 65 | f"than the degree ({self.degree})." 66 | ) 67 | self.n_control_points_per_curve = n_control_points_per_curve # c 68 | 69 | # Control points shape: (m, c, d) 70 | self.control_points = nn.Parameter(torch.empty(self.num_curves, self.n_control_points_per_curve, self.dim)) 71 | nn.init.xavier_uniform_(self.control_points) 72 | 73 | if isinstance(knots_config, int): 74 | # Knots are shared by all m curves 75 | knot_buffer = uniform_augmented_knots( 76 | self.n_control_points_per_curve, self.degree, dtype=self.control_points.dtype 77 | ) 78 | else: # knots_config is a torch.Tensor 79 | knot_buffer = knots_config.to(dtype=self.control_points.dtype, copy=True) 80 | 81 | self.register_buffer("knots", knot_buffer) 82 | # Determine knot range for normalization, assuming knots are sorted. 
83 | # Effective parameter range for B-spline is [knots[degree], knots[n_control_points_per_curve]] 84 | self._knot_min = knot_buffer[self.degree].item() 85 | self._knot_max = knot_buffer[self.n_control_points_per_curve].item() 86 | 87 | def __repr__(self): 88 | return ( 89 | f"{self.__class__.__name__}(" 90 | f"num_curves={self.num_curves}, " 91 | f"n_control_points_per_curve={self.n_control_points_per_curve}, " 92 | f"dim={self.dim}, degree={self.degree}, " 93 | f"knots_shape={self.knots.shape if hasattr(self, 'knots') else None})" 94 | ) 95 | 96 | def _prepare_arg(self, u: torch.Tensor) -> torch.Tensor: 97 | return self.normalize_fn(u, self.normalization_scale, out_min=self._knot_min, out_max=self._knot_max) 98 | 99 | def forward(self, u: torch.Tensor): 100 | """Evaluate a batch of B-spline curves. 101 | 102 | Args: 103 | u: Parameter values of size :math:`(B, M)`, where :math:`B` is the mini-batch size, and :math:`M` is the number 104 | of curves, and must be equal to `self.num_curves`. 105 | 106 | Returns: 107 | Points on the B-spline curves of shape :math:`(B, M, D)`, where :math:`D` is ``self.dim``. 108 | 109 | """ 110 | if u.ndim != 2 or u.shape[1] != self.num_curves: 111 | raise ValueError( 112 | f"Input u must be a 2D tensor of shape (N, num_curves={self.num_curves}). Got shape: {u.shape}" 113 | ) 114 | 115 | u_prepared = self._prepare_arg(u) 116 | return self._forward_core(u_prepared) 117 | 118 | def _forward_core(self, u_prepared: torch.Tensor) -> torch.Tensor: 119 | # u_prepared has shape (N, M) 120 | # self.control_points has shape (M, C, D) 121 | # Should return tensor of shape (N, M, D) 122 | raise NotImplementedError("This method should be implemented in derived classes") 123 | 124 | 125 | class BSplineEmbeddings(BSplineCurveBase): 126 | r"""Embeddings layer based on B-Spline curves. 127 | 128 | Useful as the first layer in a neural network, where the input comes from a data-set. 129 | 130 | The learnable parameters are the control points of :math:`M` curves in :math:`\mathbb{R}^D`.
131 | All curves share the same degree and knot configuration. 132 | 133 | The input of this layer is normalized to the range :math:`[-1, 1]` (or, for a custom knot vector, to its effective parameter range) 134 | using the specified normalization strategy. 135 | 136 | Args: 137 | num_curves: Number of B-spline curves to define in this module (:math:`M`). 138 | dim: Dimension of each curve's output points (:math:`D`). 139 | degree: Degree of the B-spline (default: 3). 140 | knots_config: 141 | If an int, it specifies the number of control points per curve (:math:`C`). 142 | A uniformly-spaced knot vector will be automatically generated in [-1, 1]. 143 | If a torch.Tensor, it explicitly specifies the knot values. The number 144 | of control points will be inferred. The tensor should be 1D. 145 | normalize_fn: Normalization method for the layer's input, either a name (``"rational"`` or ``"clamp"``) or a callable. (default: "rational") 146 | normalization_scale: Scale factor for normalization (default: 1.0). 147 | 148 | Note: 149 | Assumes the input of this layer is not learnable, and thus doesn't require computing gradients. 150 | 151 | """ 152 | 153 | def _forward_core(self, u_prepared: torch.Tensor) -> torch.Tensor: 154 | return bspline_embeddings(u_prepared, self.control_points, self.knots, self.degree) 155 | 156 | 157 | class BSplineCurve(BSplineCurveBase): 158 | r"""B-Spline curves layer that allows back-propagating through its input. 159 | 160 | The learnable parameters are the control points of :math:`M` curves in :math:`\mathbb{R}^D`. 161 | All curves share the same degree and knot configuration. 162 | 163 | The input of this layer is normalized to the range :math:`[-1, 1]` (or, for a custom knot vector, to its effective parameter range) 164 | using the specified normalization strategy. 165 | 166 | Args: 167 | num_curves: Number of B-spline curves to define in this module (:math:`M`). 168 | dim: Dimension of each curve's output points (:math:`D`). 169 | degree: Degree of the B-spline (default: 3).
170 | knots_config: 171 | If an int, it specifies the number of control points per curve (:math:`C`). 172 | A uniformly-spaced knot vector will be automatically generated in [-1, 1]. 173 | If a torch.Tensor, it explicitly specifies the knot values. The number 174 | of control points will be inferred. The tensor should be 1D. 175 | normalize_fn: Normalization method layer's input. (default: "rational") 176 | normalization_scale: Scale factor for normalization (default: 1.0). 177 | 178 | """ 179 | 180 | def _forward_core(self, u_prepared: torch.Tensor) -> torch.Tensor: 181 | return bspline_curves(u_prepared, self.control_points, self.knots, self.degree) 182 | -------------------------------------------------------------------------------- /src/torchcurves/modules/_kan_tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Sum(nn.Module): 6 | """A pooling layer that sums along the given dimension. 7 | 8 | Args: 9 | dim: The dimension along which to sum. 10 | 11 | """ 12 | 13 | def __init__(self, dim: int = -2): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | def forward(self, x: torch.Tensor): 18 | return torch.sum(x, self.dim) 19 | -------------------------------------------------------------------------------- /src/torchcurves/modules/_legendre.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..functional import legendre_curves 7 | from ..types import NormalizationFn 8 | from ._normalization import _normalization_catalogue 9 | 10 | 11 | class LegendreCurve(nn.Module): 12 | r"""PyTorch module for a batch of parametrized curves using Legendre polynomial basis. 13 | 14 | The learnable parameters are the control points (coefficients) of the 15 | `Legendre series `_ for each curve. 16 | All curves share the same degree. 
The input of this layer is normalized to :math:`[-1, 1]`. 17 | Each curve is: 18 | 19 | .. math:: 20 | 21 | \mathbf{C}_m(u) = \sum_{k=0}^{\mathrm{degree}} \mathbf{C}_{m,k} \cdot P_k(u), 22 | 23 | where :math:`P_k` is the :math:`k`-th Legendre polynomial. 24 | 25 | Args: 26 | num_curves: Number of Legendre curves to define (:math:`M`). 27 | dim: Dimension of each curve's output points (:math:`D`). 28 | degree: Degree of the Legendre polynomial basis (shared by all curves). 29 | The number of coefficients per curve will be ``degree + 1``. 30 | normalize_fn: 31 | Normalization method for this layer's input, either a name (``"rational"`` or ``"clamp"``) or a callable. (default: "rational") 32 | normalization_scale (float): 33 | Scale factor for normalization (default: 1.0). 34 | 35 | """ 36 | 37 | def __init__( 38 | self, 39 | num_curves: int, 40 | dim: int, 41 | degree: int, 42 | normalize_fn: Union[Literal["clamp", "rational"], NormalizationFn] = "rational", 43 | normalization_scale: float = 1.0, 44 | ): 45 | super().__init__() 46 | 47 | if not isinstance(num_curves, int) or num_curves <= 0: 48 | raise ValueError("num_curves must be a positive integer.") 49 | if not isinstance(dim, int) or dim <= 0: 50 | raise ValueError("dim must be a positive integer.") 51 | if not isinstance(degree, int) or degree < 0: 52 | raise ValueError("degree must be a non-negative integer.") 53 | 54 | self.num_curves = num_curves # M 55 | self.dim = dim # D 56 | self.degree = degree 57 | self.n_coefficients = self.degree + 1 # C (coefficients per curve) 58 | 59 | if isinstance(normalize_fn, str): 60 | normalize_fn_from_catalogue = _normalization_catalogue.get(normalize_fn) 61 | if normalize_fn_from_catalogue is None: 62 | raise ValueError(f"Unknown normalization {normalize_fn}") 63 | self.normalize_fn = normalize_fn_from_catalogue 64 | else: 65 | self.normalize_fn = normalize_fn 66 | 67 | self.normalization_scale = normalization_scale 68 | if self.normalization_scale <= 0: 69 | raise ValueError(f"Normalization scale must be positive, but
{normalization_scale} was given.") 70 | 71 | # Coefficients shape: (C, M, D) = (n_coefficients, num_curves, dim), the layout legendre_curves expects 72 | self.coefficients = nn.Parameter(torch.empty(self.n_coefficients, self.num_curves, self.dim)) 73 | nn.init.xavier_uniform_(self.coefficients) 74 | 75 | def forward(self, u: torch.Tensor) -> torch.Tensor: 76 | """Evaluate the batch of Legendre curves. 77 | 78 | Args: 79 | u: Parameter values of size :math:`(B, M)`, where :math:`B` is the mini-batch size, and :math:`M` is the number 80 | of curves, and must be equal to `self.num_curves`. 81 | 82 | Returns: 83 | Points on the Legendre curves of shape :math:`(B, M, D)`, where :math:`D` is ``self.dim``. 84 | 85 | """ 86 | if u.ndim != 2 or u.shape[1] != self.num_curves: 87 | raise ValueError( 88 | f"Input u must be a 2D tensor of shape (N, num_curves={self.num_curves}). Got shape: {u.shape}" 89 | ) 90 | 91 | u_normalized = self.normalize_fn(u, self.normalization_scale, out_min=-1.0, out_max=1.0) 92 | return legendre_curves(u_normalized, self.coefficients) 93 | 94 | def __repr__(self): 95 | return ( 96 | f"{self.__class__.__name__}(" 97 | f"num_curves={self.num_curves}, " 98 | f"dim={self.dim}, degree={self.degree}, " 99 | f"n_coefficients_per_curve={self.n_coefficients})" 100 | ) 101 | -------------------------------------------------------------------------------- /src/torchcurves/modules/_normalization.py: -------------------------------------------------------------------------------- 1 | from ..functional import clamp, rational 2 | from ..types import NormalizationFn 3 | 4 | _normalization_catalogue: dict[str, NormalizationFn] = { 5 | "rational": rational, 6 | "clamp": clamp, 7 | } 8 | -------------------------------------------------------------------------------- /src/torchcurves/types.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, Sequence, Union 2 | 3 | import torch 4 | 5 | Numeric = Union[int, float] 6 | """A number""" 7 | 8 | TensorLike = Union[torch.Tensor, Sequence[Numeric]] 9 | """A PyTorch
tensor or a sequence of numbers""" 10 | 11 | 12 | class NormalizationFn(Protocol): 13 | """Protocol for normalization functions. 14 | 15 | A normalization function maps its input into the interval ``[out_min, out_max]``, controlled by ``scale``. 16 | 17 | Args: 18 | x: The input tensor (or sequence of numbers) to normalize. 19 | scale: Scale factor for normalization. 20 | out_min: Lower bound of the output interval. 21 | out_max: Upper bound of the output interval. 22 | 23 | Returns: 24 | The normalized tensor. 25 | 26 | """ 27 | 28 | def __call__(self, x: TensorLike, scale: float, out_min: float, out_max: float) -> torch.Tensor: ... 29 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexshtf/torchcurves/78e2c4c193edc65aedfdf4804bd068afff3c5214/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexshtf/torchcurves/78e2c4c193edc65aedfdf4804bd068afff3c5214/tests/conftest.py -------------------------------------------------------------------------------- /tests/test_bspline.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchcurves import BSplineCurve 8 | from torchcurves.functional import bspline_curves 9 | 10 | 11 | class TestBSplineFunction(unittest.TestCase): 12 | def setUp(self): 13 | self.default_dtype = torch.float64 # For gradcheck 14 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | # self.device = torch.device("cpu") # Force CPU for easier debugging if needed 16 | # print(f"Using device: {self.device}") 17 | 18 | @staticmethod 19 | def generate_clamped_knot_vector( 20 |
n_control_points: int, degree: int, device="cpu", dtype=torch.float32 21 | ) -> torch.Tensor: 22 | """Generate a clamped knot vector in [-1, 1].""" 23 | if n_control_points <= degree: 24 | raise ValueError("Number of control points must be greater than degree.") 25 | 26 | # Total number of knots m = n_control_points + degree + 1. 27 | # Correct clamping: first p+1 knots are k_min, last p+1 knots are k_max 28 | k_min, k_max = -1.0, 1.0 29 | 30 | head_knots = torch.full((degree + 1,), k_min, dtype=dtype, device=device) 31 | tail_knots = torch.full((degree + 1,), k_max, dtype=dtype, device=device) 32 | 33 | num_internal_knots = n_control_points - degree - 1 34 | if num_internal_knots < 0: 35 | raise ValueError("Not enough control points for the given degree to form internal knots.") 36 | 37 | if num_internal_knots == 0: 38 | internal_knots = torch.empty(0, dtype=dtype, device=device) 39 | else: 40 | internal_knots = torch.linspace(k_min, k_max, num_internal_knots + 2, dtype=dtype, device=device)[1:-1] 41 | 42 | return torch.cat([head_knots, internal_knots, tail_knots]) 43 | 44 | def test_constant_function_degree0(self): 45 | degree = 0 46 | # control_points: (M, C, D) -> (1 curve, 1 CP, 1 Dim) 47 | control_points = torch.tensor([[[2.5]]], dtype=self.default_dtype, device=self.device) 48 | n_cp_c = control_points.shape[1] 49 | knots = self.generate_clamped_knot_vector(n_cp_c, degree, device=self.device, dtype=self.default_dtype) 50 | self.assertEqual(knots.shape[0], n_cp_c + degree + 1) 51 | 52 | u_values_scalar = torch.tensor([0.0, 0.5, 0.99], dtype=self.default_dtype, device=self.device) 53 | 54 | for u_val_scalar_item in u_values_scalar: 55 | # u: (N, M) -> (1 sample, 1 curve) 56 | u = u_val_scalar_item.view(1, 1) 57 | # points: (N, M, D) -> (1, 1, 1) 58 | points = bspline_curves(u, control_points, knots, degree) 59 | self.assertAlmostEqual( 60 | points.squeeze().item(), control_points.squeeze().item(), places=5, msg=f"Failed for u={u.item()}" 61 | ) 62 | 63 | u_gc 
= u.clone().requires_grad_(True) 64 | cp_gc = control_points.clone() 65 | 66 | # Output is (1,1,1), gradcheck handles this. 67 | self.assertTrue( 68 | torch.autograd.gradcheck( 69 | lambda x: bspline_curves(x, cp_gc, knots, degree), # noqa: B023 70 | u_gc, 71 | eps=1e-6, 72 | atol=1e-5, 73 | rtol=1e-3, 74 | nondet_tol=1e-7, 75 | ) 76 | ) 77 | 78 | points_gc = bspline_curves(u_gc, cp_gc, knots, degree) 79 | points_gc.sum().backward() # .sum() for scalar loss 80 | self.assertAlmostEqual(u_gc.grad.squeeze().item(), 0.0, places=5, msg=f"Grad_u non-zero for u={u.item()}") 81 | 82 | def test_constant_function_all_cps_same(self): 83 | degree = 2 84 | n_cp_c = 4 85 | const_val = 5.0 86 | # control_points: (M,C,D) -> (1, 4, 1) 87 | control_points = torch.full((1, n_cp_c, 1), const_val, dtype=self.default_dtype, device=self.device) 88 | knots = self.generate_clamped_knot_vector(n_cp_c, degree, device=self.device, dtype=self.default_dtype) 89 | 90 | # u_scalar: (N,) 91 | u_scalar = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0], dtype=self.default_dtype, device=self.device) 92 | # u: (N, M) -> (N, 1) 93 | u = u_scalar.unsqueeze(1) 94 | 95 | # points: (N, M, D) -> (N, 1, 1) 96 | points = bspline_curves(u, control_points, knots, degree) 97 | expected_points = torch.full((u.shape[0], 1, 1), const_val, dtype=self.default_dtype, device=self.device) 98 | torch.testing.assert_close(points, expected_points, atol=1e-5, rtol=1e-5) 99 | 100 | u_gc = u.clone().requires_grad_(True) 101 | cp_gc = control_points.clone().requires_grad_(True) 102 | 103 | output = bspline_curves(u_gc, cp_gc, knots, degree) 104 | output.sum().backward() 105 | 106 | # u_gc.grad: (N,1) 107 | torch.testing.assert_close(u_gc.grad, torch.zeros_like(u_gc), atol=1e-5, rtol=1e-5) 108 | # cp_gc.grad: (1, C, D). Sum of basis functions is 1. 
109 | self.assertAlmostEqual(cp_gc.grad.sum().item(), u.shape[0], places=5) 110 | 111 | def test_linear_function_degree1(self): 112 | degree = 1 113 | # control_points: (M,C,D) -> (1, 2, 1) 114 | control_points = torch.tensor([[[0.0], [1.0]]], dtype=self.default_dtype, device=self.device) 115 | n_cp_c = control_points.shape[1] 116 | knots = self.generate_clamped_knot_vector(n_cp_c, degree, device=self.device, dtype=self.default_dtype) 117 | 118 | u_scalar = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=self.default_dtype, device=self.device) 119 | # u: (N,M) -> (N,1) 120 | u = u_scalar.unsqueeze(1) 121 | # For knots [-1,-1,1,1] and u in [-1,1], C(u) = ( (1-u)/2 * P0 + (1+u)/2 * P1 ) if knots are normalized to 122 | # [0,1] internally 123 | # If knots are [-1,-1,1,1] and u is directly used, then for u in [-1,1], it's linear interpolation. 124 | # The current BSplineFunction expects u to be in the knot range. 125 | # With knots [-1,-1,1,1], P0=[-1], P1=[1], then C(u)=u. 126 | # Here P0=[0], P1=[1]. Knots are [-1,-1,1,1]. 127 | # N01(u) = (knots[1+1]-u)/(knots[1+1]-knots[1]) = (1-u)/(1-(-1)) = (1-u)/2 for u in [-1,1) 128 | # N11(u) = (u-knots[1])/(knots[1+1]-knots[1]) = (u-(-1))/(1-(-1)) = (u+1)/2 for u in [-1,1) 129 | # C(u) = (1-u)/2 * 0 + (u+1)/2 * 1 = (u+1)/2 130 | # To get C(u)=u, we need u_norm = (u+1)/2. If input u is already in [-1,1], then we expect (u+1)/2. 131 | # Let's adjust CPs or expected points. If CPs are [0],[1] and knots are [-1,-1,1,1] 132 | # C(-1) = P0 = 0. C(1) = P1 = 1. (u+1)/2. 133 | # If we want C(u) = u for u in [-1,1], then P0=-1, P1=1. 134 | # Let's keep P0=0, P1=1. 
Then expected is (u_scalar+1)/2 135 | expected_points_scalar = (u_scalar + 1.0) / 2.0 136 | expected_points = expected_points_scalar.unsqueeze(1).unsqueeze(1) # (N,1,1) 137 | 138 | points = bspline_curves(u, control_points, knots, degree) 139 | torch.testing.assert_close(points, expected_points, atol=1e-6, rtol=1e-5) 140 | 141 | u_gc = u.clone().requires_grad_(True) 142 | cp_gc = control_points.clone().requires_grad_(True) 143 | 144 | self.assertTrue( 145 | torch.autograd.gradcheck( 146 | lambda x: bspline_curves(x, cp_gc.detach(), knots, degree).sum(), 147 | u_gc.detach().requires_grad_(True), 148 | eps=1e-6, 149 | atol=1e-5, 150 | rtol=1e-3, 151 | nondet_tol=1e-7, 152 | ) 153 | ) 154 | self.assertTrue( 155 | torch.autograd.gradcheck( 156 | lambda x: bspline_curves(u_gc.detach(), x, knots, degree).sum(), 157 | cp_gc.detach().requires_grad_(True), 158 | eps=1e-6, 159 | atol=1e-5, 160 | rtol=1e-3, 161 | nondet_tol=1e-7, 162 | ) 163 | ) 164 | 165 | output_an = bspline_curves(u_gc, cp_gc.detach(), knots, degree) 166 | output_an.sum().backward() 167 | # C'(u) = 0.5 168 | expected_grad_u = torch.full_like(u_gc, 0.5) 169 | torch.testing.assert_close(u_gc.grad, expected_grad_u, atol=1e-6, rtol=1e-5) 170 | 171 | def test_parabola_degree2(self): 172 | degree = 2 173 | # Knots [-1,-1,-1, 1,1,1]. u in [-1,1]. 174 | # N02=(1-u_norm)^2, N12=2*u_norm(1-u_norm), N22=u_norm^2 where u_norm = (u+1)/2 175 | # C(u) = N02*P0 + N12*P1 + N22*P2. 176 | # To get C(u) = u_norm^2 = ((u+1)/2)^2: P0=0, P1=0, P2=1. 
177 | # control_points: (M,C,D) -> (1,3,1) 178 | control_points = torch.tensor([[[0.0], [0.0], [1.0]]], dtype=self.default_dtype, device=self.device) 179 | n_cp_c = control_points.shape[1] 180 | knots = self.generate_clamped_knot_vector(n_cp_c, degree, device=self.device, dtype=self.default_dtype) 181 | 182 | u_scalar = torch.tensor([-1.0, -0.6, -0.2, 0.2, 0.6, 1.0], dtype=self.default_dtype, device=self.device) 183 | u = u_scalar.unsqueeze(1) # (N,1) 184 | 185 | u_norm_scalar = (u_scalar + 1.0) / 2.0 186 | expected_points_scalar = u_norm_scalar.pow(2) 187 | expected_points = expected_points_scalar.unsqueeze(1).unsqueeze(1) # (N,1,1) 188 | 189 | points = bspline_curves(u, control_points, knots, degree) 190 | torch.testing.assert_close(points, expected_points, atol=1e-6, rtol=1e-5) 191 | 192 | u_gc = u.clone().requires_grad_(True) 193 | cp_gc = control_points.clone().requires_grad_(True) 194 | 195 | self.assertTrue( 196 | torch.autograd.gradcheck( 197 | lambda x_u: bspline_curves(x_u, cp_gc.detach(), knots, degree).sum(), 198 | u_gc.detach().requires_grad_(True), 199 | eps=1e-6, 200 | atol=1e-4, # Increased atol for parabola 201 | rtol=1e-3, 202 | nondet_tol=1e-7, 203 | ) 204 | ) 205 | self.assertTrue( 206 | torch.autograd.gradcheck( 207 | lambda x_cp: bspline_curves(u_gc.detach(), x_cp, knots, degree).sum(), 208 | cp_gc.detach().requires_grad_(True), 209 | eps=1e-6, 210 | atol=1e-5, 211 | rtol=1e-3, 212 | nondet_tol=1e-7, 213 | ) 214 | ) 215 | 216 | output_an = bspline_curves(u_gc, cp_gc.detach(), knots, degree) 217 | output_an.sum().backward() 218 | # C'(u) = d/du [((u+1)/2)^2] = 2 * ((u+1)/2) * (1/2) = (u+1)/2 219 | expected_grad_u_scalar = (u_gc.detach().squeeze(1) + 1.0) / 2.0 220 | expected_grad_u = expected_grad_u_scalar.unsqueeze(1) 221 | torch.testing.assert_close(u_gc.grad, expected_grad_u, atol=1e-6, rtol=1e-5) 222 | 223 | def test_boundary_values(self): 224 | degree = 3 225 | n_cp_c = 5 226 | # control_points: (M,C,D) -> (1,5,2) 227 | control_points = 
torch.randn(1, n_cp_c, 2, dtype=self.default_dtype, device=self.device) 228 | knots = self.generate_clamped_knot_vector(n_cp_c, degree, device=self.device, dtype=self.default_dtype) 229 | 230 | # u: (N,M) -> (1,1) 231 | u_start = torch.tensor([[-1.0]], dtype=self.default_dtype, device=self.device) # Min knot value 232 | u_end = torch.tensor([[1.0]], dtype=self.default_dtype, device=self.device) # Max knot value 233 | 234 | point_start = bspline_curves(u_start, control_points, knots, degree) # (1,1,2) 235 | point_end = bspline_curves(u_end, control_points, knots, degree) # (1,1,2) 236 | 237 | # control_points[:, 0, :] is (1,2). Need (1,1,2) 238 | torch.testing.assert_close(point_start, control_points[:, 0:1, :], atol=1e-6, rtol=1e-5) 239 | torch.testing.assert_close(point_end, control_points[:, -1:, :], atol=1e-6, rtol=1e-5) 240 | 241 | def test_multiple_dimensions(self): 242 | degree = 2 243 | # control_points: (M,C,D) -> (1,3,2) 244 | control_points_data = torch.tensor( 245 | [[[0.0, 0.0], [0.5, 1.0], [1.0, 0.0]]], dtype=self.default_dtype, device=self.device 246 | ) 247 | n_cp_c = control_points_data.shape[1] 248 | knots = self.generate_clamped_knot_vector(n_cp_c, degree, device=self.device, dtype=self.default_dtype) 249 | 250 | u_scalar = torch.tensor([-1.0, 0.0, 1.0], dtype=self.default_dtype, device=self.device) 251 | u = u_scalar.unsqueeze(1) # (N,1) 252 | 253 | # Expected points: (N,1,D) 254 | # u_norm = (u_scalar+1)/2 -> [0, 0.5, 1] 255 | # C(u_norm) = (1-u_norm)^2 P0 + 2u_norm(1-u_norm)P1 + u_norm^2 P2 256 | expected_points_calc = torch.empty((u_scalar.shape[0], 1, 2), dtype=self.default_dtype, device=self.device) 257 | P0, P1, P2 = control_points_data[0, 0], control_points_data[0, 1], control_points_data[0, 2] # noqa: N806 258 | 259 | expected_points_calc[0, 0, :] = P0 # u_norm = 0 260 | expected_points_calc[1, 0, :] = 0.25 * P0 + 0.5 * P1 + 0.25 * P2 # u_norm = 0.5 261 | expected_points_calc[2, 0, :] = P2 # u_norm = 1 262 | 263 | points = 
bspline_curves(u, control_points_data, knots, degree) 264 | torch.testing.assert_close(points, expected_points_calc, atol=1e-6, rtol=1e-5) 265 | 266 | u_gc = u.clone().requires_grad_(True) 267 | cp_gc = control_points_data.clone().requires_grad_(True) 268 | self.assertTrue( 269 | torch.autograd.gradcheck( 270 | lambda x_u: bspline_curves(x_u, cp_gc.detach(), knots, degree).sum(), 271 | u_gc.detach().requires_grad_(True), 272 | eps=1e-6, 273 | atol=1e-5, 274 | rtol=1e-3, 275 | nondet_tol=1e-7, 276 | ) 277 | ) 278 | self.assertTrue( 279 | torch.autograd.gradcheck( 280 | lambda x_cp: bspline_curves(u_gc.detach(), x_cp, knots, degree).sum(), 281 | cp_gc.detach().requires_grad_(True), 282 | eps=1e-6, 283 | atol=1e-5, 284 | rtol=1e-3, 285 | nondet_tol=1e-7, 286 | ) 287 | ) 288 | 289 | def test_batch_processing_u_values_single_curve(self): # Renamed for clarity 290 | degree = 1 291 | # control_points: (M,C,D) -> (1,2,2) 292 | control_points = torch.tensor([[[0.0, 1.0], [2.0, 3.0]]], dtype=self.default_dtype, device=self.device) 293 | n_cp_c = control_points.shape[1] 294 | knots = self.generate_clamped_knot_vector( 295 | n_cp_c, degree, device=self.device, dtype=self.default_dtype 296 | ) # Knots [-1,-1,1,1] 297 | 298 | u_scalar_batch = torch.tensor( 299 | [-1.0, 0.0, 1.0], dtype=self.default_dtype, device=self.device 300 | ) # Batch of N u-values 301 | u_batch = u_scalar_batch.unsqueeze(1) # (N,1) for 1 curve 302 | 303 | # Expected points (N,1,D) 304 | # u_norm = (u_scalar_batch+1)/2 305 | # C(u_norm) = (1-u_norm)P0 + u_norm*P1 306 | expected_points_batch = torch.empty( 307 | (u_batch.shape[0], 1, control_points.shape[2]), dtype=self.default_dtype, device=self.device 308 | ) 309 | P0, P1 = control_points[0, 0, :], control_points[0, 1, :] # noqa: N806 310 | u_norm_vals = (u_scalar_batch + 1.0) / 2.0 311 | for i, u_n_val in enumerate(u_norm_vals): 312 | expected_points_batch[i, 0, :] = (1 - u_n_val) * P0 + u_n_val * P1 313 | 314 | points_batch = bspline_curves(u_batch, 
control_points, knots, degree) 315 | torch.testing.assert_close(points_batch, expected_points_batch, atol=1e-6, rtol=1e-5) 316 | self.assertEqual(points_batch.shape, (u_batch.shape[0], 1, control_points.shape[2])) 317 | 318 | u_gc_batch = u_batch.clone().requires_grad_(True) 319 | cp_gc = control_points.clone().requires_grad_(True) 320 | 321 | self.assertTrue( 322 | torch.autograd.gradcheck( 323 | lambda x_u: bspline_curves(x_u, cp_gc.detach(), knots, degree).sum(), 324 | u_gc_batch.detach().requires_grad_(True), 325 | eps=1e-6, 326 | atol=1e-5, 327 | rtol=1e-3, 328 | nondet_tol=1e-7, 329 | ) 330 | ) 331 | self.assertTrue( 332 | torch.autograd.gradcheck( 333 | lambda x_cp: bspline_curves(u_gc_batch.detach(), x_cp, knots, degree).sum(), 334 | cp_gc.detach().requires_grad_(True), 335 | eps=1e-6, 336 | atol=1e-5, 337 | rtol=1e-3, 338 | nondet_tol=1e-7, 339 | ) 340 | ) 341 | 342 | def test_multiple_curves_equivalence(self): 343 | num_curves_m = 3 344 | n_samples_n = 5 345 | dim_d = 2 346 | degree = 2 347 | n_cp_c = 4 # Number of control points per curve 348 | 349 | knots = self.generate_clamped_knot_vector( 350 | n_cp_c, degree, device=self.device, dtype=self.default_dtype 351 | ) # Knots are in [-1,1] 352 | 353 | control_points_batched = torch.randn(num_curves_m, n_cp_c, dim_d, dtype=self.default_dtype, device=self.device) 354 | control_points_batched_clone_for_grad = control_points_batched.clone().requires_grad_(True) 355 | 356 | # u values for M curves: (N, M), in knot range [-1,1] 357 | u_batched_rand = torch.rand(n_samples_n, num_curves_m, dtype=self.default_dtype, device=self.device) 358 | # Scale u to be within the effective knot range [knots[degree], knots[n_cp_c]] 359 | # For default knots: knots[degree]=-1, knots[n_cp_c]=1 360 | knot_min_effective = knots[degree] 361 | knot_max_effective = knots[n_cp_c] # This is the start of the last span's p+1 knots. 362 | # For u, it should be knots[n_cp_c] which is the end of the domain. 
363 | 364 | u_batched = u_batched_rand * (knot_max_effective - knot_min_effective) + knot_min_effective 365 | u_batched_clone_for_grad = u_batched.clone().requires_grad_(True) 366 | 367 | # 1. Evaluate all curves together 368 | points_batched = bspline_curves(u_batched_clone_for_grad, control_points_batched_clone_for_grad, knots, degree) 369 | 370 | # 2. Evaluate each curve individually 371 | points_individual_list = [] 372 | for i in range(num_curves_m): 373 | cp_single = control_points_batched[i : i + 1, :, :].clone() # Shape (1, C, D) 374 | u_single = u_batched[:, i : i + 1].clone() # Shape (N, 1) 375 | 376 | # For individual evaluation, BSplineFunction expects (M_cp=1, C, D) and (N, M_u=1) 377 | points_single = bspline_curves(u_single, cp_single, knots, degree) # Output (N, 1, D) 378 | points_individual_list.append(points_single) 379 | 380 | points_stacked = torch.cat(points_individual_list, dim=1) # (N,M,D) 381 | torch.testing.assert_close(points_batched.data, points_stacked.data, atol=1e-6, rtol=1e-5) 382 | 383 | # Compare backward pass 384 | grad_output = torch.randn_like(points_batched) 385 | 386 | points_batched.backward(grad_output) 387 | grad_u_batched_actual = u_batched_clone_for_grad.grad.clone() 388 | grad_cp_batched_actual = control_points_batched_clone_for_grad.grad.clone() 389 | 390 | # Zero grads for individual calculations 391 | # We need new tensors for individual grad accumulation if we want to compare to original batched grads 392 | 393 | expected_grad_u_from_individuals = torch.zeros_like(u_batched) 394 | expected_grad_cp_from_individuals = torch.zeros_like(control_points_batched) 395 | 396 | for i in range(num_curves_m): 397 | cp_single_grad_target = control_points_batched[i : i + 1, :, :].detach().clone().requires_grad_(True) 398 | u_single_grad_target = u_batched[:, i : i + 1].detach().clone().requires_grad_(True) 399 | 400 | points_single_eval = bspline_curves(u_single_grad_target, cp_single_grad_target, knots, degree) 401 | 
grad_output_single = grad_output[:, i : i + 1, :] 402 | points_single_eval.backward(grad_output_single) 403 | 404 | expected_grad_u_from_individuals[:, i : i + 1] = u_single_grad_target.grad 405 | expected_grad_cp_from_individuals[i : i + 1, :, :] = cp_single_grad_target.grad 406 | 407 | torch.testing.assert_close(grad_u_batched_actual, expected_grad_u_from_individuals, atol=1e-6, rtol=1e-5) 408 | torch.testing.assert_close(grad_cp_batched_actual, expected_grad_cp_from_individuals, atol=1e-6, rtol=1e-5) 409 | 410 | 411 | class TestBSplineCurveModule(unittest.TestCase): 412 | def setUp(self): 413 | self.default_dtype = torch.float64 414 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 415 | 416 | def test_init_with_int(self): 417 | num_curves = 1 418 | dim = 2 419 | degree = 3 420 | n_cps_per_curve = 5 421 | module = ( 422 | BSplineCurve(num_curves=num_curves, dim=dim, degree=degree, knots_config=n_cps_per_curve) 423 | .to(self.device) 424 | .to(self.default_dtype) 425 | ) 426 | 427 | self.assertEqual(module.num_curves, num_curves) 428 | self.assertEqual(module.n_control_points_per_curve, n_cps_per_curve) 429 | self.assertEqual(module.dim, dim) 430 | self.assertEqual(module.degree, degree) 431 | self.assertIsInstance(module.control_points, nn.Parameter) 432 | self.assertTrue(module.control_points.requires_grad) 433 | self.assertEqual(module.control_points.shape, (num_curves, n_cps_per_curve, dim)) 434 | self.assertIsInstance(module.knots, torch.Tensor) 435 | self.assertEqual(module.knots.shape[0], n_cps_per_curve + degree + 1) 436 | # Check if knots are clamped to [-1,1] 437 | self.assertTrue(torch.all(module.knots[0 : degree + 1] == -1.0)) 438 | self.assertTrue(torch.all(module.knots[n_cps_per_curve:] == 1.0)) 439 | 440 | def test_init_with_tensor(self): 441 | num_curves = 1 442 | dim = 3 443 | degree = 2 444 | # n_cp=4, deg=2 -> knots=7. 
Example knots in [0,1] 445 | knots_tensor = torch.tensor([0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0], dtype=self.default_dtype) 446 | expected_n_cps_per_curve = 4 # 7 - 2 - 1 = 4 447 | 448 | module = ( 449 | BSplineCurve(num_curves=num_curves, dim=dim, degree=degree, knots_config=knots_tensor) 450 | .to(self.device) 451 | .to(self.default_dtype) 452 | ) 453 | 454 | self.assertEqual(module.n_control_points_per_curve, expected_n_cps_per_curve) 455 | self.assertEqual(module.control_points.shape, (num_curves, expected_n_cps_per_curve, dim)) 456 | torch.testing.assert_close(module.knots, knots_tensor.to(self.device).to(self.default_dtype)) 457 | 458 | def test_init_errors(self): 459 | with self.assertRaisesRegex(ValueError, "must be greater than the degree"): 460 | BSplineCurve(num_curves=1, dim=2, degree=3, knots_config=3) # n_cp <= degree 461 | 462 | knots_tensor_short = torch.tensor([0.0, 0.0, 1.0, 1.0]) 463 | with self.assertRaisesRegex(ValueError, "must be greater than the degree"): 464 | BSplineCurve(num_curves=1, dim=2, degree=3, knots_config=knots_tensor_short) 465 | 466 | with self.assertRaisesRegex(TypeError, "knots_config must be an int .*or.*Tensor.*"): 467 | BSplineCurve(num_curves=1, dim=2, degree=3, knots_config="wrong_type") # type: ignore 468 | 469 | knots_tensor_2d = torch.tensor([[0.0, 1.0]]) 470 | with self.assertRaisesRegex(ValueError, "Provided knots_config tensor must be 1D"): 471 | BSplineCurve(num_curves=1, dim=2, degree=1, knots_config=knots_tensor_2d) 472 | 473 | def test_forward_pass_shape_and_device(self): 474 | num_curves = 1 475 | dim = 3 476 | degree = 2 477 | n_cps_per_curve = 4 478 | batch_size = 10 # Number of u-samples per curve 479 | module = ( 480 | BSplineCurve(num_curves=num_curves, dim=dim, degree=degree, knots_config=n_cps_per_curve) 481 | .to(self.device) 482 | .to(self.default_dtype) 483 | ) 484 | 485 | # u: (N, M) 486 | u_scalar = torch.linspace(-1, 1, batch_size, device=self.device, dtype=self.default_dtype) 487 | u = 
u_scalar.unsqueeze(1) # (N,1) for M=1 curve 488 | 489 | points = module(u) # Output (N,M,D) 490 | 491 | self.assertEqual(points.shape, (batch_size, num_curves, dim)) 492 | self.assertEqual(points.device, self.device) 493 | self.assertEqual(points.dtype, self.default_dtype) 494 | 495 | def test_boundary_interpolation_with_clamp_normalization(self): 496 | num_curves = 1 497 | dim = 2 498 | degree = 3 499 | n_cps_per_curve = 5 500 | module = ( 501 | BSplineCurve( 502 | num_curves=num_curves, dim=dim, degree=degree, knots_config=n_cps_per_curve, normalize_fn="clamp" 503 | ) 504 | .to(self.device) 505 | .to(self.default_dtype) 506 | ) # Knots are [-1,1] 507 | 508 | # u: (N,M) -> (1,1) 509 | u_start = torch.tensor([[-1.0]], device=self.device, dtype=self.default_dtype) 510 | u_end = torch.tensor([[1.0]], device=self.device, dtype=self.default_dtype) 511 | 512 | point_start = module(u_start) # (1,1,D) 513 | point_end = module(u_end) # (1,1,D) 514 | 515 | # module.control_points is (1,C,D) 516 | torch.testing.assert_close(point_start, module.control_points[:, 0:1, :]) 517 | torch.testing.assert_close(point_end, module.control_points[:, -1:, :]) 518 | 519 | def test_backward_pass(self): 520 | num_curves = 1 521 | dim = 2 522 | degree = 2 523 | n_cps_per_curve = 4 524 | module = ( 525 | BSplineCurve(num_curves=num_curves, dim=dim, degree=degree, knots_config=n_cps_per_curve) 526 | .to(self.device) 527 | .to(self.default_dtype) 528 | ) 529 | 530 | # u: (N,M) 531 | u = torch.tensor([[-0.7], [0.6]], device=self.device, dtype=self.default_dtype) # N=2, M=1 532 | 533 | self.assertIsNone(module.control_points.grad) 534 | 535 | points = module(u) # (2,1,D) 536 | loss = points.sum() 537 | loss.backward() 538 | 539 | self.assertIsNotNone(module.control_points.grad) 540 | self.assertEqual(module.control_points.grad.shape, module.control_points.shape) # (1,C,D) 541 | self.assertNotEqual(torch.sum(module.control_points.grad**2).item(), 0.0) 542 | 543 | def test_gradcheck_module(self): 
544 | num_curves = 1 545 | dim = 2 546 | degree = 2 547 | n_cps_per_curve = 3 548 | module = ( 549 | BSplineCurve(num_curves=num_curves, dim=dim, degree=degree, knots_config=n_cps_per_curve) 550 | .to(self.device) 551 | .to(self.default_dtype) 552 | ) 553 | 554 | # u_gc: (N,M) 555 | u_gc = torch.tensor([[-0.75], [0.25]], device=self.device, dtype=self.default_dtype).requires_grad_(True) 556 | 557 | # Check BSplineFunction.apply part 558 | cp_gc = module.control_points.clone().requires_grad_(True) # (1,C,D) 559 | knots = module.knots 560 | current_degree = module.degree # Use module's degree 561 | 562 | # Output of apply is (N,M,D), sum for gradcheck 563 | self.assertTrue( 564 | torch.autograd.gradcheck( 565 | lambda u_in, cp_in: bspline_curves(u_in, cp_in, knots, current_degree).sum(), 566 | (u_gc, cp_gc), 567 | eps=1e-6, 568 | atol=1e-4, 569 | rtol=1e-3, 570 | nondet_tol=1e-7, 571 | ) 572 | ) 573 | 574 | # Check module call w.r.t 'u' 575 | module_clone = ( 576 | BSplineCurve(num_curves=num_curves, dim=dim, degree=degree, knots_config=n_cps_per_curve) 577 | .to(self.device) 578 | .to(self.default_dtype) 579 | ) 580 | module_clone.load_state_dict(module.state_dict()) 581 | 582 | u_gc_mod = torch.tensor([[-0.6]], device=self.device, dtype=self.default_dtype).requires_grad_(True) # (1,1) 583 | 584 | # Output of module is (N,M,D), sum for gradcheck 585 | self.assertTrue( 586 | torch.autograd.gradcheck( 587 | lambda u_in: module_clone(u_in).sum(), u_gc_mod, eps=1e-6, atol=1e-4, rtol=1e-3, nondet_tol=1e-7 588 | ) 589 | ) 590 | 591 | def test_device_movement(self): 592 | if not torch.cuda.is_available(): 593 | self.skipTest("CUDA not available, skipping device movement test.") 594 | 595 | num_curves = 1 596 | dim = 2 597 | degree = 2 598 | n_cps_per_curve = 4 599 | module_cpu = BSplineCurve(num_curves, dim, degree, n_cps_per_curve) 600 | 601 | self.assertEqual(module_cpu.control_points.device.type, "cpu") 602 | self.assertEqual(module_cpu.knots.device.type, "cpu") 603 | 
604 | module_cuda = module_cpu.to("cuda").to(self.default_dtype) 605 | 606 | self.assertEqual(module_cuda.control_points.device.type, "cuda") 607 | self.assertEqual(module_cuda.knots.device.type, "cuda") 608 | 609 | u_cuda = torch.tensor([[-0.7], [0.6]], device="cuda", dtype=self.default_dtype) # (2,1) 610 | points = module_cuda(u_cuda) # (2,1,D) 611 | 612 | self.assertEqual(points.device.type, "cuda") 613 | self.assertEqual(points.shape, (2, num_curves, dim)) 614 | 615 | loss = points.sum() 616 | loss.backward() 617 | self.assertIsNotNone(module_cuda.control_points.grad) 618 | self.assertEqual(module_cuda.control_points.grad.device.type, "cuda") 619 | 620 | 621 | def test_bspline_curves_default_knots_device_dtype(): 622 | dtype = torch.float64 623 | device = torch.device("cpu") 624 | 625 | u = torch.tensor([[0.0]], dtype=dtype, device=device) 626 | control_points = torch.zeros((1, 4, 1), dtype=dtype, device=device) 627 | 628 | out = bspline_curves(u, control_points) 629 | 630 | assert out.dtype == control_points.dtype 631 | 632 | 633 | def test_bspline_curves_default_knots_cuda(): 634 | if not torch.cuda.is_available(): 635 | pytest.skip("CUDA not available, skipping test.") 636 | 637 | dtype = torch.float64 638 | device = torch.device("cuda") 639 | 640 | control_points = torch.randn(1, 4, 1, dtype=dtype, device=device) 641 | u = torch.linspace(-1, 1, 5, dtype=dtype, device=device).unsqueeze(1) 642 | 643 | result = bspline_curves(u, control_points, knots=None, degree=3) 644 | 645 | assert result.device == device 646 | assert result.dtype == dtype 647 | 648 | 649 | if __name__ == "__main__": 650 | pytest.main([__file__, "-v"]) 651 | -------------------------------------------------------------------------------- /tests/test_legendre.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | import torch.nn as nn 7 | 8 | from torchcurves import 
LegendreCurve 9 | from torchcurves.functional import legendre_curves 10 | 11 | 12 | @pytest.mark.parametrize("num_curves", [1, 2, 5]) 13 | @pytest.mark.parametrize("dim", [1, 2, 3]) 14 | @pytest.mark.parametrize("degree", [0, 1, 2, 3]) 15 | @pytest.mark.parametrize("n_samples", [1, 10, 100]) 16 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 17 | def test_legendre_curves(num_curves, dim, degree, n_samples, dtype): 18 | torch.random.manual_seed(42) # For reproducibility 19 | coefs = torch.randn(1 + degree, num_curves, dim, dtype=dtype) 20 | x = 2 * torch.rand(n_samples, num_curves, dtype=dtype) 21 | torch_eval = legendre_curves(x, coefs) 22 | for ci in range(num_curves): 23 | for mi in range(dim): 24 | coef_np = coefs[:, ci, mi].numpy() 25 | x_np = x[:, ci].numpy() 26 | np_vals = np.polynomial.legendre.legval(x_np, coef_np) 27 | torch_vals = torch_eval[:, ci, mi].numpy() 28 | np.testing.assert_allclose( 29 | np_vals, 30 | torch_vals, 31 | rtol=1e-3 if dtype == torch.float32 else 1e-10, 32 | err_msg=f"Mismatch for curve {ci}, dimension {mi} with degree {degree} and dtype {dtype}", 33 | ) 34 | 35 | 36 | class TestLegendreCurveModule(unittest.TestCase): 37 | def setUp(self): 38 | self.default_dtype = torch.float64 39 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | def test_init(self): 42 | num_curves = 2 43 | dim = 3 44 | degree = 4 45 | module = LegendreCurve(num_curves, dim, degree).to(self.device).to(self.default_dtype) 46 | 47 | self.assertEqual(module.num_curves, num_curves) 48 | self.assertEqual(module.dim, dim) 49 | self.assertEqual(module.degree, degree) 50 | self.assertEqual(module.n_coefficients, degree + 1) 51 | self.assertIsInstance(module.coefficients, nn.Parameter) 52 | self.assertTrue(module.coefficients.requires_grad) 53 | self.assertEqual(module.coefficients.shape, (degree + 1, num_curves, dim)) 54 | 55 | def test_init_errors(self): 56 | with self.assertRaises(ValueError): 57 | 
LegendreCurve(num_curves=0, dim=1, degree=1) # num_curves <= 0 58 | with self.assertRaises(ValueError): 59 | LegendreCurve(num_curves=1, dim=0, degree=1) # dim <= 0 60 | with self.assertRaises(ValueError): 61 | LegendreCurve(num_curves=1, dim=1, degree=-1) # degree < 0 62 | with self.assertRaises(ValueError): # Unknown normalization 63 | LegendreCurve(num_curves=1, dim=1, degree=1, normalize_fn="unknown_norm") 64 | with self.assertRaises(ValueError): # Scale <=0 65 | LegendreCurve(num_curves=1, dim=1, degree=1, normalization_scale=0) 66 | 67 | def test_forward_pass_shape_and_device(self): 68 | num_curves = 2 69 | dim = 3 70 | degree = 2 71 | n_samples = 10 72 | 73 | module = LegendreCurve(num_curves, dim, degree).to(self.device).to(self.default_dtype) 74 | 75 | # u: (N, M) 76 | u_input = torch.rand(n_samples, num_curves, device=self.device, dtype=self.default_dtype) * 2 - 1 # in [-1,1] 77 | 78 | points = module(u_input) # Output (N, M, D) 79 | 80 | self.assertEqual(points.shape, (n_samples, num_curves, dim)) 81 | self.assertEqual(points.device, self.device) 82 | self.assertEqual(points.dtype, self.default_dtype) 83 | 84 | def test_backward_pass_module(self): 85 | num_curves = 2 86 | dim = 2 87 | degree = 3 88 | n_samples = 5 89 | module = LegendreCurve(num_curves, dim, degree).to(self.device).to(self.default_dtype) 90 | 91 | u_input = torch.rand(n_samples, num_curves, device=self.device, dtype=self.default_dtype).requires_grad_(True) 92 | 93 | self.assertIsNone(module.coefficients.grad) 94 | 95 | points = module(u_input) # (N,M,D) 96 | loss = points.sum() 97 | loss.backward() 98 | 99 | self.assertIsNotNone(module.coefficients.grad) 100 | self.assertEqual(module.coefficients.grad.shape, module.coefficients.shape) 101 | self.assertNotEqual(torch.sum(module.coefficients.grad**2).item(), 0.0) 102 | 103 | self.assertIsNotNone(u_input.grad) # Check grad w.r.t. 
u as well 104 | self.assertEqual(u_input.grad.shape, u_input.shape) 105 | 106 | def test_device_movement_module(self): 107 | if not torch.cuda.is_available(): 108 | self.skipTest("CUDA not available, skipping device movement test.") 109 | 110 | num_curves = 2 111 | dim = 1 112 | degree = 2 113 | module_cpu = LegendreCurve(num_curves, dim, degree) 114 | 115 | self.assertEqual(module_cpu.coefficients.device.type, "cpu") 116 | 117 | module_cuda = module_cpu.to("cuda").to(self.default_dtype) 118 | self.assertEqual(module_cuda.coefficients.device.type, "cuda") 119 | 120 | u_cuda = torch.rand(5, num_curves, device="cuda", dtype=self.default_dtype) * 2 - 1 121 | points = module_cuda(u_cuda) 122 | 123 | self.assertEqual(points.device.type, "cuda") 124 | self.assertEqual(points.shape, (5, num_curves, dim)) 125 | 126 | loss = points.sum() 127 | loss.backward() 128 | self.assertIsNotNone(module_cuda.coefficients.grad) 129 | self.assertEqual(module_cuda.coefficients.grad.device.type, "cuda") 130 | 131 | 132 | if __name__ == "__main__": 133 | pytest.main([__file__, "-v"]) 134 | --------------------------------------------------------------------------------