├── test ├── __init__.py └── test_affine.py ├── examples ├── alice_big.jpg ├── alice_small.jpg └── basic.ipynb ├── torchreg ├── __init__.py ├── utils.py ├── metrics.py ├── affine.py └── syn.py ├── CITATION.cff ├── pyproject.toml ├── LICENSE ├── README.md └── .gitignore /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/alice_big.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codingfisch/torchreg/HEAD/examples/alice_big.jpg -------------------------------------------------------------------------------- /torchreg/__init__.py: -------------------------------------------------------------------------------- 1 | from .affine import AffineRegistration 2 | from .syn import SyNRegistration 3 | -------------------------------------------------------------------------------- /examples/alice_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codingfisch/torchreg/HEAD/examples/alice_small.jpg -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Fisch" 5 | given-names: "Lukas" 6 | title: "torchreg - Lightweight image registration library using PyTorch" 7 | version: 0.0.1 8 | #doi: 10.5281/zenodo.1234 9 | date-released: 2023-08-23 10 | url: "https://github.com/codingfisch/torchreg" -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 'torchreg' 3 | version = '0.1.3' 4 | description = 'Lightweight image registration library using PyTorch' 5 | authors = ['codingfisch '] 6 | license = 'MIT' 7 | readme = 'README.md' 8 | repository = 'https://github.com/codingfisch/torchreg' 9 | classifiers = [ 10 | 'Programming Language :: Python :: 3', 11 | 'Operating System :: OS Independent', 12 | 'Intended Audience :: Science/Research' 13 | ] 14 | 15 | 16 | [tool.poetry.dependencies] 17 | python = '^3.9' 18 | torch = '*' 19 | tqdm = '*' 20 | 21 | 22 | [tool.poetry.group.test.dependencies] 23 | pytest = '*' 24 | 25 | 26 | [build-system] 27 | requires = ['poetry-core'] 28 | build-backend = 'poetry.core.masonry.api' 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 codingfisch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /torchreg/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | INTERP_KWARGS = {'mode': 'trilinear', 'align_corners': True} 4 | 5 | 6 | def smooth_kernel(kernel_size, sigma): 7 | meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32, device=sigma.device) for size in kernel_size]) 8 | kernel = 1 9 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 10 | mean = (size - 1) / 2 11 | kernel *= 1 / (std * (2 * torch.pi)**.5) * torch.exp(-((mgrid - mean) / std) ** 2 / 2) 12 | return kernel / kernel.sum() 13 | 14 | 15 | def jacobi_determinant(u, id_grid=None): 16 | gradient = jacobi_gradient(u, id_grid) 17 | dx, dy, dz = gradient[..., 2], gradient[..., 1], gradient[..., 0] 18 | jdet0 = dx[2] * (dy[1] * dz[0] - dy[0] * dz[1]) 19 | jdet1 = dx[1] * (dy[2] * dz[0] - dy[0] * dz[2]) 20 | jdet2 = dx[0] * (dy[2] * dz[1] - dy[1] * dz[2]) 21 | jdet = jdet0 - jdet1 + jdet2 22 | return F.pad(jdet[None, None, 2:-2, 2:-2, 2:-2], (2, 2, 2, 2, 2, 2), mode='replicate')[0, 0] 23 | 24 | 25 | def jacobi_gradient(u, id_grid=None): 26 | if id_grid is None: 27 | id_grid = create_grid(u.shape[1:4], u.device) 28 | x = 0.5 * (u + id_grid) * (torch.tensor(u.shape[1:4], device=u.device, dtype=u.dtype) - 1) 29 | window = torch.tensor([-.5, 0, .5], device=u.device) 30 | w = torch.zeros((3, 1, 3, 3, 3), device=u.device, dtype=u.dtype) 31 | w[2, 0, :, 1, 1] = window 32 | w[1, 0, 1, :, 1] = window 33 | w[0, 0, 1, 1, :] = window 34 | x = x.permute(4, 0, 1, 2, 3) 35 | x = F.conv3d(x, w) 36 | x = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate') # 'circular' for bfloat16 37 | return x.permute(0, 2, 3, 4, 1) 38 | 39 | 40 | def create_grid(shape, device): 41 | return F.affine_grid(torch.eye(4, device=device)[None, :3], [1, 3, *shape], align_corners=INTERP_KWARGS['align_corners']) 42 | -------------------------------------------------------------------------------- /torchreg/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .utils import create_grid, jacobi_gradient 5 | 6 | 7 | def dice_loss(x1, x2): 8 | return 1 - dice_score(x1, x2) 9 | 10 | 11 | def dice_score(x1, x2): 12 | dim = [2, 3, 4] if len(x2.shape) == 5 else [2, 3] 13 | inter = torch.sum(x1 * x2, dim=dim) 14 | union = torch.sum(x1 + x2, dim=dim) 15 | return (2. * inter / union).mean() 16 | 17 | 18 | class LinearElasticity(torch.nn.Module): 19 | def __init__(self, mu=2., lam=1., refresh_id_grid=False): 20 | super(LinearElasticity, self).__init__() 21 | self.mu = mu 22 | self.lam = lam 23 | self.id_grid = None 24 | self.refresh_id_grid = refresh_id_grid 25 | 26 | def forward(self, u): 27 | if self.id_grid is None or self.refresh_id_grid: 28 | self.id_grid = create_grid(u.shape[1:4], u.device) 29 | gradients = jacobi_gradient(u, self.id_grid) 30 | u_xz, u_xy, u_xx = jacobi_gradient(gradients[None, 2], self.id_grid) 31 | u_yz, u_yy, u_yx = jacobi_gradient(gradients[None, 1], self.id_grid) 32 | u_zz, u_zy, u_zx = jacobi_gradient(gradients[None, 0], self.id_grid) 33 | e_xy = .5 * (u_xy + u_yx) 34 | e_xz = .5 * (u_xz + u_zx) 35 | e_yz = .5 * (u_yz + u_zy) 36 | sigma_xx = 2 * self.mu * u_xx + self.lam * (u_xx + u_yy + u_zz) 37 | sigma_xy = 2 * self.mu * e_xy 38 | sigma_xz = 2 * self.mu * e_xz 39 | sigma_yy = 2 * self.mu * u_yy + self.lam * (u_xx + u_yy + u_zz) 40 | sigma_yz = 2 * self.mu * e_yz 41 | sigma_zz = 2 * self.mu * u_zz + self.lam * (u_xx + u_yy + u_zz) 42 | return (sigma_xx ** 2 + sigma_xy ** 2 + sigma_xz ** 2 + 43 | sigma_yy ** 2 + sigma_yz ** 2 + sigma_zz ** 2).mean() 44 | 45 | 46 | class NCC(torch.nn.Module): 47 | def __init__(self, kernel_size=7, epsilon_numerator=1e-5, epsilon_denominator=1e-5): 48 | super(NCC, self).__init__() 49 | self.kernel_size = kernel_size 50 | self.eps_nr = epsilon_numerator 51 | self.eps_dr = epsilon_denominator 52 | 53 | def forward(self, pred, targ): 54 | kernel = torch.ones([*targ.shape[:2]] + 3 * [self.kernel_size], device=targ.device) 55 | t_sum = F.conv3d(targ, kernel, padding=self.kernel_size // 2) 56 | p_sum = F.conv3d(pred, kernel, padding=self.kernel_size // 2) 57 | t2_sum = F.conv3d(targ ** 2, kernel, padding=self.kernel_size // 2) 58 | p2_sum = F.conv3d(pred ** 2, kernel, padding=self.kernel_size // 2) 59 | tp_sum = F.conv3d(targ * pred, kernel, padding=self.kernel_size // 2) 60 | cross = tp_sum - t_sum * p_sum / kernel.sum() 61 | t_var = F.relu(t2_sum - t_sum ** 2 / kernel.sum()) 62 | p_var = F.relu(p2_sum - p_sum ** 2 / kernel.sum()) 63 | cc = (cross ** 2 + self.eps_nr) / (t_var * p_var + self.eps_dr) 64 | return -torch.mean(cc) 65 | -------------------------------------------------------------------------------- /test/test_affine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from unittest import TestCase 3 | from torchreg import AffineRegistration 4 | from torchreg.affine import compose_affine, affine_transform, init_parameters, _check_parameter_shapes 5 | 6 | 7 | class TestAffineRegistration(TestCase): 8 | def test_fit(self): 9 | for batch_size in [1, 2]: 10 | for n_dim in [2, 3]: 11 | reg = AffineRegistration(scales=(1,), is_3d=n_dim == 3, learning_rate=1e-1, verbose=False) 12 | moving = synthetic_image(batch_size, n_dim, shift=1) 13 | static = synthetic_image(batch_size, n_dim, shift=0) 14 | fitted_moved = reg(moving, static, return_moved=True) 15 | fitted_affine = reg.get_affine() 16 | affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]]) 17 | affine[:, -1, -1] += -1/3 18 | self.assertTrue(torch.allclose(fitted_affine, affine, atol=1e-2)) 19 | moved = affine_transform(moving, affine) 20 | self.assertTrue(torch.allclose(fitted_moved, moved, atol=1e-2)) 21 | 22 | def test_affine_transform(self): 23 | for batch_size in [1, 2]: 24 | for n_dim in [2, 3]: 25 | moving = synthetic_image(batch_size, n_dim, shift=1) 26 | static = synthetic_image(batch_size, n_dim, shift=0) 27 | affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]]) 28 | affine[:, -1, -1] += -1/3 29 | moved = affine_transform(moving, affine) 30 | self.assertTrue(torch.allclose(moved, static, atol=1e-6)) 31 | 32 | def test_init_parameters(self): 33 | for batch_size in [1, 2]: 34 | for is_3d in [False, True]: 35 | params = init_parameters(is_3d=is_3d, batch_size=batch_size) 36 | self.assertIsInstance(params, list) 37 | self.assertEqual(len(params), 4) 38 | for param in params: 39 | self.assertTrue(isinstance(param, torch.nn.Parameter)) 40 | _check_parameter_shapes(*params, is_3d=is_3d, batch_size=batch_size) 41 | 42 | def test_compose_affine(self): 43 | for batch_size in [1, 2]: 44 | for n_dim in [2, 3]: 45 | translation = torch.zeros(batch_size, n_dim) 46 | rotation = torch.stack(batch_size * [torch.eye(n_dim)]) 47 | zoom = torch.ones(batch_size, n_dim) 48 | shear = torch.zeros(batch_size, n_dim) 49 | affine = compose_affine(translation, rotation, zoom, shear) 50 | id_affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]]) 51 | self.assertTrue(torch.equal(affine, id_affine)) 52 | 53 | 54 | def synthetic_image(batch_size, n_dim, shift): 55 | shape = [batch_size, 1, 7, 7, 7][:2 + n_dim] 56 | x = torch.zeros(*shape) 57 | if n_dim == 3: 58 | x[:, :, 2 - shift:5 - shift, 2:5, 2:5] = 1 59 | else: 60 | x[:, :, 2 - shift:5 - shift, 2:5] = 1 61 | return x 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchreg 2 | 3 | torchreg is a tiny (~300 lines) PyTorch-based library for 2D and 3D image registration. 4 | 5 |

