├── .gitignore ├── LICENSE ├── Pipfile ├── README.md ├── setup.py ├── sgcn ├── __init__.py ├── masked │ ├── __init__.py │ ├── functions.py │ └── tensor.py └── nn │ ├── __init__.py │ ├── affinity.py │ ├── attention.py │ └── normalization.py └── tests ├── conftest.py ├── masked ├── test_funtions.py └── test_tensor.py └── nn ├── test_affinity.py └── test_attention.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | sgcn = {editable = true, path = "."} 8 | 9 | [dev-packages] 10 | pytest = "*" 11 | mock = "*" 12 | 13 | [requires] 14 | python_version = "3.7" 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse graph convolution network 2 | 3 | Attention model is based on: 4 | 5 | Vaswani, Ashish, et al. "[Attention is all you need.](http://papers.nips.cc/paper/7181-attention-is-all-you-need)" Advances in Neural Information Processing Systems. 2017. 6 | 7 | Graph Attention Network: 8 | 9 | Veličković, Petar, et al. "[Graph Attention Networks.](https://arxiv.org/abs/1710.10903)" arXiv preprint arXiv:1710.10903 (2017). 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """Setup file. 4 | 5 | Installs the modules. 6 | """ 7 | 8 | import os 9 | from setuptools import setup, find_packages 10 | 11 | 12 | __currdir__ = os.getcwd() 13 | __readme__ = os.path.join(__currdir__, "README.md") 14 | 15 | 16 | install_requires = [ 17 | "attrs", 18 | "numpy", 19 | "torch" 20 | ] 21 | 22 | setup( 23 | name="sgcn", 24 | version="0.1", 25 | packages=find_packages(), 26 | author="CERC DS4DM", 27 | license="MIT", 28 | description="Sparse graph neural networks in PyTorch", 29 | long_description=open(__readme__).read(), 30 | install_requires=install_requires 31 | ) 32 | -------------------------------------------------------------------------------- /sgcn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /sgcn/masked/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ds4dm/sparse-gcn/801bb5fc3290ba64376416765e5b0b6c45b9a330/sgcn/masked/__init__.py -------------------------------------------------------------------------------- /sgcn/masked/functions.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """Functions for automatic differentiation of masked tensors.""" 4 | 5 | from typing import Tuple, Optional 6 | 7 | import torch 8 | from torch.autograd import Function 9 | 10 | 11 | class MatMul(Function): 12 | """See matmul().""" 13 | 14 | @staticmethod 15 | def forward( 16 | ctx, 17 | indices: torch.Tensor, 18 | values: torch.Tensor, 19 | size: torch.Size, 20 | B: torch.Tensor 21 | ) -> torch.Tensor: 22 | """Forward computation. 23 | 24 | Computes the matrix multiplication of `A` and `B`, where `A` is defined 25 | by the sparse tensor (`indices`, `values`, `shape`) 26 | 27 | Parameters 28 | ---------- 29 | indices 30 | Indices to define `A` according to Pytorch sparse tensor. 31 | values: 32 | Values to define `A` according to Pytorch sparse tensor. 33 | size: 34 | Size to define `A` according to Pytorch sparse tensor. 35 | B: 36 | Other matrix to multiply by. 37 | 38 | Returns 39 | ------- 40 | output 41 | The result of the sparse multiplication of `A` and `B`. 
42 | 43 | """ 44 | A = torch.sparse_coo_tensor(indices, values, size) 45 | # Save indices because we don't want to rely on `A._indices()` 46 | ctx.save_for_backward(A, indices, B) 47 | return A @ B 48 | 49 | @ staticmethod 50 | def backward( 51 | ctx, grad_output: torch.Tensor 52 | ) -> Tuple[None, Optional[torch.Tensor], None, Optional[torch.Tensor]]: 53 | """Backward computation. 54 | 55 | Parameters 56 | ---------- 57 | grad_output: 58 | The gradient of some quantity with regard to the output of this 59 | function. 60 | 61 | Returns 62 | ------- 63 | outputs: 64 | The gradient wrt to each of the inputs of forward. None if does 65 | not exist or not required. 66 | 67 | """ 68 | A, indices, B = ctx.saved_tensors 69 | grad_A = grad_B = None 70 | 71 | if ctx.needs_input_grad[1]: 72 | grad_A = matmulmasked(grad_output, B.t(), indices) 73 | if ctx.needs_input_grad[3]: 74 | grad_B = A.t() @ grad_output 75 | 76 | return None, grad_A, None, grad_B 77 | 78 | 79 | # define function aliases, useful since Function.apply() 80 | # does not support named arguments 81 | def matmul( 82 | indices: torch.Tensor, 83 | values: torch.Tensor, 84 | size: torch.Size, 85 | B: torch.Tensor 86 | ) -> torch.Tensor: 87 | """Matrix multiplication with a sparse tensor. 88 | 89 | Computes the matrix multiplication of `A` and `B`, where `A` is defined 90 | by the sparse tensor (`indices`, `values`, `shape`) 91 | 92 | Parameters 93 | ---------- 94 | indices 95 | Indices to define `A` according to Pytorch sparse tensor. 96 | values: 97 | Values to define `A` according to Pytorch sparse tensor. 98 | size: 99 | Size to define `A` according to Pytorch sparse tensor. 100 | B: 101 | Other matrix to multiply by. 102 | 103 | Returns 104 | ------- 105 | output 106 | The result of the sparse multiplication of `A` and `B`. 107 | 108 | """ 109 | return MatMul.apply(indices, values, size, B) 110 | 111 | 112 | def matmulmasked( 113 | A: torch.Tensor, B: torch.Tensor, indices: torch.Tensor 114 | ) -> torch.Tensor: 115 | """Matrix multiplication and mask. 116 | 117 | This function computes `(A @ B) * m` in a masked way. Only the values with 118 | a positive mask are computed by levraging sparse computations. 119 | Note that this function yields a small numeric difference from its 120 | equivalent dense version. 121 | 122 | Parameters 123 | ---------- 124 | A 125 | Tensor of size (n, p) 126 | B 127 | Tensor of size (p, m) 128 | indices 129 | The mask defining the computation to do. Has to be given according to 130 | Pytorch indices convention for sparse tensors. 131 | 132 | Returns 133 | ------- 134 | values 135 | The values of `A @ B` associated with the `indices` passed. 136 | 137 | """ 138 | idx, jdx = indices 139 | A_values = A.index_select(0, idx) 140 | B_values = B.t().index_select(0, jdx) 141 | 142 | AB_values = torch.bmm( 143 | A_values.view(A_values.size(0), 1, -1), 144 | B_values.view(B_values.size(0), -1, 1) 145 | ).view(-1) 146 | 147 | return AB_values 148 | -------------------------------------------------------------------------------- /sgcn/masked/tensor.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | """Module for the MaskedTensor class.""" 5 | 6 | # FIXME: Pyflakes flaging unquoted forward references 7 | from __future__ import annotations 8 | 9 | from typing import Optional, Union, Callable 10 | 11 | import attr 12 | import torch 13 | 14 | from . 
import functions as func
15 | 
16 | 
17 | @attr.s(auto_attribs=True, frozen=True, cmp=False)
18 | class MaskedTensor:
19 |     """MaskedTensor class.
20 | 
21 |     A class defining a sparse tensor whose sparsity pattern represents
22 |     the only existing elements.
23 |     This is similar to applying a mask of zeros (on all the indices not
24 |     provided) after each operation.
25 | 
26 |     For instance, the gradient with regard to this tensor keeps the same
27 |     sparsity pattern, even though the gradient wrt a sparse tensor isn't
28 |     necessarily sparse.
29 | 
30 |     This class is represented as a hybrid COO sparse tensor. Sparse dimensions
31 |     are the first ones and dense dimensions are the last ones.
32 | 
33 |     The tensor must always be coalesced, i.e. it should not contain duplicate
34 |     coordinates, otherwise many operations would not give the right results.
35 |     This is not enforced explicitly.
36 | 
37 |     Why not differentiate through PyTorch sparse tensors?
38 |     As of PyTorch version 0.4.1, it is difficult to enforce that the values
39 |     stay in the same order. In particular, coalescing doesn't always put the
40 |     indices in the same order (example on cpu, using `transpose` then
41 |     `coalesce`). Being transparent about the indices ensures that no
42 |     permutation is done without explicit knowledge.
43 |     """
44 | 
45 |     indices: torch.Tensor = attr.ib(converter=lambda x: x.long())
46 |     values: torch.Tensor = attr.ib()
47 |     shape: torch.Size = attr.ib(converter=torch.Size)
48 |     dtype: torch.dtype = attr.ib()
49 |     device: torch.device = attr.ib(converter=torch.device)
50 | 
51 |     @shape.default
52 |     def _shape_default(self):
53 |         sparse_dims, _ = self.indices.max(1)
54 |         dense_dims = self.values.size()[1:]
55 |         return torch.Size(1 + sparse_dims) + dense_dims
56 | 
57 |     @dtype.default
58 |     def _dtype_default(self):
59 |         return self.values.dtype
60 | 
61 |     @device.default
62 |     def _device_default(self):
63 |         return self.values.device
64 | 
65 |     @indices.validator
66 |     def _check_indices(self, attribute, val):
67 |         if len(val.size()) != 2:
68 |             raise ValueError("Indices must have two dimensions.")
69 | 
70 |     @values.validator
71 |     def _check_values(self, attribute, val):
72 |         if self.indices.size(1) != val.size(0):
73 |             raise ValueError("Indices and values must have same nnz.")
74 | 
75 |     def __attrs_post_init__(self):
76 |         """Initialize after attr.
77 | 
78 |         Moves `indices` and `values` to the correct `device` and converts
79 |         `values` to the desired type.
80 |         """
81 |         # Actualize the device and dtype of indices and values
82 |         indices = self.indices.to(device=self.device)
83 |         values = self.values.to(device=self.device, dtype=self.dtype)
84 |         # Necessary to use this to overcome the frozen aspect of the class
85 |         object.__setattr__(self, "indices", indices)
86 |         object.__setattr__(self, "values", values)
87 | 
88 |     @classmethod
89 |     def from_sparse(ctx, tensor: torch.Tensor) -> "MaskedTensor":
90 |         """Build from a torch sparse tensor.
91 | 
92 |         This function is not differentiable.
93 |         """
94 |         return ctx(
95 |             indices=tensor._indices(),
96 |             values=tensor._values(),
97 |             shape=tensor.shape,
98 |             dtype=tensor.dtype,
99 |             device=tensor.device
100 |         )
101 | 
102 |     def to_sparse(self) -> torch.Tensor:
103 |         """Return the object as a torch sparse Tensor.
104 | 
105 |         This method is not differentiable.
106 |         """
107 |         return torch.sparse_coo_tensor(
108 |             indices=self.indices,
109 |             values=self.values,
110 |             size=self.shape,
111 |             dtype=self.dtype,
112 |             device=self.device
113 |         )
114 | 
115 |     def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
116 |         """Shape of the MaskedTensor, or the size of one specific dimension."""
117 |         if dim is None:
118 |             return self.shape
119 |         else:
120 |             return self.shape[dim]
121 | 
122 |     @property
123 |     def sparse_dims(self) -> int:
124 |         """Return the number of sparse dimensions."""
125 |         return self.indices.size(0)
126 | 
127 |     @property
128 |     def dense_dims(self) -> int:
129 |         """Return the number of dense dimensions."""
130 |         return len(self.values.shape) - 1
131 | 
132 |     @property
133 |     def dims(self) -> int:
134 |         """Return the total number of dimensions."""
135 |         return len(self.shape)
136 | 
137 |     def with_values(self, values: torch.Tensor) -> "MaskedTensor":
138 |         """Return a new MaskedTensor with different values.
139 | 
140 |         The sparsity pattern, device, and dtype are preserved.
141 | 
142 |         Parameters
143 |         ----------
144 |         values:
145 |             New values to use. Must have the same size along the
146 |             first dimension as the previous values.
147 | 
148 |         Returns
149 |         -------
150 |         output:
151 |             The new values wrapped in a new MaskedTensor, hence
152 |             preserving the sparsity pattern.
153 | 
154 |         """
155 |         # We have to compute the shape here because new values may not have the
156 |         # same dimensions, and sparse dimensions might have been set to
157 |         # something else than the default.
158 |         # Shape is the sparse shape + the new dense shape.
159 |         shape = self.shape[:self.sparse_dims] + values.size()[1:]
160 |         return MaskedTensor(
161 |             indices=self.indices,
162 |             values=values,
163 |             shape=shape,
164 |             dtype=self.dtype,
165 |             device=self.device
166 |         )
167 | 
168 |     def apply(
169 |         self, func: Callable[[torch.Tensor], torch.Tensor]
170 |     ) -> "MaskedTensor":
171 |         """Apply a function on all the values.
172 | 
173 |         Does not change the sparsity pattern.
174 | 
175 |         Parameters
176 |         ----------
177 |         func:
178 |             Takes as input the values and returns a new Tensor. The function
179 |             must preserve the first axis.
180 | 
181 |         Returns
182 |         -------
183 |         output:
184 |             The transformed values wrapped in a new MaskedTensor, hence
185 |             preserving the sparsity pattern.
186 | 
187 |         """
188 |         return self.with_values(func(self.values))
189 | 
190 |     def transpose(self, dim1: int, dim2: int) -> "MaskedTensor":
191 |         """Transpose two dimensions."""
192 |         # Transposing two sparse dimensions by permuting axes
193 |         if dim1 < self.sparse_dims and dim2 < self.sparse_dims:
194 |             idx = torch.arange(self.indices.size(0))
195 |             # Using item() because this gets a view of the tensor.
196 |             idx[dim1], idx[dim2] = idx[dim2].item(), idx[dim1].item()
197 |             # Using a tensor to permute the sizes of the sparse dimensions
198 |             shape = torch.Size(torch.tensor(self.shape).index_select(0, idx))
199 |             shape += self.shape[self.sparse_dims:]
200 |             return MaskedTensor(
201 |                 indices=self.indices.index_select(0, idx.to(self.device)),
202 |                 values=self.values,
203 |                 shape=shape,
204 |                 dtype=self.dtype,
205 |                 device=self.device
206 |             )
207 | 
208 |         # Transposing two dense dimensions by simple transposition
209 |         elif dim1 >= self.sparse_dims and dim2 >= self.sparse_dims:
210 |             offset = self.sparse_dims - 1
211 |             values = self.values.transpose(dim1 - offset, dim2 - offset)
212 |             return self.with_values(values)
213 | 
214 |         else:
215 |             raise RuntimeError("Transposing sparse and dense dims.")
216 | 
217 |     def t(self) -> "MaskedTensor":
218 |         """Transpose a matrix."""
219 |         if self.dims == 2:
220 |             return self.transpose(0, 1)
221 |         else:
222 |             raise RuntimeError("t() is valid for matrices only.")
223 | 
224 |     def sum(
225 |         self, dim: Optional[int] = None, keepdim: bool = False
226 |     ) -> Union[torch.Tensor, "MaskedTensor"]:
227 |         """Sum across one or all dimensions.
228 | 
229 |         Parameters
230 |         ----------
231 |         dim:
232 |             Dimension to sum across, or None for all dimensions.
233 |         keepdim:
234 |             Whether or not to keep the original number of dimensions (the
235 |             dimension summed over gets length one). Does nothing if no
236 |             axis is specified.
237 | 
238 |         Returns
239 |         -------
240 |         output
241 |             A new tensor, dense or masked depending on the summation dimensions.
242 | 
243 |         """
244 |         if dim is None:
245 |             return self.values.sum()
246 |         elif dim < self.sparse_dims:
247 |             # Cannot remove dim and use coalesce because it may permute
248 |             # the values
249 |             raise NotImplementedError("Summation over sparse dimension.")
250 |         else:
251 |             return self.with_values(
252 |                 self.values.sum(dim - self.sparse_dims + 1, keepdim=keepdim))
253 | 
254 |     def mm(self, other: torch.Tensor) -> torch.Tensor:
255 |         """Perform matrix matrix multiplication.
256 | 
257 |         Supports self as a sparse matrix (2 sparse dims, no dense dims) and
258 |         other as a dense matrix (2 dims) with matching shapes.
259 | 
260 |         Parameters
261 |         ----------
262 |         other:
263 |             Dense matrix left multiplied by this tensor.
264 | 
265 |         Returns
266 |         -------
267 |         output:
268 |             The dense result of `self @ other`.
269 | 
270 |         """
271 |         return func.matmul(self.indices, self.values, self.shape, other)
272 | 
273 |     def mv(self, other: torch.Tensor) -> torch.Tensor:
274 |         """Perform matrix vector multiplication.
275 | 
276 |         Supports self as a sparse matrix (2 sparse dims, no dense dims) and
277 |         other as a dense vector (1 dim) with matching shapes.
278 | 
279 |         Parameters
280 |         ----------
281 |         other:
282 |             Dense vector left multiplied by this tensor.
283 | 
284 |         Returns
285 |         -------
286 |         output:
287 |             The dense result of `self @ other`.
288 | 
289 |         """
290 |         return self.mm(other.unsqueeze(-1)).squeeze(-1)
291 | 
292 |     def mask_mm(self, A: torch.Tensor, B: torch.Tensor) -> "MaskedTensor":
293 |         """Compute the masked matrix multiplication.
294 | 
295 |         Computes `A @ B` only for the indices provided in the current
296 |         MaskedTensor.
297 | 
298 |         Parameters
299 |         ----------
300 |         A:
301 |             Left dense matrix (two dimensions only).
302 |         B:
303 |             Right dense matrix (two dimensions only).
304 | 
305 |         Returns
306 |         -------
307 |         output:
308 |             The masked result of `A @ B`.
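
        Example
        -------
        With a mask holding the single index pair (0, 1), only the entry
        `(A @ B)[0, 1]` is computed and stored in the returned values; the
        result keeps the mask's sparsity pattern.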
309 | 
310 |         """
311 |         return self.with_values(func.matmulmasked(A, B, self.indices))
312 | 
--------------------------------------------------------------------------------
/sgcn/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | 
4 | from .attention import Attention, MultiHeadAttention
5 | from .affinity import Affinity, DotProduct
6 | from .normalization import Normalization, NoNorm
7 | 
8 | 
9 | __all__ = [
10 |     "Attention", "MultiHeadAttention",
11 |     "Affinity", "DotProduct",
12 |     "Normalization", "NoNorm"
13 | ]
14 | 
--------------------------------------------------------------------------------
/sgcn/nn/affinity.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | """Affinity classes."""
4 | 
5 | from typing import Optional, Union
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from math import sqrt
10 | 
11 | from sgcn.masked.tensor import MaskedTensor
12 | 
13 | 
14 | class Affinity(nn.Module):
15 |     """Affinity base class.
16 | 
17 |     An affinity class implements a function `forward` with three parameters:
18 |     the attention keys, the attention queries and some optional query-key
19 |     specifics (e.g. a mask).
20 |     """
21 | 
22 |     pass
23 | 
24 | 
25 | class DotProduct(Affinity):
26 |     """Dot product attention.
27 | 
28 |     The affinity between a query and a key is computed using their dot
29 |     product.
30 |     """
31 | 
32 |     def __init__(self, scaled: bool = True) -> None:
33 |         """Initialize affinity.
34 | 
35 |         Parameters
36 |         ----------
37 |         scaled:
38 |             Whether or not to scale the attention weights by the inverse
39 |             of the square root of the key dimension.
40 | 
41 |         """
42 |         super().__init__()
43 |         self.scaled = scaled
44 | 
45 |     def forward(
46 |         self,
47 |         Q: torch.Tensor,
48 |         K: torch.Tensor,
49 |         m: Optional[Union[torch.Tensor, MaskedTensor]] = None
50 |     ) -> Union[torch.Tensor, MaskedTensor]:
51 |         """Compute dot-product affinity.
52 | 
53 |         Parameters
54 |         ----------
55 |         Q
56 |             Queries tensor with dimension (n_queries, feat_dim).
57 |         K
58 |             Keys tensor with dimensions (n_keys, feat_dim).
59 |         m
60 |             Optional mask. If the mask is given in dense form it is applied
61 |             by elementwise multiplication. If given in sparse form, only
62 |             the nonzero values will be computed.
63 | 
64 |         Returns
65 |         -------
66 |         affinity
67 |             The matrix of attention weights; the first dimension is the query
68 |             index, the second dimension is the key index. If the mask is given
69 |             as a MaskedTensor, the result is also a MaskedTensor, otherwise
70 |             the result is dense.
71 | 
72 |         """
73 |         if isinstance(m, MaskedTensor):
74 |             QKt = m.mask_mm(Q, K.t())
75 |             if self.scaled:
76 |                 QKt = QKt.apply(lambda x: x / sqrt(K.size(1)))
77 | 
78 |         else:
79 |             QKt = Q @ K.t()
80 |             QKt = QKt if m is None else QKt * m
81 |             if self.scaled:
82 |                 QKt = QKt / sqrt(K.size(1))
83 | 
84 |         return QKt
85 | 
--------------------------------------------------------------------------------
/sgcn/nn/attention.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | """Attention module.
4 | 
5 | Attention as defined in Attention is All you Need.
6 | https://arxiv.org/abs/1706.03762
7 | """
8 | 
9 | from typing import Union, Optional
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | from sgcn.masked.tensor import MaskedTensor
15 | from . import affinity as aff
16 | from . import normalization as norm
17 | 
18 | 
19 | class Attention(nn.Module):
20 |     """Attention.
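
    Combines an `Affinity` module, which computes the query-key attention
    weights, with a `Normalization` module, which rescales those weights,
    and returns the weighted sum of the values.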
21 | 
22 |     TODO: if necessary make a batched version of this. Note that batches of
23 |     different sizes can already be handled using a block diagonal mask.
24 |     """
25 | 
26 |     def __init__(
27 |         self,
28 |         affinity: aff.Affinity,
29 |         normalization: norm.Normalization
30 |     ) -> None:
31 |         """Initialize the Attention.
32 | 
33 |         Parameters
34 |         ----------
35 |         affinity
36 |             Object of type Affinity to compute the affinity between keys
37 |             and attention queries.
38 |         normalization
39 |             Object of type Normalization to apply a correction to the
40 |             attention weights.
41 | 
42 |         """
43 |         super().__init__()
44 |         self.affinity = affinity
45 |         self.normalization = normalization
46 | 
47 |     def forward(
48 |         self,
49 |         K: torch.Tensor,
50 |         V: torch.Tensor,
51 |         Q: torch.Tensor,
52 |         m: Optional[Union[torch.Tensor, MaskedTensor]] = None
53 |     ) -> torch.Tensor:
54 |         """Compute attention.
55 | 
56 |         According to _Attention is All you Need_:
57 |         > An attention function can be described as mapping a query and a set
58 |         > of key-value pairs to an output, where the query, keys, values, and
59 |         > output are all vectors. The output is computed as a weighted sum of
60 |         > the values, where the weight assigned to each value is computed by a
61 |         > compatibility function of the query with the corresponding key.
62 |         https://arxiv.org/abs/1706.03762
63 | 
64 |         Parameters
65 |         ----------
66 |         K:
67 |             Attention keys. First dimension is the key index, the others are
68 |             feature values.
69 |         V:
70 |             Attention values. First dimension is the value index. There
71 |             should be as many attention values as there are keys.
72 |         Q:
73 |             Queries to make on attention keys.
74 |         m:
75 |             A matrix of dimensions number of queries by number of keys.
76 |             Passed to the affinity function. Can be used to make a mask
77 |             or to pass additional query data (e.g. edge information for
78 |             a graph).
79 | 
80 |         Returns
81 |         -------
82 |         attention:
83 |             First dimension is aligned with the query indices. Other dimensions
84 |             are similar to the value ones.
85 | 
86 |         """
87 |         QKt = self.affinity(Q, K, m)
88 |         QKt_n = self.normalization(QKt)
89 |         if isinstance(QKt_n, MaskedTensor):
90 |             return QKt_n.mm(V)
91 |         else:
92 |             return QKt_n @ V
93 | 
94 | 
95 | class MultiHeadAttention(Attention):
96 |     """Dot product attention with multiple heads.
97 | 
98 |     Linearly projects the keys, values, and queries and applies dot product
99 |     attention to the result. This process is repeated as many times as there
100 |     are heads, and the results are concatenated together.
101 |     """
102 | 
103 |     def __init__(
104 |         self,
105 |         in_key: int,
106 |         in_value: int,
107 |         in_query: int,
108 |         n_head: int,
109 |         head_qk: int,
110 |         head_v: int
111 |     ) -> None:
112 |         """Initialize multi head attention.
113 | 
114 |         Parameters
115 |         ----------
116 |         in_key:
117 |             Dimension of input keys.
118 |         in_value:
119 |             Dimension of input values.
120 |         in_query:
121 |             Dimension of input queries.
122 |         n_head:
123 |             Number of heads to use.
124 |         head_qk:
125 |             Dimension of every projected head for queries and keys. They share
126 |             the same dimension as the affinity is computed through a dot product.
127 |         head_v:
128 |             Dimension of every projected head for values.
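
        Notes
        -----
        The output of `forward` concatenates the heads, so it has
        `n_head * head_v` features per query.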
129 | 
130 |         """
131 |         super().__init__(
132 |             affinity=aff.DotProduct(), normalization=norm.NoNorm()
133 |         )
134 |         self.lin_k = nn.Linear(in_key, head_qk * n_head)
135 |         self.lin_v = nn.Linear(in_value, head_v * n_head)
136 |         self.lin_q = nn.Linear(in_query, head_qk * n_head)
137 |         self._n_head = n_head
138 | 
139 |     def _view_heads(self, X: torch.Tensor) -> torch.Tensor:
140 |         """Reshape output of Linear by number of heads."""
141 |         if X.dim() == 2:
142 |             out_dim = X.size(1)
143 |             return X.view(-1, self._n_head, out_dim // self._n_head)
144 |         else:
145 |             raise RuntimeError(
146 |                 f"Only dimension 2 supported, received: {X.dim()}"
147 |             )
148 | 
149 |     def forward(
150 |         self,
151 |         K: torch.Tensor,
152 |         V: torch.Tensor,
153 |         Q: torch.Tensor,
154 |         m: Optional[Union[torch.Tensor, MaskedTensor]] = None
155 |     ) -> torch.Tensor:
156 |         """Compute attention.
157 | 
158 |         Parameters
159 |         ----------
160 |         K:
161 |             Attention keys. First dimension is the key index, the others are
162 |             feature values.
163 |         V:
164 |             Attention values. First dimension is the value index. There
165 |             should be as many attention values as there are keys.
166 |         Q:
167 |             Queries to make on attention keys.
168 |         m:
169 |             A matrix of dimensions number of queries by number of keys.
170 |             Passed to the affinity function. Can be used to make a mask
171 |             or to pass additional query data (e.g. edge information for
172 |             a graph).
173 | 
174 |         Returns
175 |         -------
176 |         attention:
177 |             First dimension is aligned with the query indices. Second dimension
178 |             is the number of heads times the output dimension of one value head
179 |             (`head_v`).
180 | 
181 |         """
182 |         K_proj = self._view_heads(self.lin_k(K))
183 |         V_proj = self._view_heads(self.lin_v(V))
184 |         Q_proj = self._view_heads(self.lin_q(Q))
185 | 
186 |         V_out = []
187 |         for k in range(self._n_head):
188 |             V_out.append(super().forward(
189 |                 K=K_proj[:, k], V=V_proj[:, k], Q=Q_proj[:, k], m=m
190 |             ))
191 | 
192 |         return torch.cat(V_out, dim=1)
193 | 
--------------------------------------------------------------------------------
/sgcn/nn/normalization.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | """Normalization classes."""
4 | 
5 | from typing import Union
6 | 
7 | import torch
8 | import torch.nn as nn
9 | 
10 | from sgcn.masked.tensor import MaskedTensor
11 | 
12 | 
13 | class Normalization(nn.Module):
14 |     """Normalization base class.
15 | 
16 |     A normalization class implements a function `forward` with one parameter:
17 |     the unnormalized attention weights of every query over every value.
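
    For instance, a softmax over the key dimension would recover standard
    scaled dot-product attention; the identity normalization `NoNorm` below
    applies no correction at all.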
18 | """ 19 | 20 | pass 21 | 22 | 23 | class NoNorm(Normalization): 24 | """No normalization class.""" 25 | 26 | def forward( 27 | self, QKt: Union[torch.Tensor, MaskedTensor] 28 | ) -> Union[torch.Tensor, MaskedTensor]: 29 | """Identity function.""" 30 | return QKt 31 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | """Pytest objects.""" 4 | 5 | import pytest 6 | import torch 7 | 8 | 9 | @pytest.fixture(params=["cpu", "cuda"]) 10 | def device(request): 11 | """Device to run test on.""" 12 | _device = torch.device(request.param) 13 | if _device.type == "cuda" and not torch.cuda.is_available(): 14 | pytest.skip() 15 | return _device 16 | -------------------------------------------------------------------------------- /tests/masked/test_funtions.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from torch.autograd import gradcheck 7 | 8 | import sgcn.masked.functions as F 9 | 10 | 11 | def _allclose(A, B): 12 | return np.allclose(A.detach().cpu(), B.detach().cpu()) 13 | 14 | 15 | @pytest.fixture 16 | def sparse_matrix_data(device): 17 | idx = torch.tensor( 18 | [[1, 1, 9, 4], [2, 3, 0, 4]], dtype=torch.int64, device=device) 19 | values = torch.rand(4, device=device) 20 | return idx, values, (13, 7) 21 | 22 | 23 | def test_matmul_forward(sparse_matrix_data, device): 24 | A = torch.sparse_coo_tensor(*sparse_matrix_data) 25 | B = torch.rand(7, 5, device=device) 26 | AB_tested = F.matmul(*sparse_matrix_data, B) 27 | AB_expected = A @ B 28 | 29 | assert _allclose(AB_tested, AB_expected) 30 | 31 | 32 | def test_matmul_grad(sparse_matrix_data, device): 33 | idx, values, size = sparse_matrix_data 34 | B = torch.rand(7, 5, device=device) 35 | 36 | # gradcheck requires double precision 37 | values = values.double().requires_grad_() 38 | B = B.double().requires_grad_() 39 | 40 | assert gradcheck(F.matmul, (idx, values, size, B)) 41 | 42 | 43 | def test_matmulmasked_forward(sparse_matrix_data, device): 44 | A = torch.rand(13, 5, device=device) 45 | B = torch.rand(5, 7, device=device) 46 | indices, _, _ = sparse_matrix_data 47 | 48 | AB_values_tested = F.matmulmasked(A, B, indices) 49 | AB_alues_expected = (A @ B)[tuple(indices)] 50 | assert _allclose(AB_values_tested, AB_alues_expected) 51 | 52 | 53 | def test_matmulmasked_grad(sparse_matrix_data, device): 54 | # gradcheck requires double precision 55 | A = torch.rand(13, 5, device=device).double().requires_grad_() 56 | B = torch.rand(5, 7, device=device).double().requires_grad_() 57 | indices, _, _ = sparse_matrix_data 58 | 59 | assert gradcheck(F.matmulmasked, (A, B, indices)) 60 | -------------------------------------------------------------------------------- /tests/masked/test_tensor.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from mock import MagicMock 7 | 8 | from sgcn.masked.tensor import MaskedTensor 9 | 10 | 11 | def _allclose(A, B): 12 | return np.allclose(A.detach().cpu(), B.detach().cpu()) 13 | 14 | 15 | @pytest.fixture 16 | def data(device): 17 | idx = torch.tensor( 18 | [[1, 1, 9, 4], [2, 3, 0, 4]], dtype=torch.int64, device=device) 19 | values = torch.rand(4, 3, 9, device=device) 20 | return idx, values, (13, 7, 3, 9) 21 | 22 | 23 | 
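# The `maskedtensor` fixture below wraps the random `data` above in a
# MaskedTensor of shape (13, 7, 3, 9): two sparse dimensions (13, 7) indexed
# by the 2 x 4 `idx` tensor, and two dense dimensions (3, 9) from the values.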
@pytest.fixture 24 | def maskedtensor(data): 25 | return MaskedTensor(*data) 26 | 27 | 28 | def test_initialization(data): 29 | idx, val, shape = data 30 | 31 | # shape 32 | assert MaskedTensor(idx, val).shape == (10, 5, 3, 9) 33 | assert MaskedTensor(idx, val, shape).shape == shape 34 | assert isinstance(MaskedTensor(idx, val).shape, torch.Size) 35 | # dtype 36 | assert MaskedTensor(idx, val).dtype == val.dtype 37 | assert MaskedTensor(idx, val, dtype=torch.int).dtype == torch.int 38 | # device 39 | assert isinstance( 40 | MaskedTensor(idx, val, device="cpu").device, torch.device) 41 | # These is a test when fixture device is cuda. 42 | assert MaskedTensor(idx, val, device="cpu").device.type == "cpu" 43 | assert MaskedTensor(idx.cpu(), val).device == val.device 44 | # indices 45 | assert MaskedTensor(idx.float(), val).indices.dtype == idx.dtype 46 | with pytest.raises(ValueError): 47 | MaskedTensor(idx.t(), val) 48 | # values 49 | with pytest.raises(ValueError): 50 | MaskedTensor(idx, val[1:]) 51 | 52 | 53 | def test_to_sparse(data): 54 | m = MaskedTensor(*data) 55 | s = torch.sparse_coo_tensor(*data) 56 | 57 | assert torch.is_tensor(m.to_sparse()) 58 | assert m.to_sparse().is_sparse 59 | assert _allclose(s.to_dense(), m.to_sparse().to_dense()) 60 | 61 | 62 | def test_from_sparse(data): 63 | s = torch.sparse_coo_tensor(*data) 64 | m = MaskedTensor.from_sparse(s) 65 | 66 | assert isinstance(m, MaskedTensor) 67 | # Should suceed is test_to_sparse succeed 68 | assert _allclose(s.to_dense(), m.to_sparse().to_dense()) 69 | 70 | 71 | def test_size(maskedtensor): 72 | assert maskedtensor.size() == maskedtensor.shape 73 | assert maskedtensor.size(3) == maskedtensor.shape[3] 74 | 75 | 76 | def test_dims(maskedtensor): 77 | assert maskedtensor.sparse_dims == 2 78 | assert maskedtensor.dense_dims == 2 79 | assert maskedtensor.dims == 4 80 | 81 | 82 | def test_with_values(maskedtensor): 83 | v = torch.rand(4, 13).to(maskedtensor.device) 84 | m = maskedtensor.with_values(v) 85 | 86 | assert isinstance(m, MaskedTensor) 87 | assert (m.values == v).all().item() 88 | assert m.shape == (13, 7, 13) 89 | assert m.dtype == maskedtensor.dtype 90 | assert m.device == maskedtensor.device 91 | assert m.indices is maskedtensor.indices 92 | 93 | 94 | def test_apply(maskedtensor): 95 | v = torch.rand(4, 13).to(maskedtensor.device) 96 | func = MagicMock() 97 | func.return_value = v 98 | m = maskedtensor.apply(func) 99 | 100 | assert isinstance(m, MaskedTensor) 101 | func.assert_called_once_with(maskedtensor.values) 102 | assert (m.values == v).all().item() 103 | assert m.indices is maskedtensor.indices 104 | assert (m.shape[:m.sparse_dims] 105 | == maskedtensor.shape[:maskedtensor.sparse_dims]) 106 | 107 | 108 | def test_sum(maskedtensor): 109 | assert (maskedtensor.sum() == maskedtensor.values.sum()).item() 110 | # sum on dense dim 111 | m = maskedtensor.sum(2) 112 | assert (m.values == maskedtensor.values.sum(1)).all().item() 113 | assert m.shape == (13, 7, 9) 114 | # keepdim 115 | assert maskedtensor.sum(2, True).shape == (13, 7, 1, 9) 116 | # not implemented yet 117 | with pytest.raises(NotImplementedError): 118 | maskedtensor.sum(1) 119 | 120 | 121 | def test_mm(maskedtensor, device): 122 | # mm only works for sparse matrices (2 sparse dims) 123 | maskedtensor = maskedtensor.with_values(maskedtensor.values[:, 0, 0]) 124 | other = torch.rand(7, 5, device=device) 125 | 126 | result = maskedtensor.mm(other) 127 | expected_result = maskedtensor.to_sparse() @ other 128 | assert torch.is_tensor(result) 129 | assert 
not result.is_sparse 130 | assert _allclose(result, expected_result) 131 | 132 | 133 | def test_mv(maskedtensor, device): 134 | # mv only works for sparse matrices (2 sparse dims) 135 | maskedtensor = maskedtensor.with_values(maskedtensor.values[:, 0, 0]) 136 | other = torch.rand(7, device=device) 137 | 138 | result = maskedtensor.mv(other) 139 | expected_result = (maskedtensor.to_sparse().to_dense() @ other) 140 | assert torch.is_tensor(result) 141 | assert not result.is_sparse 142 | assert _allclose(result, expected_result) 143 | 144 | 145 | def test_mask_mm(maskedtensor, device): 146 | A = torch.rand(13, 5, device=device) 147 | B = torch.rand(5, 7, device=device) 148 | dense_mask = maskedtensor.with_values(torch.ones(4, device=device)) \ 149 | .to_sparse().to_dense() 150 | 151 | result = maskedtensor.mask_mm(A, B) 152 | expected_result = (A @ B) * dense_mask 153 | assert isinstance(result, MaskedTensor) 154 | assert result.shape == (13, 7) 155 | assert _allclose(result.to_sparse().to_dense(), expected_result) 156 | 157 | 158 | def test_transpose(maskedtensor): 159 | # sparse dims 160 | t = maskedtensor.transpose(0, 1) 161 | assert t.shape == (7, 13, 3, 9) 162 | assert (t.indices[0] == maskedtensor.indices[1]).all().item() 163 | assert (t.indices[1] == maskedtensor.indices[0]).all().item() 164 | 165 | # dense dims 166 | t = maskedtensor.transpose(2, 3) 167 | assert t.shape == (13, 7, 9, 3) 168 | assert (t.values == maskedtensor.values.transpose(1, 2)).all().item() 169 | 170 | # raises 171 | with pytest.raises(RuntimeError): 172 | maskedtensor.transpose(0, 2) 173 | 174 | 175 | def test_t(maskedtensor): 176 | with pytest.raises(RuntimeError): 177 | maskedtensor.t() 178 | 179 | # Make 2 dimensional 180 | maskedtensor = maskedtensor.with_values(maskedtensor.values[:, 0, 0]) 181 | assert maskedtensor.t().shape == (7, 13) 182 | assert (maskedtensor.t().indices 183 | == maskedtensor.transpose(0, 1).indices).all().item() 184 | -------------------------------------------------------------------------------- /tests/nn/test_affinity.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from sgcn.masked.tensor import MaskedTensor 8 | import sgcn.nn.affinity as aff 9 | 10 | 11 | def _allclose(A, B): 12 | return np.allclose(A.detach().cpu(), B.detach().cpu()) 13 | 14 | 15 | @pytest.fixture(params=["dense", "masked"]) 16 | def data(device, request): 17 | K = torch.rand(7, 3, device=device, requires_grad=True) 18 | V = torch.rand(7, 4, device=device, requires_grad=True) 19 | Q = torch.rand(4, 3, device=device, requires_grad=True) 20 | md = (torch.rand(4, 7, device=device, requires_grad=True) > .6).float() 21 | 22 | if request.param == "dense": 23 | return K, V, Q, md 24 | else: 25 | idx = md.nonzero() 26 | mm = MaskedTensor(idx.t(), torch.ones(len(idx), device=device), (4, 7)) 27 | return K, V, Q, mm 28 | 29 | 30 | def test_dotproduct_no_mask(data): 31 | K, _, Q, _ = data 32 | func = aff.DotProduct(scaled=False) 33 | coefs = func(Q, K) 34 | assert torch.is_tensor(coefs) 35 | assert _allclose(coefs, Q @ K.t()) 36 | 37 | 38 | def test_dotproduct(data): 39 | K, _, Q, m = data 40 | func = aff.DotProduct(scaled=False) 41 | coefs = func(Q, K, m) 42 | assert isinstance(coefs, m.__class__) 43 | 44 | # check if mask is respected 45 | if isinstance(m, MaskedTensor): 46 | dense_mask = m.to_sparse().to_dense() 47 | dense_coefs = coefs.to_sparse().to_dense() 48 | else: 49 | dense_mask = m 50 | 
dense_coefs = coefs 51 | 52 | checks = (dense_coefs * (1 - dense_mask)) > 0 53 | assert not checks.any() 54 | -------------------------------------------------------------------------------- /tests/nn/test_attention.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import mock 4 | import pytest 5 | import numpy as np 6 | import torch 7 | 8 | import sgcn 9 | from sgcn.masked.tensor import MaskedTensor 10 | 11 | 12 | def _allclose(A, B): 13 | return np.allclose(A.detach().cpu(), B.detach().cpu(), 1e-4, 1e-7) 14 | 15 | 16 | @pytest.fixture(params=["dense", "masked"]) 17 | def data(device, request): 18 | K = torch.rand(7, 3, device=device, requires_grad=True) 19 | V = torch.rand(7, 4, device=device, requires_grad=True) 20 | Q = torch.rand(4, 3, device=device, requires_grad=True) 21 | md = (torch.rand(4, 7, device=device, requires_grad=True) > .6).float() 22 | 23 | if request.param == "dense": 24 | return K, V, Q, md 25 | else: 26 | idx = md.nonzero() 27 | mm = MaskedTensor(idx.t(), torch.ones(len(idx), device=device), (4, 7)) 28 | return K, V, Q, mm 29 | 30 | 31 | def test_attention(data): 32 | K, V, Q, m = data 33 | # Mock module forward 34 | aff = type( 35 | "Affinity", 36 | (sgcn.nn.Affinity, ), 37 | {"forward": mock.MagicMock(return_value=m)} 38 | )() 39 | norm = type( 40 | "Normalization", 41 | (sgcn.nn.Normalization, ), 42 | {"forward": mock.MagicMock(return_value=m)} 43 | )() 44 | 45 | attention = sgcn.nn.Attention(affinity=aff, normalization=norm) 46 | attention(K, V, Q, m=m) 47 | 48 | aff.forward.assert_called_once_with(Q, K, m) 49 | norm.forward.assert_called_once_with(m) 50 | 51 | 52 | def test_multi_head_attention(data, device): 53 | K, V, Q, m = data 54 | att = sgcn.nn.MultiHeadAttention( 55 | in_key=3, in_value=4, in_query=3, 56 | n_head=13, head_qk=9, head_v=11 57 | ).to(device) 58 | out = att(K, V, Q, m=m) 59 | 60 | assert out.shape == (4, 13*11) 61 | if isinstance(m, MaskedTensor): 62 | out_d = att(K, V, Q, m=m.to_sparse().to_dense()) 63 | assert _allclose(out, out_d) 64 | 65 | 66 | def test_multi_head_attention_backward(data, device): 67 | K, V, Q, m = data 68 | att = sgcn.nn.MultiHeadAttention( 69 | in_key=3, in_value=4, in_query=3, 70 | n_head=13, head_qk=9, head_v=11 71 | ).to(device) 72 | out = att(K, V, Q, m=m) 73 | 74 | out.sum().backward() 75 | --------------------------------------------------------------------------------
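Usage sketch (illustrative only, not part of the repository): the masked tensor acts as a sparse attention mask over graph edges, and the attention modules consume it directly. The toy graph, feature sizes, and values below are made up; the sketch assumes the package is installed in editable mode, as the Pipfile does.

# coding: utf-8
# Minimal sketch of how the pieces compose; edges and dimensions are toy values.
import torch

from sgcn.masked.tensor import MaskedTensor
from sgcn.nn import MultiHeadAttention

# A toy graph on 4 nodes: each column of `edges` is a (query, key) pair that
# is allowed to attend. No duplicate coordinates, as MaskedTensor expects.
edges = torch.tensor([[0, 1, 2, 3, 3],
                      [1, 0, 3, 2, 0]])
mask = MaskedTensor(edges, torch.ones(edges.size(1)), (4, 4))

K = torch.rand(4, 3)  # node features used as keys
V = torch.rand(4, 4)  # node features used as values
Q = torch.rand(4, 3)  # node features used as queries

attention = MultiHeadAttention(
    in_key=3, in_value=4, in_query=3, n_head=2, head_qk=5, head_v=6
)
out = attention(K, V, Q, m=mask)  # dense tensor of shape (4, n_head * head_v) = (4, 12)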