├── tps ├── __init__.py ├── instance.py └── functions.py ├── example.png ├── pyproject.toml ├── Readme.md ├── License ├── .gitignore └── poetry.lock /tps/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .instance import * 3 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzing/tps-deformation/HEAD/example.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tps" 3 | version = "1.0.0" 4 | description = "Thin plate spline transformer implementation" 5 | authors = ["Tzu-Ting "] 6 | license = "BSD-3-Clause" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.8" 11 | numpy = "^1.22.4" 12 | 13 | [tool.poetry.group.dev.dependencies] 14 | 15 | [build-system] 16 | requires = ["poetry-core"] 17 | build-backend = "poetry.core.masonry.api" 18 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # TPS deformation 2 | 3 | Python implementation of the [thin plate spline] function. 4 | 5 | Rewritten from [daeyun/TPS-Deformation], which was originally MATLAB code. 6 | 7 | 8 | [thin plate spline]: https://en.wikipedia.org/wiki/Thin_plate_spline 9 | [daeyun/TPS-Deformation]: https://github.com/daeyun/TPS-Deformation 10 | 11 | 12 | ## Usage 13 | 14 | Use `tps.find_coefficients` to get the coefficients, and then transform other points from the source surface to the deformed surface with `tps.transform`. Alternatively, use the shortcut class `tps.TPS` (see the example below). 15 | 16 | Both 2D and 3D points are supported.
**Note:** the points should be given as an N by 2 or N by 3 matrix. 17 | 18 | 19 | ***Example*** 20 | 21 | ```py 22 | samp = np.linspace(-2, 2, 4) 23 | xx, yy = np.meshgrid(samp, samp) 24 | 25 | # make source surface, get uniformly distributed control points 26 | source_xy = np.stack([xx, yy], axis=2).reshape(-1, 2) 27 | 28 | # make deformed surface 29 | yy[:, [0, 3]] *=2 30 | deform_xy = np.stack([xx, yy], axis=2).reshape(-1, 2) 31 | 32 | # get coefficients, using the class 33 | trans = tps.TPS(source_xy, deform_xy) 34 | 35 | # make other points along a bottom-left to upper-right line on the source surface 36 | samp2 = np.linspace(-1.8, 1.8, 10) 37 | test_xy = np.tile(samp2, [2, 1]).T 38 | 39 | # get transformed points 40 | transformed_xy = trans(test_xy) 41 | ``` 42 | 43 | ![](./example.png) 44 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | Copyright 2014, Daeyun Shin. All rights reserved. 2 | Copyright 2018, tzing 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /tps/instance.py: --------------------------------------------------------------------------------
import numpy

from . import functions

__all__ = ['TPS']


class TPS:
    """Thin plate spline (TPS) deformation warping.

    The TPS coefficients are solved once in the constructor. Afterwards the
    instance maps points from the source (original) surface onto the deformed
    surface, either by calling it directly or through its ``transform`` alias.
    """

    def __init__(self,
                 control_points: numpy.ndarray,
                 target_points: numpy.ndarray,
                 lambda_: float = 0.,
                 solver: str = 'exact'):
        """Create an instance that stores the TPS coefficients.

        Arguments
        ---------
        control_points : numpy.array
            p by d array of control points on the source surface
        target_points : numpy.array
            p by d array of the corresponding target points on the deformed
            surface
        lambda_ : float
            regularization parameter; the default 0 requests exact
            interpolation of the control points
        solver : str
            the solver used to get the coefficients: 'exact' (default) for the
            exact solution, or 'lstsq' for the least-squares solution

        Raises
        ------
        ValueError
            propagated from ``functions.find_coefficients``, e.g. when the
            control and target points differ in shape or the solver name is
            unknown
        numpy.linalg.LinAlgError
            with the 'exact' solver, if the linear system is singular (e.g.
            all control points lie in a lower-dimensional subspace)
        """
        # Keep a private 2-D copy: the coefficients are solved against exactly
        # these points, so later in-place edits of the caller's array must not
        # leak into subsequent transforms.
        self.control_points = numpy.array(control_points, copy=True, ndmin=2)
        self.coefficient = functions.find_coefficients(
            self.control_points, target_points, lambda_, solver)

    def __call__(self, source_points: numpy.ndarray) -> numpy.ndarray:
        """Transform points from the original surface onto the destination
        (deformed) surface.

        Arguments
        ---------
        source_points : numpy.array
            n by d array of source points to be transformed

        Return
        ------
        deformed_points : numpy.array
            n by d array of the transformed points
        """
        return functions.transform(source_points, self.control_points,
                                   self.coefficient)

    # Method-style alias so callers can write ``instance.transform(points)``.
    transform = __call__
