├── spit ├── __init__.py ├── utils │ ├── __init__.py │ ├── cuda │ │ ├── concom.cu │ │ ├── scatterhist.cu │ │ └── cossimargmax.cu │ ├── scatter.py │ ├── indexing.py │ ├── concom.py │ ├── scatterhist.py │ └── cossimargmax.py ├── tokenizer │ ├── __init__.py │ ├── densesp.py │ ├── voronoi.py │ ├── proc.py │ └── tokenizer.py └── nn.py ├── .gitattributes ├── assets ├── fig1.png ├── fig2.png ├── fig3.png └── fig1_dark.png ├── setup.py ├── LICENSE ├── .gitignore ├── README.md └── hubconf.py /spit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /spit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /spit/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-generated=true -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsb-ifi/SPiT/HEAD/assets/fig1.png -------------------------------------------------------------------------------- /assets/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsb-ifi/SPiT/HEAD/assets/fig2.png -------------------------------------------------------------------------------- /assets/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsb-ifi/SPiT/HEAD/assets/fig3.png 
-------------------------------------------------------------------------------- /assets/fig1_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dsb-ifi/SPiT/HEAD/assets/fig1_dark.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='spit', 5 | version='0.1', 6 | description='A Spitting Image: Modular Superpixel Tokenization in Vision Transformers', 7 | author='Marius Aasan ', 8 | licence='MIT', 9 | packages=['spit'], 10 | install_requires = [ 11 | 'torch >= 2.1.0', 12 | 'torchvision >= 0.16.0', 13 | 'scipy >= 1.10.1', 14 | 'cupy >= 13.0.0', 15 | 'numba >= 0.59.0', 16 | ] 17 | ) -------------------------------------------------------------------------------- /spit/utils/cuda/concom.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ 2 | void dpcc_roots_kernel(long long n, long long* P) { 3 | const int thread_id = blockDim.x * blockIdx.x + threadIdx.x; 4 | const int thread_count = gridDim.x * blockDim.x; 5 | bool flag = true; 6 | while (flag) { 7 | flag = false; 8 | for (long long v = thread_id; v < n; v += thread_count) { 9 | long long root = P[v]; 10 | while (root != P[root]) { 11 | root = P[root]; 12 | } 13 | P[v] = root; 14 | if (P[v] != P[P[v]]) { 15 | flag = true; 16 | } 17 | } 18 | __syncthreads(); 19 | } 20 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Marius Aasan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including 
without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /spit/tokenizer/densesp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .tokenizer import diffmap_from_seg, concatenate_diffmap, superpixel_tokenizer 5 | 6 | class DenseSPEdgeEmbedder(nn.Module): 7 | 8 | '''Embeds superpixel segmentation maps into dense edge maps. 9 | 10 | This allows for embedding superpixel information into image representations 11 | without losing the original image resolution. 12 | 13 | Example: 14 | ``` 15 | >>> import torch 16 | >>> from spit.tokenizer.densesp import DenseSPEdgeEmbedder 17 | >>> img = ... 
# (1, 3, 224, 224) input image 18 | >>> tokenizer = DenseSPEdgeEmbedder() 19 | >>> out = tokenizer(img) 20 | >>> out.shape 21 | torch.Size([1, 5, 224, 224]) 22 | # Out can then be fed to a 5-channel Conv2d layer 23 | ``` 24 | ''' 25 | 26 | def __init__(self, maxlvl:int=4, drop_delta:bool=False, bbox_reg:bool=True, learn_vrange:bool=True, return_seg:bool=False): 27 | '''Sets up the embedder. 28 | 29 | Args: 30 | maxlvl (int): Level argument passed on to `superpixel_tokenizer`. 31 | drop_delta (bool): Passed on to `superpixel_tokenizer`. 32 | bbox_reg (bool): Passed on to `superpixel_tokenizer`. 33 | learn_vrange (bool): If True, `vmin`/`vmax` are learnable parameters; otherwise fixed buffers. 34 | return_seg (bool): If True, `forward` returns `(img, seg)` instead of `img`. 35 | ''' 36 | super().__init__() 37 | self._lgrad = 27.8 38 | self._lcol = 10. 39 | self._maxlvl = maxlvl 40 | self.drop_delta = drop_delta 41 | self.bbox_reg = bbox_reg 42 | self.return_seg = return_seg 43 | if learn_vrange: 44 | self.vmin = nn.Parameter(torch.tensor(-torch.pi/2)) 45 | self.vmax = nn.Parameter(torch.tensor(torch.pi/2)) 46 | else: 47 | self.register_buffer('vmin', torch.tensor(-torch.pi/2)) 48 | self.register_buffer('vmax', torch.tensor(torch.pi/2)) 49 | 50 | def forward(self, img:torch.Tensor): 51 | '''Runs `superpixel_tokenizer` on `img` and returns `concatenate_diffmap(img, seg)` (plus `seg` if `return_seg`).''' 52 | _, seg, _ = superpixel_tokenizer( 53 | img, self._lgrad, self._lcol, self.drop_delta, self.bbox_reg, 54 | False, self._maxlvl, 55 | ) 56 | img = concatenate_diffmap(img, seg, vmin=self.vmin, vmax=self.vmax) # type: ignore 57 | if self.return_seg: 58 | return img, seg 59 | return img 60 | -------------------------------------------------------------------------------- /spit/utils/scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def scatter_mean_2d(src:torch.Tensor, idx:torch.Tensor) -> torch.Tensor: 4 | '''Computes scatter mean with 2d source and 1d index over first dimension. 5 | 6 | Args: 7 | src (torch.Tensor): Source tensor. 8 | idx (torch.Tensor): Index tensor. 9 | 10 | Returns: 11 | torch.Tensor: Output tensor.
12 | ''' 13 | assert src.ndim == 2 14 | assert len(src) == len(idx) 15 | if idx.ndim == 1: 16 | idx = idx.unsqueeze(1).expand(*src.shape) 17 | out = src.new_empty(idx.max()+1, src.shape[1]) # type: ignore 18 | return out.scatter_reduce_(0, idx, src, 'mean', include_self=False) 19 | 20 | 21 | def scatter_add_1d(src:torch.Tensor, idx:torch.Tensor, n:int) -> torch.Tensor: 22 | '''Computes scatter add with 1d source and 1d index. 23 | 24 | Args: 25 | src (torch.Tensor): Source tensor. 26 | idx (torch.Tensor): Index tensor. 27 | n (int): No. outputs. 28 | 29 | Returns: 30 | torch.Tensor: Output tensor. 31 | ''' 32 | assert src.ndim == 1 33 | assert len(src) == len(idx) 34 | out = src.new_zeros(n) 35 | return out.scatter_add_(0, idx, src) 36 | 37 | 38 | def scatter_range_2d(src:torch.Tensor, idx:torch.Tensor) -> torch.Tensor: 39 | '''Computes scattered range (max-min) with 2d source and 1d index. 40 | 41 | Args: 42 | src (torch.Tensor): Source tensor. 43 | idx (torch.Tensor): Index tensor. 44 | 45 | Returns 46 | torch.Tensor: Output tensor. 47 | ''' 48 | if idx.ndim == 1: 49 | idx = idx.unsqueeze(1).expand(*src.shape) 50 | mx = src.new_empty(idx.max()+1, src.shape[1]) # type: ignore 51 | mn = src.new_empty(idx.max()+1, src.shape[1]) # type: ignore 52 | mx.scatter_reduce_(0, idx, src, 'amax', include_self=False) 53 | mn.scatter_reduce_(0, idx, src, 'amin', include_self=False) 54 | return mx - mn 55 | 56 | 57 | def scatter_cov_2d(src:torch.Tensor, idx:torch.Tensor) -> torch.Tensor: 58 | '''Scatter covariance reduction. 59 | 60 | NOTE: Runs two passes. 61 | 62 | Args: 63 | src (torch.Tensor): Source tensor. 64 | idx (torch.Tensor): Index tensor. 65 | 66 | Returns: 67 | torch.Tensor: Output tensor. 
68 | ''' 69 | d = src.shape[-1] 70 | mu = scatter_mean_2d(src, idx) 71 | diff = (src - mu[idx]) 72 | return scatter_mean_2d( 73 | (diff.unsqueeze(-1) @ diff.unsqueeze(-2)).view(-1,d**2), 74 | idx, 75 | ) 76 | -------------------------------------------------------------------------------- /spit/utils/indexing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from typing import Union, Sequence 5 | 6 | def unravel_index( 7 | indices:torch.Tensor, shape:Union[Sequence[int], torch.Tensor] 8 | ) -> torch.Tensor: 9 | '''Converts a tensor of flat indices into a tensor of coordinate vectors. 10 | 11 | Args: 12 | index (torch.Tensor): Indices to unravel. 13 | shape (tuple[int]): Shape of tensor. 14 | 15 | Returns: 16 | torch.Tensor: Tensor (long) of unraveled indices. 17 | ''' 18 | try: 19 | shape = indices.new_tensor(torch.Size(shape))[:,None] # type: ignore 20 | except Exception: 21 | pass 22 | shape = F.pad(shape, (0,0,0,1), value=1) # type: ignore 23 | coefs = shape[1:].flipud().cumprod(dim=0).flipud() 24 | return torch.div(indices[None], coefs, rounding_mode='trunc') % shape[:-1] 25 | 26 | 27 | def fast_uidx_1d(ar:torch.Tensor) -> torch.Tensor: 28 | '''Pretty fast unique index calculation for 1d tensors. 29 | 30 | Args: 31 | ar (torch.Tensor): Tensor to compute unique indices for. 32 | 33 | Returns: 34 | torch.Tensor: Tensor (long) of indices. 35 | ''' 36 | assert ar.ndim == 1, f'Need dim of 1, got: {ar.ndim}!' 37 | perm = ar.argsort() 38 | aux = ar[perm] 39 | mask = ar.new_zeros(aux.shape[0], dtype=torch.bool) 40 | mask[:1] = True 41 | mask[1:] = aux[1:] != aux[:-1] 42 | return perm[mask] 43 | 44 | 45 | def fast_uidx_long2d(ar:torch.Tensor) -> torch.Tensor: 46 | '''Pretty fast unique index calculation for 2d long tensors (row wise). 47 | 48 | Args: 49 | ar (torch.Tensor): Tensor to compute unique indices for. 50 | 51 | Returns: 52 | torch.Tensor: Tensor (long) of indices. 
53 | ''' 54 | assert ar.ndim == 2, f'Need dim of 2, got: {ar.ndim}!' 55 | m = ar.max() + 1 56 | r, c = ar 57 | cons = r*m + c 58 | return fast_uidx_1d(cons) 59 | 60 | 61 | def lexsort(*tensors:torch.Tensor) -> torch.Tensor: 62 | '''Lexicographical sort of multidimensional tensor. 63 | 64 | Args: 65 | src (torch.Tensor): Input tensor. 66 | dim (int): Dimension to sort over, defaults to -1. 67 | 68 | Returns: 69 | torch.Tensor: Sorting indices for multidimensional tensor. 70 | ''' 71 | numel = tensors[0].numel() 72 | assert all([t.ndim == 1 for t in tensors]) 73 | assert all([t.numel() == numel for t in tensors[1:]]) 74 | idx = tensors[0].argsort(dim=0, stable=True) 75 | for k in tensors[1:]: 76 | idx = idx.gather(0, k.gather(0, idx).argsort(dim=0, stable=True)) 77 | return idx -------------------------------------------------------------------------------- /spit/utils/cuda/scatterhist.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ void flatnorm_scatterhist_kernel( 2 | const float* features, 3 | float* output, 4 | const long long* indices, 5 | const float* bins, 6 | const float* sigmaptr, 7 | const long long num_pixels, 8 | const long long num_features, 9 | const long long num_bins 10 | ) { 11 | 12 | long long thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 13 | long long thread_cnt = gridDim.x * blockDim.x; 14 | const float sigma = sigmaptr[0]; 15 | 16 | for (long long pixel_idx = thread_idx; pixel_idx < num_pixels; pixel_idx += thread_cnt) { 17 | long long output_idx_base = indices[pixel_idx] * num_bins * num_features; 18 | 19 | for (long long feature = 0; feature < num_features; ++feature) { 20 | float feature_val = features[pixel_idx * num_features + feature]; 21 | for (long long bin = 0; bin < num_bins; ++bin) { 22 | float bin_val = bins[bin]; 23 | float z = (feature_val - bin_val) / sigma; 24 | float hist_val = exp(-0.5 * z * z); 25 | 26 | // Calculate the output index 27 | long long output_idx 
= output_idx_base + feature * num_bins + bin; 28 | 29 | // Atomic add to the output 30 | atomicAdd(&output[output_idx], hist_val); 31 | } 32 | } 33 | } 34 | } 35 | 36 | extern "C" __global__ 37 | void scatter_joint_hist( 38 | const long long* seg, 39 | const float* feats, 40 | const float* mesh_y, 41 | const float* mesh_x, 42 | const long long* featcombs, 43 | float* output, 44 | float* sigmaptr, 45 | const long long n, 46 | const long long nbins, 47 | const long long nfeats, 48 | const long long feat_dim 49 | ) { 50 | long long idx = blockDim.x * blockIdx.x + threadIdx.x; 51 | long long nbins2 = nbins * nbins; 52 | float sigma = sigmaptr[0]; 53 | 54 | if (idx < n) { 55 | long long s_idx = seg[idx]; 56 | float y; 57 | float x; 58 | float z1; 59 | float z2; 60 | float value; 61 | long long j_y; 62 | long long j_x; 63 | 64 | for (long long j = 0; j < nfeats; j++){ 65 | j_y = featcombs[2*j]; 66 | j_x = featcombs[2*j+1]; 67 | y = feats[idx*feat_dim + j_y]; 68 | x = feats[idx*feat_dim + j_x]; 69 | 70 | for (long long i = 0; i < nbins2; i++) { 71 | z1 = (y - mesh_y[i]) / sigma; 72 | z2 = (x - mesh_x[i]) / sigma; 73 | value = exp(-0.5 * (z1 * z1 + z2 * z2)); 74 | atomicAdd(&output[s_idx * nfeats * nbins2 + j * nbins2 + i], value); 75 | } 76 | } 77 | } 78 | } 79 | 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the 
exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # Additional stuff 158 | .vscode/ 159 | testing.ipynb 160 | .trash/ 161 | -------------------------------------------------------------------------------- /spit/utils/concom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import cupy 4 | from warnings import warn 5 | from scipy.sparse import coo_matrix as cpu_coo_matrix 6 | from scipy.sparse.csgraph import connected_components as cpu_concom 7 | from typing import Optional 8 | 9 | 10 | __location__ = os.path.realpath( 11 | os.path.join(os.getcwd(), os.path.dirname(__file__)) 12 | ) 13 | 14 | with open(os.path.join(__location__, 'cuda', 'concom.cu'), 'r') as cc_file: 15 | _dpcc_roots_kernel_code = cc_file.read() 16 | 17 | _dpcc_roots_kernel = cupy.RawKernel( 18 | _dpcc_roots_kernel_code, 'dpcc_roots_kernel' 19 | ) 20 | 21 | def dpcc_recursive( 22 | n:int, u:torch.Tensor, v:torch.Tensor, 
labels:torch.Tensor, 23 | it:int, maxit:int=50, tpb:int=128, bpg:Optional[int]=None 24 | ): 25 | '''Recursive loop for parallel connected components for CUDA. 26 | 27 | Args: 28 | n (int): Number of vertices in graph. 29 | u (torch.Tensor): Source edges. 30 | v (torch.Tensor): Target edges. 31 | labels (torch.Tensor): Labels of nodes. 32 | it (int): Current iteration. 33 | maxit (int, optional): Max number of iterations. 34 | tpb (int, optional): Threads per block. 35 | bpg (int, optional): Blocks per grid (def: get from n, tpg). 36 | ''' 37 | assert labels.device == u.device 38 | assert labels.device == v.device 39 | 40 | # Get current number of edges 41 | m = len(u) 42 | 43 | # Calculate default blocks per grid 44 | if bpg is None: 45 | bpg = (n + (tpb - 1)) // tpb 46 | 47 | # Convergence, end recursion 48 | if m == 0: 49 | return 50 | 51 | # Maximum iterations 52 | if it > maxit: 53 | msg = f'DPCC recursion limit - curit: {it} > maxit: {maxit}.' 54 | warn(msg, RuntimeWarning) 55 | return 56 | 57 | # Low->High vs. High->Low idxcount 58 | l2h = (u < v).sum() 59 | h2l = m - l2h 60 | 61 | # Pick largest for maximum graph reduction 62 | if l2h >= h2l: 63 | mask = u < v 64 | else: 65 | mask = u > v 66 | 67 | # Contract labels 68 | labels[u[mask]] = v[mask] 69 | 70 | # Compute roots 71 | with cupy.cuda.Device(labels.device.index): 72 | _dpcc_roots_kernel((bpg,), (tpb,), (n, cupy.from_dlpack(labels))) 73 | 74 | # Compute new edges 75 | mask = labels[u] != labels[v] 76 | uprime = labels[u[mask]] 77 | vprime = labels[v[mask]] 78 | 79 | # Recurse 80 | dpcc_recursive(n, uprime, vprime, labels, it+1, maxit, tpb, bpg) 81 | 82 | 83 | def cc_gpu(src:torch.Tensor, tgt:torch.Tensor, n:int, tpb:int=128) -> torch.Tensor: 84 | '''Parallel connected components algorithm on CUDA. 85 | 86 | Args: 87 | src (int): Source edges. 88 | tgt (int): Target edges. 89 | n (int): Number of vertices in graph. 90 | 91 | Returns: 92 | torch.Tensor: Connected components of graph. 
93 | ''' 94 | # Init labels 95 | device = src.device 96 | labels = torch.arange(n, device=device) 97 | 98 | # Connected Components 99 | dpcc_recursive(n, src, tgt, labels, 0, tpb=tpb) 100 | 101 | # Return unique inverse 102 | return labels.unique(return_inverse=True)[1] 103 | 104 | 105 | def cc_cpu(src:torch.Tensor, tgt:torch.Tensor, n:int) -> torch.Tensor: 106 | '''Computes connected components using SciPy / CPU 107 | 108 | Args: 109 | src (int): Source edges. 110 | tgt (int): Target edges. 111 | n (int): Number of vertices in graph. 112 | 113 | Returns: 114 | torch.Tensor: Connected components of graph. 115 | ''' 116 | ones = torch.ones_like(src, device='cpu').numpy() 117 | edges = (src.numpy(), tgt.numpy()) 118 | csr = cpu_coo_matrix((ones, edges), shape=(n,n)).tocsr() 119 | return src.new_tensor(cpu_concom(csr)[1]) 120 | 121 | 122 | def connected_components(src:torch.Tensor, tgt:torch.Tensor, n:int, tpb=1024) -> torch.Tensor: 123 | '''Connected components algorithm (device agnostic). 124 | 125 | Args: 126 | src (int): Source edges. 127 | tgt (int): Target edges. 128 | n (int): Number of vertices in graph. 129 | 130 | Returns: 131 | torch.Tensor: Connected components of graph. 
132 | ''' 133 | assert src.shape == tgt.shape 134 | if src.device.type == 'cpu': 135 | return cc_cpu(src, tgt, n) 136 | return cc_gpu(src, tgt, n, tpb=tpb) -------------------------------------------------------------------------------- /spit/utils/cuda/cossimargmax.cu: -------------------------------------------------------------------------------- 1 | __device__ __forceinline__ unsigned long long to_packed_ull(float sim, unsigned int index) { 2 | unsigned short score = static_cast(round(sim * 65535)); 3 | unsigned long long packed_value; 4 | packed_value = (static_cast(score) << 48) | static_cast(index); 5 | return packed_value; 6 | } 7 | 8 | extern "C" __global__ 9 | void argmax_cosine_kernel( 10 | const unsigned long long m, 11 | const unsigned long long d, 12 | const unsigned long long n, 13 | const float *vertices, 14 | const unsigned long long *u, 15 | const unsigned long long *v, 16 | unsigned long long *packed, 17 | const float* muptr, 18 | const float* stdptr, 19 | const float *size 20 | ) { 21 | const unsigned long long thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 22 | const unsigned long long thread_cnt = gridDim.x * blockDim.x; 23 | const float mu = muptr[0]; 24 | const float std = stdptr[0]; 25 | 26 | for (long long tid = thread_idx; tid < m; tid += thread_cnt) { 27 | unsigned long long i = u[tid]; 28 | unsigned long long j = v[tid]; 29 | float sim = 0; 30 | float norm_i = 0; 31 | float norm_j = 0; 32 | 33 | if (i == j) { 34 | sim = (size[i] - mu) / std; 35 | sim = sim < -.75 ? -.75 : sim > .75 ? 
.75 : sim; 36 | } 37 | else { 38 | for (unsigned long long k = 0; k < d; k++) { 39 | unsigned long long idx_i = i * d + k; 40 | unsigned long long idx_j = j * d + k; 41 | 42 | sim += vertices[idx_i] * vertices[idx_j]; 43 | norm_i += vertices[idx_i] * vertices[idx_i]; 44 | norm_j += vertices[idx_j] * vertices[idx_j]; 45 | } 46 | sim = sim / (sqrtf(norm_i) * sqrtf(norm_j)); 47 | } 48 | sim = (sim + 1.0f) / 2; 49 | atomicMax((unsigned long long *)&packed[i], to_packed_ull(sim, j)); 50 | atomicMax((unsigned long long *)&packed[j], to_packed_ull(sim, i)); 51 | } 52 | } 53 | 54 | extern "C" __global__ 55 | void argmax_cosine_kernel_bbox( 56 | const unsigned long long m, 57 | const unsigned long long d, 58 | const unsigned long long n, 59 | const float *vertices, 60 | const unsigned long long *u, 61 | const unsigned long long *v, 62 | unsigned long long *packed, 63 | const float* muptr, 64 | const float* stdptr, 65 | const float* cmixptr, 66 | const float *size, 67 | const float *ymin, 68 | const float *xmin, 69 | const float *ymax, 70 | const float *xmax 71 | ) { 72 | const unsigned long long thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 73 | const unsigned long long thread_cnt = gridDim.x * blockDim.x; 74 | const float mu = muptr[0]; 75 | const float std = stdptr[0]; 76 | const float cmix = cmixptr[0]; 77 | 78 | for (long long tid = thread_idx; tid < m; tid += thread_cnt) { 79 | unsigned long long i = u[tid]; 80 | unsigned long long j = v[tid]; 81 | float sim = 0; 82 | float cpw = 0; 83 | float per = 0; 84 | float norm_i = 0; 85 | float norm_j = 0; 86 | 87 | if (i == j) { 88 | sim = (size[i] - mu) / std; 89 | sim = sim < -0.75f ? -0.75f : sim > 0.75f ? 
0.75f : sim; 90 | per = ymax[i] - ymin[i] + xmax[i] - xmin[i] + 2.0f; 91 | cpw = 4.0f * size[i] / (per * per); 92 | } 93 | else { 94 | per = ( 95 | (max(ymax[i], ymax[j]) - min(ymin[i], ymin[j])) + 96 | (max(xmax[i], xmax[j]) - min(xmin[i], xmin[j])) + 2.0f 97 | ); 98 | cpw = 4.0f * (size[i] + size[j]) / (per * per); 99 | 100 | for (unsigned long long k = 0; k < d; k++) { 101 | unsigned long long idx_i = i * d + k; 102 | unsigned long long idx_j = j * d + k; 103 | sim += vertices[idx_i] * vertices[idx_j]; 104 | norm_i += vertices[idx_i] * vertices[idx_i]; 105 | norm_j += vertices[idx_j] * vertices[idx_j]; 106 | } 107 | sim = sim / (sqrtf(norm_i) * sqrtf(norm_j)); 108 | } 109 | cpw = cpw < 0 ? 0 : cpw > 1 ? 1 : cpw; 110 | sim = cmix * cpw + (1.0f - cmix) * (sim + 1.0f) / 2.0f; 111 | atomicMax((unsigned long long *)&packed[i], to_packed_ull(sim, j)); 112 | atomicMax((unsigned long long *)&packed[j], to_packed_ull(sim, i)); 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /spit/tokenizer/voronoi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import Optional 4 | from ..utils.indexing import fast_uidx_1d 5 | from ..utils.scatter import scatter_mean_2d, scatter_cov_2d 6 | 7 | 8 | def voronoi(img: torch.Tensor, num_cells:int) -> torch.Tensor: 9 | '''Constructs a Voronoi partition. 10 | 11 | NOTE: In the paper, the RViT models were trained 12 | using pregenerated Voronoi partitions, as on-line 13 | computations are quite memory intensive. 14 | 15 | Args: 16 | img (torch.Tensor): An image. 17 | num_cells (int): Desired number of cells for the partition. 18 | 19 | Returns: 20 | Voronoi partitioning. 
21 | ''' 22 | B,_,H,W = img.shape 23 | N = B*H*W 24 | C = num_cells 25 | dev = img.device 26 | shape = torch.tensor([B,H,W,1], device=dev)[:,None] 27 | coefs = shape[1:].flipud().cumprod(dim=0).flipud() 28 | byx = torch.div(torch.arange(N, device=dev)[None], coefs, rounding_mode='trunc') % shape[:-1] 29 | y, x = byx[1:] / torch.tensor([H,W], device=byx.device)[:,None] 30 | b = byx[0] 31 | sy,sx = torch.rand(2,B,C,device=byx.device) 32 | std = C/(H*W)**.5 33 | 34 | # One liner gaussian kernels 35 | def gauss1d(x): return x.div(std).pow_(2).neg_().exp_() 36 | def gauss2d(x, y): return (gauss1d(x) + gauss1d(y)) / 2 37 | 38 | out = gauss2d(y[:,None] - sy[b], x[:,None] - sx[b]).argmax(-1).view(B,H,W) 39 | return out 40 | 41 | 42 | def chunked_voronoi(img: torch.Tensor, num_cells:int, chunks:int=16) -> torch.Tensor: 43 | '''Uses chunking to lower memory overhead of Voronoi. 44 | 45 | NOTE: In the paper, the RViT models were trained 46 | using pregenerated Voronoi partitions, as on-line 47 | computations are quite memory intensive. 48 | 49 | Args: 50 | img (torch.Tensor): An image. 51 | num_cells (int): Desired number of cells for the partition. 52 | chunks (int): Number of chunks to use. 53 | 54 | Returns: 55 | Voronoi partitioning. 
56 | ''' 57 | B, _, H, W = img.shape 58 | outs = [] 59 | cums = 0 60 | for c, chunk in enumerate(img.chunk(chunks, 0)): 61 | vor = voronoi(chunk, num_cells) + cums 62 | cums = vor.max().item() 63 | outs.append(vor) 64 | 65 | out = torch.cat(outs, 0) 66 | return out.view(-1).unique(return_inverse=True)[1].view(B,H,W) 67 | 68 | 69 | def _init_sobol_spatial_centroids( 70 | img:torch.Tensor, num_cells:int 71 | ) -> tuple[ 72 | torch.Tensor, 73 | torch.Tensor, 74 | torch.Tensor 75 | ]: 76 | B, _, H, W = img.shape 77 | device = img.device 78 | engine = torch.quasirandom.SobolEngine(2, True) 79 | eps = 1e-7 80 | centroids = ( 81 | engine.draw(B*num_cells).clip(0,1-1e-7) * 82 | torch.tensor([[H,W]]) 83 | ).round().long().to(device) 84 | init_index = torch.arange(B, device=device).repeat_interleave(num_cells) 85 | sizes = torch.full((B,), num_cells, dtype=torch.long, device=device) 86 | return centroids, init_index, sizes 87 | 88 | def _pcatree( 89 | points:torch.Tensor, cur_idx:torch.Tensor, 90 | cur_sizes:Optional[torch.Tensor]=None, steps:Optional[int]=None 91 | ) -> tuple[ 92 | list[torch.Tensor], 93 | list[torch.Tensor], 94 | list[torch.Tensor], 95 | ]: 96 | if cur_sizes is None: 97 | cur_sizes = cur_idx.bincount() 98 | if steps is None: 99 | num_cells = points.shape[0] // (cur_idx.max().item() + 1) 100 | steps = int(torch.log2(torch.tensor(num_cells)).ceil().long().item()) 101 | 102 | all_weights = [] 103 | all_mus = [] 104 | indices = [cur_idx] 105 | 106 | for _ in range(steps): 107 | cov = scatter_cov_2d(points, cur_idx).view(-1,2,2) 108 | mu = scatter_mean_2d(points, cur_idx) 109 | clus_rng = torch.arange(len(cov)) 110 | 111 | eigval, eigvec = torch.linalg.eigh(cov) 112 | weights = eigvec[clus_rng, eigval.argmax(-1)] 113 | cpoints = points - mu[cur_idx] 114 | dot_products = (cpoints * weights[cur_idx]).sum(-1) 115 | split = dot_products >= 0 116 | cur_idx = 2*cur_idx + split 117 | all_weights.append(weights) 118 | all_mus.append(mu) 119 | indices.append(cur_idx) 
120 | 121 | return indices, all_weights, all_mus 122 | 123 | 124 | def fast_pseudo_voronoi( 125 | img:torch.Tensor, num_cells:int 126 | ) -> torch.Tensor: 127 | '''Computes a fast pseudo Voronoi tesselation. 128 | 129 | NOTE: This method uses PCA Trees first proposed by Sproull, 1991. 130 | While not strictly Voronoi tesselations, they are sufficiently 131 | similar, and samples faster O(n log n). 132 | https://doi.org/10.1007/BF01759061 133 | 134 | Args: 135 | img (torch.Tensor): Input image of shape B,C.H,W. 136 | num_cells (int): Number of cells to compute in tree. Should ideally be a power of 2. 137 | 138 | Returns: 139 | torch.Tensor: Pseudo Voronoi partitioning. 140 | ''' 141 | B, _, H, W = img.shape 142 | device = img.device 143 | points, cur_idx, sizes = _init_sobol_spatial_centroids(img, num_cells) 144 | points = points.float() 145 | 146 | _, weights, mus = _pcatree(points, cur_idx, sizes) 147 | 148 | byx = torch.stack(torch.meshgrid( 149 | torch.arange(B, device=device), 150 | torch.arange(H, device=device), 151 | torch.arange(W, device=device), 152 | indexing='ij' 153 | ), 0).view(3,-1).mT 154 | 155 | idx, cur_idx, yx = torch.arange(B*H*W), byx[:,0], byx[:,1:] 156 | yx = yx.float() 157 | steps = len(weights) 158 | 159 | for step in range(steps): 160 | split = ( 161 | (yx - mus[step][cur_idx]) * weights[step][cur_idx] 162 | ).sum(-1) >= 0 163 | cur_idx = 2*cur_idx + split 164 | 165 | return cur_idx.unique(return_inverse=True)[1].view(B,H,W) 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /spit/utils/scatterhist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy, cupy 4 | import os 5 | 6 | try: 7 | import numba 8 | 9 | @numba.jit(nopython=True, parallel=True) 10 | def _flatnorm_scatterhist_kernel_cpu( # type: ignore 11 | features, output, indices, bins, 12 | sigma, num_pixels, num_features, 
num_bins 13 | ): 14 | for idx in numba.prange(num_pixels): 15 | output_idx = indices[idx] 16 | for feature in range(num_features): 17 | feat_val = features[idx][feature] 18 | 19 | for bin in range(num_bins): 20 | bin_val = bins[bin] 21 | diff = feat_val - bin_val 22 | hist_val = numpy.exp(-0.5 * (diff / sigma) ** 2) 23 | output[output_idx][feature * num_bins + bin] += hist_val 24 | 25 | except ImportError: 26 | numba = None 27 | def _flatnorm_scatterhist_kernel_cpu( 28 | features, output, indices, bins, 29 | sigma, num_pixels, num_features, num_bins 30 | ): 31 | raise NotImplementedError('Numba not installed, only CUDA available.') 32 | 33 | 34 | __location__ = os.path.realpath( 35 | os.path.join(os.getcwd(), os.path.dirname(__file__)) 36 | ) 37 | 38 | with open(os.path.join(__location__, 'cuda', 'scatterhist.cu'), 'r') as sch_file: 39 | _scatterhist_kernel_code = sch_file.read() 40 | 41 | _flatnorm_scatterhist_kernel = cupy.RawKernel( 42 | _scatterhist_kernel_code, 'flatnorm_scatterhist_kernel' 43 | ) 44 | 45 | _scatter_joint_hist_kernel = cupy.RawKernel( 46 | _scatterhist_kernel_code, 'scatter_joint_hist' 47 | ) 48 | 49 | def scatter_hist( 50 | mapping:torch.Tensor, features:torch.Tensor, num_bins:int, 51 | sigma:float=0.025, low:float=-1, high:float=1, tpb=128 52 | ) -> torch.Tensor: 53 | '''Scattered histogram computation. 54 | 55 | Effectively computes KDE histograms with a Gaussian kernel using scatter operations. 56 | Args: 57 | mapping (torch.Tensor): Hierograph mapping, precomputed. 58 | features (torch.Tensor): Base pixel features. 59 | num_bins (int): Level to compute. 60 | low (int, optional): Min val of bins. 61 | high (int, optional): Max val of bins. 62 | tpb (int, optional): Threads per block for GPU. 
63 | ''' 64 | device = features.device 65 | dtype = features.dtype 66 | delta = 1/num_bins 67 | bins = torch.linspace(low+delta, high-delta, num_bins, dtype=dtype, device=device) 68 | num_pixels, num_features = features.shape 69 | output = features.new_zeros( 70 | mapping.max()+1, # type: ignore 71 | num_features*num_bins, 72 | dtype=torch.float # Use float as output for portability with 6.1 devices. 73 | ) 74 | 75 | if device.type == 'cpu': 76 | _tonp = lambda x: x.numpy() 77 | _flatnorm_scatterhist_kernel_cpu( 78 | _tonp(features), 79 | _tonp(output), 80 | _tonp(mapping), 81 | _tonp(bins), 82 | sigma, 83 | num_pixels, 84 | num_features, 85 | num_bins, 86 | ) 87 | return output 88 | 89 | sigma = features.new_tensor([sigma]) # type: ignore 90 | bpg = (num_pixels + (tpb - 1)) // tpb 91 | _todl = cupy.from_dlpack 92 | with cupy.cuda.Device(device.index): 93 | if dtype == torch.half: 94 | raise NotImplementedError() 95 | elif dtype == torch.float: 96 | kernel = _flatnorm_scatterhist_kernel 97 | else: 98 | raise TypeError(f'No support for dtype:{dtype}') 99 | kernel( 100 | (bpg,), (tpb,), ( 101 | _todl(features), 102 | _todl(output), 103 | _todl(mapping), 104 | _todl(bins), 105 | _todl(sigma), 106 | num_pixels, 107 | num_features, 108 | num_bins, 109 | ) 110 | ) 111 | 112 | return output.to(dtype=dtype) 113 | 114 | 115 | def scatter_joint_hist( 116 | seg:torch.Tensor, feats:torch.Tensor, num_seg, num_bins, featcombs, 117 | sigma=0.025, low=-1, high=1, 118 | tpb=1024, 119 | ): 120 | '''Scattered histogram computation. 121 | 122 | Effectively computes KDE 2d histograms with a Gaussian kernel using scatter operations. 123 | Args: 124 | seg (torch.Tensor): Segmentation. 125 | feats (torch.Tensor): Base pixel features. 126 | num_seg (int): Number of superpixels. 127 | num_bins (int): Number of bins in each dimension. 128 | featcombs (list[tuple[int, int]]): Index of joint features for histogram. 129 | sigma (float): Bandwith of KDE kernel. 
130 | low (int, optional): Min val of bins. 131 | high (int, optional): Max val of bins. 132 | tpb (int, optional): Threads per block for GPU. 133 | 134 | Returns: 135 | Tensor: 2D KDE histogram of features. 136 | ''' 137 | n, feat_dim = feats.shape 138 | delta = 1/num_bins 139 | featcombs = seg.new_tensor(featcombs) 140 | num_feats = len(featcombs) 141 | 142 | assert n == len(seg) 143 | assert featcombs.max() < feat_dim, f'{featcombs.max().item()=}>={feat_dim=}' 144 | assert featcombs.min() >= 0 145 | assert feats.dtype == torch.float 146 | 147 | bins1d = torch.linspace(low+delta, high-delta, num_bins, device=seg.device) 148 | mesh_y, mesh_x = [mesh.flatten() for mesh in torch.meshgrid(bins1d, bins1d, indexing='ij')] # type:ignore 149 | output = bins1d.new_zeros(num_seg, num_feats, num_bins**2) 150 | sigmaptr = bins1d.new_tensor([sigma]) 151 | 152 | bpg = (n + tpb - 1) // tpb 153 | _todl = cupy.from_dlpack 154 | with cupy.cuda.Device(seg.device.index) as cpdev: 155 | _scatter_joint_hist_kernel( 156 | (bpg,), (tpb,), ( 157 | _todl(seg), 158 | _todl(feats), 159 | _todl(mesh_y), 160 | _todl(mesh_x), 161 | _todl(featcombs), 162 | _todl(output), 163 | _todl(sigmaptr), 164 | n, num_bins, num_feats, feat_dim 165 | ) 166 | ) 167 | return output.view(num_seg, -1) 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # A Spitting Image: Modular Superpixel Tokenization in Vision Transformers 4 | 5 | **[Marius Aasan](https://www.mn.uio.no/ifi/english/people/aca/mariuaas/), [Odd Kolbjørnsen](https://www.mn.uio.no/math/english/people/aca/oddkol/), [Anne Schistad Solberg](https://www.mn.uio.no/ifi/english/people/aca/anne/), [Adín Ramírez Rivera](https://www.mn.uio.no/ifi/english/people/aca/adinr/)**
6 | 7 | 8 | **[DSB @ IFI @ UiO](https://www.mn.uio.no/ifi/english/research/groups/dsb/)**
9 | 10 | [![Website](https://img.shields.io/badge/Website-green)](https://dsb-ifi.github.io/SPiT/) 11 | [![PaperArxiv](https://img.shields.io/badge/Paper-arXiv-red)](https://arxiv.org/abs/2408.07680) 12 | [![PaperECCVW](https://img.shields.io/badge/Paper-ECCVW_2024-blue)](https://doi.org/10.1007/978-3-031-93806-1_11) 13 | [![NotebookExample](https://img.shields.io/badge/Notebook-Example-orange)](https://nbviewer.jupyter.org/github/dsb-ifi/SPiT/blob/main/notebooks/eval_in1k.ipynb)
14 | 15 | ![SPiT Figure 1](/assets/fig1.png#gh-light-mode-only "Examples of feature maps from SPiT-B16") 16 | ![SPiT Figure 1](/assets/fig1_dark.png#gh-dark-mode-only "Examples of feature maps from SPiT-B16") 17 | 18 |
19 | 20 | ## SPiT: Superpixel Transformers 21 | 22 | This repo contains code and weights for **A Spitting Image: Modular Superpixel Tokenization in Vision Transformers**, accepted for MELEX, ECCVW 2024. 23 | 24 | For an introduction to our work, visit the [project webpage](https://dsb-ifi.github.io/SPiT/). 25 | 26 | ## Installation 27 | 28 | The package can currently be installed via: 29 | 30 | ```bash 31 | # HTTPS 32 | pip install git+https://github.com/dsb-ifi/SPiT.git 33 | 34 | # SSH 35 | pip install git+ssh://git@github.com/dsb-ifi/SPiT.git 36 | ``` 37 | 38 | ## Loading models 39 | 40 | You can load the Superpixel Transformer model easily via `torch.hub`: 41 | 42 | ```python 43 | model = torch.hub.load( 44 | 'dsb-ifi/spit', 45 | 'spit_base_16', 46 | pretrained=True, 47 | source='github', 48 | ) 49 | ``` 50 | 51 | This will load the model and download the pretrained weights, which are stored in your local `torch.hub` directory. 52 | If you prefer downloading weights manually, feel free to use: 53 | 54 | | Model | Link | MD5 | 55 | |-|-|-| 56 | | SPiT-S16 | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EZ57Sad2uf9Dizwm3VYhvw4BVdHOxsEJcgyf4vgKsdmgZg) |8e899c846a75c51e1c18538db92efddf| 57 | | SPiT-S16 (w. grad.) | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/Eb9FViSwap5JqYe1mtlC3jQBE-nAMG88MfJfmypT_J8r0Q) |e49be7009c639c0ccda4bd68ed34e5af| 58 | | SPiT-B16 | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EXhsshO-DvlIii87kyyEVtoBRFbZaTp8SqTgDJhQ1iQIBw) |9d3483a4c6fdaf603ee6528824d48803| 59 | | SPiT-B16 (w. grad.) | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EcahlrAzXZ5Bsozrqs4dWLABHFX-V5VH8jQR5ygHhZH30A) |9394072a5d488977b1af05c02aa0d13c| 60 | | ViT-S16 | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EWqHDQvY5V5PjKkMmO5fcFEBKuN6WTfr4a99u8vpNT67WQ) |73af132e4bb1405b510a5eb2ea74cf22| 61 | | ViT-S16 (w. grad.)
| [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EenEECYQaQZFl_GeU2N9q7YB-XOHNyaJXHnC74qREU3cSQ) |b8e4f1f219c3baef47fc465eaef9e0d4| 62 | | ViT-B16 | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EUWJM_RY9IRPvM9dsp2Zzi8B6ZOnhQ_C666TMESzmAQ0sQ) |ce45dcbec70d61d1c9f944e1899247f1| 63 | | ViT-B16 (w. grad.) | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EdGx5GaXRshPpOh0gsCHU4cBeZ0FxexzuBm7vTtm67nuTw) |1caa683ecd885347208b0db58118bf40| 64 | | RViT-B16 | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/Ed9R0bQOmslLiPnFX_P0hRoBUf_zQ4pfHXZ3BpQ4iW8JYA) |18c13af67d10f407c3321eb1ca5eb568| 65 | | RViT-B16 (w. grad.) | [Manual Download](https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EflpV7TP04RKmxg1qfiNovUBo149q0P9j4tmoOTQ-NkV-Q) |50d25403adfd5a12d7cb07f7ebfced97| 66 | 67 | 68 | ## More Examples 69 | 70 | We provide a [Jupyter notebook](https://nbviewer.jupyter.org/github/dsb-ifi/SPiT/blob/main/notebooks/eval_in1k.ipynb) as a sandbox for loading, evaluating, and extracting segmentations for the models. 71 | 72 | ## Notes: 73 | 74 | ### RViT and On-Line Voronoi Tessellation 75 | 76 | Currently the code features some slight modifications to streamline use of the RViT models. The original RViT models sampled partitions from a dataset of pre-computed Voronoi tessellations for training and evaluation. This is impractical for deployment, and we have yet to implement a CUDA kernel for computing Voronoi tessellations with lower memory overhead. 77 | 78 | However, we have developed a fast implementation for generating tessellations with PCA trees [1], which mimic Voronoi tessellations relatively well and can be computed on-the-fly. There are, however, still some minor issues with the small-capacity RViT models. Consequently, the RViT-B16 models will perform marginally differently from the reported results in the paper.
*We appreciate the readers' patience with regard to this matter.* 79 | 80 | Note that the RViT models are inherently stochastic, so different runs can yield different results. Also, it is worth mentioning that SPiT models can yield slightly different results for each run, due to nondeterministic behaviours in CUDA kernels. 81 | 82 | 83 | [1] Refinements to nearest-neighbor searching in $k$-dimensional trees [(Sproull, 1991)](https://doi.org/10.1007/BF01759061) 84 | 85 | ## Progress and Current To-dos: 86 | 87 | - [X] Include foundational code and model weights. 88 | - [X] Add manual links with MD5 hash for manual weight download. 89 | - [X] Add module for loading models, and provide example notebook. 90 | - [X] Create temporary solution to on-line Voronoi tessellation. 91 | - [X] Add `hubconf.py` for PyTorch Hub compatibility. 92 | - [ ] Add example for extracting attribution maps with Att.Flow and Proto.PCA. 93 | - [ ] Add example for computing sufficiency and comprehensiveness. 94 | - [ ] Add assets for computed attribution maps for XAI experiments. 95 | - [ ] Add code and examples for salient segmentation. 96 | 97 | ## Citation 98 | 99 | If you find our work useful, please consider citing our paper.
100 | 101 | ``` 102 | @inproceedings{Aasan2024, 103 | title={A Spitting Image: Modular Superpixel Tokenization in Vision Transformers}, 104 | author={Aasan, Marius and Kolbj\o{}rnsen, Odd and Schistad Solberg, Anne and Ram\'irez Rivera, Ad\'in}, 105 | booktitle={{CVF/ECCV} Computer Vision -- {ECCVW} 2024 -- {MELEX}}, 106 | year={2024} 107 | doi="https://doi.org/10.1007/978-3-031-93806-1_11", 108 | } 109 | ``` 110 | -------------------------------------------------------------------------------- /spit/utils/cossimargmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cupy 3 | import os 4 | from typing import Optional 5 | 6 | 7 | __location__ = os.path.realpath( 8 | os.path.join(os.getcwd(), os.path.dirname(__file__)) 9 | ) 10 | 11 | with open(os.path.join(__location__, 'cuda', 'cossimargmax.cu'), 'r') as csam_file: 12 | _argmax_cosine_kernel_code = csam_file.read() 13 | 14 | _argmax_cosine_kernel = cupy.RawKernel( 15 | _argmax_cosine_kernel_code, 'argmax_cosine_kernel' 16 | ) 17 | 18 | _argmax_cosine_kernel_bbox = cupy.RawKernel( 19 | _argmax_cosine_kernel_code, 'argmax_cosine_kernel_bbox' 20 | ) 21 | 22 | def argmax_cosine_gpu( 23 | vfeat:torch.Tensor, edges:torch.Tensor, sizes:torch.Tensor, nnz:int, 24 | tpb:int=1024, lvl:Optional[int]=None, bbox:Optional[torch.Tensor]=None, 25 | cmix:float=0.1 26 | ): 27 | '''Computes argmax over cosine similarity in a graph using custom CUDA kernel. 28 | 29 | Secures parallel execution by packing similarity scores and argmax indices 30 | in single unsigned long long where the highest bits encode the similarity. 31 | This ensures that we can use atomicMax calls to update the values on both 32 | edges, which eliminates race conditions. Since we do not have to initialise 33 | multiple potentially long arrays, this saves on memory consumption, while 34 | ensuring good performance. Faster than reduction, sacrificing some precision 35 | in the similarities. 
36 | 37 | The packing is done by encoding the similarities as uint16. We bitshift by 48 38 | and then the indices of the array is `or`'ed into the remaining bits. After applying 39 | atomicMax, the result is `and`'ed with a 48-bit value to retrieve the indices. 40 | 41 | Args: 42 | vfeat (torch.Tensor): Vertex features. 43 | edges (torch.Tensor): Edge indices. 44 | sizes (torch.Tensor): Current node sizes. 45 | tpb (int, optional): Threads per block. 46 | lvl (int, optional): Current level. 47 | ''' 48 | dtype = vfeat.dtype 49 | u, v = edges[0].contiguous(), edges[1].contiguous() 50 | vfeat = vfeat.contiguous() 51 | sfl = sizes.contiguous().to(dtype=dtype) 52 | if lvl is None: 53 | mu = sfl.mean()[None] * 2 54 | else: 55 | # Cheaper to use expected value than compute mean size 56 | # but less adaptive to current samples. 57 | mu = sfl.new_tensor([4**(lvl-1)]) * 2 58 | std = sfl.std().clip(min=1e-6)[None] 59 | cmixptr = sfl.new_tensor([cmix]) 60 | m = len(u) 61 | d = vfeat.shape[-1] 62 | packed = edges.new_zeros(nnz) 63 | bpg = (m + tpb - 1) // tpb 64 | 65 | _todl = cupy.from_dlpack 66 | with cupy.cuda.Device(vfeat.device.index) as cpdev: 67 | if dtype == torch.float: 68 | if bbox is None: 69 | kernel = _argmax_cosine_kernel 70 | else: 71 | kernel = _argmax_cosine_kernel_bbox 72 | else: 73 | raise TypeError(f'No support for dtype:{dtype}') 74 | 75 | if bbox is None: 76 | kernel( 77 | (bpg,), (tpb,), ( 78 | m, d, nnz, 79 | _todl(vfeat), _todl(u), _todl(v), 80 | _todl(packed), _todl(mu), 81 | _todl(std), _todl(sfl) 82 | ) 83 | ) 84 | else: 85 | ymin, xmin, ymax, xmax = bbox.to(dtype=dtype) 86 | kernel( 87 | (bpg,), (tpb,), ( 88 | m, d, nnz, 89 | _todl(vfeat), _todl(u), _todl(v), 90 | _todl(packed), _todl(mu), _todl(std), _todl(cmixptr), 91 | _todl(sfl), _todl(ymin), _todl(xmin), _todl(ymax), _todl(xmax) 92 | ) 93 | ) 94 | 95 | return packed & 0xFFFFFFFFFFFF 96 | 97 | 98 | def packed_scatter_argmax(src:torch.Tensor, idx:torch.Tensor, n:int) -> torch.Tensor: 99 | 
'''Computes scatter argmax with 1d source and 1d index. 100 | 101 | Uses packing, i.e., reduced precision over the tensors to retrieve the argmax. 102 | Assumes inputs in range [-1, 1]. Could be generalized by scaling but, meh. 103 | Packs the src tensor as a virtual uint16, and bitshifts by 47, avoiding 104 | the sign bit of the int64. The indices of the array is or'ed into the remaining 105 | bits. The atomicMax operation takes the max over the src. The result is then 106 | and'ed with the 47-bit representation to retrieve the indices. 107 | 108 | NOTE: This is almost how the CUDA kernel works, except by considering the inputs 109 | as unsigned long longs, we squeeze a little more headroom for large indices. 110 | 111 | Args: 112 | src (torch.Tensor): Source tensor in range [-1, 1]. 113 | idx (torch.Tensor): Index tensor. 114 | n (int): No. outputs. 115 | 116 | Returns: 117 | torch.Tensor: Output tensor. 118 | ''' 119 | assert src.ndim == 1 120 | assert len(src) == len(idx) 121 | assert (len(src) & 0x7FFF800000000000) == 0 122 | 123 | shorts = (src.clip(-1, 1).add(1).div(2) * (2**16 - 1)).long() 124 | packed = (shorts << 47) | torch.arange(len(src), device=src.device) 125 | out = packed.new_zeros(n) 126 | out.scatter_reduce_(0, idx, packed, 'amax', include_self=False) 127 | return out & 0x7FFFFFFFFFFF 128 | 129 | 130 | def argmax_cosine_pytorch(vfeat, edges, sizes, nnz, lvl=None): 131 | '''Computes argmax over cosine similarity in a graph using pytorch. 132 | 133 | This is device agnostic, but requires much higher memory overhead. 134 | 135 | Args: 136 | V (torch.Tensor): Vertex features. 137 | E (torch.Tensor): Edge indices. 138 | s (torch.Tensor): Current node sizes. 139 | m (float): Mean of current sizes. 
140 | ''' 141 | u, v = edges 142 | sfl = sizes.float() 143 | if lvl is None: 144 | mu = sfl.mean() 145 | else: 146 | mu = sfl.new_tensor(4**(lvl-1)) 147 | stdwt = (sfl - mu) / sfl.std().clip(min=1e-6) 148 | weight = torch.where(u == v, stdwt[u].clip(-.75, .75), 1.0) 149 | sim = torch.cosine_similarity(vfeat[u], vfeat[v]) * weight 150 | udir, vdir = torch.cat([edges, edges.flip(0)], 1) 151 | simdir = sim[None,:].expand(2,-1).reshape(-1) 152 | argmax_idx = packed_scatter_argmax(simdir, udir, nnz) 153 | return vdir[argmax_idx] 154 | 155 | 156 | def cosine_similarity_argmax(vfeat, edges, sizes, nnz, force_pt=False, lvl:Optional[int]=None, bbox:Optional[torch.Tensor]=None, tpb=1024): 157 | '''Computes argmax over cosine similarity in a graph using pytorch. 158 | 159 | Checks if tensors are located on CPU or GPU, preferring the custom CUDA 160 | implementation to PyTorch for improved memory footprint. 161 | 162 | Args: 163 | vfeat (torch.Tensor): Vertex features. 164 | edges (torch.Tensor): Edge indices. 165 | sizes (torch.Tensor): Current node sizes. 166 | nnz (float): Number of vertices. 
167 | ''' 168 | assert vfeat.device == edges.device 169 | assert vfeat.device == sizes.device 170 | if vfeat.device.type == 'cpu' or force_pt: 171 | return argmax_cosine_pytorch(vfeat, edges, sizes, nnz, lvl=lvl) 172 | return argmax_cosine_gpu(vfeat, edges, sizes, nnz, tpb=tpb, lvl=lvl, bbox=bbox) 173 | -------------------------------------------------------------------------------- /spit/tokenizer/proc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | _in1k_mean = torch.tensor([0.485, 0.456, 0.406]) 5 | _in1k_std = torch.tensor([0.229, 0.224, 0.225]) 6 | 7 | def _in1k_norm(tensor, dim=-1): 8 | shape = [1] * tensor.ndim 9 | shape[dim] = -1 10 | mean = _in1k_mean.view(shape).to(tensor.device) 11 | std = _in1k_std.reshape(shape).to(tensor.device) 12 | return (tensor - mean) / std 13 | 14 | def _in1k_unnorm(tensor, dim=-1): 15 | shape = [1] * tensor.ndim 16 | shape[dim] = -1 17 | mean = _in1k_mean.view(shape).to(tensor.device) 18 | std = _in1k_std.reshape(shape).to(tensor.device) 19 | return tensor * std + mean 20 | 21 | def kuma_contrast1(x:torch.Tensor, mu:float, lambda_:float): 22 | '''Contrast adjustment with the Kumaraswamy CDF; inputs are clipped to [0, 1]. 23 | 24 | Args: 25 | x (torch.Tensor): Feature tensor. 26 | mu (float): Input value that is mapped to 0.5 (clipped to [0, 1]). 27 | lambda_ (float): Shape parameter `a` of the Kumaraswamy CDF (clipped to >= 0). 28 | 29 | Returns: 30 | torch.Tensor: Adjusted features. 31 | ''' 32 | x = x.clip(0,1) 33 | m, a = x.new_tensor(mu).clip_(0, 1), x.new_tensor(lambda_).clip_(0) 34 | b = -(x.new_tensor(2)).log_() / (1-m**a).log_() 35 | return 1 - (1 - x**a)**b 36 | 37 | 38 | def kuma_contrast1_(x:torch.Tensor, mu:float, lambda_:float): 39 | '''In-place contrast adjustment with the Kumaraswamy CDF; inputs are clipped to [0, 1]. 40 | 41 | Args: 42 | x (torch.Tensor): Feature tensor. 
43 | mu (float): Input value that is mapped to 0.5 (clipped to [0, 1]). 44 | lambda_ (float): Shape parameter `a` of the Kumaraswamy CDF (clipped to >= 0). 45 | 46 | Returns: 47 | torch.Tensor: Adjusted features (the input tensor, modified in place).
48 | ''' 49 | x.clip_(0,1) 50 | m, a = x.new_tensor(mu).clip_(0, 1), x.new_tensor(lambda_).clip_(0) 51 | b = -(x.new_tensor(2)).log_() / (1-m**a).log_() 52 | return x.pow_(a).mul_(-1).add_(1).pow_(b).mul_(-1).add_(1) 53 | 54 | 55 | def asinh_contrast(features:torch.Tensor, lambda_:float) -> torch.Tensor: 56 | '''Contrast adjustment with Arcsinh. 57 | 58 | Args: 59 | features (torch.Tensor): Feature tensor. 60 | lambda_ (float): Contrast multiplier; > 0 applies arcsinh(lambda_*x)/arcsinh(lambda_), < 0 applies the inverse sinh form, 0 is a no-op. 61 | 62 | Returns: 63 | torch.Tensor: Adjusted features. 64 | ''' 65 | if lambda_ == 0: 66 | return features 67 | tmul = features.new_tensor(lambda_) 68 | m, d = tmul, torch.arcsinh(tmul) 69 | if lambda_ > 0: 70 | return features.mul(m).arcsinh().div(d) 71 | return features.mul(d).sinh().div(m) 72 | 73 | 74 | def asinh_contrast_(features:torch.Tensor, lambda_:float) -> torch.Tensor: 75 | '''In-place contrast adjustment with Arcsinh. 76 | 77 | Args: 78 | features (torch.Tensor): Feature tensor. 79 | lambda_ (float): Contrast multiplier; > 0 applies arcsinh(lambda_*x)/arcsinh(lambda_), < 0 applies the inverse sinh form, 0 is a no-op. 80 | 81 | Returns: 82 | torch.Tensor: Adjusted features (the input tensor, modified in place). 83 | ''' 84 | if lambda_ == 0: 85 | return features 86 | tmul = features.new_tensor(lambda_) 87 | m, d = tmul, torch.arcsinh(tmul) 88 | if lambda_ > 0: 89 | return features.mul_(m).arcsinh_().div_(d) 90 | return features.mul_(d).sinh_().div_(m) 91 | 92 | 93 | def pthroot_contrast(features:torch.Tensor, lambda_:float) -> torch.Tensor: 94 | '''Hard contrast adjustment with pth-root. 95 | 96 | Args: 97 | features (torch.Tensor): Feature tensor. 98 | lambda_ (float): Pth root for contrast. 99 | 100 | Returns: 101 | torch.Tensor: Adjusted features. 
102 | ''' 103 | mul = features.sign().div_(features.new_tensor(1.0).exp_().sub_(1).pow_(1/lambda_)) 104 | return features.abs().exp_().sub_(1).pow_(1/lambda_).mul_(mul) 105 | 106 | 107 | def pthroot_contrast_(features:torch.Tensor, lambda_:float) -> torch.Tensor: 108 | '''In-place hard contrast adjustment with pth-root. 109 | 110 | Args: 111 | features (torch.Tensor): Feature tensor.
112 | lambda_ (float): Pth root for contrast. 113 | 114 | Returns: 115 | torch.Tensor: Adjusted features (the input tensor, modified in place). 116 | ''' 117 | mul = features.sign().div_(features.new_tensor(1.0).exp_().sub_(1).pow_(1/lambda_)) 118 | features.abs_().exp_().sub_(1).pow_(1/lambda_).mul_(mul) 119 | return features 120 | 121 | 122 | def scharr_features(img:torch.Tensor, lambda_:float) -> torch.Tensor: 123 | '''Computes contrast-enhanced Scharr features of an image. 124 | 125 | Args: 126 | img (torch.Tensor): Image tensor, presumably (B, C, H, W); channels are averaged first. 127 | lambda_ (float): Multiplier for contrast, passed to asinh_contrast_. 128 | 129 | Returns: 130 | torch.Tensor: Two-channel (d/dy, d/dx) Scharr gradients (normalised by 16) after arcsinh contrast. 131 | ''' 132 | img = img.mean(1, keepdim=True) 133 | kernel = img.new_tensor([[[[-3.,-10,-3.],[0.,0.,0.],[3.,10,3.]]]]) 134 | kernel = torch.cat([kernel, kernel.mT], dim=0) 135 | out = F.conv2d( 136 | F.pad(img, 4*[1], mode='replicate'), 137 | kernel, 138 | stride=1 139 | ).div_(16) 140 | return asinh_contrast_(out, lambda_) 141 | 142 | 143 | def adjust_saturation(rgb:torch.Tensor, mul:float): 144 | '''Adjusts saturation via interpolation / extrapolation. 145 | 146 | Args: 147 | rgb (torch.Tensor): An input tensor of shape (..., 3) representing the RGB values of an image. 148 | mul (float): Saturation adjustment factor. A value of 1.0 will keep the saturation unchanged. 149 | 150 | Returns: 151 | torch.Tensor: A tensor of the same shape as the input, with adjusted saturation, clipped to [0, 1]. 152 | 153 | ''' 154 | weights = rgb.new_tensor([0.299, 0.587, 0.114]) 155 | grayscale = torch.matmul(rgb, weights).unsqueeze(dim=-1).expand_as(rgb).to(dtype=rgb.dtype) 156 | return torch.lerp(grayscale, rgb, mul).clip(0,1) 157 | 158 | 159 | def peronamalik1(img, niter=5, kappa=0.0275, gamma=0.275): 160 | """Anisotropic diffusion. 
161 | 162 | Perona-Malik anisotropic diffusion type 1, which favours high contrast 163 | edges over low contrast ones. 164 | 165 | `kappa` controls conduction as a function of gradient.
If kappa is low 166 | small intensity gradients are able to block conduction and hence diffusion 167 | across step edges. A large value reduces the influence of intensity 168 | gradients on conduction. 169 | 170 | Reference: 171 | P. Perona and J. Malik. 172 | Scale-space and edge detection using anisotropic diffusion. 173 | IEEE Transactions on Pattern Analysis and Machine Intelligence, 174 | 12(7):629-639, July 1990. 175 | 176 | Args: 177 | img (torch.Tensor): input image; diffusion runs over the last two dims. 178 | niter (int): number of iterations 179 | kappa (float): conduction coefficient. 180 | gamma (float): controls speed of diffusion (generally max 0.25) 181 | 182 | Returns: 183 | Diffused image. 184 | """ 185 | 186 | deltaS, deltaE = img.new_zeros(2, *img.shape) 187 | 188 | for _ in range(niter): 189 | deltaS[...,:-1,:] = torch.diff(img, dim=-2) 190 | deltaE[...,:,:-1] = torch.diff(img, dim=-1) 191 | 192 | gS = torch.exp(-(deltaS/kappa)**2.) 193 | gE = torch.exp(-(deltaE/kappa)**2.) 194 | 195 | S, E = gS*deltaS, gE*deltaE 196 | 197 | S[...,1:,:] = S.diff(dim=-2) 198 | E[...,:,1:] = E.diff(dim=-1) 199 | img = img + gamma*(S+E) 200 | 201 | return img 202 | 203 | 204 | def rgb_to_ycbcr(feat: torch.Tensor, dim=-1) -> torch.Tensor: 205 | r"""Convert RGB features to YCbCr. 206 | 207 | Args: 208 | feat (torch.Tensor): Pixels to be converted to YCbCr, with RGB channels along `dim`. 209 | 210 | Returns: 211 | torch.Tensor: YCbCr converted features. 212 | """ 213 | r,g,b = feat.unbind(dim) 214 | y = 0.299 * r + 0.587 * g + 0.114 * b 215 | delta = 0.5 216 | cb = (b - y) * 0.564 + delta 217 | cr = (r - y) * 0.713 + delta 218 | return torch.stack([y, cb, cr], dim) 219 | 220 | 221 | def apply_color_transform( 222 | colfeat:torch.Tensor, shape:tuple[int,...], lambda_col:float, 223 | ) -> torch.Tensor: 224 | ''' Applies anisotropic diffusion and color enhancement. 225 | 226 | Args: 227 | colfeat (torch.Tensor): Color features. 228 | shape (tuple[int]): Image shape. 
229 | lambda_col (float): Color contrast to use.
230 | ''' 231 | b, _ , h, w = shape 232 | c = colfeat.shape[-1] 233 | f = adjust_saturation(colfeat.add(1).div_(2), 2.718) 234 | 235 | f = rgb_to_ycbcr(f, -1).mul_(2).sub_(1) 236 | asinh_contrast_(f, lambda_col) 237 | f = peronamalik1( 238 | f.view(b, h, w, c).permute(0,3,1,2), 239 | 4, 240 | 0.1, 241 | 0.5 242 | ).permute(0,2,3,1).view(-1, c).clip_(-1,1) 243 | return f 244 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import warnings 4 | import requests 5 | from tqdm import tqdm 6 | from typing import Union 7 | from collections import OrderedDict 8 | from spit.nn import SPiT 9 | 10 | dependencies = ['torch', 'torchvision', 'scipy', 'cupy', 'numba', 'requests', 'tqdm'] 11 | 12 | _architecture_cfg = { 13 | 'S': OrderedDict(depth=12, emb_dim= 384, heads= 6, dop_path=0), 14 | 'B': OrderedDict(depth=12, emb_dim= 768, heads=12, dop_path=0.2), 15 | 'L': OrderedDict(depth=24, emb_dim=1024, heads=16, dop_path=0.2), 16 | } 17 | 18 | _std_cfg:dict[str, Union[str, int, float, bool]] = dict( 19 | classes = 1000, keep_k = 256, extractor = 'interpolate' 20 | ) 21 | 22 | _modelweights_url = dict( 23 | SPiT_S16 = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EZ57Sad2uf9Dizwm3VYhvw4BVdHOxsEJcgyf4vgKsdmgZg', 24 | SPiT_S16_grad = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/Eb9FViSwap5JqYe1mtlC3jQBE-nAMG88MfJfmypT_J8r0Q', 25 | SPiT_B16 = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EXhsshO-DvlIii87kyyEVtoBRFbZaTp8SqTgDJhQ1iQIBw', 26 | SPiT_B16_grad = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EcahlrAzXZ5Bsozrqs4dWLABHFX-V5VH8jQR5ygHhZH30A', 27 | ViT_S16 = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EWqHDQvY5V5PjKkMmO5fcFEBKuN6WTfr4a99u8vpNT67WQ', 28 | ViT_S16_grad = 
'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EenEECYQaQZFl_GeU2N9q7YB-XOHNyaJXHnC74qREU3cSQ', 29 | ViT_B16 = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EUWJM_RY9IRPvM9dsp2Zzi8B6ZOnhQ_C666TMESzmAQ0sQ', 30 | ViT_B16_grad = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EdGx5GaXRshPpOh0gsCHU4cBeZ0FxexzuBm7vTtm67nuTw', 31 | RViT_B16 = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/Ed9R0bQOmslLiPnFX_P0hRoBUf_zQ4pfHXZ3BpQ4iW8JYA', 32 | RViT_B16_grad = 'https://uio-my.sharepoint.com/:u:/g/personal/mariuaas_uio_no/EflpV7TP04RKmxg1qfiNovUBo149q0P9j4tmoOTQ-NkV-Q', 33 | ) 34 | 35 | def _download_model_weights(model: str, grad: bool = False) -> str: 36 | model_full = f'{model}_grad' if grad else model 37 | if model_full not in _modelweights_url: 38 | raise KeyError(f'Invalid model: {model_full}') 39 | 40 | hub_dir = torch.hub.get_dir() 41 | local_path = os.path.join(hub_dir, 'checkpoints', f'{model_full}.pth') 42 | url = _modelweights_url[model_full] 43 | 44 | if url == '': 45 | raise NotImplementedError('Sorry! 
Weights for this model have not been uploaded yet!') 46 | 47 | url += '?download=1' 48 | 49 | if not os.path.exists(local_path): 50 | print(f'Downloading pretrained weights for {model_full} to {local_path}...') 51 | response = requests.get(url, stream=True) 52 | if response.status_code == 200: 53 | total_size = int(response.headers.get('content-length', 0)) 54 | with open(local_path, 'wb') as f, tqdm( 55 | desc=model_full, 56 | total=total_size, 57 | unit='iB', 58 | unit_scale=True, 59 | unit_divisor=1024, 60 | ) as pbar: 61 | for chunk in response.iter_content(chunk_size=8192): 62 | size = f.write(chunk) 63 | pbar.update(size) 64 | print(f'Weights downloaded to: {local_path}') 65 | else: 66 | raise ConnectionError( 67 | f'Failed to download weights: HTTP status code {response.status_code}' 68 | ) 69 | 70 | return local_path 71 | 72 | def _get_pretrained_weights(model: str, grad: bool = False, **kwargs): 73 | '''Torch Hub does not like SharePoint URLs, so we download the weights manually.''' 74 | _prefix = 'SPiT_model_' 75 | _suffix = '_grad' if grad else '' 76 | hub_dir = torch.hub.get_dir() 77 | local_path = f'{hub_dir}/checkpoints/{_prefix}{model}{_suffix}.pth' 78 | if not os.path.isfile(local_path): 79 | _download_model_weights(model, grad) 80 | sd = torch.load(local_path, map_location='cpu', weights_only=True) 81 | return sd 82 | 83 | # def _get_pretrained_weights(model:str, grad:bool=True, **kwargs): 84 | # model_full = f'{model}_grad' if grad else model 85 | # url = f'{_modelweights_url.get(model_full, "")}?download=1' 86 | # return torch.hub.load_state_dict_from_url( 87 | # url=url, 88 | # map_location="cpu", 89 | # weights_only=True, 90 | # **kwargs.get('torch_hub_kwargs', {}) 91 | # ) 92 | 93 | def spit_small_16(grad:bool=True, pretrained=False, **kwargs) -> SPiT: 94 | kwargs = {**_architecture_cfg['S'], **_std_cfg} 95 | kwargs['num_bins'] = 16 96 | kwargs['tokenizer'] = 'superpixel' 97 | kwargs['drop_delta'] = not grad 98 | kwargs['sigma2d'] = 0.025 99 | 
kwargs['bbox_reg'] = False 100 | if not grad: 101 | kwargs['bbox_reg'] = True 102 | 103 | model = SPiT(**kwargs) 104 | 105 | if pretrained: 106 | warnings.warn('Note that S16 weights are not fine tuned.') 107 | sd = _get_pretrained_weights('SPiT_S16', grad, **kwargs) 108 | model.load_state_dict(sd, strict=False) 109 | return model.eval() 110 | 111 | return model.eval() 112 | 113 | def spit_base_16(grad:bool=True, pretrained=False, **kwargs) -> SPiT: 114 | kwargs = {**_architecture_cfg['B'], **_std_cfg} 115 | kwargs['num_bins'] = 16 116 | kwargs['tokenizer'] = 'superpixel' 117 | kwargs['drop_delta'] = not grad 118 | kwargs['sigma2d'] = 0.025 119 | kwargs['bbox_reg'] = False 120 | 121 | if grad: 122 | kwargs['sigma2d'] = 0.05 123 | else: 124 | kwargs['bbox_reg'] = True 125 | 126 | model = SPiT(**kwargs) # type: ignore 127 | 128 | if pretrained: 129 | sd = _get_pretrained_weights('SPiT_B16', grad, **kwargs) 130 | model.load_state_dict(sd) 131 | return model.eval() 132 | 133 | return model.eval() 134 | 135 | def vit_small_16(grad:bool=True, pretrained=False, **kwargs) -> SPiT: 136 | kwargs = {**_architecture_cfg['S'], **_std_cfg} 137 | kwargs['num_bins'] = 16 138 | kwargs['tokenizer'] = 'default' 139 | kwargs['mode'] = 'nearest' 140 | kwargs['drop_delta'] = not grad 141 | kwargs['sigma2d'] = 0.025 142 | 143 | model = SPiT(**kwargs) # type: ignore 144 | 145 | if pretrained: 146 | warnings.warn('Note that S16 weights are not fine tuned.') 147 | sd = _get_pretrained_weights('ViT_S16', grad, **kwargs) 148 | model.load_state_dict(sd) 149 | return model.eval() 150 | 151 | return model.eval() 152 | 153 | def vit_base_16(grad:bool=True, pretrained=False, **kwargs) -> SPiT: 154 | kwargs = {**_architecture_cfg['B'], **_std_cfg} 155 | kwargs['num_bins'] = 16 156 | kwargs['tokenizer'] = 'default' 157 | kwargs['mode'] = 'nearest' 158 | kwargs['drop_delta'] = not grad 159 | kwargs['sigma2d'] = 0.025 160 | 161 | model = SPiT(**kwargs) # type: ignore 162 | 163 | if pretrained: 164 | 
sd = _get_pretrained_weights('ViT_B16', grad, **kwargs) 165 | model.load_state_dict(sd) 166 | return model.eval() 167 | 168 | return model.eval() 169 | 170 | def rvit_small_16(grad:bool=True, pretrained=False, **kwargs) -> SPiT: 171 | kwargs = {**_architecture_cfg['S'], **_std_cfg} 172 | kwargs['num_bins'] = 16 173 | kwargs['tokenizer'] = 'default' 174 | kwargs['mode'] = 'bilinear' 175 | kwargs['prvt'] = True 176 | kwargs['drop_delta'] = not grad 177 | kwargs['sigma2d'] = 0.025 178 | 179 | model = SPiT(**kwargs) # type: ignore 180 | 181 | if pretrained: 182 | raise ValueError('RViT_S16 does not have pretrained weights.') 183 | 184 | return model.eval() 185 | 186 | def rvit_base_16(grad:bool=True, pretrained=False, **kwargs) -> SPiT: 187 | kwargs = {**_architecture_cfg['B'], **_std_cfg} 188 | kwargs['num_bins'] = 16 189 | kwargs['tokenizer'] = 'default' 190 | kwargs['mode'] = 'bilinear' 191 | kwargs['prvt'] = True 192 | kwargs['drop_delta'] = not grad 193 | kwargs['sigma2d'] = 0.025 194 | 195 | model = SPiT(**kwargs) # type: ignore 196 | 197 | if pretrained: 198 | sd = _get_pretrained_weights('RViT_B16', grad, **kwargs) 199 | model.load_state_dict(sd) 200 | return model.eval() 201 | 202 | return model.eval() 203 | 204 | __all__ = [ 205 | 'spit_small_16', 206 | 'spit_base_16', 207 | 'vit_small_16', 208 | 'vit_base_16', 209 | 'rvit_small_16', 210 | 'rvit_base_16', 211 | ] 212 | 213 | -------------------------------------------------------------------------------- /spit/nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.utils.checkpoint import checkpoint 6 | from torch.distributions import Beta 7 | from typing import Optional, Tuple 8 | 9 | from .tokenizer.tokenizer import ( 10 | superpixel_tokenizer, preprocess_features, preprocess_segmentation, 11 | img_coords, random_rectangular_partitions, bbox_interpolate, histogram_2d, 12 | 
postprocess_for_attention 13 | ) 14 | from .tokenizer.voronoi import fast_pseudo_voronoi, chunked_voronoi 15 | 16 | class SuperpixelTokenizer(nn.Module): 17 | 18 | def __init__(self, drop_delta:bool=False, bbox_reg=False, **kwargs): 19 | super().__init__() 20 | self.drop_delta = drop_delta 21 | self._lgrad = 27.8 22 | self._lcol = 10. 23 | self._maxlvl = 4 24 | self._bbox = bbox_reg 25 | self._final_th = 0.0 26 | 27 | def forward(self, feat:torch.Tensor, *args): 28 | return superpixel_tokenizer( 29 | feat, self._lgrad, self._lcol, self.drop_delta, 30 | bbox_reg=self._bbox, maxlvl=self._maxlvl, 31 | final_th=self._final_th 32 | ) 33 | 34 | def enable_bbox(self, lgrad:float=27.8, lcol:float=10.0): 35 | self._bbox = True 36 | self._lgrad = lgrad 37 | self._lcol = lcol 38 | 39 | def set_final_th(self, th): 40 | self._final_th = th 41 | 42 | 43 | class DefaultTokenizer(nn.Module): 44 | 45 | def __init__( 46 | self, num_bins:int, p_low:int=12, p_high:int=32, roll:bool=False, 47 | drop_delta:bool=False, mode:str='bilinear', rvt:bool=False, prvt:bool=False, **kwargs 48 | ): 49 | super().__init__() 50 | self.p_low = p_low 51 | self.p_high = p_high 52 | self.roll = roll 53 | self.drop_delta = drop_delta 54 | self.prvt = prvt 55 | self.rvt = rvt 56 | self.num_bins = num_bins 57 | if mode == 'nearest': 58 | self.p_low = num_bins 59 | self.p_high = num_bins 60 | self.roll = False 61 | 62 | def forward(self, img:torch.Tensor, seg:Optional[torch.Tensor]): 63 | nb, _, h, w = img.shape 64 | coords = img_coords(img) 65 | feat = preprocess_features(img, coords, 27.8, self.drop_delta) 66 | if seg is None: 67 | if self.rvt: 68 | num_cells = int(round((h*w)**.5)) 69 | seg = chunked_voronoi(img, num_cells) 70 | 71 | elif self.prvt: # Use pseudo random voronoi tesselation 72 | num_cells = int(round((h*w)**.5 * 1.1)) 73 | seg = fast_pseudo_voronoi(img, num_cells) 74 | 75 | else: 76 | seg = random_rectangular_partitions( 77 | nb, h, w, self.p_low, self.p_high, self.roll, device=feat.device 
78 | ) 79 | else: 80 | seg = preprocess_segmentation(seg, coords) 81 | 82 | return feat, seg, coords 83 | 84 | def enable_bbox(self, *args): 85 | raise AttributeError('Cannot use compactness regularization for ViT / RViT!') 86 | 87 | 88 | class InterpolationExtractor(nn.Module): 89 | 90 | def __init__( 91 | self, num_bins:int, in_channels:int=3, sigma2d:float=0.025, 92 | drop_delta:bool=True, tpb:int=1024, mode:str='bilinear', **kwargs 93 | ): 94 | super().__init__() 95 | assert mode in ['nearest', 'bilinear'] 96 | self.num_bins = num_bins 97 | self.in_channels = in_channels 98 | self._intp_dims = tuple(range(in_channels)) 99 | self.sigma2d = sigma2d 100 | self.tpb = tpb 101 | self.mode = mode 102 | if drop_delta: 103 | self._2d_dims = ((in_channels, in_channels+1),) 104 | else: 105 | self._2d_dims = ((in_channels, in_channels+1), (in_channels+2, in_channels+3)) 106 | self._half2d = self.num_bins**2 107 | self.drop_delta = drop_delta 108 | 109 | def forward(self, feat:torch.Tensor, seg:torch.Tensor, coords:torch.Tensor, *args): 110 | tokens_intp = bbox_interpolate( 111 | feat, seg, coords, self.num_bins, self._intp_dims, self.mode 112 | ).view(-1, self._half2d, self.in_channels) 113 | 114 | tokens_2d = histogram_2d( 115 | feat, seg, self.num_bins, self._2d_dims, self.sigma2d, self.tpb 116 | ).view(-1, len(self._2d_dims), self._half2d).mT 117 | 118 | return torch.cat((tokens_intp, tokens_2d), -1).view(tokens_intp.shape[0], -1) 119 | 120 | def _forward_old(self, feat:torch.Tensor, seg:torch.Tensor, coords:torch.Tensor, *args): 121 | tokens_intp = bbox_interpolate( 122 | feat, seg, coords, self.num_bins, self._intp_dims, self.mode 123 | ) 124 | 125 | tokens_2d = histogram_2d( 126 | feat, seg, self.num_bins, self._2d_dims, self.sigma2d, self.tpb 127 | ) 128 | 129 | # Balance normalization of gradients for 2d 130 | if not self.drop_delta: 131 | tokens_2d.mul_(4)[:,self._half2d:].mul_(0.2) 132 | 133 | return torch.cat((tokens_intp, tokens_2d), -1) 134 | 135 | def 
interpolate_patch_size(self, new_bins): 136 | self.num_bins = new_bins 137 | self._half2d = self.num_bins**2 138 | self.mode = 'bilinear' 139 | 140 | 141 | class MaskedLinear(nn.Module): 142 | 143 | def __init__(self, in_feat, out_feat, bias=True, activation=None): 144 | super().__init__() 145 | self.in_feat = in_feat 146 | self.out_feat = out_feat 147 | self.linear = nn.Linear(in_feat, out_feat, bias=bias) 148 | if activation is None: 149 | self.act = nn.Identity() 150 | else: 151 | self.act = activation 152 | if bias: 153 | self.linear.bias.data.mul_(1e-3) 154 | 155 | def forward(self, x, amask): 156 | assert x.ndim == 2 157 | masked_output = self.act(self.linear(x[amask.view(-1)])) 158 | out = torch.zeros(x.shape[0], self.out_feat, dtype=masked_output.dtype, device=x.device) 159 | out[amask.view(-1)] = masked_output 160 | return out 161 | 162 | 163 | class TokenEmbedder(nn.Module): 164 | 165 | def __init__( 166 | self, num_bins:int, emb_dim:int, keep_k:int, extractor:str, 167 | in_channels:int=3, drop_delta:bool=False, **kwargs 168 | ): 169 | super().__init__() 170 | self.num_bins = num_bins 171 | self.emb_dim = emb_dim 172 | self.in_channels = in_channels 173 | self.drop_delta = drop_delta 174 | self.keep_k = keep_k 175 | self._extractorstr = extractor 176 | 177 | if extractor == 'histogram': 178 | self.token_dim = ( 179 | in_channels * (num_bins**2 // in_channels) + num_bins**2 * (2 - drop_delta) 180 | ) 181 | elif extractor == 'interpolate': 182 | self.token_dim = num_bins**2 * (in_channels + 2 - drop_delta) 183 | else: 184 | raise ValueError(f'Invalid extractor: {extractor:=}') 185 | 186 | self.embedder = MaskedLinear(self.token_dim, self.emb_dim) 187 | self.cls_token = nn.Parameter(torch.zeros(self.emb_dim) + 1e-5) 188 | 189 | def budget_dropout(self, amask:torch.Tensor) -> torch.Tensor: 190 | '''Randomly drops features keeping max token size of k, leaving the first elements untouched. 
191 | 192 | Args: 193 | amask (torch.Tensor): Attention mask 194 | 195 | Returns: 196 | torch.Tensor: Indices of tokens to keep after budget dropout. 197 | ''' 198 | b, t = amask.shape 199 | first = torch.arange(b, device=amask.device) * t 200 | keep = torch.multinomial(amask[:, 1:].float(), t - 1) + 1 201 | b_idx = torch.arange(b, device=keep.device).view(-1, 1).expand_as(keep) * t 202 | keep = torch.cat((first.view(-1,1), (b_idx + keep)[:, :self.keep_k - 1]), dim=1) 203 | return keep.reshape(-1).sort().values 204 | 205 | def forward(self, feat:torch.Tensor, seg:torch.Tensor, coords:torch.Tensor, *args): 206 | nb = seg.shape[0] 207 | feat, amask, g_idx, b_idx = postprocess_for_attention(feat, seg, coords) 208 | 209 | if self.keep_k > 0 and self.training: 210 | keep = self.budget_dropout(amask) 211 | feat = feat[keep] 212 | amask = amask[:,:self.keep_k].contiguous() 213 | b_idx = b_idx[keep].contiguous() 214 | g_idx = torch.arange(0, nb*self.keep_k, self.keep_k, device=g_idx.device) 215 | 216 | feat = self.embedder(feat, amask) 217 | feat[g_idx] = self.cls_token.view(1, -1).expand(nb, -1) 218 | return feat, amask, g_idx, b_idx 219 | 220 | def random_sample_embed(self, feat, amask, g_idx, b_idx, keep_k, drop=True): 221 | nb = amask.shape[0] 222 | keep = None 223 | if drop: 224 | # Replace keep_k 225 | old_keep = self.keep_k 226 | self.keep_k = keep_k 227 | 228 | # Drop non_keeps 229 | keep = self.budget_dropout(amask) 230 | feat = feat[keep].clone() 231 | amask = amask[:,:self.keep_k].clone() 232 | b_idx = b_idx[keep].clone() 233 | g_idx = torch.arange(0, nb*self.keep_k, self.keep_k, device=g_idx.device) 234 | 235 | # Reset original keep_k 236 | self.keep_k = old_keep 237 | 238 | # Compute current outputs 239 | feat = self.embedder(feat, amask) 240 | feat[g_idx] = self.cls_token.view(1, -1).expand(nb, -1) 241 | return feat, amask, g_idx, b_idx, keep 242 | 243 | 244 | class MaskedMSA(nn.Module): 245 | 246 | def __init__( 247 | self, embed_dim:int, heads:int, 
dop_att:float=0.0, dop_proj:float=0.0, 248 | qkv_bias:bool=False, lnqk:bool=False, **kwargs 249 | ): 250 | super().__init__() 251 | assert embed_dim % heads == 0, f'Invalid args: embed_dim % heads != 0.' 252 | self.embed_dim = embed_dim 253 | self.heads = heads 254 | self.head_dim = embed_dim // heads 255 | self.scale = self.head_dim ** -.5 256 | self.dop_att = dop_att 257 | self.dop_proj = dop_proj 258 | self.qkv = MaskedLinear(embed_dim, 3*embed_dim, bias=qkv_bias) 259 | self.proj = MaskedLinear(embed_dim, embed_dim) 260 | if lnqk: 261 | self.ln_k = nn.LayerNorm(self.head_dim, eps=1e-6) 262 | self.ln_q = nn.LayerNorm(self.head_dim, eps=1e-6) 263 | else: 264 | self.ln_k = nn.Identity() 265 | self.ln_q = nn.Identity() 266 | 267 | def doo(self, x): 268 | return F.dropout(x, self.dop_proj, training=self.training) 269 | 270 | def doa(self, x): 271 | return F.dropout(x, self.dop_att, training=self.training) 272 | 273 | def expand_mask(self, amask): 274 | m, n = amask.shape 275 | return amask.view(m,1,1,n).expand(m, self.heads, n, n) 276 | 277 | def forward(self, feats, amask, store_att=False, pre_softmax=False): 278 | b, t = amask.shape 279 | h, d = self.heads, self.head_dim 280 | n, c = feats.shape 281 | m = n - b*t 282 | 283 | out = torch.zeros_like(feats) 284 | q, k, v = ( 285 | self.qkv(feats[m:], amask) 286 | .view(b, t, 3, h, d) 287 | .permute(2,0,3,1,4) 288 | ) 289 | if not store_att: 290 | out[m:] = self.proj( 291 | F.scaled_dot_product_attention( 292 | self.ln_q(q), self.ln_k(k), v 293 | ).transpose(1,2).reshape(-1, c), 294 | amask 295 | ) 296 | return self.doo(out) 297 | 298 | else: 299 | out[m:] = self.proj( 300 | self._manual_att( 301 | self.ln_q(q), self.ln_k(k), v, amask, pre_softmax 302 | ).transpose(1,2).reshape(-1, c), 303 | amask 304 | ) 305 | return self.doo(out) 306 | 307 | 308 | class MaskedMLP(nn.Module): 309 | 310 | def __init__(self, embed_dim:int, hid_dim:int, **kwargs): 311 | super().__init__() 312 | self.embed_dim = embed_dim 313 | 
self.hid_dim = hid_dim 314 | self.L1 = MaskedLinear(embed_dim, hid_dim, activation=nn.GELU()) 315 | self.L2 = MaskedLinear(hid_dim, embed_dim) 316 | 317 | def forward(self, x:torch.Tensor, amask:torch.Tensor): 318 | x = self.L1(x, amask) 319 | return self.L2(x, amask) 320 | 321 | 322 | class LayerScale(nn.Module): 323 | 324 | def __init__(self, embed_dim:int, init_val:float=1e-5): 325 | super().__init__() 326 | self.lambda_ = nn.Parameter(torch.full((embed_dim,), init_val)) 327 | 328 | def forward(self, x): 329 | return x * self.lambda_ 330 | 331 | 332 | class DropPath(nn.Module): 333 | 334 | def __init__(self, p:float, scale_by_keep:bool=True): 335 | super().__init__() 336 | self.p = p 337 | self.q = 1 - p 338 | self.scale_by_keep = scale_by_keep 339 | 340 | 341 | def forward(self, x:torch.Tensor, batch_idx:Optional[torch.Tensor]=None): 342 | if self.p == 0 or not self.training: 343 | return x 344 | 345 | if batch_idx is None: 346 | shape = (x.size(0), *((1,)*(x.ndim-1))) 347 | drops = x.new_empty(*shape).bernoulli_(self.q) 348 | 349 | if self.q > 0. and self.scale_by_keep: 350 | drops.div_(self.q) 351 | 352 | return x * drops 353 | 354 | nb = (batch_idx.max() + 1).item() 355 | drops = x.new_empty(nb).bernoulli_(self.q) # type: ignore 356 | 357 | if self.q > 0. 
and self.scale_by_keep: 358 | drops.div_(self.q) 359 | 360 | return x * drops.gather(0, batch_idx)[:,None] 361 | 362 | 363 | class MaskedViTBlock(nn.Module): 364 | 365 | def __init__( 366 | self, embed_dim, heads, mlp_ratio=4.0, dop_path:float=0.0, use_cp=False, **kwargs 367 | ): 368 | super().__init__() 369 | self.use_cp = use_cp 370 | self.norm1 = nn.LayerNorm(embed_dim, eps=1e-6) 371 | self.norm2 = nn.LayerNorm(embed_dim, eps=1e-6) 372 | self.ls1 = LayerScale(embed_dim) 373 | self.ls2 = LayerScale(embed_dim) 374 | self.dop1 = DropPath(dop_path) 375 | self.dop2 = DropPath(dop_path) 376 | hid_dim = int(embed_dim * mlp_ratio) 377 | self.att = MaskedMSA(embed_dim, heads, **kwargs) 378 | self.mlp = MaskedMLP(embed_dim, hid_dim) 379 | 380 | def _fwd(self, x, amask, batch_idx, store_att=False, pre_softmax=False): 381 | x = x + self.dop1(self.ls1(self.att(self.norm1(x), amask, store_att=store_att, pre_softmax=pre_softmax)), batch_idx) 382 | x = x + self.dop2(self.ls2(self.mlp(self.norm2(x), amask)), batch_idx) 383 | return x 384 | 385 | def forward(self, x, amask, batch_idx, store_att=False, pre_softmax=False): 386 | if self.use_cp and self.training: 387 | return checkpoint( 388 | self._fwd, x, amask, batch_idx, store_att=store_att, pre_softmax=pre_softmax 389 | ) 390 | return self._fwd(x, amask, batch_idx, store_att=store_att, pre_softmax=pre_softmax) 391 | 392 | 393 | class SPiT(nn.Module): 394 | 395 | def __init__( 396 | self, emb_dim:int, num_bins:int, heads:int, depth:int, classes:int, keep_k:int, 397 | extractor:str, tokenizer:str, in_channels:int=3, dop_input:float=0.0, **kwargs 398 | ): 399 | super().__init__() 400 | 401 | # Initialize tokenizer 402 | if tokenizer == 'default': 403 | self.tokenizer = DefaultTokenizer(num_bins, **kwargs) 404 | elif tokenizer == 'superpixel': 405 | self.tokenizer = SuperpixelTokenizer(**kwargs) 406 | else: 407 | raise ValueError(f'Invalid argument: {tokenizer=}') 408 | 409 | # Initialize extractor 410 | if extractor == 
'interpolate': 411 | self.extractor = InterpolationExtractor(num_bins, in_channels, **kwargs) 412 | else: 413 | raise ValueError(f'Invalid argument: {extractor=}') 414 | 415 | self.embedder = TokenEmbedder(num_bins, emb_dim, keep_k, extractor, in_channels, **kwargs) 416 | self.dop_input = dop_input 417 | self.blocks = nn.ModuleList([ 418 | MaskedViTBlock( 419 | emb_dim, 420 | heads, 421 | **kwargs 422 | ) 423 | for _ in range(depth) 424 | ]) 425 | self.norm = nn.LayerNorm(emb_dim, eps=1e-6) 426 | self.head = nn.Linear(emb_dim, classes) 427 | self.id = nn.Identity() 428 | self.segmodel = False 429 | if 'segmodel' in kwargs: 430 | self.segmodel = kwargs['segmodel'] 431 | 432 | def doi(self, x): 433 | return F.dropout(x, self.dop_input, training=self.training) 434 | 435 | def avg_noncls_tokens(self, feat:torch.Tensor, amask:torch.Tensor, g_idx:torch.Tensor): 436 | b_idx = torch.where(amask)[0].unsqueeze(1).expand(-1, feat.shape[1]) 437 | out = -feat[g_idx] 438 | out.scatter_add_( 439 | 0, b_idx, feat[amask.view(-1)] 440 | ).div_(amask.sum(1).view(-1, 1) - 1) 441 | return out 442 | 443 | def to_segmodel(self, seg_classes:int): 444 | self.to_newclasses(seg_classes) 445 | self.segmodel = True 446 | 447 | def to_newclasses(self, new_classes:int): 448 | emb_dim = self.head.in_features 449 | device = self.head.weight.device 450 | self.head = nn.Linear(emb_dim, new_classes, device=device) 451 | 452 | @staticmethod 453 | def maxmin(x, eps=1e-6): 454 | mx, mn = x[:,1:].max(-1, keepdim=True).values, x[:,1:].min(-1, keepdim=True).values 455 | return (x - mn) / (mx - mn + eps) 456 | 457 | def normalized_attmap(self, attn): 458 | attmap = attn[:,0].log() 459 | std_dev = torch.std(attmap, dim=-1, keepdim=True) 460 | mean = torch.mean(attmap, dim=-1, keepdim=True) 461 | attmap = (attmap - mean) / std_dev 462 | return attmap 463 | 464 | def normalized_pca(self, feat, prototypes, num_pc=1): 465 | if prototypes is not None: 466 | feat = torch.cat([prototypes, feat], 1) 467 | mean = 
torch.mean(feat, dim=1, keepdim=True) 468 | feat_centered = feat - mean 469 | if prototypes is not None: 470 | feat_centered = feat_centered[:,prototypes.shape[1]:] 471 | _, _, V = torch.pca_lowrank(feat_centered, num_pc) 472 | proj = torch.einsum('bnd,bdp->bnp', feat_centered, V).max(-1).values 473 | std_dev = torch.std(proj[:,1:], dim=-1, keepdim=True) 474 | mean = torch.mean(proj[:,1:], dim=-1, keepdim=True) 475 | proj[:,1:] = (proj[:,1:] - mean) / std_dev 476 | return proj 477 | 478 | def forward( 479 | self, feat:torch.Tensor, seg:Optional[torch.Tensor] = None, 480 | headless=False, return_seg=None, return_attn=False, return_pca=False, 481 | prototypes=None, do_lazy=False, newtok=False, parttok=False, 482 | save_feats=None, **kwargs 483 | ): 484 | if do_lazy: 485 | return torch.inverse(torch.ones((0, 0), device="cuda:0")) 486 | if not newtok: 487 | if parttok: 488 | feat, seg, coords, sizes, nnz, bbox = self.tokenizer(feat) 489 | feat, seg, coords, nnz = self.extractor(feat, seg, coords, sizes, nnz, bbox) 490 | feat = feat.flatten(1,-1) 491 | else: 492 | feat, seg, coords = self.tokenizer(feat, seg) 493 | feat = self.extractor(feat, seg, coords) 494 | feat, amask, g_idx, b_idx = self.embedder(feat, seg, coords) 495 | else: 496 | feat, seg, coords, sizes, nnz, bbox = self.tokenizer(feat) 497 | feat, seg, coords, nnz = self.extractor(feat, seg, coords, sizes, nnz, bbox) 498 | feat, seg, amask = self.embedder(feat, seg, coords, nnz) 499 | amcp = amask.clone() 500 | amcp[:,1:] = 0 501 | g_idx = amcp.view(-1) 502 | b_idx = None 503 | attn = None 504 | pca = None 505 | out = [] 506 | eye = torch.eye(amask.shape[-1], device=feat.device) 507 | head_fn = self.id if headless else self.head 508 | 509 | if return_seg is None: 510 | return_seg = self.segmodel 511 | 512 | feat = self.doi(feat) 513 | store_feat = [] 514 | save_feats = [len(self.blocks)-1] if save_feats is None else save_feats 515 | psm = kwargs.get('pre_softmax', False) 516 | for _i, block in 
enumerate(self.blocks): 517 | feat = block(feat, amask, b_idx, store_att=return_attn, pre_softmax=psm) 518 | if _i in save_feats: 519 | store_feat.append(feat) 520 | if return_attn: 521 | if attn is None: 522 | attn = 0.9*block.att._attn + 0.1*eye 523 | else: 524 | attn = attn.clip(1e-8, 1) @ (0.9*block.att._attn + 0.1*eye) # Preclip 0.1 525 | block.att._attn = None 526 | feat = torch.cat([self.norm(f.view(-1, f.shape[-1])) for f in store_feat], -1) 527 | 528 | if not return_seg: 529 | out.append(head_fn(feat[g_idx])) 530 | 531 | else: 532 | assert seg is not None 533 | seg = seg - seg.view(seg.shape[0], -1).min(-1).values[:,None,None] + 1 534 | out += [head_fn(feat[amask.view(-1)]), seg] 535 | 536 | if attn is not None: 537 | if kwargs.get('attn_as_matrix', False): 538 | out.append(attn) 539 | else: 540 | attn = self.normalized_attmap(attn.max(1).values) 541 | out.append(self.maxmin(attn).view(-1)[amask.view(-1)]) 542 | 543 | if return_pca: 544 | pca = self.normalized_pca(feat.view(*amask.shape, -1), prototypes) 545 | out.append(self.maxmin(pca).view(-1)[amask.view(-1)]) 546 | 547 | if len(out) > 1: 548 | return tuple(out) 549 | 550 | return out[0] 551 | 552 | 553 | def explain(self, img, seg, label, explanations=512, keep_k_bounds=(0.1, 0.3)): 554 | assert len(label) == len(img) 555 | assert explanations > 0 556 | assert not self.training 557 | 558 | with torch.no_grad(): 559 | feat, seg, coords = self.tokenizer(img, seg) 560 | feat = self.extractor(feat, seg, coords) 561 | ofeat, oamask, og_idx, ob_idx = postprocess_for_attention(feat, seg, coords) 562 | scores = feat.new_zeros(explanations, oamask.shape[0]) 563 | covariates = feat.new_zeros(explanations, ofeat.shape[0]) 564 | attn = None 565 | rng = torch.arange(len(label), device=label.device) 566 | _sc, _sh = max(keep_k_bounds) - min(keep_k_bounds), min(keep_k_bounds) 567 | for _ex in range(explanations): 568 | keep_k_ratio = torch.rand(1).item() * _sc + _sh 569 | keep_k = int(keep_k_ratio * oamask.shape[-1]) 
570 | feat, amask, g_idx, b_idx, keep = self.embedder.random_sample_embed(ofeat, oamask, og_idx, ob_idx, keep_k) 571 | covariates[_ex,keep] = 1 572 | for block in self.blocks: 573 | feat = block(feat, amask, b_idx) 574 | scores[_ex] = self.head(self.norm(feat)[g_idx]).softmax(-1)[rng, label] 575 | 576 | feat, amask, g_idx, b_idx, keep = self.embedder.random_sample_embed(ofeat, oamask, og_idx, ob_idx, 0, drop=False) 577 | eye = torch.eye(amask.shape[-1], device=feat.device) 578 | for block in self.blocks: 579 | assert isinstance(block, MaskedViTBlock) 580 | ofeat = block(feat, amask, b_idx, store_att=True) 581 | if attn is None: 582 | attn = block.att._attn.max(1).values + eye 583 | else: 584 | attn = attn @ (block.att._attn.max(1).values + eye) 585 | block.att._attn = None # type: ignore 586 | 587 | 588 | covariates = covariates.view(explanations, amask.shape[0], -1).permute(1,0,2) 589 | 590 | return self.norm(feat), attn, covariates, scores.mT 591 | -------------------------------------------------------------------------------- /spit/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from typing import Sequence, Optional 5 | 6 | from .proc import ( 7 | kuma_contrast1_, asinh_contrast_, 8 | apply_color_transform, scharr_features, 9 | _in1k_unnorm 10 | ) 11 | from ..utils.indexing import unravel_index, fast_uidx_long2d 12 | from ..utils.scatter import scatter_mean_2d, scatter_add_1d 13 | from ..utils.scatterhist import scatter_hist, scatter_joint_hist 14 | from ..utils.cossimargmax import cosine_similarity_argmax 15 | from ..utils.concom import connected_components 16 | 17 | 18 | def getmem(d) -> float: 19 | '''Returns reserved memory on device. 20 | 21 | Args: 22 | d (torch.device): A torch device. 23 | 24 | Returns: 25 | float: Currently reserved memory on device. 
26 | ''' 27 | if d.type == 'cpu': 28 | return 0 29 | a, b = torch.cuda.mem_get_info(d) 30 | return (b-a) / (1024**2) 31 | 32 | 33 | def init_segedges(img:torch.Tensor) -> tuple[torch.Tensor,torch.Tensor,int]: 34 | '''Computes edges and initial segmentation for an image. 35 | 36 | Args: 37 | img (torch.Tensor): Input image of shape [B,C,H,W]. 38 | 39 | Returns: 40 | tuple[torch.Tensor]: tuple of segmentation, edges, and number of elements. 41 | ''' 42 | nb, _, h, w = img.shape 43 | nnz = nb*h*w 44 | seg = torch.arange(nnz, device=img.device).view(nb, h, w) 45 | lr = seg.unfold(-1, 2, 1).reshape(-1, 2).mT 46 | ud = seg.unfold(-2, 2, 1).reshape(-1, 2).mT 47 | edges = torch.cat([lr, ud], -1) 48 | return seg.view(-1), edges, nnz 49 | 50 | 51 | def img_coords(img:torch.Tensor): 52 | '''Returns image coordinates. 53 | 54 | Args: 55 | img (torch.Tensor): Input image. 56 | 57 | Returns: 58 | torch.Tensor: Image coordinates of shape `(b,h,w)`. 59 | ''' 60 | nb, _, h, w = img.shape 61 | return unravel_index( 62 | torch.arange(nb*h*w, device=img.device), # type:ignore 63 | (nb,h,w) 64 | ) 65 | 66 | 67 | def get_hierarchy_sublevel( 68 | hierograph:Sequence[torch.Tensor], level1:int, level2:int 69 | ) -> torch.Tensor: 70 | '''Retrieves superpixel mapping between level1 to level2. 71 | 72 | Args: 73 | hierograph (List[torch.Tensor]): List of mapping tensors. 74 | level1 (int): Level to compute mapping from. 75 | level2 (int): Level to compute mapping to. 76 | 77 | Returns: 78 | torch.Tensor: Mapping of indices from level1 to level2. 
79 | ''' 80 | nlvl = len(hierograph) 81 | if not (0 <= level1 < nlvl and 0 <= level2 < nlvl): 82 | raise ValueError("Invalid hierarchy levels") 83 | 84 | if level1 == level2: 85 | return hierograph[level1] 86 | 87 | min_level, max_level = min(level1, level2), max(level1, level2) 88 | 89 | segmentation = hierograph[min_level] 90 | for i in range(min_level + 1, max_level + 1): 91 | segmentation = hierograph[i][segmentation] 92 | 93 | return segmentation 94 | 95 | 96 | def get_hierarchy_level(hierograph:Sequence[torch.Tensor], level:int) -> torch.Tensor: 97 | '''Retrieves superpixel mapping from initial pixels to level. 98 | 99 | Args: 100 | hierograph (Iterable[torch.Tensor]): List of mapping tensors. 101 | level (int): Level to compute mapping to. 102 | 103 | Returns: 104 | torch.Tensor: Mapping of indices up to level. 105 | ''' 106 | return get_hierarchy_sublevel(hierograph, 0, level) 107 | 108 | 109 | def preprocess_features( 110 | img:torch.Tensor, coords:torch.Tensor, lambda_delta:float, drop_delta:bool=False 111 | ) -> torch.Tensor: 112 | '''Preprocesses image for feature extraction. 113 | 114 | Args: 115 | img (torch.Tensor): Input image. 116 | lamdba_delta (float): Hyperparameter for gradient contrast adjustment. 117 | drop_delta (bool, optional): Flag for skipping gradient features. 118 | 119 | Returns: 120 | torch.Tensor: Preprocessed features. 
121 | ''' 122 | shape_proc = lambda x: x.permute(1,0,2,3).reshape(x.shape[1], -1).unbind(0) 123 | 124 | nb, _, h, w = img.shape 125 | den = max(h, w) 126 | _, y, x = coords.to(img.dtype).mul(2/den).sub_(1).to(img.dtype) 127 | r, g, b = shape_proc(img) 128 | kuma_contrast1_(r, .485, .539).mul_(2).sub_(1) 129 | kuma_contrast1_(g, .456, .507).mul_(2).sub_(1) 130 | kuma_contrast1_(b, .406, .404).mul_(2).sub_(1) 131 | features = [r,g,b,y,x] 132 | 133 | if not drop_delta: 134 | gy, gx = shape_proc(scharr_features(img, lambda_delta)) 135 | features = [*features, gy, gx] 136 | 137 | return torch.stack(features, -1) 138 | 139 | 140 | def preprocess_segmentation(seg:torch.Tensor, coords:torch.Tensor) -> torch.Tensor: 141 | '''Preprocesses image segmentation. 142 | 143 | Applied when segmentation indices are not unique across batches. 144 | 145 | Args: 146 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 147 | coords (torch.Tensor): Tensor of shape (3,B*H*W) of pixel coords. 148 | 149 | Returns: 150 | torch.Tensor: Segmentation with unique indices across batches. 151 | ''' 152 | nb, h, w = seg.shape 153 | shifts = torch.arange(nb, device=seg.device).mul_(h*w).view(-1, 1, 1) 154 | return (seg + shifts).unique(return_inverse=True)[1] 155 | 156 | 157 | def histogram_1d( 158 | X:torch.Tensor, seg:torch.Tensor, num_bins:int, dims:Sequence[int], 159 | sigma:float, tpb:int=1024 160 | ) -> torch.Tensor: 161 | '''Computes 1d histogram features for selected dimensions. 162 | 163 | Args: 164 | X (torch.Tensor): Features of shape (B*H*W,D). 165 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 166 | num_bins (int): Number of bins. 167 | dims (Iterable[int]): Feature dimensions to compute histograms for. 168 | sigma (float): Sigma for KDE. 169 | tpb (int): Threads per block. 170 | 171 | Returns: 172 | torch.Tensor: 1d histogram features. 
173 | ''' 174 | seg = seg.view(-1) 175 | tdims = seg.new_tensor(dims) if not torch.is_tensor(dims) else dims 176 | den = seg.bincount().to(dtype=X.dtype).unsqueeze(-1).mul_(len(tdims)/16) 177 | out = scatter_hist(seg, X[:,tdims].clone(), num_bins, sigma=sigma, tpb=tpb) 178 | return out / den 179 | 180 | 181 | def histogram_2d( 182 | X:torch.Tensor, seg:torch.Tensor, num_bins:int, dims:Sequence[tuple[int,int]], 183 | sigma:float, tpb:int=1024 184 | ) -> torch.Tensor: 185 | '''Computes 2d histogram features for selected dimensions. 186 | 187 | NOTE: The dimensions are pairs of dimensions we want the joint histograms over. 188 | 189 | Args: 190 | X (torch.Tensor): Features of shape (B*H*W,D). 191 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 192 | num_bins (int): Number of bins. 193 | dims (Iterable[tuple[int,int]]): Pairs of dimensions for computing histograms. 194 | sigma (float): Sigma for KDE. 195 | tpb (int): Threads per block. 196 | 197 | Returns: 198 | torch.Tensor: 2d histogram features. 199 | ''' 200 | seg = seg.view(-1) 201 | m = int(seg.max().item())+1 202 | den = seg.bincount().to(dtype=X.dtype).unsqueeze(-1).mul_(1/4*(num_bins/16)**2) 203 | out = scatter_joint_hist(seg, X, m, num_bins, dims, sigma=sigma, tpb=tpb) 204 | return out / den 205 | 206 | 207 | def bbox_coords(seg:torch.Tensor, coords:torch.Tensor) -> torch.Tensor: 208 | '''Calculates the bounding box coordinates for each partition in segmentation maps. 209 | 210 | NOTE: Uses ordering convention ymin, xmin, ymax, xmax 211 | 212 | Args: 213 | seg (torch.Tensor): Segmentation map of shape (B, H, W). 214 | coords (torch.Tensor): Tensor of shape (3,B*H*W) of pixel coords. 215 | 216 | Returns: 217 | torch.Tensor: bbox coordinates (ymin, xmin, ymax, xmax) for each partition. 
218 | ''' 219 | nb, h, w = seg.shape 220 | _, y, x = coords 221 | bbox = seg.new_zeros(4, int(seg.max().item()) + 1) 222 | bbox[0].scatter_reduce_(0, seg.view(-1), y, 'amin', include_self=False) 223 | bbox[1].scatter_reduce_(0, seg.view(-1), x, 'amin', include_self=False) 224 | bbox[2].scatter_reduce_(0, seg.view(-1), y, 'amax', include_self=False) 225 | bbox[3].scatter_reduce_(0, seg.view(-1), x, 'amax', include_self=False) 226 | return bbox 227 | 228 | 229 | def bbox_interpolate( 230 | feat:torch.Tensor, seg:torch.Tensor, coords:torch.Tensor, 231 | num_bins:int, dims:Sequence[int], mode:str='bilinear' 232 | ) -> torch.Tensor: 233 | '''Interpolation of partition to fixed square size. 234 | 235 | This function assumes that the dimensions specified in dims are the last two dimensions 236 | of the input tensor X. It uses bilinear or nearest neighbour interpolation and generates 237 | a mask for each bounding box to ensure the interpolation only affects the pixels within 238 | the bounding box. 239 | 240 | Args: 241 | X (torch.Tensor): Features of shape (B*H*W,D). 242 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 243 | coords (torch.Tensor): Tensor of shape (3,B*H*W) of pixel coords. 244 | num_bins (int): Dimension of the square interpolation. 245 | dims (Iterable[int]): An iterable of dimensions / channels to interpolate. 246 | mode (str): The interpolation mode, either `nearest` or `bilinear`. 247 | 248 | Returns: 249 | torch.Tensor: A tensor of square bilinearly interpolated features. 
250 | ''' 251 | assert mode in ['bilinear', 'nearest'] 252 | nb, h, w = seg.shape 253 | b, y, x = coords 254 | dims = seg.new_tensor(dims) if not torch.is_tensor(dims) else dims # type: ignore 255 | 256 | # Construct the batch indices of the segmentation 257 | b_idx = seg.view(-1).mul(nb).add(b).unique() % nb 258 | 259 | # Construct image and bbox coordinates 260 | img = feat[:,dims].view(nb, h, w, -1) 261 | ymin, xmin, ymax, xmax = bbox_coords(seg, coords).view(4, -1, 1, 1) 262 | 263 | # Construct the grid 264 | grid_base = torch.linspace(0, 1, num_bins, device=feat.device, dtype=feat.dtype) 265 | ygrid, xgrid = torch.meshgrid(grid_base, grid_base, indexing='ij') 266 | ygrid, xgrid = ygrid.reshape(-1, num_bins**2, 1), xgrid.reshape(-1, num_bins**2, 1) 267 | 268 | # Get coordinates and indices for batch / channel dimensions 269 | h_pos = ygrid * (ymax - ymin) + ymin 270 | w_pos = xgrid * (xmax - xmin) + xmin 271 | b_idx = b_idx.view(-1, 1, 1).expand(-1, num_bins**2, -1) 272 | c_idx = dims.view(1,1,-1).expand(*b_idx.shape[:2], -1) # type: ignore 273 | 274 | if mode == 'bilinear': 275 | 276 | # Construct lower and upper bounds 277 | h_floor = h_pos.floor().long().clamp(0, h-1) 278 | w_floor = w_pos.floor().long().clamp(0, w-1) 279 | h_ceil = (h_floor + 1).clamp(0, h-1) 280 | w_ceil = (w_floor + 1).clamp(0, w-1) 281 | 282 | # Construct fractional parts of bilinear coordinates 283 | Uh, Uw = h_pos - h_floor, w_pos - w_floor 284 | Lh, Lw = 1 - Uh, 1 - Uw 285 | hfwf, hfwc, hcwf, hcwc = Lh*Lw, Lh*Uw, Uh*Lw, Uh*Uw 286 | 287 | # Get interpolated features 288 | bilinear = ( 289 | img[b_idx, h_floor, w_floor, c_idx] * hfwf + 290 | img[b_idx, h_floor, w_ceil, c_idx] * hfwc + 291 | img[b_idx, h_ceil, w_floor, c_idx] * hcwf + 292 | img[b_idx, h_ceil, w_ceil, c_idx] * hcwc 293 | ) 294 | 295 | # Get masks 296 | srange = torch.arange(b_idx.shape[0], device=feat.device).view(-1,1) 297 | masks = ( 298 | (seg[b_idx[:,:,0], h_floor[:,:,0], w_floor[:,:,0]] == srange).unsqueeze(-1) * 
hfwf + 299 | (seg[b_idx[:,:,0], h_floor[:,:,0], w_ceil[:,:,0]] == srange).unsqueeze(-1) * hfwc + 300 | (seg[b_idx[:,:,0], h_ceil[:,:,0], w_floor[:,:,0]] == srange).unsqueeze(-1) * hcwf + 301 | (seg[b_idx[:,:,0], h_ceil[:,:,0], w_ceil[:,:,0]] == srange).unsqueeze(-1) * hcwc 302 | ) 303 | 304 | return (bilinear * masks).view(bilinear.shape[0], -1) 305 | 306 | elif mode == 'nearest': 307 | # Construct lower and upper bounds 308 | h_pos = h_pos.round().long().clamp(0, h-1) 309 | w_pos = w_pos.round().long().clamp(0, w-1) 310 | 311 | # Get interpolated features 312 | nearest = img[b_idx, h_pos, w_pos, c_idx] 313 | 314 | # Get masks 315 | srange = torch.arange(b_idx.shape[0], device=feat.device).view(-1,1) 316 | mask = (seg[b_idx[:,:,0], h_pos[:,:,0], w_pos[:,:,0]] == srange).unsqueeze(-1) 317 | 318 | return (nearest * mask).view(nearest.shape[0], -1) 319 | 320 | raise ValueError(f'Invalid interpolation mode: {mode=}') 321 | 322 | 323 | def postprocess_for_attention( 324 | feat:torch.Tensor, seg:torch.Tensor, coords:torch.Tensor 325 | ) -> tuple[torch.Tensor,...]: 326 | '''Postprocess features for self-attention operators. 327 | 328 | Essentially computes an attention mask for non-fixed numbers of patches and pads 329 | the features to accept a global class token. 330 | 331 | Args: 332 | feat (torch.Tensor): Features of shape (B*H*W,D). 333 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 334 | coords (torch.Tensor): Tensor of shape (3,B*H*W) of pixel coords. 335 | 336 | Returns: 337 | tuple[torch.Tensor]: Output features, attention mask, global- and batch indices. 
338 | ''' 339 | nb, b = seg.shape[0], coords[0] 340 | b_idx = seg.view(-1).mul(nb).add(b).unique() % nb 341 | bc = b_idx.bincount() 342 | maxdim = bc.max() + 1 343 | idx = ( 344 | torch.arange(len(b_idx), device=b_idx.device) - 345 | (bc.cumsum(-1) - bc).repeat_interleave(bc) 346 | ) 347 | 348 | amask = feat.new_zeros(nb, maxdim, dtype=torch.bool) 349 | outfeat = feat.new_zeros(nb, maxdim, feat.shape[-1]) 350 | amask[b_idx, idx+1] = True 351 | amask[:,0] = True 352 | outfeat[b_idx, idx+1] = feat 353 | g_idx = torch.arange(0, nb*maxdim, maxdim, device=idx.device) 354 | b_idx = torch.arange(nb*maxdim, device=b_idx.device) // maxdim 355 | 356 | return outfeat.view(-1, feat.shape[-1]), amask, g_idx, b_idx 357 | 358 | 359 | def init_img_graph( 360 | img:torch.Tensor, lambda_delta:float, lambda_col:float, drop_delta:bool=False 361 | ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], 362 | torch.Tensor, torch.Tensor, int, torch.Tensor]: 363 | '''Initialises image graph and base level features. 364 | 365 | Args: 366 | img (torch.Tensor): Image tensor. 367 | grad_contrast (float): Parameter for gradient contrast. 368 | col_contrast (float): Parameter for color contrast. 369 | drop_delta (bool, optional): Drop computation of discrete gradient features. Default: False 370 | 371 | Returns: 372 | tuple[torch.Tensor]: tuple with image graph features. 
373 | ''' 374 | nb, _, h, w = img.shape 375 | lab, edges, nnz = init_segedges(img) 376 | coords = unravel_index(lab, (nb, h, w)) 377 | feat = preprocess_features(img, coords, lambda_delta, False) # Keep delta for vfeat 378 | sizes = torch.ones_like(lab) 379 | maxval = asinh_contrast_(img.new_tensor(13/16), lambda_delta).mul_(2**.5) 380 | vfeat = torch.cat([ 381 | apply_color_transform(feat[:,:3], img.shape, lambda_col), 382 | feat[:,-2:].norm(2, dim=1, keepdim=True).div_(maxval).mul_(2).sub_(1), 383 | ], -1).float() 384 | 385 | if drop_delta: 386 | feat = feat[:,:5] 387 | 388 | return ( 389 | lab, edges, vfeat, feat, sizes, nnz, coords 390 | ) 391 | 392 | 393 | def spit_step( 394 | lab:torch.Tensor, edges:torch.Tensor, vfeat:Optional[torch.Tensor], 395 | sizes:torch.Tensor, nnz:int, lvl:int, bbox:Optional[torch.Tensor]=None, tpb:int=1024, 396 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor, Optional[torch.Tensor]]: 397 | '''Computes a Superpixel Hierarchy step. 398 | 399 | Args: 400 | lab (torch.Tensor): Labels, segmentation. 401 | edges (torch.Tensor): Edges. 402 | vfeat (torch.Tensor): Superpixel features. 403 | sizes (torch.Tensor): Superpixel sizes. 404 | nnz (int): No. vertices. 405 | lvl (int): Level of superpixel hierarchy. 406 | tpb (int, optional): CUDA threads per block. Default: 1024 407 | 408 | Returns: 409 | tuple[torch.Tensor]: Updated parameters. 
410 | ''' 411 | assert vfeat is not None 412 | sim = cosine_similarity_argmax(vfeat, edges, sizes, nnz, tpb=tpb, lvl=None, bbox=bbox) 413 | cc = connected_components(lab, sim, nnz, tpb=tpb) 414 | vfeat_new = scatter_mean_2d(vfeat, cc) 415 | edges_new = cc[edges].contiguous() 416 | edges_new = edges_new[:, fast_uidx_long2d(edges_new)] 417 | lab_new = cc.unique() 418 | nnz_new = len(lab_new) 419 | sizes_new = scatter_add_1d(sizes, cc, nnz_new) 420 | bbox_new = None 421 | if bbox is not None: 422 | bbox_new = bbox.new_zeros(4, nnz_new) 423 | bbox_new[0].scatter_reduce_(0, cc, bbox[0], 'amin', include_self=False) 424 | bbox_new[1].scatter_reduce_(0, cc, bbox[1], 'amin', include_self=False) 425 | bbox_new[2].scatter_reduce_(0, cc, bbox[2], 'amax', include_self=False) 426 | bbox_new[3].scatter_reduce_(0, cc, bbox[3], 'amax', include_self=False) 427 | 428 | return lab_new, edges_new, vfeat_new, sizes_new, nnz_new, cc, bbox_new 429 | 430 | 431 | def _aggstep(cc, edges, vfeat, sizes, nnz, bbox=None, do_feat=True): 432 | vfeat_new = None 433 | if do_feat: 434 | vfeat_new = scatter_mean_2d(vfeat, cc) 435 | edges_new = cc[edges] 436 | edges_new = edges_new[:, fast_uidx_long2d(edges_new)].contiguous() 437 | nnz_new = cc.max().item() + 1 438 | lab_new = torch.arange(nnz_new, device=cc.device) 439 | sizes_new = scatter_add_1d(sizes, cc, nnz_new) 440 | bbox_new = None 441 | if bbox is not None: 442 | bbox_new = bbox.new_zeros(4, nnz_new) 443 | bbox_new[0].scatter_reduce_(0, cc, bbox[0], 'amin', include_self=False) 444 | bbox_new[1].scatter_reduce_(0, cc, bbox[1], 'amin', include_self=False) 445 | bbox_new[2].scatter_reduce_(0, cc, bbox[2], 'amax', include_self=False) 446 | bbox_new[3].scatter_reduce_(0, cc, bbox[3], 'amax', include_self=False) 447 | return lab_new, edges_new, vfeat_new, sizes_new, nnz_new, cc, bbox_new 448 | 449 | 450 | def _finalstep(th, edges, vfeat, sizes, nnz, bbox, tpb=1024): 451 | u, v = edges 452 | diff = vfeat[u] - vfeat[v] 453 | mask = 
(diff.max(-1).values - diff.min(-1).values).abs() < th 454 | cc = connected_components(u[mask], v[mask], nnz, tpb=tpb) 455 | return _aggstep(cc, edges, vfeat, sizes, nnz, bbox) 456 | 457 | 458 | def superpixel_tokenizer( 459 | img:torch.Tensor, lambda_grad:float, lambda_col:float, drop_delta:bool, 460 | bbox_reg:bool=False, debug:bool=False, maxlvl:int=4, tpb:int=1024, 461 | final_th:float=0.0, deactivate_in1k_unnorm:bool=False 462 | ): 463 | '''Superpixel tokenizer and feature preprocessor. 464 | 465 | Args: 466 | img (torch.Tensor): Image of shape [B, 3, H, W]. 467 | lambda_grad (float): Lambda for gradient. 468 | lambda_col (float): Lambda for color. 469 | drop_delta (float): Drop discrete gradient features. 470 | bbox_reg: (bool): Whether to use bounding box compactness regularization. 471 | debug (bool): Print debug info for superpixel iterations. 472 | maxlvl (int): Max number of levels. Defaults to 4. 473 | tbp (int): Threads per block for cuda computations. 474 | final_th (float): Use final thresholding. 475 | deactivate_in1k_unnorm (bool): Force tokenizer to ignore normalization. Defaults to False. 
476 | ''' 477 | if not deactivate_in1k_unnorm: 478 | if (img < 0).any(): 479 | img = _in1k_unnorm(img, 1) 480 | 481 | # Assert 3 channels 482 | assert img.shape[1] == 3 483 | 484 | # Init variables 485 | device = img.device 486 | batch_size, _, height, width = img.shape 487 | lab, edges, vfeat, feat, sizes, nnz, coords = init_img_graph( 488 | img, lambda_grad, lambda_col, drop_delta 489 | ) 490 | bbox = None 491 | if bbox_reg: 492 | bbox = torch.stack([coords[1], coords[2], coords[1], coords[2]], 0) 493 | 494 | hierograph = [lab] 495 | lvl = 0 496 | 497 | # If debug, check before main loop 498 | if debug: 499 | print(f"lvl:{lvl:3} nnz:{nnz:12} mu:{sizes.float().mean().item():8.9f} mem:{getmem(device):8}") 500 | 501 | # Main loop 502 | while lvl < maxlvl: 503 | 504 | lvl += 1 505 | lab, edges, vfeat, sizes, nnz, cc, bbox = spit_step( 506 | lab, edges, vfeat, sizes, nnz, lvl, bbox, tpb 507 | ) 508 | hierograph.append(cc) 509 | if debug: 510 | print(f"lvl:{lvl:3} nnz:{nnz:12} mu:{sizes.float().mean().item():8.9f} mem:{getmem(device):8}") 511 | 512 | if final_th > 0: 513 | 514 | lvl += 1 515 | lab, edges, vfeat, sizes, nnz, cc, bbox = _finalstep( 516 | final_th, edges, vfeat, sizes, nnz, bbox, tpb 517 | ) 518 | hierograph.append(cc) 519 | 520 | # Compile segmentation from hierarchy 521 | seg = get_hierarchy_level(hierograph, lvl) 522 | return feat, seg.view(batch_size, height, width), coords 523 | 524 | 525 | def random_rectangular_partitions( 526 | b:int, h:int, w:int, p_low:int, p_high:int, 527 | roll:bool=True, square:bool=True, 528 | device:torch.device=torch.device('cpu') 529 | ) -> torch.Tensor: 530 | '''Generates random square partitions of dimension BxHxW. 531 | 532 | Args: 533 | b (int): Batch size. 534 | h (int): Raster height. 535 | w (int): Raster width. 536 | p_low (int): Minimum partition size. 537 | p_high (int): Maximum partition size (inclusive). 538 | roll (bool): Flag for randomized rolling of rows and columns. 
539 | square (bool): Flag for enforcing square partitions. 540 | device (torch.device): Output device. 541 | 542 | Returns: 543 | torch.Tensor: Square partition. 544 | 545 | ''' 546 | def _quickroll(A:torch.Tensor, shifts:tuple[torch.Tensor,torch.Tensor]): 547 | shape = A.shape 548 | for i in range(2): 549 | A = A.mT.reshape(-1, shape[-2 + i]) 550 | rng = torch.arange(shape[-2 + i], device=A.device).view(1,-1).expand_as(A) 551 | idx = (rng + shifts[i].view(-1,1)) % shape[-2 + i] 552 | A = A.gather(1, idx).view(*shape) 553 | return A 554 | 555 | ps_h, ps_w = torch.randint(p_low, p_high+1, (2,b), device=device) 556 | ps_w = ps_h if square else ps_w 557 | ceil_h, ceil_w = -(-h//ps_h), -(-w//ps_w) 558 | ceil_hw = ceil_h * ceil_w 559 | 560 | partition_ids = torch.arange(ceil_hw.sum().item(), device=device) 561 | bs = torch.arange(b, device=device).repeat_interleave(ceil_hw) 562 | idx = partition_ids - (ceil_hw.cumsum(-1) - ceil_hw)[bs] 563 | y, x = idx // ceil_w[bs], idx % ceil_w[bs] 564 | batched_x = x + (ceil_w.cumsum(-1) - ceil_w)[bs] 565 | 566 | psize_h = ( 567 | torch.stack([ps_h, h - (ceil_h - 1)*ps_h], -1) 568 | .view(-1) 569 | .repeat_interleave( 570 | torch.stack([ceil_h-1, ceil_h.new_ones(b)], -1) 571 | .view(-1) 572 | ) 573 | ) 574 | psize_w = ( 575 | torch.stack([ps_w, w - (ceil_w - 1)*ps_w], -1) 576 | .view(-1) 577 | .repeat_interleave( 578 | torch.stack([ceil_w-1, ceil_w.new_ones(b)], -1) 579 | .view(-1) 580 | ) 581 | ) 582 | 583 | out = ( 584 | partition_ids 585 | .repeat_interleave(psize_w[batched_x]) 586 | .view(-1, w) 587 | .repeat_interleave(psize_h, dim=0) 588 | .view(b, h, w) 589 | ) 590 | 591 | if not roll: 592 | return out 593 | 594 | roll_h = ( 595 | torch.rand_like(ceil_h.float()) 596 | .mul_(ceil_h) 597 | .long() 598 | .mul_(ps_h) 599 | .repeat_interleave(h) 600 | ) 601 | roll_w = ( 602 | torch.rand_like(ceil_w.float()) 603 | .mul_(ceil_w) 604 | .long() 605 | .mul_(ps_w) 606 | .repeat_interleave(w) 607 | ) 608 | return _quickroll(out, (roll_h, 
roll_w)) 609 | 610 | 611 | def diffmap_from_seg( 612 | seg:torch.Tensor, dim:int=1, dtype:torch.dtype=torch.float, 613 | vmin:float=-torch.pi/2, vmax:float=torch.pi/2 614 | ) -> torch.Tensor: 615 | '''Computes a difference map from a segmentation map. 616 | 617 | Args: 618 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 619 | dim (int): Dimension to stack differences along. 620 | dtype (torch.dtype): Output data type. 621 | vmin (float): Minimum value for normalization. 622 | vmax (float): Maximum value for normalization. 623 | 624 | Returns: 625 | torch.Tensor: Difference map of shape (B,2,H,W) for dim=1. 626 | ''' 627 | B,H,W = seg.shape 628 | dw = seg.diff(1,-1,prepend=seg.new_zeros(*([1]*seg.ndim)).expand(B,H,1)) != 0 629 | dh = seg.diff(1,-2,prepend=seg.new_zeros(*([1]*seg.ndim)).expand(B,1,W)) != 0 630 | d = torch.stack([dw,dh], dim=dim).to(dtype) 631 | if vmin != 0.0 or vmax != 1.0: 632 | return d.mul_(vmax-vmin).add_(vmin) 633 | return d 634 | 635 | 636 | def concatenate_diffmap( 637 | img:torch.Tensor, seg:torch.Tensor, vmin:float=-torch.pi/2, vmax:float=torch.pi/2 638 | ) -> torch.Tensor: 639 | '''Concatenates difference map to image. 640 | 641 | Args: 642 | img (torch.Tensor): Image of shape (B,C,H,W). 643 | seg (torch.Tensor): Segmentation map of shape (B,H,W). 644 | vmin (float): Minimum value for normalization. 645 | vmax (float): Maximum value for normalization. 646 | 647 | Returns: 648 | torch.Tensor: Image with difference map concatenated along channel dimension. 649 | ''' 650 | diffmap = diffmap_from_seg(seg, dim=1, dtype=img.dtype, vmin=vmin, vmax=vmax) 651 | return torch.cat([img, diffmap], 1) 652 | --------------------------------------------------------------------------------