├── .gitignore ├── AUTHORS ├── ChangeLog ├── LICENSE ├── MANIFEST.in ├── README.md ├── requirements.txt ├── setup.cfg ├── setup.py ├── test └── test_voting.py └── torchfields ├── __init__.py ├── fields.py ├── inversion.py ├── utils.py └── voting.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### From standard python .gitignore template ### 2 | align_env/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | logs/* 10 | logs_old/* 11 | dump/* 12 | data/images/* 13 | data/labels/* 14 | data/masks/* 15 | data/tf/* 16 | data/slices/* 17 | data/* 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | # vi 120 | *.swp 121 | 122 | ### Custom for this project ### 123 | 124 | 125 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Barak Nehoran 2 | Dodam Ih 3 | Nico Kemnitz 4 | Thomas Macrina -------------------------------------------------------------------------------- /ChangeLog: -------------------------------------------------------------------------------- 1 | CHANGES 2 | ======= 3 | 4 | v0.1.2 5 | ------ 6 | 7 | * get_vote_weights_*: use softmax to avoid overflow issues 8 | 9 | v0.1.1 10 | ------ 11 | 12 | * Priority Vote: Adjusting kernel size for priority vote operations will now affect blurring 13 | * Priority Vote: Consensus threshold can be 0 14 | 15 | v0.1.0 16 | ------ 17 | 18 | * Add priority\_vote 19 | * Add vote\_with\_distances 20 | * Add voting\_with\_variances to include prior weights 21 | 22 | v0.0.6 23 | ------ 24 | 25 | * feat(use\_identity\_mapping\_cache): allow caching identity\_mapping() results 26 | * fix: 0.7.0 compatibility (#8) 27 | * Update license in setup.cfg 28 | 29 | v0.0.5 30 | ------ 31 | 32 | * Update documentation for sample() 33 | * Allow non-square displacement fields 34 | * Clean 
up identity mapping code 35 | * Remove caching of identity mappings 36 | * Use align\_corners=False from PyTorch 1.3.0 Will no longer work with earlier PyTorch versions 37 | * Convert to using MIT License 38 | 39 | v0.0.4 40 | ------ 41 | 42 | * Prevent producing NaN in inverse backward pass 43 | * Allow accessing field type as torchfields.Field 44 | * Ensure contiguous gradients in inversion backward pass 45 | * Allow padding to be explicitly given in \_pad() 46 | * Bump required pytorch version to 1.1.0 47 | * Factor out voting, inversion, and util functions 48 | * Update README.md 49 | * Update README.md 50 | 51 | v0.0.3 52 | ------ 53 | 54 | * Change ndim to ndimension() to support wider range of PyTorch versions 55 | 56 | v0.0.2 57 | ------ 58 | 59 | * [Fix] affine\_field incorrect dimensions bug 60 | 61 | v0.0.1 62 | ------ 63 | 64 | * [Fix] inverse not working on cpu 65 | * Add setup.py, setup.cfg, README.md, .gitignore, and requirements.txt 66 | * DisplacementField: minor bug fixes, refactors, and comments 67 | * Safe division to avoid NaNs during backward pass 68 | * Use winding number to test inclusion rather than bounding i,j 69 | * Epsilon for comparison to zero 70 | * Autopad functionality in left inverse 71 | * More efficient left inverse using sparse tensors 72 | * Mean finite vector function 73 | * Inefficient implementation of left inverse 74 | * Add DisplacementField class to abstract displacement field operations 75 | * Initial commit 76 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Barak Nehoran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude AUTHORS 2 | exclude ChangeLog 3 | exclude MANIFEST.in 4 | recursive-exclude *.egg-info/** -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchfields 2 | A [PyTorch](https://github.com/pytorch/pytorch) add-on for working with image mappings and displacement fields, including Spatial Transformers 3 | 4 | Torchfields provides an abstraction that neatly encapsulates the functionality of displacement fields 5 | as used in [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) and [Optical Flow Estimation](https://en.wikipedia.org/wiki/Optical_flow). 6 | 7 | Fields can be treated as normal PyTorch tensors for most 8 | purposes, and also include additional functionality for composing 9 | displacements and sampling from tensors. 10 | 11 | ### Installation 12 | 13 | To install torchfields simply do 14 | 15 | ``` 16 | pip install torchfields 17 | ``` 18 | 19 | 20 | ### Introduction 21 | 22 | A **displacement field** represents a *mapping* or *flow* that indicates how an image should be warped. 23 | 24 | It is essentially a spatial tensor containing displacement vectors at each pixel, where each displacement vector indicates the displacement distance and direction at that pixel. 25 | 26 | 27 | #### Displacement field conventions 28 | 29 | ##### Units 30 | 31 | The standard unit of displacement is a **half-image**, so a displacement vector of magnitude 2 means that the displacement distance is equal to the side length of the displaced image. 32 | 33 | **Note**: *This convention originates from the original [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) paper where such fields were presented as mappings, with -1 representing the left or top edge of the image, and +1 representing the right or bottom edge.* 34 | 35 | `torchfields` also supports seamlessly converting to and from units of **pixels** using the `pixels()` and `from_pixels()` functions. 36 | 37 | ##### Displacement direction 38 | 39 | The most common way to warp an image by a displacement field is by sampling from it at the points pointed to by the field vectors. 40 | This is often referred to as the **Eulerian** or **pull** convention, since the vectors in the field point to the locations from which the image should be *pulled*. 41 | This is achieved by calling the `sample()` function (which in fact wraps PyTorch's built-in `grid_sample()`, while converting the conventions as necessary). 42 | 43 | An alternative way to warp an image by a displacement field is by sending each pixel of the image along the corresponding displacement vector to its new location. This is referred to as the **Lagrangian** or **push** convention, since the vectors of the field indicate where an image pixel should be *pushed* to. This direction, while seemingly intuitive, is much less straight-forward to implement, since there is no definitive way to handle the discretization (for instance, what to do when the destinations are not whole pixel coordinates, when two sources map to the same destination, and when nothing maps into a destination pixel). 44 | The solution for warping in the Lagrangian direction is to **first invert the field** using `inverse()`, and then warp the image normally using `sample()`. 
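For example, here is a minimal sketch of both conventions (the tensor sizes and the random image below are illustrative only):

```python
import torch
import torchfields  # adds the .field() conversion to torch.Tensor

image = torch.rand(1, 1, 64, 64)           # (N, C, H, W)
field = torch.zeros(1, 2, 64, 64).field()  # zero displacement field, (N, 2, H, W)

# Eulerian / "pull": each vector says where to pull that output pixel from
pulled = field.sample(image)

# Lagrangian / "push": invert the field first, then sample as usual
pushed = field.inverse().sample(image)
```

With the zero (identity) field shown here, both results should match the original image up to numerical precision.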
45 | 46 | *To read more about the two ways to describe flow fields, see the [Wikipedia article](https://en.wikipedia.org/wiki/Lagrangian_and_Eulerian_specification_of_the_flow_field) on the subject.* 47 | 48 | 49 | #### Relationship to PyTorch tensors 50 | 51 | Displacement fields inherit from `torch.Tensor`, so all functionality from [PyTorch](https://github.com/pytorch/pytorch) tensors also works with displacement fields. That is, any PyTorch function that accepts a `torch.Tensor` type will also implicitly accept a `torchfields` displacement field. 52 | 53 | Furthermore, the module installs itself (through monkey patching) as 54 | 55 | ```python 56 | torch.Field 57 | ``` 58 | 59 | mirroring the `torch.Tensor` module, and all the functionality of the `torchfields` package can be conveniently accessed through that shortcut. This shortcut gets activated at the first import (using `import torchfields`). 60 | 61 | Note, however, that the `torchfields` package is neither endorsed by nor maintained by the PyTorch developer community, and is instead a separate project maintained by researchers at Princeton University. 62 | 63 | 64 | 65 | ### Tutorial 66 | 67 | To learn more and get started with using `torchfields` check out the [tutorial](https://colab.research.google.com/drive/1KrUjFbWjwwnsyNFTpNCZjjIJyMUP8eFx). 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | setuptools>=34.0.0 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = torchfields 3 | url = https://github.com/seung-lab/torchfields 4 | summary = A PyTorch add-on for working with image mappings and displacement fields, including Spatial Transformers 5 | description-content-type = text/markdown 6 | description-file = README.md 7 | author = Barak Nehoran, Nico Kemnitz 8 | author_email = bnehoran@users.noreply.github.com, nkemnitz@users.noreply.github.com 9 | home-page = https://github.com/seung-lab/torchfields 10 | license = MIT 11 | classifiers = 12 | Intended Audience :: Developers 13 | Development Status :: 4 - Beta 14 | Programming Language :: Python :: 3 15 | Programming Language :: Python :: 3.7 16 | Programming Language :: Python :: 3.8 17 | Programming Language :: Python :: 3.9 18 | Programming Language :: Python :: 3.10 19 | Topic :: Scientific/Engineering 20 | License :: OSI Approved :: MIT License 21 | 22 | [global] 23 | setup-hooks = pbr.hooks.setup_hook 24 | 25 | [files] 26 | packages = torchfields 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name='torchfields', 8 | version='0.1.2', 9 | author='Barak Nehoran, Nico Kemnitz', 10 | author_email='bnehoran@users.noreply.github.com, nkemnitz@users.noreply.github.com', 11 | description='A PyTorch add-on for working with image mappings and displacement fields, including Spatial Transformers', 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | packages=find_packages(), 15 | scripts=[], 16 | url="https://github.com/seung-lab/torchfields", 17 | setup_requires=[ 18 | 
'pbr', 19 | ], 20 | pbr=True, 21 | ) 22 | -------------------------------------------------------------------------------- /test/test_voting.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torchfields 4 | 5 | # def test_vote_shape(): 6 | # assert False 7 | # 8 | # 9 | # def test_voting_subsets(): 10 | # assert False 11 | # 12 | # 13 | # def test_voting_weights(): 14 | # assert False 15 | # 16 | # 17 | # def test_vote(): 18 | # assert False 19 | 20 | 21 | def test_vote_with_variances(): 22 | f = torch.zeros((2, 2, 1, 1)).field() 23 | f[1] = 1 24 | v = torch.zeros((2, 1, 1, 1)) 25 | vf = f.vote_with_variances(var=v, softmin_temp=1, blur_sigma=0, subset_size=1) 26 | tf = torch.ones((1, 2, 1, 1)).field() / 2.0 27 | assert torch.equal(tf, vf) 28 | v[1] = 1 29 | # softmin_temp root of (e^(-sqrt(2)/x))/(e^(-sqrt(2)/x)+e^(0/x))-0.25=0 30 | vf = f.vote_with_variances(var=v, softmin_temp=1.28727, blur_sigma=0, subset_size=1) 31 | tf = torch.ones((1, 2, 1, 1)).field() / 4.0 32 | assert torch.allclose(tf, vf) 33 | 34 | 35 | def test_vote_with_distances(): 36 | f = torch.zeros((2, 2, 1, 1)).field() 37 | f[1] = 1 38 | d = torch.ones((2, 1, 1)) 39 | d[1] = 2 40 | df = f.vote_with_distances(distances=d, softmin_temp=1, blur_sigma=0, subset_size=1) 41 | tf = torch.ones((1, 2, 1, 1)).field() / 2.0 42 | assert torch.equal(tf, df) 43 | f = torch.zeros((2, 2, 1, 1)).field() 44 | f[1] = 1 45 | d = torch.ones((2, 1, 1)) 46 | d[1] = 2 47 | df = f.get_vote_weights_with_distances( 48 | distances=d, softmin_temp=1, blur_sigma=0, subset_size=2 49 | ) 50 | tf = torch.ones((2, 1, 1)) / 3.0 51 | tf[0] *= 2.0 52 | assert torch.equal(tf, df) 53 | f = torch.zeros((3, 2, 1, 1)).field() 54 | f[1] = 1 55 | f[2] = 1.2 56 | d = torch.ones((3, 1, 1)) 57 | d[1] = 2 58 | d[2] = 3 59 | df = f.vote_with_distances( 60 | distances=d, softmin_temp=0.01, blur_sigma=0, subset_size=2 61 | ) 62 | tf = ( 63 | torch.ones((1, 2, 1, 1)).field() * 3 / 5.0 64 | + torch.ones((1, 2, 1, 1)).field() * 1.2 * 2 / 5.0 65 | ) 66 | assert torch.allclose(tf, df) 67 | 68 | 69 | def test_priority_vote(): 70 | # if subsets are only 1, then return weights that identify highest priority 71 | f = torch.zeros((2, 2, 1, 1)).field() 72 | p = torch.ones((2, 1, 1)) 73 | p[1] = 0 74 | vfw = f.get_priority_vote_weights(priorities=p, subset_size=1) 75 | tfw = torch.ones((2, 1, 1)) 76 | tfw[1] = 0 77 | assert torch.equal(tfw, vfw) 78 | 79 | # v1 in consensus with v3, so return weights for v1 80 | f = torch.zeros((3, 2, 1, 1)).field() 81 | f[1] = 1 82 | p = torch.full((3, 1, 1), fill_value=3) 83 | p[1] = 2 84 | p[2] = 1 85 | vfw = f.get_priority_vote_weights( 86 | priorities=p, consensus_threshold=1, subset_size=2 87 | ) 88 | tfw = torch.ones((3, 1, 1)) 89 | tfw[1] = 0 90 | tfw[2] = 0 91 | assert torch.equal(tfw, vfw) 92 | 93 | # regardless of whether the consensus_threshold marks v2 as in consensus 94 | vfw = f.get_priority_vote_weights( 95 | priorities=p, consensus_threshold=2, subset_size=2 96 | ) 97 | tfw = torch.ones((3, 1, 1)) 98 | tfw[1] = 0 99 | tfw[2] = 0 100 | assert torch.equal(tfw, vfw) 101 | 102 | # v1 in consensus with v3, then v2 in consensus in v3 103 | f = torch.zeros((3, 2, 2, 1)).field() 104 | f[0, :, 1, :] = 1 105 | f[1, :, 0, :] = 1 106 | p = torch.full((3, 2, 1), fill_value=3) 107 | p[1] = 2 108 | p[2] = 1 109 | vfw = f.get_priority_vote_weights( 110 | priorities=p, consensus_threshold=1, subset_size=2 111 | ) 112 | tfw = torch.ones((3, 2, 1)) 113 | tfw[0, 1, :] = 0 
114 | tfw[1, 0, :] = 0 115 | tfw[2] = 0 116 | assert torch.equal(tfw, vfw) 117 | 118 | # increasing the consensus_threshold makes v1 always in consensus 119 | vfw = f.get_priority_vote_weights( 120 | priorities=p, consensus_threshold=2, subset_size=2 121 | ) 122 | tfw = torch.ones((3, 2, 1)) 123 | tfw[1] = 0 124 | tfw[2] = 0 125 | assert torch.equal(tfw, vfw) 126 | 127 | # increasing the consensus_threshold makes v1 always in consensus 128 | vf = f.priority_vote(priorities=p, consensus_threshold=2, subset_size=2) 129 | tf = torch.zeros((1, 2, 2, 1)) 130 | tf[0, :, 1, :] = 1 131 | assert torch.equal(tf, vf) 132 | 133 | # just test that blurring doesn't throw an error 134 | f = torch.ones((3, 2, 4, 4)).field() 135 | p = torch.ones((3, 4, 4)) 136 | vf = f.priority_vote(priorities=p, consensus_threshold=2, subset_size=2) 137 | assert torch.allclose(f, vf) 138 | 139 | # consensus_threshold=0 should return v2 140 | f = torch.zeros((3, 2, 1, 1)).field() 141 | f[0] = 1 142 | p = torch.full((3, 1, 1), fill_value=3) 143 | p[1] = 2 144 | p[2] = 1 145 | vfw = f.get_priority_vote_weights( 146 | priorities=p, consensus_threshold=0, subset_size=2 147 | ) 148 | tfw = torch.zeros((3, 1, 1)) 149 | tfw[1] = 1 150 | assert torch.equal(tfw, vfw) 151 | 152 | # no negative consensus_threshold 153 | with pytest.raises(ValueError): 154 | vfw = f.get_priority_vote_weights( 155 | priorities=p, consensus_threshold=-1, subset_size=2 156 | ) 157 | 158 | 159 | def test_gaussian_blur(): 160 | f = torch.ones((3, 1, 4, 4)) 161 | gf = torchfields.voting.gaussian_blur(data=f) 162 | assert torch.allclose(gf, f) 163 | -------------------------------------------------------------------------------- /torchfields/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .fields import DisplacementField as Field 3 | from .fields import set_identity_mapping_cache 4 | 5 | torch.Field = Field 6 | -------------------------------------------------------------------------------- /torchfields/fields.py: -------------------------------------------------------------------------------- 1 | """PyTorch tensor type for working with displacement vector fields 2 | """ 3 | from typing import Any 4 | import torch 5 | import torch.nn.functional as F 6 | from functools import wraps 7 | 8 | from .utils import permute_input, permute_output, ensure_dimensions 9 | from . import inversion 10 | from . import voting 11 | 12 | 13 | #################################### 14 | # DisplacementField Class Definition 15 | #################################### 16 | 17 | 18 | class DisplacementField(torch.Tensor): 19 | """An abstraction that encapsulates functionality of displacement fields 20 | as used in Spatial Transformer Networks. 21 | 22 | DisplacementFields can be treated as normal PyTorch tensors for most 23 | purposes, and also include additional functionality for composing 24 | displacements and sampling from tensors. 25 | """ 26 | 27 | @classmethod 28 | def __torch_function__(cls, func, types, args=(), kwargs=None): 29 | if kwargs is None: 30 | kwargs = {} 31 | 32 | if not all(issubclass(cls, t) for t in types): 33 | return NotImplemented 34 | 35 | return super().__torch_function__(func, types, args, kwargs) 36 | 37 | def __new__(cls, *args, **kwargs): 38 | return super().__new__(cls, *args, **kwargs) 39 | 40 | def __init__(self, *args, **kwargs): 41 | if len(self.shape) < 3: 42 | raise ValueError( 43 | "The displacement field must have a components " 44 | "dimension. 
Only {} dimensions are present.".format(len(self.shape)) 45 | ) 46 | if self.shape[-3] != 2: 47 | raise ValueError( 48 | "The displacement field must have exactly 2 " 49 | "components, not {}.".format(self.shape[-3]) 50 | ) 51 | 52 | def __repr__(self, *args, **kwargs): 53 | out = super().__repr__(*args, **kwargs) 54 | return out.replace("tensor", "field", 1).replace("\n ", "\n") 55 | 56 | _cache_identities = False 57 | _identities = {} 58 | 59 | @classmethod 60 | def _set_identity_mapping_cache(cls, mode: bool, clear_cache: bool = False) -> None: 61 | cls._cache_identities = mode 62 | if clear_cache: 63 | cls._identities = {} 64 | 65 | @classmethod 66 | def is_identity_mapping_cache_enabled(cls) -> bool: 67 | """``True`` if identity_mapping() calls are currently cached, ``else False``.""" 68 | return cls._cache_identities 69 | 70 | # Conversion to and from torch.Tensor 71 | 72 | def field_(self, *args, **kwargs): 73 | """Converts a `torch.Tensor` to a `DisplacementField` 74 | 75 | Note: This does not make a copy, but rather modifies it in place. 76 | Because of this, nothing is added to the computation graph. 77 | To produce a new `DisplacementField` from a tensor and/or add a 78 | step to the computation graph, instead use `field()`, 79 | the not-in-place version. 80 | """ 81 | allowed_types = DisplacementField.__bases__ 82 | if not isinstance(self, allowed_types): 83 | raise TypeError( 84 | "'{}' cannot be converted to '{}'. Valid options are: {}".format( 85 | type(self).__name__, 86 | DisplacementField.__name__, 87 | [base.__module__ + "." + base.__name__ for base in allowed_types], 88 | ) 89 | ) 90 | if len(self.shape) < 3: 91 | raise ValueError( 92 | "The displacement field must have a components " 93 | "dimension. Only {} dimensions are present.".format(len(self.shape)) 94 | ) 95 | if self.shape[-3] != 2: 96 | raise ValueError( 97 | "The displacement field must have exactly 2 " 98 | "components, not {}.".format(self.shape[-3]) 99 | ) 100 | self.__class__ = DisplacementField 101 | self.__init__(*args, **kwargs) # in case future __init__ is nonempty 102 | return self 103 | 104 | torch.Tensor.field_ = field_ # adds conversion to torch.Tensor superclass 105 | _from_superclass = field_ # for use in `return_subclass_type()` 106 | 107 | def field(data, *args, **kwargs): 108 | """Converts a `torch.Tensor` to a `DisplacementField` 109 | """ 110 | if isinstance(data, torch.Tensor): 111 | return DisplacementField.field_(data.clone(), *args, **kwargs) 112 | else: 113 | return DisplacementField.field_(torch.tensor(data, *args, **kwargs).float()) 114 | 115 | torch.Tensor.field = field # adds conversion to torch.Tensor superclass 116 | torch.field = field 117 | 118 | def tensor_(self): 119 | """Converts the `DisplacementField` to a standard `torch.Tensor` 120 | in-place 121 | 122 | Note: This does not make a copy, but rather modifies it in place. 123 | Because of this, nothing is added to the computation graph. 124 | To produce a new `torch.Tensor` from a `DisplacementField` and/or 125 | add a copy step to the computation graph, instead use `tensor()`, 126 | the not-in-place version. 
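For example (an arbitrary small field; shown only to illustrate the two variants):

            df = torch.zeros(1, 2, 8, 8).field()
            t = df.tensor()     # new plain torch.Tensor copy; df keeps its type
            df.tensor_()        # converts df itself; no copy, nothing added to the graph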
127 | """ 128 | self.__class__ = torch.Tensor 129 | return self 130 | 131 | def tensor(self): 132 | """Converts the `DisplacementField` to a standard `torch.Tensor` 133 | """ 134 | return self.clone().tensor_() 135 | 136 | # Constuctors for typical displacent fields 137 | 138 | def identity(*args, **kwargs): 139 | """Returns an identity displacement field (containing all zero vectors) 140 | 141 | See :func:`torch.zeros` 142 | """ 143 | if len(args) > 0 and isinstance(args[0], torch.Tensor): 144 | tensor_like, *args = args 145 | if "device" not in kwargs or kwargs["device"] is None: 146 | kwargs["device"] = tensor_like.device 147 | if "size" not in kwargs or kwargs["size"] is None: 148 | kwargs["size"] = tensor_like.shape 149 | if "dtype" not in kwargs or kwargs["dtype"] is None: 150 | kwargs["dtype"] = tensor_like.dtype 151 | return torch.zeros(*args, **kwargs).field_() 152 | 153 | zeros_like = zeros = identity 154 | 155 | def ones(*args, **kwargs): 156 | """Returns a displacement field type tensor of all ones. 157 | 158 | The result is a translation field of half the image in all coordinates, 159 | which is not usually a useful field on its own, but can be multiplied 160 | by a factor to get different translations. 161 | 162 | See :func:`torch.ones` 163 | """ 164 | if len(args) > 0 and isinstance(args[0], torch.Tensor): 165 | tensor_like, *args = args 166 | if "device" not in kwargs or kwargs["device"] is None: 167 | kwargs["device"] = tensor_like.device 168 | if "size" not in kwargs or kwargs["size"] is None: 169 | kwargs["size"] = tensor_like.shape 170 | if "dtype" not in kwargs or kwargs["dtype"] is None: 171 | kwargs["dtype"] = tensor_like.dtype 172 | return torch.ones(*args, **kwargs).field_() 173 | 174 | ones_like = ones 175 | 176 | def rand(*args, **kwargs): 177 | """Returns a displacement field type tensor with each vector 178 | component randomly sampled from the uniform distribution on [0, 1). 179 | 180 | See :func:`torch.rand` 181 | """ 182 | if len(args) > 0 and isinstance(args[0], torch.Tensor): 183 | tensor_like, *args = args 184 | if "device" not in kwargs or kwargs["device"] is None: 185 | kwargs["device"] = tensor_like.device 186 | if "size" not in kwargs or kwargs["size"] is None: 187 | kwargs["size"] = tensor_like.shape 188 | if "dtype" not in kwargs or kwargs["dtype"] is None: 189 | kwargs["dtype"] = tensor_like.dtype 190 | return torch.rand(*args, **kwargs).field_() 191 | 192 | rand_like = rand 193 | 194 | @torch.no_grad() 195 | def rand_in_bounds(*args, **kwargs): 196 | """Returns a displacement field where each displacement 197 | vector samples from a uniformly random location from within the 198 | bounds of the sampled tensor (when called with `sample()` or 199 | `compose()`). 200 | 201 | See :func:`torch.rand` for the function signature. 202 | """ 203 | rand_tensor = DisplacementField.rand(*args, **kwargs) 204 | if not isinstance(rand_tensor, DisplacementField): 205 | # if incompatible, fail with the proper error 206 | rand_tensor = DisplacementField._from_superclass(rand_tensor) 207 | field = rand_tensor * 2 - 1 # rescale to [-1, 1) 208 | field = field - field.identity_mapping() 209 | return field.requires_grad_(rand_tensor.requires_grad) 210 | 211 | rand_in_bounds_like = rand_in_bounds 212 | 213 | def _get_parameters(tensor, shape=None, device=None, dtype=None, override=False): 214 | """Auxiliary function to deduce the right set of parameters to a tensor 215 | function. 216 | In particular, if `tensor` is a `torch.Tensor`, it uses those values. 
217 | Otherwise, if the values are not explicitly specified, returns the 218 | default values. 219 | If `override` is set to `True`, then the parameters passed override 220 | those of the tensor unless they are None. 221 | """ 222 | if isinstance(tensor, torch.Tensor): 223 | shape = shape if override and (shape is not None) else tensor.shape 224 | device = device if override and (device is not None) else tensor.device 225 | dtype = dtype if override and (dtype is not None) else tensor.dtype 226 | else: 227 | if device is None: 228 | try: 229 | device = torch.cuda.current_device() 230 | except AssertionError: 231 | device = "cpu" 232 | if dtype is None: 233 | dtype = torch.float 234 | if isinstance(shape, tuple): 235 | batch_dim = shape[0] if len(shape) > 3 else 1 236 | if len(shape) < 2: 237 | raise ValueError( 238 | "The shape must have at least two spatial " 239 | "dimensions. Recieved shape {}.".format(shape) 240 | ) 241 | while len(shape) < 4: 242 | shape = (1,) + shape 243 | else: 244 | try: 245 | shape = torch.Size((1, 2, shape, shape)) 246 | batch_dim = 1 247 | except TypeError: 248 | raise TypeError( 249 | "'shape' must be an 'int', 'tuple', or " 250 | "'torch.Size'. Received '{}'".format(type(shape).__qualname__) 251 | ) 252 | device = torch.device(device) 253 | if dtype == torch.double: 254 | tensor_type = ( 255 | torch.DoubleTensor if device.type == "cpu" else torch.cuda.DoubleTensor 256 | ) 257 | elif dtype == torch.float: 258 | tensor_type = ( 259 | torch.FloatTensor if device.type == "cpu" else torch.cuda.FloatTensor 260 | ) 261 | else: 262 | raise ValueError( 263 | "The data type must be either torch.float or " 264 | "torch.double. Recieved {}.".format(dtype) 265 | ) 266 | return { 267 | "shape": shape, 268 | "batch_dim": batch_dim, 269 | "device": device, 270 | "dtype": dtype, 271 | "tensor_type": tensor_type, 272 | } 273 | 274 | @torch.no_grad() 275 | def identity_mapping(size, device=None, dtype=None): 276 | """Returns an identity mapping with -1 and +1 at the corners of the 277 | image (not the centers of the border pixels as in PyTorch 1.1). 278 | 279 | Note that this is NOT an identity displacement field, and therefore 280 | sampling with it will not return the input. 281 | To get the identity displacement field, use `identity()`. 282 | Instead, this creates a mapping that maps each coordinate to its 283 | own coordinate vector (in the [-1, +1] space). 284 | 285 | Args: 286 | size: either an `int` or a `torch.Size` of the form `(N, C, H, W)`. 287 | `C` is ignored. 288 | device (torch.device): the device (cpu/cuda) on which to create 289 | the mapping 290 | dtype (torch.dtype): the data type of resulting mapping. Can be 291 | `torch.float` or `torch.double`, specifying either double 292 | or single precision floating points 293 | 294 | Returns: 295 | DisplacementField of size `(N, 2, H, W)`, or `(1, 2, H, W)` if 296 | `size` is given as an `int` 297 | 298 | If called on an instance of `torch.Tensor` or `DisplacementField`, the 299 | `size`, `device`, and `dtype` of that instance are used. 300 | For example 301 | 302 | df = DisplacementField(1,1,10,10) 303 | ident = df.identity_mapping() # uses df.shape and df.device 304 | 305 | NOTE: If `use_identity_mapping_cache` is enabled, the returned field will 306 | be a reference to the field in cache. Use `clone()` on the returned 307 | field if you plan to perform inplace modifications and do not want 308 | to alter the cached version. 
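For instance (sizes here are arbitrary), the cache can be enabled with the `set_identity_mapping_cache` context manager, and the returned field cloned before any in-place edits:

            with torchfields.set_identity_mapping_cache(True):
                ident = df.identity_mapping().clone()  # clone before in-place edits
                ident += 0.5                           # safe: the cached copy is untouched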
309 | """ 310 | # find the right set of parameters 311 | params = DisplacementField._get_parameters(size, size, device, dtype) 312 | shape, batch_dim, device, tensor_type = [ 313 | params[key] for key in ("shape", "batch_dim", "device", "tensor_type") 314 | ] 315 | 316 | # look in the cache and create from scratch if not there 317 | if ( 318 | DisplacementField._cache_identities == True 319 | and (shape, device, tensor_type) in DisplacementField._identities 320 | ): 321 | Id = DisplacementField._identities[shape, device, tensor_type] 322 | else: 323 | id_theta = tensor_type([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], device=device) 324 | id_theta = id_theta.expand(batch_dim, *id_theta.shape[1:]) 325 | Id = F.affine_grid(id_theta, shape, align_corners=False) 326 | Id = Id.permute(0, 3, 1, 2).field_() # move the components to 2nd position 327 | 328 | if DisplacementField._cache_identities == True: 329 | DisplacementField._identities[shape, device, tensor_type] = Id 330 | 331 | return Id 332 | 333 | @classmethod 334 | def affine_field(cls, aff, size, offset=(0.0, 0.0), device=None, dtype=None): 335 | """Returns a displacement field for an affine transform within a bbox 336 | 337 | Args: 338 | aff: 2x3 ndarray or torch.Tensor. The affine matrix defining the 339 | affine transform 340 | offset: tuple with (x-offset, y-offset) 341 | size: an `int`, a `tuple` or a `torch.Size` of the form 342 | `(N, C, H, W)`. `C` is ignored. 343 | 344 | Returns: 345 | DisplacementField for the given affine transform of size 346 | `(N, 2, H, W)`, or `(1, 2, H, W)` if `size` is given as an `int` 347 | 348 | Note: 349 | the affine matrix defines the transformation that warps the 350 | destination to the source, such that, 351 | ``` 352 | \vec{x_s} = A \vec{x_d} 353 | ``` 354 | where x_s is a point in the source image, x_d a point in the 355 | destination image, and A is the affine matrix. The field returned 356 | will be defined over the destination image. So the matrix A should 357 | define the location in the source image that contribute to a pixel 358 | in the destination image. 359 | """ 360 | params = DisplacementField._get_parameters( 361 | aff, size, device, dtype, override=True 362 | ) 363 | device, dtype, tensor_type, size, batch_dim = [ 364 | params[key] 365 | for key in ("device", "dtype", "tensor_type", "shape", "batch_dim") 366 | ] 367 | if isinstance(aff, list): 368 | aff = tensor_type(aff, device=device) 369 | if aff.ndimension() == 2: 370 | aff.unsqueeze_(0) 371 | N = 1 372 | elif aff.ndimension() == 3: 373 | N = aff.shape[0] 374 | else: 375 | raise ValueError( 376 | "Expected 2 or 3-dimensional affine matrix. 
" 377 | "Received shape {}.".format(aff.shape) 378 | ) 379 | if N == 1 and batch_dim > 1: 380 | aff = aff.expand(batch_dim, *aff.shape[1:]) 381 | N = batch_dim 382 | if offset[0] != 0 or offset[1] != 0: 383 | z = tensor_type([[0.0, 0.0, 1.0]], device=device) 384 | z = z.expand(N, *z.shape) 385 | A = torch.cat([aff, z], 1) 386 | B = tensor_type( 387 | [[1.0, 0.0, offset[0]], [0.0, 1.0, offset[1]], [0.0, 0.0, 1.0]], 388 | device=device, 389 | ) 390 | B = B.expand(N, *B.shape) 391 | Bi = tensor_type( 392 | [[1.0, 0.0, -offset[0]], [0.0, 1.0, -offset[1]], [0.0, 0.0, 1.0]], 393 | device=device, 394 | ) 395 | Bi = Bi.expand(N, *Bi.shape) 396 | aff = torch.mm(Bi, torch.mm(A, B))[:, :2] 397 | M = F.affine_grid(aff, size, align_corners=False) 398 | # Id is an identity mapping without the overhead of `identity_mapping` 399 | id_aff = tensor_type([[1, 0, 0], [0, 1, 0]], device=device) 400 | id_aff = id_aff.expand(N, *id_aff.shape) 401 | Id = F.affine_grid(id_aff, size, align_corners=False) 402 | M = M - Id 403 | M = M.permute(0, 3, 1, 2).field_() # move the components to 2nd position 404 | return M 405 | 406 | # Basic vector field properties 407 | 408 | def is_identity(self, eps=None, magn_eps=None): 409 | """Checks if this is the identity displacement field, up to some 410 | tolerance `eps`, which is 0 by default. 411 | 412 | Args: 413 | eps: can either be a floating point number or a tensor of the same 414 | shape, in which case each location in the field can have a 415 | different tolerance. 416 | magn_eps: similar to eps, except bounds the magnitude of each 417 | vector instead of the components. 418 | 419 | If neither `eps` nor `magn_eps` are specified, the default is zero 420 | tolerance. 421 | 422 | Note that this does NOT check for identity mappings created by 423 | `identity_mapping()`. To check for that, subtract 424 | `self.identity_mapping()` first. 425 | 426 | This function is called and negated by `__bool__()`, which makes 427 | the following equivalent: 428 | 429 | if df: 430 | do_something() 431 | 432 | and 433 | 434 | if not df.is_identity(): 435 | do_something() 436 | 437 | since `df.is_identity()` is equivalent to `not df`. 
438 | """ 439 | if eps is None and magn_eps is None: 440 | return (self == 0.0).all() 441 | else: 442 | is_id = True 443 | if eps is not None: 444 | is_id = is_id and (self >= -eps).all() and (self <= eps).all() 445 | if magn_eps is not None: 446 | is_id = is_id and (self.magnitude(True) <= magn_eps).all() 447 | return is_id 448 | 449 | def __bool__(self): 450 | return not self.is_identity().tensor_() 451 | 452 | __nonzero__ = __bool__ 453 | 454 | def magnitude(self, keepdim=False): 455 | """Computes the magnitude of the displacement at each location in the 456 | displacement field 457 | 458 | Args: 459 | self: `DisplacementField` of shape `(N, 2, H, W)` 460 | 461 | Returns: 462 | `torch.Tensor` of shape `(N, H, W)` or `(N, 1, H, W)` if 463 | `keepdim` is `True`, containing the magnitude of the displacement 464 | """ 465 | return self.tensor().pow(2).sum(dim=-3, keepdim=keepdim).sqrt() 466 | 467 | def distance(self, other, keepdim=False) -> torch.Tensor: 468 | """Compute the pointwise Euclidean distance between two displacement 469 | fields 470 | 471 | Args: 472 | self, other: DisplacementFields of the same shape `(N, 2, H, W)` 473 | 474 | Returns: 475 | `torch.Tensor` of shape `(N, H, W)` or `(N, 1, H, W)` if 476 | `keepdim` is `True`, containing the distance at each location in 477 | the displacement fields 478 | """ 479 | return (self - other).magnitude(keepdim=keepdim) 480 | 481 | def mean_vector(self, keepdim=False): 482 | """Compute the mean displacement vector of each field in a batch 483 | 484 | Args: 485 | self: DisplacementFields of shape `(N, 2, H, W)` 486 | keepdim: if `True`, retains the spatial dimensions in the output 487 | 488 | Returns: 489 | `torch.Tensor` of shape `(N, 2)` or `DisplacementField` of shape 490 | `(N, 2, 1, 1)` if `keepdim` is `True`, containing the mean vector 491 | of each field 492 | """ 493 | if keepdim: 494 | return self.mean(-1, keepdim=keepdim).mean(-2, keepdim=keepdim) 495 | else: 496 | return self.mean(-1).mean(-1) 497 | 498 | def mean_finite_vector(self, keepdim=False): 499 | """Compute the mean displacement vector of the finite elements in 500 | each field in a batch 501 | 502 | Args: 503 | self: DisplacementFields of shape `(N, 2, H, W)` 504 | keepdim: if `True`, retains the spatial dimensions in the output 505 | 506 | Returns: 507 | `torch.Tensor` of shape `(N, 2)` or `DisplacementField` of shape 508 | `(N, 2, 1, 1)` if `keepdim` is `True`, containing the mean finite 509 | vector of each field 510 | """ 511 | mask = torch.isfinite(self).all(-3, keepdim=True) 512 | self = self.where(mask, torch.tensor(0).to(self)) 513 | if keepdim: 514 | sum = self.sum(-1, keepdim=keepdim).sum(-2, keepdim=keepdim) 515 | count = mask.sum(-1, keepdim=keepdim).sum(-2, keepdim=keepdim) 516 | else: 517 | sum = self.sum(-1).sum(-1) 518 | count = mask.sum(-1).sum(-1) 519 | return sum / count.clamp(min=1).float() 520 | 521 | def mean_nonzero_vector(self, keepdim=False): 522 | """Compute the mean displacement vector of the nonzero elements in 523 | each field in a batch 524 | 525 | Note: to get the mean displacement vector of all elements, run 526 | 527 | field.mean(-1).mean(-1) 528 | 529 | Args: 530 | self: DisplacementFields of shape `(N, 2, H, W)` 531 | keepdim: if `True`, retains the spatial dimensions in the output 532 | 533 | Returns: 534 | `torch.Tensor` of shape `(N, 2)` or `DisplacementField` of shape 535 | `(N, 2, 1, 1)` if `keepdim` is `True`, containing the mean nonzero 536 | vector of each field 537 | """ 538 | mask = self.magnitude(keepdim=True) > 0 539 | 
if keepdim: 540 | sum = self.sum(-1, keepdim=keepdim).sum(-2, keepdim=keepdim) 541 | count = mask.sum(-1, keepdim=keepdim).sum(-2, keepdim=keepdim) 542 | else: 543 | sum = self.sum(-1).sum(-1) 544 | count = mask.sum(-1).sum(-1) 545 | return sum / count.clamp(min=1).float() 546 | 547 | def min_vector(self, keepdim=False): 548 | """Compute the minimum displacement vector of each field in a batch 549 | 550 | Args: 551 | self: DisplacementFields of shape `(N, 2, H, W)` 552 | keepdim: if `True`, retains the spatial dimensions in the output 553 | 554 | Returns: 555 | `torch.Tensor` of shape `(N, 2)` or `DisplacementField` of shape 556 | `(N, 2, 1, 1)` if `keepdim` is `True`, containing the minimum 557 | vector of each field 558 | """ 559 | if keepdim: 560 | return self.min(-1, keepdim=keepdim).values.min(-2, keepdim=keepdim).values 561 | else: 562 | return self.min(-1).values.min(-1).values 563 | 564 | def max_vector(self, keepdim=False): 565 | """Compute the maximum displacement vector of each field in a batch 566 | 567 | Args: 568 | self: DisplacementFields of shape `(N, 2, H, W)` 569 | keepdim: if `True`, retains the spatial dimensions in the output 570 | 571 | Returns: 572 | `torch.Tensor` of shape `(N, 2)` or `DisplacementField` of shape 573 | `(N, 2, 1, 1)` if `keepdim` is `True`, containing the maximum 574 | vector of each field 575 | """ 576 | if keepdim: 577 | return self.max(-1, keepdim=keepdim).values.max(-2, keepdim=keepdim).values 578 | else: 579 | return self.max(-1).values.max(-1).values 580 | 581 | # Conversions to and from other representations of the displacement field 582 | 583 | def pixels(self, size=None): 584 | """Convert the displacement distances to units of pixels from the 585 | standard [-1, 1] distance convention. 586 | 587 | Note that while out of convenience, the type of 588 | the result is `DisplacementField`, many `DisplacementField` 589 | operations on it will produce incorrect results, since it will 590 | be in the wrong units. 591 | 592 | Args: 593 | self (DisplacementField): the field to convert 594 | size (int or torch.Size): the size, in pixels, of the tensor to be 595 | sampled. Used to calculate the pixel size. If not specified 596 | the size is assumed to be the size of the displacement field. 597 | 598 | Returns: 599 | a `DisplacementField` type tensor containing displacements in 600 | units of pixels 601 | """ 602 | if size is None: 603 | size = self.shape 604 | if isinstance(size, tuple): 605 | size = size[-1] 606 | return self * (size / 2) 607 | 608 | def from_pixels(self, size=None): 609 | """Convert the displacement distances from units of pixels to the 610 | standard [-1, 1] distance convention. 611 | 612 | This reverses the operation of `pixels()` 613 | 614 | Args: 615 | self (DisplacementField): the field to convert 616 | size (int or torch.Size): the size, in pixels, of the tensor to be 617 | sampled. Used to calculate the pixel size. If not specified 618 | the size is assumed to be the size of the displacement field. 619 | 620 | Returns: 621 | a `DisplacementField` type tensor containing displacements in 622 | units of pixels 623 | """ 624 | if size is None: 625 | size = self.shape 626 | if isinstance(size, tuple): 627 | size = size[-1] 628 | return self / (size / 2) 629 | 630 | def mapping(self): 631 | """Convert the displacement field to a mapping, where each location 632 | contains the coordinates of another location to which it maps. 
633 | 634 | Note that while out of convenience, the type of 635 | the result is `DisplacementField`, many `DisplacementField` 636 | operations on it will produce incorrect results, since it will 637 | be in the wrong units. 638 | 639 | The units of the mapping will be in the standard [-1, 1] convention. 640 | 641 | Args: 642 | self (DisplacementField): the field to convert 643 | 644 | Returns: 645 | a `DisplacementField` type tensor containing the same field 646 | represented as a mapping 647 | """ 648 | return self + self.identity_mapping() 649 | 650 | def from_mapping(self): 651 | """Convert a mapping to a displacement field which contains the 652 | displacement at each location. 653 | 654 | The units of the mapping should be in the standard [-1, 1] convention. 655 | 656 | Args: 657 | self (DisplacementField): the mapping to convert 658 | 659 | Returns: 660 | a `DisplacementField` containing the mapping represented 661 | as a displacement field 662 | """ 663 | return self - self.identity_mapping() 664 | 665 | def pixel_mapping(self, size=None): 666 | """Convert the displacement field to a pixel mapping, where each pixel 667 | contains the coordinates of another pixel to which it maps. 668 | 669 | Note that while out of convenience, the type of 670 | the result is `DisplacementField`, many `DisplacementField` 671 | operations on it will produce incorrect results, since it will 672 | be in the wrong units. 673 | 674 | The units of the mapping will be in pixels in the range [0, size-1]. 675 | 676 | Args: 677 | self (DisplacementField): the field to convert 678 | size (int or torch.Size): the size, in pixels, of the tensor to be 679 | sampled. Used to calculate the pixel size. If not specified 680 | the size is assumed to be the size of the displacement field. 681 | 682 | Returns: 683 | a `DisplacementField` type tensor containing the same field 684 | represented as a pixel mapping 685 | """ 686 | if size is None: 687 | size = self.shape 688 | if isinstance(size, tuple): 689 | size = size[-1] 690 | return self.mapping().pixels(size) + (size - 1) / 2 691 | 692 | def from_pixel_mapping(self, size=None): 693 | """Convert a mapping to a displacement field which contains the 694 | displacement at each location. 695 | 696 | The units of the mapping should be in pixels in the range [0, size-1]. 697 | 698 | Args: 699 | self (DisplacementField): the pixel mapping to convert 700 | size (int or torch.Size): the size, in pixels, of the tensor to be 701 | sampled. Used to calculate the pixel size. If not specified 702 | the size is assumed to be the size of the displacement field. 
703 | 704 | Returns: 705 | a `DisplacementField` containing the pixel mapping represented 706 | as a displacement field 707 | """ 708 | if size is None: 709 | size = self.shape 710 | if isinstance(size, tuple): 711 | size = size[-1] 712 | return (self - (size - 1) / 2).from_pixels(size).from_mapping() 713 | 714 | # Aliases for the components of the displacent vectors 715 | 716 | @property 717 | def x(self): 718 | """The column component of the displacent field 719 | """ 720 | return self[..., 0:1, :, :] 721 | 722 | @x.setter 723 | def x(self, value): 724 | self[..., 0:1, :, :] = value 725 | 726 | j = x # j & x are both aliases for the column component of the displacent 727 | 728 | @property 729 | def y(self): 730 | """The row component of the displacent field 731 | """ 732 | return self[..., 1:2, :, :] 733 | 734 | @y.setter 735 | def y(self, value): 736 | self[..., 1:2, :, :] = value 737 | 738 | i = y # i & y are both aliases for the row component of the displacent 739 | 740 | # Functions for sampling, composing, mapping, warping 741 | 742 | @ensure_dimensions(ndimensions=4, arg_indices=(1, 0), reverse=True) 743 | def sample(self, input, mode="bilinear", padding_mode="zeros"): 744 | r"""A wrapper for the PyTorch grid sampler to sample or warp and image 745 | by a displacent field. 746 | 747 | The displacement vector field encodes relative displacements from 748 | which to pull from the input, where vectors with values -1 or +1 749 | reference a displacement equal to the distance from the center point 750 | to the edges of the input. 751 | 752 | Args: 753 | `input` (Tensor): should be a PyTorch Tensor or DisplacementField 754 | on the same GPU or CPU as `self`, with `input` having 755 | dimensions :math:`(N, C, H_in, W_in)`, whenever `self` has 756 | dimensions :math:`(N, 2, H_out, W_out)`. 757 | The shape of the output will be :math:`(N, C, H_out, W_out)`. 758 | `mode` (str): 'bilinear' or 'nearest' 759 | `padding_mode` (str): determines the value sampled when a 760 | displacement vector's source falls outside of the input. 761 | Options are: 762 | * "zeros" : produce the value zero (okay for sampling images 763 | with zero as background, but potentially 764 | problematic for sampling masks and terrible for 765 | sampling from other displacement vector fields) 766 | * "border" : produces the value at the nearest inbounds pixel 767 | (great for sampling from masks and from other 768 | residual displacement fields) 769 | * "reflection" : reflects any sampling points that lie out 770 | of bounds until they fall inside the 771 | sampling range 772 | 773 | Returns: 774 | `output` (Tensor): the input after being warped by `self`, 775 | having shape :math:`(N, C, H_out, W_out)` 776 | 777 | See the PyTorch documentation of the underlying function for additional 778 | details: 779 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample 780 | but note that the conventions used there are different. 781 | """ 782 | field = self + self.identity_mapping() 783 | field = field.permute(0, 2, 3, 1) # move components to last position 784 | out = F.grid_sample( 785 | input, field, mode=mode, padding_mode=padding_mode, align_corners=False 786 | ) 787 | if not isinstance(input, DisplacementField): 788 | out.tensor_() 789 | return out 790 | 791 | def compose_with(self, other, mode="bilinear"): 792 | r"""Compose this displacement field with another displacement field. 
793 | If `f = self` and `g = other`, then this computes 794 | `f⚬g` such that `(f⚬g)(x) ~= f(g(x))` for any tensor `x`. 795 | 796 | Returns: 797 | a displacement field such that when it is used to sample a tensor, 798 | it is the (approximate) equivalent of sampling with `other` 799 | and then with `self`. 800 | 801 | The reason this is only an approximate equivalence is because when 802 | sampling twice, information is inevitably lost in the intermediate 803 | stage. Sampling with the composed field is therefore more precise. 804 | """ 805 | return self + self.sample(other, padding_mode="border") 806 | 807 | def __call__(self, x, mode="bilinear"): 808 | """Syntactic sugar for `compose_with()` or `sample()`, depending on 809 | the type of the sampled tensor. 810 | 811 | Be careful when using this that the sampled tensor is of the correct 812 | type for the desired outcome. 813 | For better assurance, it can be safer to call the functions explicitly. 814 | """ 815 | if isinstance(x, DisplacementField): 816 | return self.compose_with(x, mode=mode) 817 | else: 818 | return self.sample(x, mode=mode) 819 | 820 | def multicompose(self, *others): 821 | """Composes multiple displacement fields with one another. 822 | This takes a list of fields :math:`f_0, f_1, ..., f_n` 823 | and composes them to get 824 | :math:`f_0 ⚬ f_1 ⚬ ... ⚬ f_n ~= f_0(f_1(...(f_n)))` 825 | 826 | Use of this function is not always recommended because of the 827 | potential for boundary effects when composing multiple displacements. 828 | Specifically, whenever a vector samples from out of bounds, the 829 | nearest vector is used, which may not be the desired behavior and can 830 | become a worse approximation of it as more displacement fields are 831 | composed together. 832 | """ 833 | f = self 834 | for g in others: 835 | f = (f)(g) 836 | return f 837 | 838 | @ensure_dimensions(ndimensions=4, arg_indices=(0), reverse=True) 839 | def up(self, mips=None, scale_factor=2): 840 | """Upsamples by `mips` mip levels or by a factor of `scale_factor`, 841 | whichever one is specified. 842 | If neither are specified explicitly, upsamples by a factor of two, or 843 | in other words, one mip level. 844 | """ 845 | if mips is not None: 846 | scale_factor = 2 ** mips 847 | if scale_factor == 1: 848 | return self 849 | return F.interpolate( 850 | self, scale_factor=scale_factor, mode="bilinear", align_corners=False 851 | ) 852 | 853 | @ensure_dimensions(ndimensions=4, arg_indices=(0), reverse=True) 854 | def down(self, mips=None, scale_factor=2): 855 | """Downsample by `mips` mip levels or by a factor of `scale_factor`, 856 | whichever one is specified. 857 | If neither are specified explicitly, downsamples by a factor of two, or 858 | in other words, one mip level. 859 | """ 860 | if mips is not None: 861 | scale_factor = 2 ** mips 862 | if scale_factor == 1: 863 | return self 864 | return F.interpolate( 865 | self, scale_factor=1.0 / scale_factor, mode="bilinear", align_corners=False 866 | ) 867 | 868 | # Displacement Field Inverses 869 | 870 | def inverse(self, *args, **kwargs): 871 | """Return a symmetric inverse approximation for the displacement field 872 | 873 | Given a displacement field `f`, its symmetric inverse is a displacement 874 | field `f_inv` such that 875 | `f(f_inv) ~= identity ~= f_inv(f)` 876 | 877 | In other words 878 | :math:`f_{inv} = \argmin_{g} |f(g)|^2 + |g(f)|^2` 879 | 880 | Note that this is an approximation for the symmetric inverse. 
881 | In cases for which only one inverse direction is desired, a better 882 | one-sided approximation can be achieved using `linverse()` or 883 | `rinverse()`. 884 | 885 | Also note that this overrides the `inverse()` method of `torch.Tensor`, 886 | but this definition cannot conflict, since `torch.Tensor.inverse` is 887 | only able to accept 2-dimensional tensors, and a `DisplacementField` 888 | is always at least 3-dimensional (2 spatial + 1 component dimension). 889 | """ 890 | # TODO: Implement symmetric inverse. Currently using left inverse. 891 | return self.linverse(*args, **kwargs) 892 | 893 | def __invert__(self, *args, **kwargs): 894 | """Return a symmetric inverse approximation for the displacement field 895 | 896 | Given a displacement field `f`, its symmetric inverse is a displacement 897 | field `f_inv` such that 898 | `f(f_inv) ~= identity ~= f_inv(f)` 899 | 900 | In other words 901 | :math:`f_{inv} = \argmin_{g} |f(g)|^2 + |g(f)|^2` 902 | 903 | Note that this is an approximation for the symmetric inverse. 904 | In cases for which only one inverse direction is desired, a better 905 | approximation can be achieved using `linverse()` and `rinverse()`. 906 | 907 | This is syntactic sugar for `inverse()`, and allows the symmetric 908 | inverse to be called as `~f` rather than `f.inverse()`. 909 | """ 910 | return self.inverse(*args, **kwargs) 911 | 912 | @wraps(inversion.linverse) 913 | def linverse(self, autopad=True): 914 | return inversion.linverse(self, autopad=True) 915 | 916 | @wraps(inversion.rinverse) 917 | def rinverse(self, *args, **kwargs): 918 | return inversion.rinverse(self, autopad=True) 919 | 920 | # Adapting functions inherited from torch.Tensor 921 | 922 | @permute_output 923 | @permute_input 924 | def fft(self, *args, **kwargs): 925 | return super(type(self), self).fft(*args, **kwargs) 926 | 927 | @permute_output 928 | @permute_input 929 | def ifft(self, *args, **kwargs): 930 | return super(type(self), self).ifft(*args, **kwargs) 931 | 932 | @permute_output 933 | def rfft(self, *args, **kwargs): 934 | # Present for completeness, but cannot be called on a DisplacementField 935 | return super(type(self), self).rfft(*args, **kwargs) 936 | 937 | @permute_input 938 | def irfft(self, *args, **kwargs): 939 | return super(type(self), self).irfft(*args, **kwargs) 940 | 941 | def __rpow__(self, other): 942 | # defined explicitly since pytorch default gives infinite recursion 943 | return self.new_tensor(other).__pow__(self) 944 | 945 | # Vector Voting 946 | 947 | @wraps(voting.gaussian_blur) 948 | def gaussian_blur(self, sigma=1, kernel_size=5): 949 | return voting.gaussian_blur(self, sigma, kernel_size) 950 | 951 | @wraps(voting.vote) 952 | def get_vote_shape(self): 953 | return voting.get_vote_shape(self) 954 | 955 | @wraps(voting.vote) 956 | def get_subset_size(self, subset_size=None): 957 | return voting.get_subset_size(self, subset_size) 958 | 959 | @wraps(voting.vote) 960 | def get_vote_subsets(self, subset_size=None): 961 | return voting.get_vote_subsets(self, subset_size) 962 | 963 | @wraps(voting.vote) 964 | def linear_combination(self, weights): 965 | return voting.linear_combination(self, weights=weights) 966 | 967 | @wraps(voting.vote) 968 | def smoothed_combination(self, weights, blur_sigma=2., kernel_size=5): 969 | return voting.smoothed_combination(self, 970 | weights=weights, 971 | blur_sigma=blur_sigma, 972 | kernel_size=kernel_size) 973 | 974 | @wraps(voting.vote) 975 | def get_vote_weights(self, softmin_temp=1, blur_sigma=1, subset_size=None): 976 | 
return voting.get_vote_weights(self, softmin_temp, blur_sigma, subset_size) 977 | 978 | @wraps(voting.vote) 979 | def vote(self, softmin_temp=1, blur_sigma=1, subset_size=None): 980 | return voting.vote(self, softmin_temp, blur_sigma, subset_size) 981 | 982 | @wraps(voting.vote) 983 | def get_vote_weights_with_variances( 984 | self, var, softmin_temp=1, blur_sigma=1, subset_size=None 985 | ): 986 | return voting.get_vote_weights_with_variances( 987 | self, var, softmin_temp, blur_sigma, subset_size 988 | ) 989 | 990 | @wraps(voting.vote) 991 | def vote_with_variances(self, var, softmin_temp=1, blur_sigma=1, subset_size=None): 992 | return voting.vote_with_variances( 993 | self, var, softmin_temp, blur_sigma, subset_size 994 | ) 995 | 996 | @wraps(voting.vote) 997 | def get_vote_weights_with_distances( 998 | self, distances, softmin_temp=1, blur_sigma=1, subset_size=None 999 | ): 1000 | return voting.get_vote_weights_with_distances( 1001 | self, distances, softmin_temp, blur_sigma, subset_size 1002 | ) 1003 | 1004 | @wraps(voting.vote) 1005 | def vote_with_distances( 1006 | self, distances, softmin_temp=1, blur_sigma=1, subset_size=None 1007 | ): 1008 | return voting.vote_with_distances( 1009 | self, distances, softmin_temp, blur_sigma, subset_size 1010 | ) 1011 | 1012 | @wraps(voting.vote) 1013 | def get_priority_vote_weights( 1014 | self, priorities, consensus_threshold=2, subset_size=None 1015 | ): 1016 | return voting.get_priority_vote_weights( 1017 | self, 1018 | priorities, 1019 | consensus_threshold=consensus_threshold, 1020 | subset_size=subset_size, 1021 | ) 1022 | 1023 | @wraps(voting.vote) 1024 | def priority_vote( 1025 | self, priorities, consensus_threshold=2, blur_sigma=1, subset_size=None 1026 | ): 1027 | return voting.priority_vote( 1028 | self, 1029 | priorities, 1030 | consensus_threshold=consensus_threshold, 1031 | blur_sigma=blur_sigma, 1032 | subset_size=subset_size, 1033 | ) 1034 | 1035 | class set_identity_mapping_cache(): 1036 | """Context-manager that controls caching of identity_mapping() results. 1037 | 1038 | ``set_identity_mapping_cache`` will enable or disable the cache (:attr: `mode`). 1039 | It can be used as a context-manager or as a function. 1040 | 1041 | If enabled, cache results of identity_mapping() calls based on 1042 | (shape, device, dtype) for faster recall. 1043 | 1044 | This may also improve repeated calls to other torchfields methods, 1045 | such as `sample()`, `rand_in_bounds()`, `pixel_mapping()`, etc. 1046 | 1047 | The trade-off is a higher burden on CPU/GPU memory, therefore 1048 | caching is disabled by default. 1049 | 1050 | Args: 1051 | mode (bool): Flag whether to enable cache (``True``), or disable (``False``). 1052 | clear_cache (bool): Optional flag whether or not to empty any existing 1053 | cached results. Default is ``False``. 1054 | 1055 | Note: For performance reasons, the returned field from identity_mapping() 1056 | will be the cached, *mutable* field. Use `clone()` on the returned field 1057 | if you plan to perform in-place modifications and do not want to alter the 1058 | cached version. 
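Example (illustrative only; `field` and `image` stand in for any displacement field and sampled tensor):

            # as a context manager
            with torchfields.set_identity_mapping_cache(True):
                warped = field.sample(image)  # identity lookups are served from the cache

            # or as a plain call that remains in effect afterwards
            torchfields.set_identity_mapping_cache(True, clear_cache=True)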
1059 | """ 1060 | 1061 | def __init__(self, mode: bool, clear_cache: bool = False) -> None: 1062 | self.prev = DisplacementField.is_identity_mapping_cache_enabled() 1063 | DisplacementField._set_identity_mapping_cache(mode, clear_cache) 1064 | 1065 | def __enter__(self) -> None: 1066 | pass 1067 | 1068 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 1069 | DisplacementField._set_identity_mapping_cache(self.prev) 1070 | -------------------------------------------------------------------------------- /torchfields/inversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .utils import ensure_dimensions 5 | 6 | 7 | # Inversion helper functions 8 | 9 | def _tensor_min(*args): 10 | """Elementwise minimum of a sequence of tensors""" 11 | minimum, *rest = args 12 | for arg in rest: 13 | minimum = minimum.min(arg) 14 | return minimum 15 | 16 | 17 | def _tensor_max(*args): 18 | """Elementwise maximum of a sequence of tensors""" 19 | maximum, *rest = args 20 | for arg in rest: 21 | maximum = maximum.max(arg) 22 | return maximum 23 | 24 | 25 | class _BackContig(torch.autograd.Function): 26 | """Ensure that the gradient is contiguous in the backward pass""" 27 | @staticmethod 28 | def forward(ctx, input): 29 | return input 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | return grad_output.contiguous() 34 | 35 | 36 | _back_contig = _BackContig.apply 37 | 38 | 39 | def _pad(inp, padding=None): 40 | """Pads the field just enough to eliminate border effects""" 41 | if padding is None: 42 | with torch.no_grad(): 43 | *_, H, W = inp.shape 44 | mapping = inp.pixel_mapping() 45 | pad_yl = mapping.y[..., 0, :].max().ceil().int().item() 46 | pad_yh = (H-1-mapping.y[..., -1, :].min()).ceil().int().item() 47 | pad_xl = mapping.x[..., :, 0].max().ceil().int().item() 48 | pad_xh = (W-1-mapping.x[..., :, -1].min()).ceil().int().item() 49 | pad_yl = max(pad_yl, 0) + 1 50 | pad_yh = max(pad_yh, 0) + 1 51 | pad_xl = max(pad_xl, 0) + 1 52 | pad_xh = max(pad_xh, 0) + 1 53 | # ensure that the new field is square (that is, newH = newW) 54 | newH, newW = pad_yl + H + pad_yh, pad_xl + W + pad_xh 55 | if newH > newW: 56 | pad_xh += newH - newW 57 | elif newW > newH: 58 | pad_yh += newW - newH 59 | padding = (pad_xl, pad_xh, pad_yl, pad_yh) 60 | return (F.pad(inp.pixels(), padding, mode='replicate').field() 61 | .from_pixels(), padding) 62 | 63 | 64 | def _unpad(inp, padding): 65 | """Crops the field back to its original size""" 66 | p_xl, p_xh, p_yl, p_yh = padding 67 | p_xh = inp.shape[-1] - p_xh 68 | p_yh = inp.shape[-2] - p_yh 69 | return inp.pixels()[..., p_yl:p_yh, p_xl:p_xh].from_pixels() 70 | 71 | 72 | def _fold(inp): 73 | """Collapse the matrix at each pixel onto the local neighborhood 74 | 75 | The input to this function is a spatial tensor in which the 76 | entry contained in every spatial pixel is itself a small matrix. 77 | The values of this matrix correspond to the neighborhood of that 78 | pixel. For instance, the value at the center of the matrix 79 | corresponds to the pixel itself, whereas the value above and to 80 | the left of the center corresponds to the pixel's upper left 81 | neighbor, and so on. 82 | 83 | This function collapses this into a spatial tensor with scalar 84 | values by summing the respective values corresponding to each 85 | pixel. 
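In terms of shapes, an input of shape `(C, kH, kW, H, W)`, holding a `kH` by `kW` neighborhood matrix at each of the `H` by `W` pixels, is collapsed to an output of shape `(1, C, H, W)`.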
86 | """ 87 | pad = (0, (inp.shape[2] + 1) % 2, # last dimension 88 | 0, (inp.shape[1] + 1) % 2) # second to last dimension 89 | res = F.pad(_back_contig(inp), pad) 90 | res = F.fold( 91 | res.view(1, res.shape[0]*res.shape[1]*res.shape[2], -1).contiguous(), 92 | output_size=inp.shape[3:], kernel_size=inp.shape[1:3], 93 | padding=((inp.shape[1])//2, (inp.shape[2])//2)) 94 | return res 95 | 96 | 97 | @torch.no_grad() 98 | def _winding_number(Px, Py, v00, v01, v10, v11, eps3): 99 | """Gives the winding number of a quadrilateral around a grid point 100 | When the winding number is non-zero, the quadrilateral contains 101 | the grid point. More specifically, only positive winding numbers 102 | are relevant in this context, since inverted quadrilaterals are 103 | not considered. 104 | For edge cases, this is a more accurate measure than checking 105 | whether i and j fall within the range [0,1). 106 | Based on http://geomalgorithms.com/a03-_inclusion.html 107 | """ 108 | try: 109 | # vertices in counterclockwise order (viewing y as up) 110 | V = torch.stack([v00, v01, v11, v10, v00], dim=0).field_() 111 | v00 = v01 = v10 = v11 = None # del v00, v01, v10, v11 112 | # initial and final vertex for each edge 113 | Vi, Vf = V[:-1], V[1:] 114 | V = None # del V 115 | # sign of cross product indicates direction around grid point 116 | cross = ((Vf.x - Vi.x)*(Py - Vi.y) - (Px - Vi.x)*(Vf.y - Vi.y)) 117 | # upward crossing of rightward ray from grid point 118 | upward = (Vi.y <= Py) & (Vf.y > Py) & (cross > -eps3) 119 | # downward crossing of rightward ray from grid point 120 | downward = (Vi.y > Py) & (Vf.y <= Py) & (cross < -eps3) 121 | Vi = Vf = Px = Py = cross = None # del Vi, Vf, Px, Py, cross 122 | # winding number = diff between number of up and down crossings 123 | return (upward.int() - downward.int()).sum(dim=0) 124 | except RuntimeError: 125 | # In case this is an out-of-memory error, clear temp tensors 126 | v00 = v01 = v10 = v11 = Px = Py = None 127 | V = Vi = Vf = cross = upward = downward = None 128 | raise 129 | 130 | 131 | ############################# 132 | # Displacement Field Inverses 133 | ############################# 134 | 135 | @ensure_dimensions(ndimensions=4, arg_indices=(0), reverse=True) 136 | def linverse(self, autopad=True): 137 | r"""Return a left inverse approximation for the displacement field 138 | 139 | Given a displacement field `f`, its left inverse is a displacement 140 | field `g` such that 141 | `g(f) ~= identity` 142 | 143 | In other words 144 | :math:`f_{inv} = \argmin_{g} |g(f)|^2` 145 | """ 146 | if len(self.shape) != 4 or self.shape[0] > 1: 147 | raise NotImplementedError('Left inverse is currently implemented ' 148 | 'only for single-batch fields. 
' 149 | 'Received batch size {}' 150 | .format(','.join( 151 | str(n) for n in self.shape[:-3]))) 152 | # comparison to 0 153 | eps1 = 2**(-51) if self.dtype is torch.double else 2**(-23) 154 | # denominator fudge factor to avoid dividing by 0 155 | eps2 = eps1 * 2**(-10) 156 | # tolarance for point containment in a quadrilateral 157 | eps3 = 2**-16 158 | 159 | try: 160 | # pad the field 161 | if autopad: 162 | field, padding = _pad(self) 163 | else: 164 | field = self 165 | padding = None 166 | 167 | # vectors at the four corners of each pixel's quadrilateral 168 | mapping = field.pixel_mapping() 169 | v00 = mapping[..., :-1, :-1] 170 | v01 = mapping[..., :-1, 1:] 171 | v10 = mapping[..., 1:, :-1] 172 | v11 = mapping[..., 1:, 1:] 173 | mapping = None # del mapping 174 | 175 | with torch.no_grad(): 176 | # find each quadrilateral's (set of 4 vectors) span, in pixels 177 | v_min = _tensor_min(v00, v01, v10, v11).floor() 178 | v_min.y.clamp_(0, field.shape[-2] - 1) 179 | v_min.x.clamp_(0, field.shape[-1] - 1) 180 | v_max = _tensor_max(v00, v01, v10, v11).floor() + 1 181 | v_max.y.clamp_(0, field.shape[-2] - 1) 182 | v_max.x.clamp_(0, field.shape[-1] - 1) 183 | # d_x and d_y are the largest spans in x and y 184 | d = (v_max - v_min).max_vector().max(0)[0].long() 185 | v_max = None # del v_max 186 | d_x, d_y = list(d.cpu().numpy()) 187 | d = ((d//2).unsqueeze(-1).unsqueeze(-1) # center of the span 188 | .unsqueeze(-1).unsqueeze(-1)).to(v_min) 189 | v_min.y.clamp_(0, field.shape[-2] - 1 - d_y) 190 | v_min.x.clamp_(0, field.shape[-1] - 1 - d_x) 191 | # u is an identity pixel mapping of a d_y by d_x neighborhood 192 | u = field.identity().pixel_mapping()[..., :d_y, :d_x].round() 193 | ux = u.x.unsqueeze(-1).unsqueeze(-1) 194 | uy = u.y.unsqueeze(-1).unsqueeze(-1) 195 | u = None # del u 196 | 197 | # subtract out v_min to bring all quadrilaterals near zero 198 | v00 = (v00 - v_min).unsqueeze(-4).unsqueeze(-4) 199 | v01 = (v01 - v_min).unsqueeze(-4).unsqueeze(-4) 200 | v10 = (v10 - v_min).unsqueeze(-4).unsqueeze(-4) 201 | v11 = (v11 - v_min).unsqueeze(-4).unsqueeze(-4) 202 | 203 | # quadratic coefficients in gridsample solution `a*j^2+b*j+c=0` 204 | a = ((v00.x - v01.x) * (v00.y - v01.y - v10.y + v11.y) 205 | - (v00.x - v01.x - v10.x + v11.x) * (v00.y - v01.y)) 206 | b = ((ux - v00.x) * (v00.y - v01.y - v10.y + v11.y) 207 | + (v00.x - v01.x) * (-v00.y + v10.y) 208 | - (-v00.x + v10.x) * (v00.y - v01.y) 209 | - (v00.x - v01.x - v10.x + v11.x) * (uy - v00.y)) 210 | c = (ux - v00.x)*(-v00.y + v10.y) - (-v00.x + v10.x)*(uy - v00.y) 211 | # quadratic formula solution (note positive root is always invalid) 212 | j_temp = ((b + (b.pow(2) - 4*a*c).clamp(min=eps2).sqrt()).abs() 213 | / (2*a).abs().clamp(min=eps2)) 214 | # corner case when a == 0 (reduces to `b*j + c = 0`) 215 | j_temp = j_temp.where(a.abs() > eps1, c.abs()/b.abs().clamp(min=eps2)) 216 | a = b = c = None # del a, b, c 217 | # get i from j_temp 218 | i = ((uy - v00.y + (v00.y - v01.y) * j_temp).abs() 219 | / (-v00.y + v10.y + (v00.y - v01.y - v10.y + v11.y) * j_temp) 220 | .abs().clamp(min=eps2)) 221 | j_temp = None # del j_temp 222 | # j has significantly smaller rounding error for near-trapezoids 223 | j = ((ux - v00.x + (v00.x - v10.x) * i).abs() 224 | / (-v00.x + v01.x + (v00.x - v10.x - v01.x + v11.x) * i) 225 | .abs().clamp(min=eps2)) 226 | # winding_number > 0 means point is contained in the quadrilateral 227 | wn = _winding_number(ux, uy, v00, v01, v10, v11, eps3) 228 | ux = uy = None # del ux, uy 229 | v00 = v01 = v10 = v11 = None # 
del v00, v01, v10, v11 230 | 231 | # negative of the bilinear interpolation to produce inverse vector 232 | v00 = field[..., :-1, :-1].unsqueeze(-3).unsqueeze(-3) 233 | v01 = field[..., :-1, 1:].unsqueeze(-3).unsqueeze(-3) 234 | v10 = field[..., 1:, :-1].unsqueeze(-3).unsqueeze(-3) 235 | v11 = field[..., 1:, 1:].unsqueeze(-3).unsqueeze(-3) 236 | inv = -((1-i)*(1-j)*v00 + (1-i)*j*v01 + i*(1-j)*v10 + i*j*v11) 237 | v00 = v01 = v10 = v11 = None # del v00, v01, v10, v11 238 | 239 | # mask out inverse vectors at points outside the quadrilaterals 240 | mask = (wn > 0) & torch.isfinite(i) & torch.isfinite(j) 241 | i = j = wn = None # del i, j, wn 242 | inv = inv.where(mask, torch.tensor(0.).to(inv)) 243 | # append mask to keep track of how many contributions to each pixel 244 | inv = torch.cat((inv, mask.to(inv)), 1) 245 | mask = None 246 | 247 | # indices at which to place each inverse vector in a sparse tensor 248 | indices = ((v_min.unsqueeze(-3).unsqueeze(-3) + d) 249 | .view(2, -1).flip(0).round().long().contiguous()) 250 | v_min = d = None 251 | inv = inv.view(3, d_y, d_x, -1).permute(3, 0, 1, 2).contiguous() 252 | # construct sparse tensor and use `to_dense` to arrange vectors 253 | SparseTensor = (torch.cuda.sparse.FloatTensor if self.is_cuda 254 | else torch.sparse.FloatTensor) 255 | inv = SparseTensor(indices, inv, (*field.shape[-2:], 3, d_y, d_x), 256 | device=inv.device) 257 | inv = inv.to_dense().permute(2, 3, 4, 0, 1) 258 | # fold the d_y by d_x neighborhoods by summing the overlaps 259 | inv = _fold(inv) 260 | # divide each pixel by number of contributions to get an average 261 | inv = inv[:, :2] / inv[:, 2:].clamp(min=1.) 262 | 263 | # crop back to original shape 264 | if autopad: 265 | inv = _unpad(inv.field(), padding) 266 | return inv 267 | except RuntimeError: 268 | # In case this is an out-of-memory error, clear temporary tensors 269 | self = field = mapping = v_min = v_max = d = d_x = d_y = None 270 | u = ux = uy = v00 = v01 = v10 = v11 = wn = None 271 | a = b = c = j_temp = i = j = None 272 | mask = inv = indices = None 273 | raise 274 | 275 | 276 | def rinverse(self, *args, **kwargs): 277 | r"""Return a right inverse approximation for the displacement field 278 | 279 | Given a displacement field `f`, its right inverse is a displacement 280 | field `g` such that 281 | `f(g) ~= identity` 282 | 283 | In other words 284 | :math:`f_{inv} = \argmin_{g} |f(g)|^2` 285 | """ 286 | raise NotImplementedError 287 | -------------------------------------------------------------------------------- /torchfields/utils.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | 4 | ############################################# 5 | # Decorators for enforcing return value types 6 | ############################################# 7 | 8 | def return_subclass_type(cls): 9 | """Class decorator for a subclass to encourage it to return its own 10 | subclass type whenever its inherited functions would otherwise return 11 | the superclass type. 12 | 13 | This works by attempting to convert any return values of the superclass 14 | type to the subclass type, and then defaulting back to the original 15 | return value on any errors during conversion. 16 | 17 | If running the subclass constructor has undesired side effects, 18 | the class can define a `_from_superclass()` function that casts 19 | to the subclass type more directly. 20 | This function should raise an exception if the type is not compatible. 
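As an illustrative, hypothetical sketch, such a `_from_superclass` could be a `@classmethod` that first validates compatibility (for example, that the component dimension has size 2) and then simply returns `cls(data)`.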
21 | If `_from_superclass` is not defined, the class constructor is called 22 | by default. 23 | """ 24 | def decorator(f): 25 | @wraps(f) 26 | def f_decorated(*args, **kwargs): 27 | out = f(*args, **kwargs) 28 | try: 29 | if not isinstance(out, cls) and isinstance(out, cls.__bases__): 30 | return cls._from_superclass(out) 31 | except Exception: 32 | pass 33 | # Result cannot be returned as subclass type 34 | return out 35 | return f_decorated 36 | 37 | # fall back to constructor if _from_superclass not defined 38 | try: 39 | cls._from_superclass 40 | except AttributeError: 41 | cls._from_superclass = cls 42 | 43 | for name in dir(cls): 44 | attr = getattr(cls, name) 45 | if name not in dir(object) and callable(attr): 46 | try: 47 | # check if this attribute is flagged to keep its return type 48 | if attr._keep_type: 49 | continue 50 | except AttributeError: 51 | pass 52 | setattr(cls, name, decorator(attr)) 53 | return cls 54 | 55 | 56 | def dec_keep_type(keep=True): 57 | """Function decorator that adds a flag to tell `return_subclass_type()` 58 | to leave the function's return type as is. 59 | 60 | This is useful for functions that intentionally return a value of 61 | superclass type. 62 | 63 | If a boolean argument is passed to the decorator as 64 | 65 | @dec_keep_type(True) 66 | def func(): 67 | pass 68 | 69 | then that agument determines whether to enable the flag. If no argument 70 | is passed, the flag is enabled as if `True` were passed. 71 | 72 | @dec_keep_type 73 | def func(): 74 | pass 75 | 76 | """ 77 | def _dec_keep_type(keep_type): 78 | def _set_flag(f): 79 | f._keep_type = keep_type 80 | return f 81 | return _set_flag 82 | if isinstance(keep, bool): # boolean argument passed 83 | return _dec_keep_type(keep) 84 | else: # the argument is actually the function itself 85 | func = keep 86 | return _dec_keep_type(True)(func) 87 | 88 | 89 | ########################################################################### 90 | # Decorators to convert the inputs and outputs of DisplacementField methods 91 | ########################################################################### 92 | 93 | def permute_input(f): 94 | """Function decorator to permute the input dimensions from the 95 | DisplacementField convention `(N, 2, H, W)` to the standard PyTorch 96 | field convention `(N, H, W, 2)` before passing it into the function. 97 | """ 98 | @wraps(f) 99 | def f_new(self, *args, **kwargs): 100 | ndims = self.ndimension() 101 | perm = self.permute(*range(ndims-3), -2, -1, -3) 102 | return f(perm, *args, **kwargs) 103 | return f_new 104 | 105 | 106 | def permute_output(f): 107 | """Function decorator to permute the dimensions of the function output 108 | from the standard PyTorch field convention `(N, H, W, 2)` to the 109 | DisplacementField convention `(N, 2, H, W)` before returning it. 110 | """ 111 | @wraps(f) 112 | def f_new(self, *args, **kwargs): 113 | out = f(self, *args, **kwargs) 114 | ndims = out.ndimension() 115 | return out.permute(*range(ndims-3), -1, -3, -2) 116 | return f_new 117 | 118 | 119 | def ensure_dimensions(ndimensions=4, arg_indices=(0,), reverse=False): 120 | """Function decorator to ensure that the the input has the 121 | approprate number of dimensions 122 | 123 | If it has too few dimensions, it pads the input with dummy dimensions. 124 | 125 | Args: 126 | ndimensions (int): number of dimensions to pad to 127 | arg_indices (int or List[int]): the indices of inputs to pad 128 | Note: Currently, this only works on arguments passed by 129 | position. 
Those inputs must be a torch.Tensor or 130 | DisplacementField. 131 | reverse (bool): if `True`, it then also removes the added dummy 132 | dimensions from the output, down to the number of dimensions 133 | of arg[arg_indices[0]] 134 | """ 135 | if callable(ndimensions): # it was called directly on a function 136 | func = ndimensions 137 | ndimensions = 4 138 | else: 139 | func = None 140 | if isinstance(arg_indices, int): 141 | arg_indices = (arg_indices,) 142 | assert(len(arg_indices) > 0) 143 | 144 | def decorator(f): 145 | @wraps(f) 146 | def f_decorated(*args, **kwargs): 147 | args = list(args) 148 | original_ndims = len(args[arg_indices[0]].shape) 149 | for i in arg_indices: 150 | if i >= len(args): 151 | continue 152 | while args[i].ndimension() < ndimensions: 153 | args[i] = args[i].unsqueeze(0) 154 | out = f(*args, **kwargs) 155 | while reverse and out.ndimension() > original_ndims: 156 | new_out = out.squeeze(0) 157 | if new_out.ndimension() == out.ndimension(): 158 | break # no progress made; nothing left to squeeze 159 | out = new_out 160 | return out 161 | return f_decorated 162 | 163 | if func is None: # parameters were passed to the decorator 164 | return decorator 165 | else: # the function itself was passed to the decorator 166 | return decorator(func) 167 | -------------------------------------------------------------------------------- /torchfields/voting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | ############################ 5 | # Vector vote implementation 6 | ############################ 7 | 8 | 9 | def get_padding(kernel_size): 10 | pad = (kernel_size - 1) // 2 11 | if kernel_size % 2 == 0: 12 | pad = (pad, pad + 1, pad, pad + 1) 13 | else: 14 | pad = (pad,) * 4 15 | return pad 16 | 17 | 18 | def gaussian_blur(data, sigma=1, kernel_size=5): 19 | """Gaussian blur the displacement field to reduce any unsmoothness 20 | Adapted from https://bit.ly/2JO7CCP 21 | 22 | Args: 23 | data (tensor): NxCxWxH 24 | """ 25 | import math 26 | 27 | if sigma == 0: 28 | return data.clone() 29 | pad = get_padding(kernel_size) 30 | padded = F.pad(data, pad, mode="reflect") 31 | mu = (kernel_size - 1) / 2 32 | x = torch.stack(torch.meshgrid([torch.arange(kernel_size).to(data)] * 2)) 33 | kernel = torch.exp((-(((x - mu) / sigma) ** 2)) / 2) 34 | kernel = kernel.prod(dim=0) / (2 * math.pi * sigma ** 2) 35 | kernel = kernel / kernel.sum() # renormalize so the kernel sums to 1 36 | kernel = kernel.expand(2, 1, *kernel.shape) 37 | groups = 2 38 | if data.shape[1] == 1: 39 | groups = 1 40 | return F.conv2d(padded, weight=kernel, groups=groups) 41 | 42 | 43 | def get_vote_shape(self): 44 | """Consistently split shape into N fields & shape of field""" 45 | n, _, *shape = self.shape 46 | return n, shape 47 | 48 | 49 | def get_subset_size(self, subset_size=None): 50 | """Compute the smallest majority of self if subset_size is not set""" 51 | n, _ = self.get_vote_shape() 52 | m = (n + 1) // 2 # smallest number that constitutes a majority 53 | if subset_size is not None: 54 | m = subset_size 55 | return m 56 | 57 | 58 | def get_vote_subsets(self, subset_size=None): 59 | """Compute list of majority subsets to use in vote""" 60 | n, _ = self.get_vote_shape() 61 | m = self.get_subset_size(subset_size=subset_size) 62 | from itertools import combinations 63 | 64 | subset_tuples = list(combinations(range(n), m)) 65 | return subset_tuples 66 | 67 | 68 | def linear_combination(self, weights): 69 | """Create a single
field from a set of fields given a set of weights 70 | 71 | Args: 72 | weights (tensor): (N, 1, W, H) or (N, W, H) 73 | 74 | Returns: 75 | DisplacementField with shape (1, 2, W, H) 76 | """ 77 | if len(weights.shape) == 3: 78 | weights = weights.unsqueeze(-3) 79 | return (self * weights).sum(dim=0, keepdim=True) 80 | 81 | 82 | def smoothed_combination(self, weights, blur_sigma=2.0, kernel_size=5): 83 | """Create a single field from a set of fields, given a set of weights. 84 | The weights will be spatially smoothed with a Gaussian kernel of std blur_sigma. 85 | 86 | Args: 87 | weights (tensor): (N, W, H) 88 | blur_sigma (float) 89 | kernel_size (int) 90 | 91 | Returns: 92 | DisplacementField with shape (1, 2, W, H) 93 | """ 94 | weights = weights.unsqueeze(-3) 95 | _, shape = self.get_vote_shape() 96 | # need to blur with reflected padding, which requires minimum spatial dimensions 97 | max_pad = max(get_padding(kernel_size)) 98 | if (shape[-1] > max_pad) and (shape[-2] > max_pad): 99 | weights = gaussian_blur(data=weights, sigma=blur_sigma, kernel_size=kernel_size) 100 | return self.linear_combination(weights) 101 | 102 | 103 | def get_vote_weights(self, softmin_temp=1, blur_sigma=1, subset_size=None): 104 | """Calculate per field weights for batch of displacement fields, indicating 105 | which fields should be considered part of the consensus. 106 | 107 | Args: 108 | self: DisplacementField of shape (N, 2, H, W) 109 | softmin_temp (float): temperature of softmin to use 110 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 111 | the softmin inputs. Note that the outputs are not blurred. 112 | None or 0 means no blurring. 113 | subset_size (int): number of members to each set for comparison 114 | 115 | Returns: 116 | per field weight (torch.Tensor): (N, H, W) 117 | """ 118 | from itertools import combinations 119 | 120 | if self.ndimension() != 4: 121 | raise ValueError( 122 | "Vector vote is only implemented on " 123 | "displacement fields with 4 dimensions. 
" 124 | "The input has {}.".format(self.ndimension()) 125 | ) 126 | n, shape = self.get_vote_shape() 127 | subset_size = self.get_subset_size(subset_size=subset_size) 128 | if n == 1: 129 | return torch.ones((1, *shape)).to(self) 130 | if n == subset_size: 131 | return torch.ones((1, *shape)).to(self) / n 132 | # elif n % 2 == 0: 133 | # raise ValueError('Cannot vetor vote on an even number of ' 134 | # 'displacement fields: {}'.format(n)) 135 | blurred = self.gaussian_blur(sigma=blur_sigma) if blur_sigma else self 136 | mtuples = self.get_vote_subsets(subset_size=subset_size) 137 | 138 | # compute distances for all pairs of fields 139 | dists = torch.zeros((n, n, *shape)).to(device=blurred.device) 140 | for i in range(n): 141 | for j in range(i): 142 | dists[i, j] = dists[j, i] = blurred[i].distance(blurred[j]) 143 | 144 | # compute mean distance for all majority tuples 145 | mtuple_avg = [] 146 | for mtuple in mtuples: 147 | delta = torch.stack([dists[i, j] for i, j in combinations(mtuple, 2)]).mean( 148 | dim=0 149 | ) 150 | mtuple_avg.append(delta) 151 | mavg = torch.stack(mtuple_avg) 152 | 153 | # compute weights for mtuples: smaller mean distance -> higher weight 154 | mt_weights = (-mavg / softmin_temp).softmax(dim=0) 155 | 156 | # assign mtuple weights back to individual fields 157 | field_weights = torch.zeros((n, *shape)).to(device=mt_weights.device) 158 | for i, mtuple in enumerate(mtuples): 159 | for j in mtuple: 160 | field_weights[j] += mt_weights[i] 161 | 162 | # rather than use m, prefer sum for sum precision 163 | elements_per_subset = field_weights.sum(dim=0, keepdim=True) 164 | field_weights = field_weights / elements_per_subset 165 | return field_weights 166 | 167 | 168 | def vote(self, softmin_temp=1, blur_sigma=1, subset_size=None): 169 | """Produce a single, consensus displacement field from a batch of 170 | displacement fields 171 | 172 | The resulting displacement field represents displacements that are 173 | closest to the most consistent majority of the fields. 174 | This effectively allows the fields to differentiably vote on the 175 | displacement that is most likely to be correct. 176 | 177 | Args: 178 | self: DisplacementField of shape (N, 2, H, W) 179 | softmin_temp (float): temperature of softmin to use 180 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 181 | the softmin inputs. Note that the outputs are not blurred. 182 | None or 0 means no blurring. 183 | subset_size (int): number of members to each subset 184 | 185 | Returns: 186 | DisplacementField of shape (1, 2, H, W) containing the vector 187 | vote result 188 | """ 189 | weights = self.get_vote_weights( 190 | softmin_temp=softmin_temp, blur_sigma=blur_sigma, subset_size=subset_size 191 | ) 192 | return self.linear_combination(weights=weights) 193 | 194 | 195 | def get_vote_weights_with_distances( 196 | self, distances, softmin_temp=1, blur_sigma=1, subset_size=None 197 | ): 198 | """Calculate consensus field from batch of displacement fields along with distances. 199 | Voting proceeds as normal, until it comes time to distribute the weight of each 200 | subset amongst its constitute vectors. The distribution is now based on the 201 | distances of each vector, with distances further away making a vector contribute 202 | less to consensus than nearer distance vectors in the subset. 203 | 204 | Weights should be proportional to get_vote_weights if distances are identical. 
205 | 206 | Args: 207 | self: DisplacementField of shape (N, 2, H, W) 208 | distances: Tensor of shape (N, H, W) 209 | softmin_temp (float): temperature of softmin to use 210 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 211 | the softmin inputs. Note that the outputs are not blurred. 212 | None or 0 means no blurring. 213 | subset_size (int): number of members to each set for comparison 214 | 215 | Returns: 216 | per field weight (torch.Tensor): (N, H, W) 217 | """ 218 | from itertools import combinations 219 | 220 | if self.ndimension() != 4: 221 | raise ValueError( 222 | "Vector vote is only implemented on " 223 | "displacement fields with 4 dimensions. " 224 | "The input has {}.".format(self.ndimension()) 225 | ) 226 | n, shape = self.get_vote_shape() 227 | subset_size = self.get_subset_size(subset_size=subset_size) 228 | if n == 1: 229 | return torch.ones((1, *shape)).to(self) 230 | blurred = self.gaussian_blur(sigma=blur_sigma) if blur_sigma else self 231 | # distances = distances.gaussian_blur(sigma=blur_sigma) if blur_sigma else distances 232 | subset_tuples = self.get_vote_subsets(subset_size=subset_size) 233 | 234 | # compute mean of mixture distribution for all subset tuples 235 | subset_avg = {} 236 | for subset in subset_tuples: 237 | s_avg = torch.stack([blurred[i] for i in subset]).mean(dim=0) 238 | subset_avg[subset] = s_avg 239 | 240 | # compute standard deviations of mixture distribution for all subset tuples with var=0 241 | subset_std = [] 242 | for subset in subset_tuples: 243 | s_moment_sum = torch.stack([blurred[i].pow(2) for i in subset]) 244 | s_var = (s_moment_sum - subset_avg[subset].pow(2)).mean(dim=0) 245 | subset_std.append(s_var.abs().sqrt()) 246 | subset_std_dist = torch.stack(subset_std).pow(2).sum(dim=-3).sqrt() 247 | 248 | # compute weights for subset_tuples: smaller variance -> higher weight 249 | subset_weights = (-subset_std_dist / softmin_temp).softmax(dim=0) 250 | 251 | # assign subset weights back to individual fields 252 | # use distances to partition the weights: larger distance -> less weight 253 | field_weights = torch.zeros((n, *shape)).to(device=subset_weights.device) 254 | for i, subset in enumerate(subset_tuples): 255 | dists = distances[subset, ...] 256 | weights = (1.0 / dists) * (1.0 / (1.0 / dists).sum(dim=0)) 257 | for k, j in enumerate(subset): 258 | field_weights[j] += subset_weights[i] * weights[k] 259 | return field_weights 260 | 261 | 262 | def vote_with_distances( 263 | self, distances, softmin_temp=1, blur_sigma=1, subset_size=None 264 | ): 265 | """Produce a single, consensus displacement field from a batch of 266 | displacement fields along with a distance measure that weights further 267 | fields less in the consensus. 268 | 269 | Args: 270 | self: DisplacementField of shape (N, 2, H, W) 271 | distances: Tensor of shape (N, 1, H, W) 272 | softmin_temp (float): temperature of softmin to use 273 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 274 | the softmin inputs. Note that the outputs are not blurred. 275 | None or 0 means no blurring. 
276 | subset_size (int): number of members to each subset 277 | 278 | Returns: 279 | DisplacementField of shape (1, 2, H, W) containing the vector 280 | vote result 281 | """ 282 | weights = self.get_vote_weights_with_distances( 283 | softmin_temp=softmin_temp, 284 | distances=distances, 285 | blur_sigma=blur_sigma, 286 | subset_size=subset_size, 287 | ) 288 | return self.linear_combination(weights=weights) 289 | 290 | 291 | def get_vote_weights_with_variances( 292 | self, var, softmin_temp=1, blur_sigma=1, subset_size=None 293 | ): 294 | """Calculate consensus field from batch of displacement fields along with variances. 295 | Each vector within self is treated as the mean of a distribution with isotropic 296 | variance for the corresponding location in variances. A subset of vectors is 297 | considered a mixture distribution. We assign higher weight to mixture distributions 298 | with lower variances. 299 | 300 | Weights should be proportional to get_vote_weights if variances are zero. 301 | 302 | Args: 303 | self: DisplacementField of shape (N, 2, H, W) 304 | var: Tensor of shape (N, 1, H, W) 305 | softmin_temp (float): temperature of softmin to use 306 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 307 | the softmin inputs. Note that the outputs are not blurred. 308 | None or 0 means no blurring. 309 | subset_size (int): number of members to each set for comparison 310 | 311 | Returns: 312 | per field weight (torch.Tensor): (N, H, W) 313 | """ 314 | from itertools import combinations 315 | 316 | if self.ndimension() != 4: 317 | raise ValueError( 318 | "Vector vote is only implemented on " 319 | "displacement fields with 4 dimensions. " 320 | "The input has {}.".format(self.ndimension()) 321 | ) 322 | n, shape = self.get_vote_shape() 323 | subset_size = self.get_subset_size(subset_size=subset_size) 324 | if n == 1: 325 | return torch.ones((1, *shape)).to(self) 326 | if n == subset_size: 327 | return torch.ones((1, *shape)).to(self) / n 328 | blurred = self.gaussian_blur(sigma=blur_sigma) if blur_sigma else self 329 | variances = torch.cat([var, var], dim=1).field() 330 | variances = variances.gaussian_blur(sigma=blur_sigma) if blur_sigma else variances 331 | subset_tuples = self.get_vote_subsets(subset_size=subset_size) 332 | 333 | # compute mean of mixture distribution for all subset tuples 334 | subset_avg = {} 335 | for subset in subset_tuples: 336 | s_avg = torch.stack([blurred[i] for i in subset]).mean(dim=0) 337 | subset_avg[subset] = s_avg 338 | 339 | # compute standard deviations of mixture distribution for all subset tuples 340 | subset_std = [] 341 | for subset in subset_tuples: 342 | s_moment_sum = torch.stack([variances[i] + blurred[i].pow(2) for i in subset]) 343 | s_var = (s_moment_sum - subset_avg[subset].pow(2)).mean(dim=0) 344 | subset_std.append(s_var.abs().sqrt()) 345 | subset_std_dist = torch.stack(subset_std).pow(2).sum(dim=-3).sqrt() 346 | 347 | # compute weights for subset_tuples: smaller variance -> higher weight 348 | subset_weights = (-subset_std_dist / softmin_temp).softmax(dim=0) 349 | 350 | # assign subset weights back to individual fields 351 | field_weights = torch.zeros((n, *shape)).to(device=subset_weights.device) 352 | for i, subset in enumerate(subset_tuples): 353 | for j in subset: 354 | field_weights[j] += subset_weights[i] 355 | 356 | # rather than use subset_size, prefer sum for sum precision 357 | elements_per_subset = field_weights.sum(dim=0, keepdim=True) 358 | field_weights = field_weights / elements_per_subset 359 | 
return field_weights 360 | 361 | 362 | def vote_with_variances(self, var, softmin_temp=1, blur_sigma=1, subset_size=None): 363 | """Produce a single, consensus displacement field from a batch of 364 | distributions, with displacement fields as mean and variances. 365 | 366 | Args: 367 | self: DisplacementField of shape (N, 2, H, W) 368 | var: Tensor of shape (N, 1, H, W) 369 | softmin_temp (float): temperature of softmin to use 370 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 371 | the softmin inputs. Note that the outputs are not blurred. 372 | None or 0 means no blurring. 373 | subset_size (int): number of members to each subset 374 | 375 | Returns: 376 | DisplacementField of shape (1, 2, H, W) containing the vector 377 | vote result 378 | """ 379 | weights = self.get_vote_weights_with_variances( 380 | softmin_temp=softmin_temp, 381 | var=var, 382 | blur_sigma=blur_sigma, 383 | subset_size=subset_size, 384 | ) 385 | return self.linear_combination(weights=weights) 386 | 387 | 388 | def get_priority_vote_weights( 389 | self, priorities, consensus_threshold=2, subset_size=None 390 | ): 391 | """Calculate weights to produce near-median vector with highest priority. 392 | This method differs from other voting approaches by favoring a single 393 | vector as much as possible, rather than averaging over any subset. 394 | 395 | Args: 396 | self: DisplacementField of shape (N, 2, H, W) 397 | priorities: Tensor of shape (N, H, W). Larger means higher priority. 398 | consensus_threshold (float): maximum distance from lowest score that will 399 | consider subset part of consensus 400 | subset_size (int): number of members to each set for comparison 401 | 402 | Returns: 403 | per field weight (torch.Tensor): (N, H, W) 404 | """ 405 | from itertools import combinations 406 | 407 | if consensus_threshold < 0.0: 408 | raise ValueError( 409 | "Expected non-negative value for consensus_threshold, but received {}.".format( 410 | consensus_threshold 411 | ) 412 | ) 413 | 414 | if self.ndimension() != 4: 415 | raise ValueError( 416 | "Vector vote is only implemented on " 417 | "displacement fields with 4 dimensions. 
" 418 | "The input has {}.".format(self.ndimension()) 419 | ) 420 | n, shape = self.get_vote_shape() 421 | subset_size = self.get_subset_size(subset_size=subset_size) 422 | if subset_size == 1: 423 | return (priorities == torch.max(priorities, dim=0, keepdim=True)[0]).float() 424 | 425 | # mtuple: majority tuples 426 | mtuples = self.get_vote_subsets(subset_size=subset_size) 427 | 428 | # compute distances for all pairs of fields 429 | dists = torch.zeros((n, n, *shape)).to(device=self.device) 430 | for i in range(n): 431 | for j in range(i): 432 | dists[i, j] = dists[j, i] = self[i].distance(self[j]) 433 | 434 | # compute mean distance for all majority tuples 435 | mtuple_avg = [] 436 | mtuple_priority = [] 437 | for mtuple in mtuples: 438 | delta = torch.stack([dists[i, j] for i, j in combinations(mtuple, 2)]).mean( 439 | dim=0 440 | ) 441 | mtuple_avg.append(delta) 442 | mtuple_priorities = torch.stack([priorities[i] for i in mtuple]) 443 | mtuple_priority.append(torch.max(mtuple_priorities, dim=0)[0]) 444 | mavg = torch.stack(mtuple_avg) 445 | # best priority for each mtuple 446 | mpriority = torch.stack(mtuple_priority) 447 | 448 | # identify vectors that participate in consensus, find their priority 449 | relative_score = mavg - torch.min(mavg, dim=0)[0] 450 | consensus_indicator = relative_score <= consensus_threshold 451 | consensus_priorities = torch.where( 452 | consensus_indicator, mpriority, torch.zeros_like(mpriority) 453 | ) 454 | consensus_priority = torch.max(consensus_priorities, dim=0, keepdim=True)[0] 455 | weights = (priorities == consensus_priority).float() 456 | return weights / weights.sum(dim=0) 457 | 458 | 459 | def priority_vote( 460 | self, 461 | priorities, 462 | consensus_threshold=2, 463 | blur_sigma=2, 464 | kernel_size=5, 465 | subset_size=None, 466 | ): 467 | """Produce a single, consensus displacement field from a batch of 468 | distributions, with displacement fields as mean and variances. 469 | 470 | Args: 471 | self: DisplacementField of shape (N, 2, H, W) 472 | priorities: Tensor of shape (N, 1, H, W). Larger means higher priority. 473 | consensus_threshold (float): maximum distance from lowest score that will 474 | consider subset part of consensus 475 | blur_sigma (float): std dev of the Gaussian kernel by which to blur 476 | the weight outputs. None or 0 means no blurring. 477 | subset_size (int): number of members to each subset 478 | 479 | Returns: 480 | DisplacementField of shape (1, 2, H, W) containing the vector 481 | vote result 482 | """ 483 | weights = self.get_priority_vote_weights( 484 | priorities=priorities, 485 | consensus_threshold=consensus_threshold, 486 | subset_size=subset_size, 487 | ) 488 | return self.smoothed_combination( 489 | weights=weights, blur_sigma=blur_sigma, kernel_size=kernel_size 490 | ) 491 | --------------------------------------------------------------------------------