# 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 |
97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # editor 107 | .vscode/* 108 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Poetry and should not be changed by hand. 2 | 3 | [[package]] 4 | name = "numpy" 5 | version = "1.24.2" 6 | description = "Fundamental package for array computing in Python" 7 | category = "main" 8 | optional = false 9 | python-versions = ">=3.8" 10 | files = [ 11 | {file = "numpy-1.24.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eef70b4fc1e872ebddc38cddacc87c19a3709c0e3e5d20bf3954c147b1dd941d"}, 12 | {file = "numpy-1.24.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d2859428712785e8a8b7d2b3ef0a1d1565892367b32f915c4a4df44d0e64f5"}, 13 | {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6524630f71631be2dabe0c541e7675db82651eb998496bbe16bc4f77f0772253"}, 14 | {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a51725a815a6188c662fb66fb32077709a9ca38053f0274640293a14fdd22978"}, 15 | {file = "numpy-1.24.2-cp310-cp310-win32.whl", hash = "sha256:2620e8592136e073bd12ee4536149380695fbe9ebeae845b81237f986479ffc9"}, 16 | {file = "numpy-1.24.2-cp310-cp310-win_amd64.whl", hash = "sha256:97cf27e51fa078078c649a51d7ade3c92d9e709ba2bfb97493007103c741f1d0"}, 17 | {file = "numpy-1.24.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7de8fdde0003f4294655aa5d5f0a89c26b9f22c0a58790c38fae1ed392d44a5a"}, 18 | {file = "numpy-1.24.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4173bde9fa2a005c2c6e2ea8ac1618e2ed2c1c6ec8a7657237854d42094123a0"}, 19 | {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:4cecaed30dc14123020f77b03601559fff3e6cd0c048f8b5289f4eeabb0eb281"}, 20 | {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a23f8440561a633204a67fb44617ce2a299beecf3295f0d13c495518908e910"}, 21 | {file = "numpy-1.24.2-cp311-cp311-win32.whl", hash = "sha256:e428c4fbfa085f947b536706a2fc349245d7baa8334f0c5723c56a10595f9b95"}, 22 | {file = "numpy-1.24.2-cp311-cp311-win_amd64.whl", hash = "sha256:557d42778a6869c2162deb40ad82612645e21d79e11c1dc62c6e82a2220ffb04"}, 23 | {file = "numpy-1.24.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d0a2db9d20117bf523dde15858398e7c0858aadca7c0f088ac0d6edd360e9ad2"}, 24 | {file = "numpy-1.24.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c72a6b2f4af1adfe193f7beb91ddf708ff867a3f977ef2ec53c0ffb8283ab9f5"}, 25 | {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c29e6bd0ec49a44d7690ecb623a8eac5ab8a923bce0bea6293953992edf3a76a"}, 26 | {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2eabd64ddb96a1239791da78fa5f4e1693ae2dadc82a76bc76a14cbb2b966e96"}, 27 | {file = "numpy-1.24.2-cp38-cp38-win32.whl", hash = "sha256:e3ab5d32784e843fc0dd3ab6dcafc67ef806e6b6828dc6af2f689be0eb4d781d"}, 28 | {file = "numpy-1.24.2-cp38-cp38-win_amd64.whl", hash = "sha256:76807b4063f0002c8532cfeac47a3068a69561e9c8715efdad3c642eb27c0756"}, 29 | {file = "numpy-1.24.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4199e7cfc307a778f72d293372736223e39ec9ac096ff0a2e64853b866a8e18a"}, 30 | {file = "numpy-1.24.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:adbdce121896fd3a17a77ab0b0b5eedf05a9834a18699db6829a64e1dfccca7f"}, 31 | {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889b2cc88b837d86eda1b17008ebeb679d82875022200c6e8e4ce6cf549b7acb"}, 32 | {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
# "sha256:f64bb98ac59b3ea3bf74b02f13836eb2e24e48e0ab0145bbda646295769bd780"}, 33 | {file = "numpy-1.24.2-cp39-cp39-win32.whl", hash = "sha256:63e45511ee4d9d976637d11e6c9864eae50e12dc9598f531c035265991910468"}, 34 | {file = "numpy-1.24.2-cp39-cp39-win_amd64.whl", hash = "sha256:a77d3e1163a7770164404607b7ba3967fb49b24782a6ef85d9b5f54126cc39e5"}, 35 | {file = "numpy-1.24.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92011118955724465fb6853def593cf397b4a1367495e0b59a7e69d40c4eb71d"}, 36 | {file = "numpy-1.24.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9006288bcf4895917d02583cf3411f98631275bc67cce355a7f39f8c14338fa"}, 37 | {file = "numpy-1.24.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:150947adbdfeceec4e5926d956a06865c1c690f2fd902efede4ca6fe2e657c3f"}, 38 | {file = "numpy-1.24.2.tar.gz", hash = "sha256:003a9f530e880cb2cd177cba1af7220b9aa42def9c4afc2a2fc3ee6be7eb2b22"}, 39 | ] 40 | 41 | [metadata] 42 | lock-version = "2.0" 43 | python-versions = "^3.8" 44 | content-hash = "bdf8f744ec76dc93b63dc5d14153cc709e35fb2cd3b53b368a4ed3deed477c44" 45 | -------------------------------------------------------------------------------- /tps/functions.py: --------------------------------------------------------------------------------
import numpy