6 | 7 | 8 | 9 |

10 | 11 | ## Usage 12 | Affine Registration of two image tensors is done via: 13 | ```python 14 | from torchreg import AffineRegistration 15 | 16 | # Load images as torch Tensors 17 | big_alice = ... # Tensor with shape [1, 3 (color channel), 1024 (pixel), 1024 (pixel)] 18 | small_alice = ... # Tensor with shape [1, 3 (color channel), 1024 (pixel), 1024 (pixel)] 19 | # Intialize AffineRegistration 20 | reg = AffineRegistration(is_3d=False) 21 | # Run it! 22 | moved_alice = reg(big_alice, small_alice) 23 | ``` 24 | 25 | ## Features 26 | 27 | Multiresolution approach to save compute (per default 1/4 + 1/2 of original resolution for 500 + 100 iterations) 28 | ```python 29 | reg = AffineRegistration(scales=(4, 2), iterations=(500, 100)) 30 | ``` 31 | Choosing which operations (translation, rotation, zoom, shear) to optimize 32 | ```python 33 | reg = AffineRegistration(with_zoom=False, with_shear=False) 34 | ``` 35 | Custom initial parameters 36 | ```python 37 | reg = AffineRegistration(zoom=torch.Tensor([[1.5, 2.]])) 38 | ``` 39 | Custom dissimilarity functions and optimizers 40 | ```python 41 | def dice_loss(x1, x2): 42 | dim = [2, 3, 4] if len(x2.shape) == 5 else [2, 3] 43 | inter = torch.sum(x1 * x2, dim=dim) 44 | union = torch.sum(x1 + x2, dim=dim) 45 | return 1 - (2. * inter / union).mean() 46 | 47 | reg = AffineRegistration(dissimilarity_function=dice_loss, optimizer=torch.optim.Adam) 48 | ``` 49 | CUDA support (NVIDIA GPU) 50 | ```python 51 | moved_alice = reg(moving=big_alice.cuda(), static=small_alice.cuda()) 52 | ``` 53 | 54 | After the registration is run, you can apply it to new images (coregistration) 55 | ```python 56 | another_moved_alice = reg.transform(another_alice, shape=(256, 256)) 57 | ``` 58 | with desired output shape. 59 | 60 | You can access the affine 61 | ```python 62 | affine = reg.get_affine() 63 | ``` 64 | and the four parameters (translation, rotation, zoom, shear) 65 | ```python 66 | translation = reg.parameters[0] 67 | rotation = reg.parameters[1] 68 | zoom = reg.parameters[2] 69 | shear = reg.parameters[3] 70 | ``` 71 | 72 | ## Installation 73 | ```bash 74 | pip install torchreg 75 | ``` 76 | 77 | ## Examples/Tutorials 78 | 79 | There are three example notebooks: 80 | 81 | - [examples/basics.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/basic.ipynb) shows the basics by using small cubes/squares as image data 82 | - [examples/images.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/image.ipynb) shows how to register alice_big.jpg to alice_small.jpg 83 | - [examples/mri.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/mri.ipynb) shows how to register MR images (Nifti files) including co-, parallel and multimodal registration 84 | 85 | ## Background 86 | 87 | If you want to know how the core of this package works, read [the blog post](https://codingfisch.github.io/2023/08/09/affine-registration-in-12-lines-of-code.html)! 88 | 89 | ## TODO 90 | - [ ] Add 2D support to SyN, NCC and LinearElasticity 91 | - [ ] Add tests for SyN 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /torchreg/affine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | 5 | 6 | class AffineRegistration: 7 | def __init__(self, scales=(4, 2), iterations=(500, 100), is_3d=True, learning_rate=1e-2, 8 | verbose=True, dissimilarity_function=torch.nn.MSELoss(), optimizer=torch.optim.Adam, 9 | init_translation=None, init_rotation=None, init_zoom=None, init_shear=None, 10 | with_translation=True, with_rotation=True, with_zoom=True, with_shear=False, 11 | align_corners=True, interp_mode=None, padding_mode='border'): 12 | self.scales = scales 13 | self.iterations = iterations[:len(scales)] 14 | self.is_3d = is_3d 15 | self.learning_rate = learning_rate 16 | self.verbose = verbose 17 | self.dissimilarity_function = dissimilarity_function 18 | self.optimizer = optimizer 19 | self.inits = (init_translation, init_rotation, init_zoom, init_shear) 20 | self.withs = (with_translation, with_rotation, with_zoom, with_shear) 21 | self.align_corners = align_corners 22 | self.interp_mode = 'trilinear' if is_3d else 'bilinear' if interp_mode is None else interp_mode 23 | self.padding_mode = padding_mode 24 | self._parameters = None 25 | self._loss = None 26 | 27 | def __call__(self, moving, static, return_moved=True): 28 | if len(moving.shape) - 4 != self.is_3d or len(static.shape) - 4 != self.is_3d: 29 | raise ValueError(f'Expected moving and static to be {4 + self.is_3d}D Tensors (2 + Spatial Dims.). ' 30 | f'Got size {moving.shape} and {static.shape}.') 31 | self._parameters = init_parameters(self.is_3d, len(static), static.device, *self.withs, *self.inits) 32 | interp_kwargs = {'mode': self.interp_mode, 'align_corners': self.align_corners} 33 | moving_ = F.interpolate(moving, static.shape[2:], **interp_kwargs) 34 | for scale, iters in zip(self.scales, self.iterations): 35 | moving_small = F.interpolate(moving_, scale_factor=1 / scale, **interp_kwargs) 36 | static_small = F.interpolate(static, scale_factor=1 / scale, **interp_kwargs) 37 | self._loss = self._fit(moving_small, static_small, iters) 38 | return self.transform(moving, static.shape[2:]).detach() if return_moved else None 39 | 40 | def _fit(self, moving, static, iterations): 41 | optimizer = self.optimizer(self._parameters, self.learning_rate) 42 | progress_bar = tqdm(range(iterations), disable=not self.verbose) 43 | for self.iter in progress_bar: 44 | optimizer.zero_grad() 45 | moved = self.transform(moving, static.shape[2:], with_grad=True) 46 | loss = self.dissimilarity_function(moved, static) 47 | progress_bar.set_description(f'Shape: {[*static.shape]}; Dissimiliarity: {loss.item()}') 48 | loss.backward() 49 | optimizer.step() 50 | return loss.item() 51 | 52 | def transform(self, moving, shape=None, with_grad=False): 53 | affine = self.get_affine(with_grad).type(moving.dtype) 54 | return affine_transform(moving, affine, shape, self.interp_mode, self.padding_mode, self.align_corners) 55 | 56 | def get_affine(self, with_grad=False): 57 | affine = compose_affine(*self._parameters) 58 | return affine if with_grad else affine.detach() 59 | 60 | 61 | def affine_transform(x, affine, shape=None, mode='bilinear', padding_mode='border', align_corners=True): 62 | shape = x.shape[2:] if shape is None else shape 63 | grid = F.affine_grid(affine, [len(x), len(shape), *shape], align_corners) 64 | sample_mode = 'bilinear' if mode == 'trilinear' else mode # grid_sample converts 'bi-' to 'trilinear' internally 65 | return F.grid_sample(x, grid, sample_mode, padding_mode, align_corners) 66 | 67 | 68 | def init_parameters(is_3d=True, batch_size=1, device='cpu', with_translation=True, with_rotation=True, with_zoom=True, 69 | with_shear=True, init_translation=None, init_rotation=None, init_zoom=None, init_shear=None): 70 | _check_parameter_shapes(init_translation, init_rotation, init_zoom, init_shear, is_3d, batch_size) 71 | n_dim = 2 + is_3d 72 | translation = torch.zeros(batch_size, n_dim).to(device) if init_translation is None else init_translation 73 | rotation = torch.stack(batch_size * [torch.eye(n_dim)]).to(device) if init_rotation is None else init_rotation 74 | zoom = torch.ones(batch_size, n_dim).to(device) if init_zoom is None else init_zoom 75 | shear = torch.zeros(batch_size, n_dim).to(device) if init_shear is None else init_shear 76 | params = [translation, rotation, zoom, shear] 77 | with_grad = [with_translation, with_rotation, with_zoom, with_shear] 78 | return [torch.nn.Parameter(param, requires_grad=grad) for param, grad in zip(params, with_grad)] 79 | 80 | 81 | def compose_affine(translation, rotation, zoom, shear): 82 | _check_parameter_shapes(translation, rotation, zoom, shear, zoom.shape[-1] == 3, zoom.shape[0]) 83 | square_matrix = torch.diag_embed(zoom) 84 | if zoom.shape[-1] == 3: 85 | square_matrix[..., 0, 1:] = shear[..., :2] 86 | square_matrix[..., 1, 2] = shear[..., 2] 87 | else: 88 | square_matrix[..., 0, 1] = shear[..., 0] 89 | square_matrix = rotation @ square_matrix 90 | return torch.cat([square_matrix, translation[:, :, None]], dim=-1) 91 | 92 | 93 | def _check_parameter_shapes(translation, rotation, zoom, shear, is_3d=True, batch_size=1): 94 | n_dim = 2 + is_3d 95 | params = {'translation': translation, 'rotation': rotation, 'zoom': zoom, 'shear': shear} 96 | for name, param in params.items(): 97 | if param is not None: 98 | desired_shape = (batch_size, n_dim, n_dim) if name == 'rotation' else (batch_size, n_dim) 99 | if param.shape != desired_shape: 100 | raise ValueError(f'Expected {name} to be size {desired_shape} since batch_size is {batch_size} ' 101 | f'and is_3d is {is_3d} -> {2 + is_3d} dimensions. Got size {param.shape}.') 102 | -------------------------------------------------------------------------------- /torchreg/syn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | 6 | from .metrics import LinearElasticity 7 | from .utils import INTERP_KWARGS, create_grid, smooth_kernel 8 | LIN_ELAST_FUNC = lambda x: LinearElasticity(mu=2., lam=1.)(x) 9 | 10 | 11 | class SyNBase: 12 | def __init__(self, time_steps=7): 13 | self.time_steps = time_steps 14 | self._grid = None 15 | 16 | def apply_flows(self, x, y, v_xy, v_yx): 17 | half_flows = self.diffeomorphic_transform(torch.cat([v_xy, v_yx, -v_xy, -v_yx])) 18 | half_images = self.spatial_transform(torch.cat([x, y]), half_flows[:2]) 19 | full_flows = self.composition_transform(half_flows[:2], half_flows[2:].flip(0)) 20 | full_images = self.spatial_transform(torch.cat([x, y]), full_flows) 21 | images = {'xy_half': half_images[:1], 'yx_half': half_images[1:2], 22 | 'xy_full': full_images[:1], 'yx_full': full_images[1:2]} 23 | flows = {'xy_half': half_flows[:1], 'yx_half': half_flows[1:2], 24 | 'xy_full': full_flows[:1], 'yx_full': full_flows[1:2]} 25 | flows = {k: flow.permute(0, 2, 3, 4, 1) for k, flow in flows.items()} 26 | return images, flows 27 | 28 | def diffeomorphic_transform(self, v): 29 | v = v / (2 ** self.time_steps) 30 | for i in range(self.time_steps): 31 | v = v + self.spatial_transform(v, v) 32 | return v 33 | 34 | def composition_transform(self, v1, v2): 35 | return v2 + self.spatial_transform(v1, v2) 36 | 37 | def spatial_transform(self, x, v): 38 | if self._grid is None: 39 | self._grid = create_grid(v.shape[2:], x.device) 40 | return F.grid_sample(x, self._grid + v.permute(0, 2, 3, 4, 1), align_corners=True, padding_mode='reflection') 41 | 42 | 43 | class SyNRegistration(SyNBase): 44 | def __init__(self, scales=(4, 2, 1), iterations=(30, 30, 10), learning_rate=1e-2, verbose=True, 45 | dissimilarity_function=torch.nn.MSELoss(), regularization_function=LIN_ELAST_FUNC, 46 | optimizer=torch.optim.Adam, sigma_img=.2, sigma_flow=.2, lambda_=2e-5, time_steps=7): 47 | super().__init__(time_steps=time_steps) 48 | self.scales = scales 49 | self.iterations = iterations 50 | self.learning_rates = [learning_rate] * len(scales) if isinstance(learning_rate, float) else learning_rate 51 | self.verbose = verbose 52 | self.dissimilarity_function = dissimilarity_function 53 | self.regularization_function = regularization_function 54 | self.optimizer = optimizer 55 | self.sigma_img = sigma_img 56 | self.sigma_flow = sigma_flow 57 | self.lambda_ = lambda_ 58 | self.v_xy = None 59 | self.v_yx = None 60 | self._grid = None 61 | 62 | def __call__(self, moving, static, v_xy=None, v_yx=None, return_moved=True): 63 | if v_xy is None: 64 | v_xy = torch.zeros((moving.shape[0], 3, *moving.shape[2:])) 65 | if v_yx is None: 66 | v_yx = torch.zeros((static.shape[0], 3, *static.shape[2:])) 67 | self.v_xy = v_xy.type(static.dtype).to(static.device) 68 | self.v_yx = v_yx.type(static.dtype).to(static.device) 69 | for scale, iters, lr in zip(self.scales, self.iterations, self.learning_rates): 70 | moving_shape, static_shape = [s for s in moving.shape[2:]], [s for s in static.shape[2:]] 71 | shape = [int(round(s / scale)) for s in static_shape] 72 | self._grid = create_grid(shape, static.device) 73 | x = F.interpolate(moving, shape, **INTERP_KWARGS) if shape != moving_shape else moving.clone() 74 | y = F.interpolate(static, shape, **INTERP_KWARGS) if shape != static_shape else static.clone() 75 | if self.sigma_img: 76 | sigma_img = self.sigma_img * 200 / torch.tensor(shape).int() 77 | x = gauss_smoothing(x, sigma_img) 78 | y = gauss_smoothing(y, sigma_img) 79 | self.fit(x, y, iters, lr) 80 | self._grid = create_grid(static.shape[2:], static.device) 81 | if return_moved: 82 | images, flows = self.apply_flows(moving, static, self.v_xy, self.v_yx) 83 | return images['xy_full'], images['yx_full'], flows['xy_full'], flows['yx_full'] 84 | 85 | def fit(self, x, y, iterations, learning_rate): 86 | v_xy = F.interpolate(self.v_xy, x.shape[2:], **INTERP_KWARGS) 87 | v_xy = torch.nn.Parameter(v_xy, requires_grad=True) 88 | v_yx = F.interpolate(self.v_yx, x.shape[2:], **INTERP_KWARGS) 89 | v_yx = torch.nn.Parameter(v_yx, requires_grad=True) 90 | sigma_flow = self.sigma_flow * torch.ones(3) 91 | optimizer = self.optimizer([v_xy, v_yx], learning_rate) 92 | progress_bar = tqdm(range(iterations), disable=not self.verbose) 93 | for _ in progress_bar: 94 | optimizer.zero_grad() 95 | images, flows = self.apply_flows(x, y, gauss_smoothing(v_xy, sigma_flow), gauss_smoothing(v_yx, sigma_flow)) 96 | dissimilarity = (self.dissimilarity_function(x, images['yx_full']) + 97 | self.dissimilarity_function(y, images['xy_full']) + 98 | self.dissimilarity_function(images['yx_half'], images['xy_half'])) 99 | regularization = (self.regularization_function(flows['yx_full']) + 100 | self.regularization_function(flows['xy_full'])) 101 | loss = dissimilarity + self.lambda_ * regularization 102 | progress_bar.set_description(f'Loss: {loss.item()}, ' 103 | f'Dissimilarity: {dissimilarity.item()}, ' 104 | f'Regularization: {regularization.item()}') 105 | loss.backward() 106 | optimizer.step() 107 | v_xy, v_yx = v_xy.detach(), v_yx.detach() 108 | v_xy, v_yx = gauss_smoothing(v_xy, sigma_flow), gauss_smoothing(v_yx, sigma_flow) 109 | self.v_xy = F.interpolate(v_xy, self.v_xy.shape[2:], **INTERP_KWARGS) if self.v_xy.shape != v_xy.shape else v_xy 110 | self.v_yx = F.interpolate(v_yx, self.v_yx.shape[2:], **INTERP_KWARGS) if self.v_yx.shape != v_yx.shape else v_yx 111 | 112 | 113 | def gauss_smoothing(x, sigma): 114 | half_kernel_size = np.array(x.shape[2:]) // 50 115 | kernel_size = 1 + 2 * half_kernel_size.clip(min=1) 116 | kernel = smooth_kernel(kernel_size.tolist(), sigma).to(x.device) 117 | kernel = kernel[None, None].repeat(x.shape[1], 1, 1, 1, 1) 118 | x = F.pad(x, (kernel_size.repeat(2)[::-1] // 2).tolist(), mode='replicate') 119 | return F.conv3d(x.type(torch.float32), kernel, groups=x.shape[1]).type(x.dtype) 120 | -------------------------------------------------------------------------------- /examples/basic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Basic Square/Cube Examples\n", 7 | "\n", 8 | "In the following we will register two little squares/cubes to get familiar with the usage of torchreg." 9 | ], 10 | "metadata": { 11 | "collapsed": false 12 | } 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "outputs": [], 18 | "source": [ 19 | "import sys\n", 20 | "sys.path.append('..')\n", 21 | "import torch\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from torchreg import AffineRegistration" 24 | ], 25 | "metadata": { 26 | "collapsed": false 27 | } 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "Creating squares" 33 | ], 34 | "metadata": { 35 | "collapsed": false 36 | } 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "outputs": [], 42 | "source": [ 43 | "# Non-centered 3² square in 7² tensor\n", 44 | "moving = torch.zeros(7, 7)\n", 45 | "moving[1:4, 2:5] = 1\n", 46 | "# Centered 3² square in 7² tensor\n", 47 | "static = torch.zeros(7, 7)\n", 48 | "static[2:5, 2:5] = 1" 49 | ], 50 | "metadata": { 51 | "collapsed": false 52 | } 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "source": [ 57 | "Plot before registration" 58 | ], 59 | "metadata": { 60 | "collapsed": false 61 | } 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": "
", 70 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVs0lEQVR4nO3dfWyVhd3/8W+lcFBsq6AgDZURNT4h6KhzgG4+jTv8lGiWOV3UkT38wYIP2Jg59A/dk3V/bNHF2axscSOLYpYNZYmALBNwcWyAEgkaxWFCpzKicW3tH0fB6/fHfa/33aHM0/bbw+ler+RKPCfXyfU5CeHt1dOWuqIoigCAYXZUtQcAMDoJDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKSoH+kLfvDBB/HGG29EQ0ND1NXVjfTlARiCoiiit7c3mpub46ijDn+PMuKBeeONN6KlpWWkLwvAMOrq6opp06Yd9pwRD0xDQ0NERFwY/y/qY+xIXx6AITgQ78cf48n+v8sPZ8QD888vi9XH2KivExiAmvI/v73y43zE4UN+AFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUgwrMQw89FDNmzIjx48fHnDlz4plnnhnuXQDUuIoD89hjj8WyZcvirrvuiueffz4uuuiiWLhwYezduzdjHwA1quLA/OhHP4qvfe1r8fWvfz3OPPPMuP/++6OlpSU6Ojoy9gFQoyoKzHvvvRfbt2+PBQsWDHh+wYIF8eyzz37oa8rlcvT09Aw4ABj9KgrMW2+9FQcPHowpU6YMeH7KlCmxb9++D31Ne3t7NDU19R8tLS2DXwtAzRjUh/x1dXUDHhdFcchz/7R8+fLo7u7uP7q6ugZzSQBqTH0lJ59wwgkxZsyYQ+5W9u/ff8hdzT+VSqUolUqDXwhATaroDmbcuHExZ86c2LBhw4DnN2zYEPPmzRvWYQDUtoruYCIi2tra4sYbb4zW1taYO3dudHZ2xt69e2PJkiUZ+wCoURUH5tprr4233347vvOd78Sbb74ZM2fOjCeffDKmT5+esQ+AGlVXFEUxkhfs6emJpqamuDiuivq6sSN5aQCG6EDxfmyMJ6K7uzsaGxsPe67fRQZACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkKK+2gM4cq1/Y0e1J/Af6L+az632BIaJOxgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkKLiwGzevDkWLVoUzc3NUVdXF48//njCLABqXcWB6evri9mzZ8eDDz6YsQeAUaK+0hcsXLgwFi5cmLEFgFGk4sBUqlwuR7lc7n/c09OTfUkAjgDpH/K3t7dHU1NT/9HS0pJ9SQCOAOmBWb58eXR3d/cfXV1d2ZcE4AiQ/iWyUqkUpVIp+zIAHGH8HAwAKSq+g3n33Xfj1Vdf7X/82muvxY4dO2LixIlx8sknD+s4AGpXxYHZtm1bXHLJJf2P29raIiJi8eLF8Ytf/GLYhgFQ2yoOzMUXXxxFUWRsAWAU8RkMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQVBaa9vT3OP//8aGhoiMmTJ8fVV18dL7/8ctY2AGpYRYHZtGlTLF26NLZs2RIbNmyIAwcOxIIFC6Kvry9rHwA1qr6Sk9etWzfg8cMPPxyTJ0+O7du3x2c+85lhHQZAbasoMP+qu7s7IiImTpz4keeUy+Uol8v9j3t6eoZySQBqxKA/5C+KItra2uLCCy+MmTNnfuR57e3t0dTU1H+0tLQM9pIA1JBBB+amm26KF154IR599NHDnrd8+fLo7u7uP7q6ugZ7SQBqyKC+RHbzzTfHmjVrYvPmzTFt2rTDnlsqlaJUKg1qHAC1q6LAFEURN998c6xevTo2btwYM2bMyNoFQI2rKDBLly6NRx55JJ544oloaGiIffv2RUREU1NTHH300SkDAahNFX0G09HREd3d3XHxxRfH1KlT+4/HHnssax8ANariL5EBwMfhd5EBkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQoqLAdHR0xKxZs6KxsTEaGxtj7ty5sXbt2qxtANSwigIzbdq0uO+++2Lbtm2xbdu2uPTSS+Oqq66KXbt2Ze0DoEbVV3LyokWLBjz+/ve/Hx0dHbFly5Y4++yzh3UYALWtosD8XwcPHoxf//rX0dfXF3Pnzv3I88rlcpTL5f7HPT09g70kADWk4g/5d+7cGccee2yUSqVYsmRJrF69Os4666yPPL+9vT2ampr6j5aWliENBqA2VByY008/PXbs2BFbtmyJb3zjG7F48eJ48cUXP/L85cuXR3d3d//R1dU1pMEA1IaKv0Q2bty4OPXUUyMiorW1NbZu3RoPPPBA/PSnP/3Q80ulUpRKpaGtBKDmDPnnYIqiGPAZCwBEVHgHc+edd8bChQujpaUlent7Y9WqVbFx48ZYt25d1j4AalRFgfn73/8eN954Y7z55pvR1NQUs2bNinXr1sXnPve5rH0A1KiKAvPzn/88awcAo4zfRQZACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkKK+2gM4cv1X87nVngDUMHcwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASDFkALT3t4edXV1sWzZsmGaA8BoMejAbN26NTo7O2PWrFnDuQeAUWJQgXn33Xfj+uuvjxUrVsTxxx8/3JsAGAUGFZilS5fGFVdcEZdffvm/PbdcLkdPT8+AA4DRr77SF6xatSqee+652Lp168c6v729Pb797W9XPAyA2lbRHUxXV1fceuut8atf/SrGjx//sV6zfPny6O7u7j+6uroGNRSA2lLRHcz27dtj//79MWfOnP7nDh48GJs3b44HH3wwyuVyjBkzZsBrSqVSlEql4VkLQM2oKDCXXXZZ7Ny5c8BzX/nKV+KMM86IO+6445C4APCfq6LANDQ0xMyZMwc8N2HChJg0adIhzwPwn81P8gOQouLvIvtXGzduHIYZAIw27mAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApKgoMPfcc0/U1dUNOE466aSsbQDUsPpKX3D22WfH73//+/7HY8aMGdZBAIwOFQemvr7eXQsA/1bFn8Hs3r07mpubY8aMGXHdddfFnj17Dnt+uVyOnp6eAQcAo19Fgbngggti5cqVsX79+lixYkXs27cv5s2bF2+//fZHvqa9vT2ampr6j5aWliGPBuDIV1cURTHYF/f19cUpp5wS3/zmN6Otre1DzymXy1Eul/sf9/T0REtLS1wcV0V93djBXhqAKjhQvB8b44no7u6OxsbGw55b8Wcw/9eECRPinHPOid27d3/kOaVSKUql0lAuA0ANGtLPwZTL5XjppZdi6tSpw7UHgFGiosDcfvvtsWnTpnjttdfiz3/+c3zhC1+Inp6eWLx4cdY+AGpURV8i+9vf/hZf+tKX4q233ooTTzwxPv3pT8eWLVti+vTpWfsAqFEVBWbVqlVZOwAYZfwuMgBSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApKg7M66+/HjfccENMmjQpjjnmmDj33HNj+/btGdsAqGH1lZz8zjvvxPz58+OSSy6JtWvXxuTJk+Ovf/1rHHfccUnzAKhVFQXmBz/4QbS0tMTDDz/c/9wnPvGJ4d4EwChQ0ZfI1qxZE62trXHNNdfE5MmT47zzzosVK1Yc9jXlcjl6enoGHACMfhUFZs+ePdHR0RGnnXZarF+/PpYsWRK33HJLrFy58iNf097eHk1NTf1HS0vLkEcDcOSrK4qi+Lgnjxs3LlpbW+PZZ5/tf+6WW26JrVu3xp/+9KcPfU25XI5yudz/uKenJ1paWuLiuCrq68YOYToAI+1A8X5sjCeiu7s7GhsbD3tuRXcwU6dOjbPOOmvAc2eeeWbs3bv3I19TKpWisbFxwAHA6FdRYObPnx8vv/zygOdeeeWVmD59+rCOAqD2VRSY2267LbZs2RL33ntvvPrqq/HII49EZ2dnLF26NGsfADWqosCcf/75sXr16nj00Udj5syZ8d3vfjfuv//+uP7667P2AVCjKvo5mIiIK6+8Mq688sqMLQCMIn4XGQApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUFf+TyUNVFEVERByI9yOKkb46AENxIN6PiP/9u/xwRjwwvb29ERHxx3hypC8NwDDp7e2Npqamw55TV3ycDA2jDz74IN54441oaGiIurq6tOv09PRES0tLdHV1RWNjY9p1RpL3dOQbbe8nwnuqFSP1noqiiN7e3mhubo6jjjr8pywjfgdz1FFHxbRp00bseo2NjaPmD9A/eU9HvtH2fiK8p1oxEu/p3925/JMP+QFIITAApBi1gSmVSnH33XdHqVSq9pRh4z0d+Ubb+4nwnmrFkfieRvxDfgD+M4zaOxgAqktgAEghMACkEBgAUozKwDz00EMxY8aMGD9+fMyZMyeeeeaZak8aks2bN8eiRYuiubk56urq4vHHH6/2pCFpb2+P888/PxoaGmLy5Mlx9dVXx8svv1ztWUPS0dERs2bN6v8ht7lz58batWurPWvYtLe3R11dXSxbtqzaU4bknnvuibq6ugHHSSedVO1ZQ/L666/HDTfcEJMmTYpjjjkmzj333Ni+fXu1Z0XEKAzMY489FsuWLYu77rornn/++bjoooti4cKFsXfv3mpPG7S+vr6YPXt2PPjgg9WeMiw2bdoUS5cujS1btsSGDRviwIEDsWDBgujr66v2tEGbNm1a3HfffbFt27bYtm1bXHrppXHVVVfFrl27qj1tyLZu3RqdnZ0xa9asak8ZFmeffXa8+eab/cfOnTurPWnQ3nnnnZg/f36MHTs21q5dGy+++GL88Ic/jOOOO67a0/5bMcp86lOfKpYsWTLguTPOOKP41re+VaVFwysiitWrV1d7xrDav39/ERHFpk2bqj1lWB1//PHFz372s2rPGJLe3t7itNNOKzZs2FB89rOfLW699dZqTxqSu+++u5g9e3a1ZwybO+64o7jwwgurPeMjjao7mPfeey+2b98eCxYsGPD8ggUL4tlnn63SKv6d7u7uiIiYOHFilZcMj4MHD8aqVauir68v5s6dW+05Q7J06dK44oor4vLLL6/2lGGze/fuaG5ujhkzZsR1110Xe/bsqfakQVuzZk20trbGNddcE5MnT47zzjsvVqxYUe1Z/UZVYN566604ePBgTJkyZcDzU6ZMiX379lVpFYdTFEW0tbXFhRdeGDNnzqz2nCHZuXNnHHvssVEqlWLJkiWxevXqOOuss6o9a9BWrVoVzz33XLS3t1d7yrC54IILYuXKlbF+/fpYsWJF7Nu3L+bNmxdvv/12tacNyp49e6KjoyNOO+20WL9+fSxZsiRuueWWWLlyZbWnRUQVfpvySPjXfwagKIrUfxqAwbvpppvihRdeiD/+8Y/VnjJkp59+euzYsSP+8Y9/xG9+85tYvHhxbNq0qSYj09XVFbfeems89dRTMX78+GrPGTYLFy7s/+9zzjkn5s6dG6ecckr88pe/jLa2tiouG5wPPvggWltb4957742IiPPOOy927doVHR0d8eUvf7nK60bZHcwJJ5wQY8aMOeRuZf/+/Yfc1VB9N998c6xZsyaefvrpEf0nHLKMGzcuTj311GhtbY329vaYPXt2PPDAA9WeNSjbt2+P/fv3x5w5c6K+vj7q6+tj06ZN8eMf/zjq6+vj4MGD1Z44LCZMmBDnnHNO7N69u9pTBmXq1KmH/A/MmWeeecR8U9OoCsy4ceNizpw5sWHDhgHPb9iwIebNm1elVfyroijipptuit/+9rfxhz/8IWbMmFHtSSmKoohyuVztGYNy2WWXxc6dO2PHjh39R2tra1x//fWxY8eOGDNmTLUnDotyuRwvvfRSTJ06tdpTBmX+/PmHfIv/K6+8EtOnT6/SooFG3ZfI2tra4sYbb4zW1taYO3dudHZ2xt69e2PJkiXVnjZo7777brz66qv9j1977bXYsWNHTJw4MU4++eQqLhucpUuXxiOPPBJPPPFENDQ09N9xNjU1xdFHH13ldYNz5513xsKFC6OlpSV6e3tj1apVsXHjxli3bl21pw1KQ0PDIZ+JTZgwISZNmlTTn5XdfvvtsWjRojj55JNj//798b3vfS96enpi8eLF1Z42KLfddlvMmzcv7r333vjiF78Yf/nLX6KzszM6OzurPe2/Vfeb2HL85Cc/KaZPn16MGzeu+OQnP1nz3/769NNPFxFxyLF48eJqTxuUD3svEVE8/PDD1Z42aF/96lf7/8ydeOKJxWWXXVY89dRT1Z41rEbDtylfe+21xdSpU4uxY8cWzc3Nxec///li165d1Z41JL/73e+KmTNnFqVSqTjjjDOKzs7Oak/q59f1A5BiVH0GA8CRQ2AASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUvx/JU/30Wjw9isAAAAASUVORK5CYII=\n" 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "plt.imshow(moving);" 78 | ], 79 | "metadata": { 80 | "collapsed": false 81 | } 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": "
", 90 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVs0lEQVR4nO3dfWyVhd3/8W+lcFBsq6AgDZURNT4h6KhzgG4+jTv8lGiWOV3UkT38wYIP2Jg59A/dk3V/bNHF2axscSOLYpYNZYmALBNwcWyAEgkaxWFCpzKicW3tH0fB6/fHfa/33aHM0/bbw+ler+RKPCfXyfU5CeHt1dOWuqIoigCAYXZUtQcAMDoJDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKSoH+kLfvDBB/HGG29EQ0ND1NXVjfTlARiCoiiit7c3mpub46ijDn+PMuKBeeONN6KlpWWkLwvAMOrq6opp06Yd9pwRD0xDQ0NERFwY/y/qY+xIXx6AITgQ78cf48n+v8sPZ8QD888vi9XH2KivExiAmvI/v73y43zE4UN+AFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUgwrMQw89FDNmzIjx48fHnDlz4plnnhnuXQDUuIoD89hjj8WyZcvirrvuiueffz4uuuiiWLhwYezduzdjHwA1quLA/OhHP4qvfe1r8fWvfz3OPPPMuP/++6OlpSU6Ojoy9gFQoyoKzHvvvRfbt2+PBQsWDHh+wYIF8eyzz37oa8rlcvT09Aw4ABj9KgrMW2+9FQcPHowpU6YMeH7KlCmxb9++D31Ne3t7NDU19R8tLS2DXwtAzRjUh/x1dXUDHhdFcchz/7R8+fLo7u7uP7q6ugZzSQBqTH0lJ59wwgkxZsyYQ+5W9u/ff8hdzT+VSqUolUqDXwhATaroDmbcuHExZ86c2LBhw4DnN2zYEPPmzRvWYQDUtoruYCIi2tra4sYbb4zW1taYO3dudHZ2xt69e2PJkiUZ+wCoURUH5tprr4233347vvOd78Sbb74ZM2fOjCeffDKmT5+esQ+AGlVXFEUxkhfs6emJpqamuDiuivq6sSN5aQCG6EDxfmyMJ6K7uzsaGxsPe67fRQZACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEgRcWB2bx5cyxatCiam5ujrq4uHn/88YRZANS6igPT19cXs2fPjgcffDBjDwCjRH2lL1i4cGEsXLgwYwsAo0jFgalUuVyOcrnc/7inpyf7kgAcAdI/5G9vb4+mpqb+o6WlJfuSABwB0gOzfPny6O7u7j+6urqyLwnAESD9S2SlUilKpVL2ZQA4wvg5GABSVHwH8+6778arr77a//i1116LHTt2xMSJE+Pkk08e1nEA1K6KA7Nt27a45JJL+h+3tbVFRMTixYvjF7/4xbANA6C2VRyYiy++OIqiyNgCwCjiMxgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBT11R7AkWv9GzuqPYH/QP/VfG61JzBM3MEAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQVBaa9vT3OP//8aGhoiMmTJ8fVV18dL7/8ctY2AGpYRYHZtGlTLF26NLZs2RIbNmyIAwcOxIIFC6Kvry9rHwA1qr6Sk9etWzfg8cMPPxyTJ0+O7du3x2c+85lhHQZAbasoMP+qu7s7IiImTpz4keeUy+Uol8v9j3t6eoZySQBqxKA/5C+KItra2uLCCy+MmTNnfuR57e3t0dTU1H+0tLQM9pIA1JBBB+amm26KF154IR599NHDnrd8+fLo7u7uP7q6ugZ7SQBqyKC+RHbzzTfHmjVrYvPmzTFt2rTDnlsqlaJUKg1qHAC1q6LAFEURN998c6xevTo2btwYM2bMyNoFQI2rKDBLly6NRx55JJ544oloaGiIffv2RUREU1NTHH300SkDAahNFX0G09HREd3d3XHxxRfH1KlT+4/HHnssax8ANariL5EBwMfhd5EBkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQoqLAdHR0xKxZs6KxsTEaGxtj7ty5sXbt2qxtANSwigIzbdq0uO+++2Lbtm2xbdu2uPTSS+Oqq66KXbt2Ze0DoEbVV3LyokWLBjz+/ve/Hx0dHbFly5Y4++yzh3UYALWtosD8XwcPHoxf//rX0dfXF3Pnzv3I88rlcpTL5f7HPT09g70kADWk4g/5d+7cGccee2yUSqVYsmRJrF69Os4666yPPL+9vT2ampr6j5aWliENBqA2VByY008/PXbs2BFbtmyJb3zjG7F48eJ48cUXP/L85cuXR3d3d//R1dU1pMEA1IaKv0Q2bty4OPXUUyMiorW1NbZu3RoPPPBA/PSnP/3Q80ulUpRKpaGtBKDmDPnnYIqiGPAZCwBEVHgHc+edd8bChQujpaUlent7Y9WqVbFx48ZYt25d1j4AalRFgfn73/8eN954Y7z55pvR1NQUs2bNinXr1sXnPve5rH0A1KiKAvPzn/88awcAo4zfRQZACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEgxZAC097eHnV1dbFs2bJhmgPAaDHowGzdujU6Oztj1qxZw7kHgFFiUIF599134/rrr48VK1bE8ccfP9ybABgFBhWYpUuXxhVXXBGXX375vz23XC5HT0/PgAOA0a++0hesWrUqnnvuudi6devHOr+9vT2+/e1vVzwMgNpW0R1MV1dX3HrrrfGrX/0qxo8f/7Fes3z58uju7u4/urq6BjUUgNpS0R3M9u3bY//+/TFnzpz+5w4ePBibN2+OBx98MMrlcowZM2bAa0qlUpRKpeFZC0DNqCgwl112WezcuXPAc1/5ylfijDPOiDvuuOOQuADwn6uiwDQ0NMTMmTMHPDdhwoSYNGnSIc8D8J/NT/IDkKLi7yL7Vxs3bhyGGQCMNu5gAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABS1Fd7AEeu/2o+t9oTgBrmDgaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApKgoMPfcc0/U1dUNOE466aSsbQDUsPpKX3D22WfH73//+/7HY8aMGdZBAIwOFQemvr7eXQsA/1bFn8Hs3r07mpubY8aMGXHdddfFnj17Dnt+uVyOnp6eAQcAo19Fgbngggti5cqVsX79+lixYkXs27cv5s2bF2+//fZHvqa9vT2ampr6j5aWliGPBuDIV1cURTHYF/f19cUpp5wS3/zmN6Otre1DzymXy1Eul/sf9/T0REtLS1wcV0V93djBXhqAKjhQvB8b44no7u6OxsbGw55b8Wcw/9eECRPinHPOid27d3/kOaVSKUql0lAuA0ANGtLPwZTL5XjppZdi6tSpw7UHgFGiosDcfvvtsWnTpnjttdfiz3/+c3zhC1+Inp6eWLx4cdY+AGpURV8i+9vf/hZf+tKX4q233ooTTzwxPv3pT8eWLVti+vTpWfsAqFEVBWbVqlVZOwAYZfwuMgBSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApKg7M66+/HjfccENMmjQpjjnmmDj33HNj+/btGdsAqGH1lZz8zjvvxPz58+OSSy6JtWvXxuTJk+Ovf/1rHHfccUnzAKhVFQXmBz/4QbS0tMTDDz/c/9wnPvGJ4d4EwChQ0ZfI1qxZE62trXHNNdfE5MmT47zzzosVK1Yc9jXlcjl6enoGHACMfhUFZs+ePdHR0RGnnXZarF+/PpYsWRK33HJLrFy58iNf097eHk1NTf1HS0vLkEcDcOSrK4qi+Lgnjxs3LlpbW+PZZ5/tf+6WW26JrVu3xp/+9KcPfU25XI5yudz/uKenJ1paWuLiuCrq68YOYToAI+1A8X5sjCeiu7s7GhsbD3tuRXcwU6dOjbPOOmvAc2eeeWbs3bv3I19TKpWisbFxwAHA6FdRYObPnx8vv/zygOdeeeWVmD59+rCOAqD2VRSY2267LbZs2RL33ntvvPrqq/HII49EZ2dnLF26NGsfADWqosCcf/75sXr16nj00Udj5syZ8d3vfjfuv//+uP7667P2AVCjKvo5mIiIK6+8Mq688sqMLQCMIn4XGQApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUFf+TyUNVFEVERByI9yOKkb46AENxIN6PiP/9u/xwRjwwvb29ERHxx3hypC8NwDDp7e2Npqamw55TV3ycDA2jDz74IN54441oaGiIurq6tOv09PRES0tLdHV1RWNjY9p1RpL3dOQbbe8nwnuqFSP1noqiiN7e3mhubo6jjjr8pywjfgdz1FFHxbRp00bseo2NjaPmD9A/eU9HvtH2fiK8p1oxEu/p3925/JMP+QFIITAApBi1gSmVSnH33XdHqVSq9pRh4z0d+Ubb+4nwnmrFkfieRvxDfgD+M4zaOxgAqktgAEghMACkEBgAUozKwDz00EMxY8aMGD9+fMyZMyeeeeaZak8aks2bN8eiRYuiubk56urq4vHHH6/2pCFpb2+P888/PxoaGmLy5Mlx9dVXx8svv1ztWUPS0dERs2bN6v8ht7lz58batWurPWvYtLe3R11dXSxbtqzaU4bknnvuibq6ugHHSSedVO1ZQ/L666/HDTfcEJMmTYpjjjkmzj333Ni+fXu1Z0XEKAzMY489FsuWLYu77rornn/++bjoooti4cKFsXfv3mpPG7S+vr6YPXt2PPjgg9WeMiw2bdoUS5cujS1btsSGDRviwIEDsWDBgujr66v2tEGbNm1a3HfffbFt27bYtm1bXHrppXHVVVfFrl27qj1tyLZu3RqdnZ0xa9asak8ZFmeffXa8+eab/cfOnTurPWnQ3nnnnZg/f36MHTs21q5dGy+++GL88Ic/jOOOO67a0/5bMcp86lOfKpYsWTLguTPOOKP41re+VaVFwysiitWrV1d7xrDav39/ERHFpk2bqj1lWB1//PHFz372s2rPGJLe3t7itNNOKzZs2FB89rOfLW699dZqTxqSu+++u5g9e3a1ZwybO+64o7jwwgurPeMjjao7mPfeey+2b98eCxYsGPD8ggUL4tlnn63SKv6d7u7uiIiYOHFilZcMj4MHD8aqVauir68v5s6dW+05Q7J06dK44oor4vLLL6/2lGGze/fuaG5ujhkzZsR1110Xe/bsqfakQVuzZk20trbGNddcE5MnT47zzjsvVqxYUe1Z/UZVYN566604ePBgTJkyZcDzU6ZMiX379lVpFYdTFEW0tbXFhRdeGDNnzqz2nCHZuXNnHHvssVEqlWLJkiWxevXqOOuss6o9a9BWrVoVzz33XLS3t1d7yrC54IILYuXKlbF+/fpYsWJF7Nu3L+bNmxdvv/12tacNyp49e6KjoyNOO+20WL9+fSxZsiRuueWWWLlyZbWnRUQVfpvySPjXfwagKIrUfxqAwbvpppvihRdeiD/+8Y/VnjJkp59+euzYsSP+8Y9/xG9+85tYvHhxbNq0qSYj09XVFbfeems89dRTMX78+GrPGTYLFy7s/+9zzjkn5s6dG6ecckr88pe/jLa2tiouG5wPPvggWltb4957742IiPPOOy927doVHR0d8eUvf7nK60bZHcwJJ5wQY8aMOeRuZf/+/Yfc1VB9N998c6xZsyaefvrpEf0nHLKMGzcuTj311GhtbY329vaYPXt2PPDAA9WeNSjbt2+P/fv3x5w5c6K+vj7q6+tj06ZN8eMf/zjq6+vj4MGD1Z44LCZMmBDnnHNO7N69u9pTBmXq1KmH/A/MmWeeecR8U9OoCsy4ceNizpw5sWHDhgHPb9iwIebNm1elVfyroijipptuit/+9rfxhz/8IWbMmFHtSSmKoohyuVztGYNy2WWXxc6dO2PHjh39R2tra1x//fWxY8eOGDNmTLUnDotyuRwvvfRSTJ06tdpTBmX+/PmHfIv/K6+8EtOnT6/SooFG3ZfI2tra4sYbb4zW1taYO3dudHZ2xt69e2PJkiXVnjZo7777brz66qv9j1977bXYsWNHTJw4MU4++eQqLhucpUuXxiOPPBJPPPFENDQ09N9xNjU1xdFHH13ldYNz5513xsKFC6OlpSV6e3tj1apVsXHjxli3bl21pw1KQ0PDIZ+JTZgwISZNmlTTn5XdfvvtsWjRojj55JNj//798b3vfS96enpi8eLF1Z42KLfddlvMmzcv7r333vjiF78Yf/nLX6KzszM6OzurPe2/Vfeb2HL85Cc/KaZPn16MGzeu+OQnP1nz3/769NNPFxFxyLF48eJqTxuUD3svEVE8/PDD1Z42aF/96lf7/8ydeOKJxWWXXVY89dRT1Z41rEbDtylfe+21xdSpU4uxY8cWzc3Nxec///li165d1Z41JL/73e+KmTNnFqVSqTjjjDOKzs7Oak/q59f1A5BiVH0GA8CRQ2AASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUvx/2RL30ephvEsAAAAASUVORK5CYII=\n" 91 | }, 92 | "metadata": {}, 93 | "output_type": "display_data" 94 | } 95 | ], 96 | "source": [ 97 | "plt.imshow(static);" 98 | ], 99 | "metadata": { 100 | "collapsed": false 101 | } 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "source": [ 106 | "Initialize AffineRegistration with is_3d=False since we have 2D images.\n", 107 | "\n", 108 | "Change scales from default (4, 2) to (1,) since downscaling would not make sense for 7x7 pixel resolution." 109 | ], 110 | "metadata": { 111 | "collapsed": false 112 | } 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 5, 117 | "outputs": [], 118 | "source": [ 119 | "reg = AffineRegistration(is_3d=False, scales=(1,), learning_rate=1e-1)" 120 | ], 121 | "metadata": { 122 | "collapsed": false 123 | } 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "source": [ 128 | "Run registration with two dimensions added (via `[None, None]`)...\n", 129 | "\n", 130 | "(...since it is torch convention for images to have **batch + channel** dimension prior to spatial dimensions x + y (+ z))" 131 | ], 132 | "metadata": { 133 | "collapsed": false 134 | } 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 6, 139 | "outputs": [ 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "Shape: [1, 1, 7, 7]; Dissimiliarity: 1.1252674817332012e-13: 100%|██████████| 500/500 [00:00<00:00, 1205.78it/s]\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "moved = reg(moving[None, None], static[None, None])\n", 150 | "moved = moved[0, 0]" 151 | ], 152 | "metadata": { 153 | "collapsed": false 154 | } 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "source": [ 159 | "Runs fast and dissimilarity approaches 0 ✔️\n", 160 | "\n", 161 | "Let's look at the moved image!" 162 | ], 163 | "metadata": { 164 | "collapsed": false 165 | } 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 7, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": "
", 174 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVs0lEQVR4nO3dfWyVhd3/8W+lcFBsq6AgDZURNT4h6KhzgG4+jTv8lGiWOV3UkT38wYIP2Jg59A/dk3V/bNHF2axscSOLYpYNZYmALBNwcWyAEgkaxWFCpzKicW3tH0fB6/fHfa/33aHM0/bbw+ler+RKPCfXyfU5CeHt1dOWuqIoigCAYXZUtQcAMDoJDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKSoH+kLfvDBB/HGG29EQ0ND1NXVjfTlARiCoiiit7c3mpub46ijDn+PMuKBeeONN6KlpWWkLwvAMOrq6opp06Yd9pwRD0xDQ0NERFwY/y/qY+xIXx6AITgQ78cf48n+v8sPZ8QD888vi9XH2KivExiAmvI/v73y43zE4UN+AFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUgwrMQw89FDNmzIjx48fHnDlz4plnnhnuXQDUuIoD89hjj8WyZcvirrvuiueffz4uuuiiWLhwYezduzdjHwA1quLA/OhHP4qvfe1r8fWvfz3OPPPMuP/++6OlpSU6Ojoy9gFQoyoKzHvvvRfbt2+PBQsWDHh+wYIF8eyzz37oa8rlcvT09Aw4ABj9KgrMW2+9FQcPHowpU6YMeH7KlCmxb9++D31Ne3t7NDU19R8tLS2DXwtAzRjUh/x1dXUDHhdFcchz/7R8+fLo7u7uP7q6ugZzSQBqTH0lJ59wwgkxZsyYQ+5W9u/ff8hdzT+VSqUolUqDXwhATaroDmbcuHExZ86c2LBhw4DnN2zYEPPmzRvWYQDUtoruYCIi2tra4sYbb4zW1taYO3dudHZ2xt69e2PJkiUZ+wCoURUH5tprr4233347vvOd78Sbb74ZM2fOjCeffDKmT5+esQ+AGlVXFEUxkhfs6emJpqamuDiuivq6sSN5aQCG6EDxfmyMJ6K7uzsaGxsPe67fRQZACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEgRcWB2bx5cyxatCiam5ujrq4uHn/88YRZANS6igPT19cXs2fPjgcffDBjDwCjRH2lL1i4cGEsXLgwYwsAo0jFgalUuVyOcrnc/7inpyf7kgAcAdI/5G9vb4+mpqb+o6WlJfuSABwB0gOzfPny6O7u7j+6urqyLwnAESD9S2SlUilKpVL2ZQA4wvg5GABSVHwH8+6778arr77a//i1116LHTt2xMSJE+Pkk08e1nEA1K6KA7Nt27a45JJL+h+3tbVFRMTixYvjF7/4xbANA6C2VRyYiy++OIqiyNgCwCjiMxgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBT11R7AkWv9GzuqPYH/QP/VfG61JzBM3MEAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQVBaa9vT3OP//8aGhoiMmTJ8fVV18dL7/8ctY2AGpYRYHZtGlTLF26NLZs2RIbNmyIAwcOxIIFC6Kvry9rHwA1qr6Sk9etWzfg8cMPPxyTJ0+O7du3x2c+85lhHQZAbasoMP+qu7s7IiImTpz4keeUy+Uol8v9j3t6eoZySQBqxKA/5C+KItra2uLCCy+MmTNnfuR57e3t0dTU1H+0tLQM9pIA1JBBB+amm26KF154IR599NHDnrd8+fLo7u7uP7q6ugZ7SQBqyKC+RHbzzTfHmjVrYvPmzTFt2rTDnlsqlaJUKg1qHAC1q6LAFEURN998c6xevTo2btwYM2bMyNoFQI2rKDBLly6NRx55JJ544oloaGiIffv2RUREU1NTHH300SkDAahNFX0G09HREd3d3XHxxRfH1KlT+4/HHnssax8ANariL5EBwMfhd5EBkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQoqLAdHR0xKxZs6KxsTEaGxtj7ty5sXbt2qxtANSwigIzbdq0uO+++2Lbtm2xbdu2uPTSS+Oqq66KXbt2Ze0DoEbVV3LyokWLBjz+/ve/Hx0dHbFly5Y4++yzh3UYALWtosD8XwcPHoxf//rX0dfXF3Pnzv3I88rlcpTL5f7HPT09g70kADWk4g/5d+7cGccee2yUSqVYsmRJrF69Os4666yPPL+9vT2ampr6j5aWliENBqA2VByY008/PXbs2BFbtmyJb3zjG7F48eJ48cUXP/L85cuXR3d3d//R1dU1pMEA1IaKv0Q2bty4OPXUUyMiorW1NbZu3RoPPPBA/PSnP/3Q80ulUpRKpaGtBKDmDPnnYIqiGPAZCwBEVHgHc+edd8bChQujpaUlent7Y9WqVbFx48ZYt25d1j4AalRFgfn73/8eN954Y7z55pvR1NQUs2bNinXr1sXnPve5rH0A1KiKAvPzn/88awcAo4zfRQZACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEgxZAC097eHnV1dbFs2bJhmgPAaDHowGzdujU6Oztj1qxZw7kHgFFiUIF599134/rrr48VK1bE8ccfP9ybABgFBhWYpUuXxhVXXBGXX375vz23XC5HT0/PgAOA0a++0hesWrUqnnvuudi6devHOr+9vT2+/e1vVzwMgNpW0R1MV1dX3HrrrfGrX/0qxo8f/7Fes3z58uju7u4/urq6BjUUgNpS0R3M9u3bY//+/TFnzpz+5w4ePBibN2+OBx98MMrlcowZM2bAa0qlUpRKpeFZC0DNqCgwl112WezcuXPAc1/5ylfijDPOiDvuuOOQuADwn6uiwDQ0NMTMmTMHPDdhwoSYNGnSIc8D8J/NT/IDkKLi7yL7Vxs3bhyGGQCMNu5gAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABS1Fd7AEeu/2o+t9oTgBrmDgaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApKgoMPfcc0/U1dUNOE466aSsbQDUsPpKX3D22WfH73//+/7HY8aMGdZBAIwOFQemvr7eXQsA/1bFn8Hs3r07mpubY8aMGXHdddfFnj17Dnt+uVyOnp6eAQcAo19Fgbngggti5cqVsX79+lixYkXs27cv5s2bF2+//fZHvqa9vT2ampr6j5aWliGPBuDIV1cURTHYF/f19cUpp5wS3/zmN6Otre1DzymXy1Eul/sf9/T0REtLS1wcV0V93djBXhqAKjhQvB8b44no7u6OxsbGw55b8Wcw/9eECRPinHPOid27d3/kOaVSKUql0lAuA0ANGtLPwZTL5XjppZdi6tSpw7UHgFGiosDcfvvtsWnTpnjttdfiz3/+c3zhC1+Inp6eWLx4cdY+AGpURV8i+9vf/hZf+tKX4q233ooTTzwxPv3pT8eWLVti+vTpWfsAqFEVBWbVqlVZOwAYZfwuMgBSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApKg7M66+/HjfccENMmjQpjjnmmDj33HNj+/btGdsAqGH1lZz8zjvvxPz58+OSSy6JtWvXxuTJk+Ovf/1rHHfccUnzAKhVFQXmBz/4QbS0tMTDDz/c/9wnPvGJ4d4EwChQ0ZfI1qxZE62trXHNNdfE5MmT47zzzosVK1Yc9jXlcjl6enoGHACMfhUFZs+ePdHR0RGnnXZarF+/PpYsWRK33HJLrFy58iNf097eHk1NTf1HS0vLkEcDcOSrK4qi+Lgnjxs3LlpbW+PZZ5/tf+6WW26JrVu3xp/+9KcPfU25XI5yudz/uKenJ1paWuLiuCrq68YOYToAI+1A8X5sjCeiu7s7GhsbD3tuRXcwU6dOjbPOOmvAc2eeeWbs3bv3I19TKpWisbFxwAHA6FdRYObPnx8vv/zygOdeeeWVmD59+rCOAqD2VRSY2267LbZs2RL33ntvvPrqq/HII49EZ2dnLF26NGsfADWqosCcf/75sXr16nj00Udj5syZ8d3vfjfuv//+uP7667P2AVCjKvo5mIiIK6+8Mq688sqMLQCMIn4XGQApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUFf+TyUNVFEVERByI9yOKkb46AENxIN6PiP/9u/xwRjwwvb29ERHxx3hypC8NwDDp7e2Npqamw55TV3ycDA2jDz74IN54441oaGiIurq6tOv09PRES0tLdHV1RWNjY9p1RpL3dOQbbe8nwnuqFSP1noqiiN7e3mhubo6jjjr8pywjfgdz1FFHxbRp00bseo2NjaPmD9A/eU9HvtH2fiK8p1oxEu/p3925/JMP+QFIITAApBi1gSmVSnH33XdHqVSq9pRh4z0d+Ubb+4nwnmrFkfieRvxDfgD+M4zaOxgAqktgAEghMACkEBgAUozKwDz00EMxY8aMGD9+fMyZMyeeeeaZak8aks2bN8eiRYuiubk56urq4vHHH6/2pCFpb2+P888/PxoaGmLy5Mlx9dVXx8svv1ztWUPS0dERs2bN6v8ht7lz58batWurPWvYtLe3R11dXSxbtqzaU4bknnvuibq6ugHHSSedVO1ZQ/L666/HDTfcEJMmTYpjjjkmzj333Ni+fXu1Z0XEKAzMY489FsuWLYu77rornn/++bjoooti4cKFsXfv3mpPG7S+vr6YPXt2PPjgg9WeMiw2bdoUS5cujS1btsSGDRviwIEDsWDBgujr66v2tEGbNm1a3HfffbFt27bYtm1bXHrppXHVVVfFrl27qj1tyLZu3RqdnZ0xa9asak8ZFmeffXa8+eab/cfOnTurPWnQ3nnnnZg/f36MHTs21q5dGy+++GL88Ic/jOOOO67a0/5bMcp86lOfKpYsWTLguTPOOKP41re+VaVFwysiitWrV1d7xrDav39/ERHFpk2bqj1lWB1//PHFz372s2rPGJLe3t7itNNOKzZs2FB89rOfLW699dZqTxqSu+++u5g9e3a1ZwybO+64o7jwwgurPeMjjao7mPfeey+2b98eCxYsGPD8ggUL4tlnn63SKv6d7u7uiIiYOHFilZcMj4MHD8aqVauir68v5s6dW+05Q7J06dK44oor4vLLL6/2lGGze/fuaG5ujhkzZsR1110Xe/bsqfakQVuzZk20trbGNddcE5MnT47zzjsvVqxYUe1Z/UZVYN566604ePBgTJkyZcDzU6ZMiX379lVpFYdTFEW0tbXFhRdeGDNnzqz2nCHZuXNnHHvssVEqlWLJkiWxevXqOOuss6o9a9BWrVoVzz33XLS3t1d7yrC54IILYuXKlbF+/fpYsWJF7Nu3L+bNmxdvv/12tacNyp49e6KjoyNOO+20WL9+fSxZsiRuueWWWLlyZbWnRUQVfpvySPjXfwagKIrUfxqAwbvpppvihRdeiD/+8Y/VnjJkp59+euzYsSP+8Y9/xG9+85tYvHhxbNq0qSYj09XVFbfeems89dRTMX78+GrPGTYLFy7s/+9zzjkn5s6dG6ecckr88pe/jLa2tiouG5wPPvggWltb4957742IiPPOOy927doVHR0d8eUvf7nK60bZHcwJJ5wQY8aMOeRuZf/+/Yfc1VB9N998c6xZsyaefvrpEf0nHLKMGzcuTj311GhtbY329vaYPXt2PPDAA9WeNSjbt2+P/fv3x5w5c6K+vj7q6+tj06ZN8eMf/zjq6+vj4MGD1Z44LCZMmBDnnHNO7N69u9pTBmXq1KmH/A/MmWeeecR8U9OoCsy4ceNizpw5sWHDhgHPb9iwIebNm1elVfyroijipptuit/+9rfxhz/8IWbMmFHtSSmKoohyuVztGYNy2WWXxc6dO2PHjh39R2tra1x//fWxY8eOGDNmTLUnDotyuRwvvfRSTJ06tdpTBmX+/PmHfIv/K6+8EtOnT6/SooFG3ZfI2tra4sYbb4zW1taYO3dudHZ2xt69e2PJkiXVnjZo7777brz66qv9j1977bXYsWNHTJw4MU4++eQqLhucpUuXxiOPPBJPPPFENDQ09N9xNjU1xdFHH13ldYNz5513xsKFC6OlpSV6e3tj1apVsXHjxli3bl21pw1KQ0PDIZ+JTZgwISZNmlTTn5XdfvvtsWjRojj55JNj//798b3vfS96enpi8eLF1Z42KLfddlvMmzcv7r333vjiF78Yf/nLX6KzszM6OzurPe2/Vfeb2HL85Cc/KaZPn16MGzeu+OQnP1nz3/769NNPFxFxyLF48eJqTxuUD3svEVE8/PDD1Z42aF/96lf7/8ydeOKJxWWXXVY89dRT1Z41rEbDtylfe+21xdSpU4uxY8cWzc3Nxec///li165d1Z41JL/73e+KmTNnFqVSqTjjjDOKzs7Oak/q59f1A5BiVH0GA8CRQ2AASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUvx/2RL30ephvEsAAAAASUVORK5CYII=\n" 175 | }, 176 | "metadata": {}, 177 | "output_type": "display_data" 178 | } 179 | ], 180 | "source": [ 181 | "plt.imshow(moved);" 182 | ], 183 | "metadata": { 184 | "collapsed": false 185 | } 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "Nice, the formally non-centered square (see moving-plot before registration) is now aligned to the static (see static-plot before registration) and therefore centered!\n", 191 | "\n", 192 | "Next, extract the affine and the four parameters which were optimized in the background." 193 | ], 194 | "metadata": { 195 | "collapsed": false 196 | } 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "outputs": [], 202 | "source": [ 203 | "affine = reg.get_affine()\n", 204 | "translation = reg._parameters[0]\n", 205 | "rotation = reg._parameters[1]\n", 206 | "zoom = reg._parameters[2]\n", 207 | "shear = reg._parameters[3]" 208 | ], 209 | "metadata": { 210 | "collapsed": false 211 | } 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 9, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "tensor([[[ 1.0000e+00, 4.2938e-08, 3.4270e-08],\n", 222 | " [-4.6447e-08, 1.0000e+00, -3.3333e-01]]])\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "print(affine)" 228 | ], 229 | "metadata": { 230 | "collapsed": false 231 | } 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 10, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "Parameter containing:\n", 242 | "tensor([[ 3.4270e-08, -3.3333e-01]], requires_grad=True) Parameter containing:\n", 243 | "tensor([[[ 1.0000e+00, 4.2938e-08],\n", 244 | " [-4.6447e-08, 1.0000e+00]]], requires_grad=True) Parameter containing:\n", 245 | "tensor([[1.0000, 1.0000]], requires_grad=True) Parameter containing:\n", 246 | "tensor([[0., 0.]])\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "print(translation, rotation, zoom, shear)" 252 | ], 253 | "metadata": { 254 | "collapsed": false 255 | } 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "source": [ 260 | "As expected (we are aligning two identical squares which are shifted in one dimension), all parameter values except of one are near 0 or 1.\n", 261 | "\n", 262 | "The \"one value\" `-3.3333e-01` is in the `translation` parameter as expected ✔️\n", 263 | "\n", 264 | "## Important side note!!!\n", 265 | "\n", 266 | "Unexpectedly, the `-3.3333e-01` value is in the **second** dimension of `translation` despite the misalignment being in the **first** dimension (`moving[1:4, 2:5] = 1` vs `static[2:5, 2:5] = 1`. That's because of the [torch convention of coordinate grids](https://discuss.pytorch.org/t/surprising-convention-for-grid-sample-coordinates/79997).\n", 267 | "\n", 268 | "This convention results in:\n", 269 | "- 2D: **X and Y coordinate are in the order `[Y, X]`**\n", 270 | "- 3D: **X, Y and Z coordinate are in the order `[Z, Y, X]`**\n", 271 | "\n", 272 | "A **workaround** if you don't like that is to flip the dimensions of the registrations input tensors via `.permute(1, 0)` (or `.permute(2, 1, 0)` in 3D):" 273 | ], 274 | "metadata": { 275 | "collapsed": false 276 | } 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 11, 281 | "outputs": [ 282 | { 283 | "name": "stderr", 284 | "output_type": "stream", 285 | "text": [ 286 | "Shape: [1, 1, 7, 7]; Dissimiliarity: 3.978223503509071e-06: 100%|██████████| 500/500 [00:00<00:00, 1186.11it/s] " 287 | ] 288 | }, 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "Parameter containing:\n", 294 | "tensor([[-3.3333e-01, -1.8018e-06]], requires_grad=True)\n" 295 | ] 296 | }, 297 | { 298 | "name": "stderr", 299 | "output_type": "stream", 300 | "text": [ 301 | "\n" 302 | ] 303 | } 304 | ], 305 | "source": [ 306 | "reg = AffineRegistration(is_3d=False, scales=(1,), learning_rate=1e-2)\n", 307 | "moved = reg(moving.permute(1, 0)[None, None], static.permute(1, 0)[None, None])\n", 308 | "translation = reg._parameters[0]\n", 309 | "print(translation)" 310 | ], 311 | "metadata": { 312 | "collapsed": false 313 | } 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "source": [ 318 | "## Cube registration (3D) without progress bar" 319 | ], 320 | "metadata": { 321 | "collapsed": false 322 | } 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 12, 327 | "outputs": [ 328 | { 329 | "name": "stdout", 330 | "output_type": "stream", 331 | "text": [ 332 | "Parameter containing:\n", 333 | "tensor([[-4.7301e-08, -4.7325e-08, -3.2846e-01]], requires_grad=True) Parameter containing:\n", 334 | "tensor([[[ 1.0000e+00, -1.7097e-08, 6.9077e-07],\n", 335 | " [-1.7084e-08, 1.0000e+00, 6.9076e-07],\n", 336 | " [ 1.5955e-09, -1.2442e-09, 1.0118e+00]]], requires_grad=True) Parameter containing:\n", 337 | "tensor([[1.0000, 1.0000, 1.0118]], requires_grad=True) Parameter containing:\n", 338 | "tensor([[0., 0., 0.]])\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "# Non-centered 3³ cube in 7³ tensor\n", 344 | "moving_cube = torch.zeros(7, 7, 7)\n", 345 | "moving_cube[1:4, 2:5, 2:5] = 1\n", 346 | "# Centered 3³ cube in 7³ tensor\n", 347 | "static_cube = torch.zeros(7, 7, 7)\n", 348 | "static_cube[2:5, 2:5, 2:5] = 1\n", 349 | "\n", 350 | "reg = AffineRegistration(is_3d=True, scales=(1,),\n", 351 | " verbose=False) # verbose=False for hidden progress bar\n", 352 | "\n", 353 | "moved_cube = reg(moving_cube[None, None], static_cube[None, None])\n", 354 | "translation = reg._parameters[0]\n", 355 | "rotation = reg._parameters[1]\n", 356 | "zoom = reg._parameters[2]\n", 357 | "shear = reg._parameters[3]\n", 358 | "print(translation, rotation, zoom, shear)" 359 | ], 360 | "metadata": { 361 | "collapsed": false 362 | } 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "source": [ 367 | "## Translation-only registration" 368 | ], 369 | "metadata": { 370 | "collapsed": false 371 | } 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 13, 376 | "outputs": [ 377 | { 378 | "name": "stderr", 379 | "output_type": "stream", 380 | "text": [ 381 | "Shape: [1, 1, 7, 7]; Dissimiliarity: 2.7626772397537636e-11: 100%|██████████| 500/500 [00:00<00:00, 1899.94it/s]\n" 382 | ] 383 | }, 384 | { 385 | "name": "stdout", 386 | "output_type": "stream", 387 | "text": [ 388 | "Parameter containing:\n", 389 | "tensor([[ 4.7112e-09, -3.3333e-01]], requires_grad=True) Parameter containing:\n", 390 | "tensor([[[1., 0.],\n", 391 | " [0., 1.]]]) Parameter containing:\n", 392 | "tensor([[1., 1.]]) Parameter containing:\n", 393 | "tensor([[0., 0.]])\n" 394 | ] 395 | } 396 | ], 397 | "source": [ 398 | "reg = AffineRegistration(is_3d=False, scales=(1,),\n", 399 | " with_rotation=False, with_zoom=False, with_shear=False)\n", 400 | "\n", 401 | "moved = reg(moving[None, None], static[None, None])\n", 402 | "\n", 403 | "translation = reg._parameters[0]\n", 404 | "rotation = reg._parameters[1]\n", 405 | "zoom = reg._parameters[2]\n", 406 | "shear = reg._parameters[3]\n", 407 | "print(translation, rotation, zoom, shear)" 408 | ], 409 | "metadata": { 410 | "collapsed": false 411 | } 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "source": [ 416 | "## Translation-only registration with initial parameter" 417 | ], 418 | "metadata": { 419 | "collapsed": false 420 | } 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 14, 425 | "outputs": [ 426 | { 427 | "name": "stderr", 428 | "output_type": "stream", 429 | "text": [ 430 | "Shape: [1, 1, 7, 7]; Dissimiliarity: 8.85713030696067e-12: 100%|██████████| 500/500 [00:00<00:00, 1974.00it/s] \n" 431 | ] 432 | }, 433 | { 434 | "name": "stdout", 435 | "output_type": "stream", 436 | "text": [ 437 | "Parameter containing:\n", 438 | "tensor([[-2.2464e-07, -3.3333e-01]], requires_grad=True) Parameter containing:\n", 439 | "tensor([[[1., 0.],\n", 440 | " [0., 1.]]]) Parameter containing:\n", 441 | "tensor([[1., 1.]]) Parameter containing:\n", 442 | "tensor([[0., 0.]])\n" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "reg = AffineRegistration(is_3d=False, scales=(1,), init_translation=torch.Tensor([[-3e-1, 0.]]),\n", 448 | " with_rotation=False, with_zoom=False, with_shear=False)\n", 449 | "\n", 450 | "moved = reg(moving[None, None], static[None, None])\n", 451 | "\n", 452 | "translation = reg._parameters[0]\n", 453 | "rotation = reg._parameters[1]\n", 454 | "zoom = reg._parameters[2]\n", 455 | "shear = reg._parameters[3]\n", 456 | "print(translation, rotation, zoom, shear)" 457 | ], 458 | "metadata": { 459 | "collapsed": false 460 | } 461 | } 462 | ], 463 | "metadata": { 464 | "kernelspec": { 465 | "display_name": "Python 3", 466 | "language": "python", 467 | "name": "python3" 468 | }, 469 | "language_info": { 470 | "codemirror_mode": { 471 | "name": "ipython", 472 | "version": 2 473 | }, 474 | "file_extension": ".py", 475 | "mimetype": "text/x-python", 476 | "name": "python", 477 | "nbconvert_exporter": "python", 478 | "pygments_lexer": "ipython2", 479 | "version": "2.7.6" 480 | } 481 | }, 482 | "nbformat": 4, 483 | "nbformat_minor": 0 484 | } 485 | --------------------------------------------------------------------------------