├── .gitignore
├── LICENSE
├── README.md
├── performer
│   ├── __init__.py
│   ├── jax_kernel.py
│   └── kernel.py
├── requirements.txt
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
.vscode
**.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Henry Mao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Performer-Pytorch
[WIP]

PyTorch implementation of the Performer self-attention module from the paper ["Rethinking Attention with Performers"](https://arxiv.org/abs/2009.14794).

This implementation is based on the reference code at https://github.com/google-research/google-research/tree/master/performer/fast_self_attention

## Test
Run the tests to compare against Google's JAX implementation:
```
python test.py
```
Requires JAX to be installed.
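
## Usage
A minimal sketch of how the kernel feature map is meant to be used. The attention combination below is not part of this repo; it is the standard FAVOR+ recipe from the paper, and the shapes are illustrative:

```python
import torch
from performer.kernel import nonnegative_softmax_kernel_feature_creator

batch, seq_len, d, num_features = 2, 16, 8, 32
q = torch.randn(batch, seq_len, d)
k = torch.randn(batch, seq_len, d)
v = torch.randn(batch, seq_len, d)

# Random projection matrix of shape (num_features, d). Drawn i.i.d. Gaussian here;
# the paper recommends orthogonal rows for lower variance.
projection = torch.randn(num_features, d)

# Map queries and keys to nonnegative random features phi(q), phi(k).
q_prime = nonnegative_softmax_kernel_feature_creator(
    q, projection, batch_dims_t=(0,), is_query=True)
k_prime = nonnegative_softmax_kernel_feature_creator(
    k, projection, batch_dims_t=(0,), is_query=False)

# Linear-attention combination: softmax(QK^T)V is approximated without ever
# materializing the (seq_len x seq_len) attention matrix.
kv = torch.einsum('bsm,bsd->bmd', k_prime, v)                 # (batch, num_features, d)
normalizer = torch.einsum('bsm,bm->bs', q_prime, k_prime.sum(dim=1))
out = torch.einsum('bsm,bmd->bsd', q_prime, kv) / normalizer.unsqueeze(-1)
```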
--------------------------------------------------------------------------------
/performer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calclavia/Performer-Pytorch/d729f91fbfdc9fcab10eed82cf66d439390a1301/performer/__init__.py
--------------------------------------------------------------------------------
/performer/jax_kernel.py:
--------------------------------------------------------------------------------
from jax import lax
import jax.numpy as jnp


def nonnegative_softmax_kernel_feature_creator(data,
                                               projection_matrix,
                                               attention_dims_t,
                                               batch_dims_t,
                                               precision,
                                               is_query,
                                               normalize_data=True,
                                               eps=0.0001):
    """Constructs nonnegative kernel features for fast softmax attention.

    Args:
        data: input for which features are computed
        projection_matrix: random matrix used to compute features
        attention_dims_t: tuple of attention dimensions
        batch_dims_t: tuple of batch dimensions
        precision: precision parameter
        is_query: predicate indicating whether input data corresponds to
            queries or keys
        normalize_data: predicate indicating whether data should be normalized
        eps: numerical stabilizer.

    Returns:
        Random features for fast softmax attention.
    """
    del attention_dims_t  # Unused; kept for API compatibility with the reference code.
    if normalize_data:
        # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
        # w_norm = w * data_normalizer for w in {q,k}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])

    # Broadcast the projection matrix across the batch dimensions.
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    # Project the (normalized) data onto the random directions: w x^T.
    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
         (batch_dims_t, batch_dims_t)),
        precision=precision)

    # ||x||^2 / 2 correction term of the softmax kernel estimator.
    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

    if is_query:
        # Queries are stabilized per position; keys are stabilized globally.
        last_dims_t = (len(data_dash.shape) - 1,)
        data_dash = ratio * (
            jnp.exp(data_dash - diag_data -
                    jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps)
    else:
        data_dash = ratio * (
            jnp.exp(data_dash - diag_data - jnp.max(data_dash)) + eps)

    return data_dash
--------------------------------------------------------------------------------
/performer/kernel.py:
--------------------------------------------------------------------------------
import math

import torch


def nonnegative_softmax_kernel_feature_creator(
        data: torch.Tensor,
        projection_matrix: torch.Tensor,
        batch_dims_t,
        is_query: bool,
        normalize_data: bool = True,
        eps: float = 0.0001):
    """
    Constructs nonnegative kernel features for fast softmax attention.

    Args:
        data: input for which features are computed
        projection_matrix: random matrix used to compute features
        batch_dims_t: tuple of batch dimensions
        is_query: predicate indicating whether input data corresponds to
            queries or keys
        normalize_data: predicate indicating whether data should be normalized
        eps: numerical stabilizer.

    Returns:
        Random features for fast softmax attention.
    """
    if normalize_data:
        # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
        # w_norm = w * data_normalizer for w in {q,k}.
        data_normalizer = 1.0 / (math.sqrt(math.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0

    ratio = 1.0 / math.sqrt(projection_matrix.shape[0])

    # Broadcast the projection matrix across the batch dimensions,
    # matching the dtype and device of the input.
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = torch.zeros(
        data_mod_shape, dtype=data.dtype, device=data.device) + projection_matrix

    # Project the (normalized) data onto the random directions: w x^T.
    data_dash = torch.matmul(
        data_normalizer * data,
        data_thick_random_matrix.transpose(-1, -2)
    )

    # ||x||^2 / 2 correction term of the softmax kernel estimator.
    diag_data = torch.square(data)
    diag_data = torch.sum(diag_data, dim=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = diag_data.unsqueeze(-1)

    if is_query:
        # Queries are stabilized per position; keys are stabilized globally.
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                      torch.max(data_dash, dim=-1, keepdim=True)[0]) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)

    return data_dash
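

# --- Illustrative helper (not part of the original repository) ---------------
# The kernel above consumes a `projection_matrix` of shape (num_features, d), but the
# repo does not include code to sample one. The sketch below shows one common way to
# draw it, following the Gaussian orthogonal random features used in Google's
# reference implementation; treat it as an assumption, not part of this port.
# Example: projection = gaussian_orthogonal_random_matrix(num_features, d)
def gaussian_orthogonal_random_matrix(num_rows: int, num_cols: int,
                                      dtype=torch.float32, device=None):
    """Draws a random projection matrix with (block-)orthogonal rows. Sketch only."""
    blocks = []
    remaining = num_rows
    while remaining > 0:
        # Orthogonalize a square Gaussian block via QR, then keep the rows we need.
        block = torch.randn(num_cols, num_cols, dtype=dtype, device=device)
        q, _ = torch.qr(block)
        blocks.append(q.t()[:min(remaining, num_cols)])
        remaining -= num_cols
    projection = torch.cat(blocks, dim=0)
    # Rescale rows so their norms match those of i.i.d. Gaussian vectors in R^{num_cols}.
    multiplier = torch.randn(num_rows, num_cols, dtype=dtype, device=device).norm(dim=1)
    return multiplier.unsqueeze(-1) * projection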
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.6.0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from performer.jax_kernel import nonnegative_softmax_kernel_feature_creator as jax_kernel
from performer.kernel import nonnegative_softmax_kernel_feature_creator as kernel

import unittest

import torch
import numpy as np


class TestPerformer(unittest.TestCase):
    def test_kernel(self):
        for is_query in [False, True]:
            # Batch, seq len, d
            data = np.random.rand(2, 4, 8)
            # r (number of random features), d
            projection_matrix = np.random.rand(3, 8)

            jax_output = jax_kernel(
                data,
                projection_matrix,
                attention_dims_t=None,
                batch_dims_t=(0,),
                precision=None,
                is_query=is_query
            )
            print('jax_output', jax_output.shape)
            print()

            output = kernel(
                torch.from_numpy(data),
                torch.from_numpy(projection_matrix),
                batch_dims_t=(0,),
                is_query=is_query
            )
            print('output', output.shape)

            assert jax_output.shape == output.shape, (jax_output.shape, output.shape)
            assert np.allclose(jax_output, output)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------