__all__ = ['find_coefficients', 'transform']


def cdist(K: numpy.ndarray, B: numpy.ndarray) -> numpy.ndarray:
    """Calculate the Euclidean distance between every row of K and every row
    of B.

    Arguments
    ---------
    K : numpy.array
        n by d array (a 1-D input is treated as a single row)
    B : numpy.array
        m by d array (a 1-D input is treated as a single row)

    Return
    ------
    D : numpy.array
        n by m matrix where D[i, j] = norm(K[i, :] - B[j, :])

    Raises
    ------
    ValueError
        If either input has more than two dimensions.
    """
    K = numpy.atleast_2d(K)
    B = numpy.atleast_2d(B)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if K.ndim != 2 or B.ndim != 2:
        raise ValueError(
            'Expected 2-D arrays, got {kd}-D and {bd}-D.'.format(
                kd=K.ndim, bd=B.ndim))

    # Broadcast (n, 1, d) - (1, m, d) -> (n, m, d), then reduce over d.
    diff = K[:, numpy.newaxis, :] - B[numpy.newaxis, :, :]
    return numpy.linalg.norm(diff, axis=2)


def pairwise_radial_basis(K: numpy.ndarray, B: numpy.ndarray) -> numpy.ndarray:
    """Compute the TPS radial basis function phi(r) between every row pair of
    K and B, where r is the Euclidean distance.

    Arguments
    ---------
    K : numpy.array
        n by d array containing n d-dimensional points.
    B : numpy.array
        m by d array containing m d-dimensional points.

    Return
    ------
    P : numpy.array
        n by m matrix where
            P(i, j) = phi( norm( K(i,:) - B(j,:) ) ),
        with phi(r) = r^2 * log(r) for r > 0 and phi(0) = 0 (the limit of
        r^2 * log(r) as r -> 0). NaN distances yield NaN entries.
    """
    # r_mat[i, j] is the Euclidean distance between K[i, :] and B[j, :].
    r_mat = cdist(K, B)

    # P corresponds to the matrix K from the reference paper. Evaluating
    # r^2 * log(r) directly is more accurate than the former r * log(r**r)
    # form near r = 1. Only r == 0 is excluded, because there the formula
    # would give 0 * -inf = nan; those entries keep phi(0) = 0 from the zero
    # initialisation. Using `!= 0` (not `> 0`) lets NaN distances propagate
    # instead of being silently masked or left uninitialised.
    P = numpy.zeros(r_mat.shape)
    nonzero = r_mat != 0
    r_nz = r_mat[nonzero]
    P[nonzero] = r_nz ** 2 * numpy.log(r_nz)
    return P


def find_coefficients(control_points: numpy.ndarray,
                      target_points: numpy.ndarray,
                      lambda_: float = 0.,
                      solver: str = 'exact') -> numpy.ndarray:
    """Given a set of control points and their corresponding target points,
    compute the coefficients of the TPS interpolant of the deforming surface.

    Arguments
    ---------
    control_points : numpy.array
        p by d array of control points
    target_points : numpy.array
        p by d array of corresponding target points on the deformed
        surface
    lambda_ : float
        regularization parameter added to the diagonal of the kernel matrix;
        0 (the default) requests exact interpolation
    solver : str
        the solver to get the coefficients (case-insensitive). The default is
        'exact' for the exact solution; use 'lstsq' for the least-squares
        solution.

    Return
    ------
    coef : numpy.ndarray
        (p + d + 1) by d matrix. The first p rows weight the radial basis
        terms. The last d + 1 rows form the affine part: the constant term
        first, then one row per input coordinate.

    Raises
    ------
    ValueError
        If the control and target points differ in shape, are not 2-D, or
        the solver name is unknown.
    numpy.linalg.LinAlgError
        With the 'exact' solver, if the system matrix is singular.

    .. seealso::

        http://cseweb.ucsd.edu/~sjb/pami_tps.pdf
    """
    # ensure data type and shape
    control_points = numpy.atleast_2d(control_points)
    target_points = numpy.atleast_2d(target_points)
    if control_points.shape != target_points.shape:
        raise ValueError(
            'Shape of control points {cp} and target points {tp} are not the same.'.
            format(cp=control_points.shape, tp=target_points.shape))
    if control_points.ndim != 2:
        raise ValueError(
            'Control points should be a p by d matrix, got shape {s}.'.format(
                s=control_points.shape))

    p, d = control_points.shape

    # K[i, j] = phi(|c_i - c_j|); P = [1 | c] holds the affine terms.
    K = pairwise_radial_basis(control_points, control_points)
    P = numpy.hstack([numpy.ones((p, 1)), control_points])

    # Relax the exact interpolation requirement by means of regularization.
    K = K + lambda_ * numpy.identity(p)

    # Block system  [[K, P], [P^T, 0]] @ X = [[targets], [0]].
    M = numpy.vstack([
        numpy.hstack([K, P]),
        numpy.hstack([P.T, numpy.zeros((d + 1, d + 1))])
    ])
    Y = numpy.vstack([target_points, numpy.zeros((d + 1, d))])

    # Solve M @ X = Y.
    # At least d + 1 control points must not lie in a common subspace; e.g.
    # for d = 2, at least 3 points must not be on a straight line. Otherwise
    # M is singular.
    solver = solver.lower()
    if solver == 'exact':
        X = numpy.linalg.solve(M, Y)
    elif solver == 'lstsq':
        X, _, _, _ = numpy.linalg.lstsq(M, Y, rcond=None)
    else:
        raise ValueError('Unknown solver: ' + solver)

    return X


def transform(source_points: numpy.ndarray, control_points: numpy.ndarray,
              coefficient: numpy.ndarray) -> numpy.ndarray:
    """Transform points from the original surface to the destination
    (deformed) surface.

    Arguments
    ---------
    source_points : numpy.array
        n by d array of source points to be transformed
    control_points : numpy.array
        the p by d control points used in the function `find_coefficients`
    coefficient : numpy.array
        the (p + d + 1)-row coefficients computed by `find_coefficients`

    Return
    ------
    deformed_points : numpy.array
        n by d array of the transformed points on the target surface

    Raises
    ------
    ValueError
        If the source and control points differ in dimension, either array
        has more than two dimensions, or the number of coefficient rows is
        not p + d + 1.
    """
    source_points = numpy.atleast_2d(source_points)
    control_points = numpy.atleast_2d(control_points)
    if source_points.shape[-1] != control_points.shape[-1]:
        raise ValueError(
            'Dimension of source points ({sd}D) and control points ({cd}D) are not the same.'.
            format(sd=source_points.shape[-1], cd=control_points.shape[-1]))

    # n by p matrix of radial basis values; this also rejects non-2-D input.
    A = pairwise_radial_basis(source_points, control_points)
    n, d = source_points.shape

    # Fail with a clear message instead of an opaque "shapes not aligned".
    coefficient = numpy.asarray(coefficient)
    expected_rows = A.shape[1] + d + 1
    if coefficient.shape[:1] != (expected_rows,):
        raise ValueError(
            'Expected {e} coefficient rows for {p} control points in {d}D, '
            'got an array of shape {s}.'.format(
                e=expected_rows, p=A.shape[1], d=d, s=coefficient.shape))

    # Same column layout as the system in `find_coefficients`: [phi | 1 | x].
    basis = numpy.hstack([A, numpy.ones((n, 1)), source_points])
    return numpy.dot(basis, coefficient)
# 160 | --------------------------------------------------------------------------------