├── .github └── workflows │ ├── build-wheels.yml │ └── test.yml ├── .gitignore ├── .mailmap ├── LICENSE ├── README.rst ├── docs ├── api_docs │ └── index.html ├── attention.md ├── builders.md ├── custom_attention_layer.md ├── events.md ├── feature_maps.md ├── index.md ├── masking.md ├── recurrent_transformers.md ├── tips_and_tricks.md └── transformers.md ├── fast_transformers ├── __init__.py ├── aggregate │ ├── __init__.py │ ├── aggregate_cpu.cpp │ ├── aggregate_cuda.cu │ └── clustered_aggregate_cuda.cu ├── attention │ ├── __init__.py │ ├── aft_attention.py │ ├── attention_layer.py │ ├── causal_linear_attention.py │ ├── clustered_attention.py │ ├── conditional_full_attention.py │ ├── exact_topk_attention.py │ ├── full_attention.py │ ├── improved_clustered_attention.py │ ├── improved_clustered_causal_attention.py │ ├── linear_attention.py │ ├── local_attention.py │ └── reformer_attention.py ├── attention_registry │ ├── __init__.py │ ├── registry.py │ └── spec.py ├── builders │ ├── __init__.py │ ├── attention_builders.py │ ├── base.py │ └── transformer_builders.py ├── causal_product │ ├── __init__.py │ ├── causal_product_cpu.cpp │ └── causal_product_cuda.cu ├── clustering │ ├── __init__.py │ └── hamming │ │ ├── __init__.py │ │ ├── cluster_cpu.cpp │ │ └── cluster_cuda.cu ├── events │ ├── __init__.py │ ├── event.py │ ├── event_dispatcher.py │ └── filters.py ├── feature_maps │ ├── __init__.py │ ├── base.py │ └── fourier_features.py ├── hashing │ ├── __init__.py │ ├── hash_cpu.cpp │ └── hash_cuda.cu ├── local_product │ ├── __init__.py │ ├── local_product_cpu.cpp │ └── local_product_cuda.cu ├── masking.py ├── recurrent │ ├── __init__.py │ ├── _utils.py │ ├── attention │ │ ├── __init__.py │ │ ├── cross_attention │ │ │ ├── __init__.py │ │ │ ├── attention_layer.py │ │ │ ├── full_attention.py │ │ │ └── linear_attention.py │ │ └── self_attention │ │ │ ├── __init__.py │ │ │ ├── attention_layer.py │ │ │ ├── full_attention.py │ │ │ └── linear_attention.py │ └── transformers.py ├── sparse_product │ ├── __init__.py │ ├── clustered_sparse_product_cpu.cpp │ ├── clustered_sparse_product_cuda.cu │ ├── sparse_product_cpu.cpp │ └── sparse_product_cuda.cu ├── transformers.py ├── utils.py └── weight_mapper.py ├── mkdocs.yml ├── setup.py ├── tests ├── __init__.py ├── aggregate │ ├── __init__.py │ ├── test_aggregate_cpu.py │ ├── test_aggregate_gpu.py │ ├── test_clustered_aggregate_cpu.py │ ├── test_clustered_aggregate_gpu.py │ ├── test_clustered_broadcast_cpu.py │ └── test_clustered_broadcast_gpu.py ├── attention │ ├── test_aft_attention.py │ ├── test_attention_layer.py │ ├── test_causal_linear_attention.py │ ├── test_clustered_transformer.py │ ├── test_clustered_transformer_gpu.py │ ├── test_full_attention.py │ ├── test_improved_clustered_transformer_gpu.py │ ├── test_linear_attention.py │ └── test_local_attention.py ├── causal_product │ ├── __init__.py │ ├── test_causal_product.py │ ├── test_causal_product_cpu.py │ └── test_causal_product_gpu.py ├── clustering │ ├── __init__.py │ └── hamming │ │ ├── __init__.py │ │ ├── test_cluster_cpu.py │ │ ├── test_cluster_gpu.py │ │ ├── test_python_api_gpu.py │ │ └── time_python_api_gpu.py ├── events │ ├── __init__.py │ ├── test_event_dispatcher.py │ ├── test_event_filters.py │ └── test_events.py ├── feature_maps │ ├── __init__.py │ └── test_fourier_features.py ├── hashing │ ├── __init__.py │ ├── test_hash_cpu.py │ └── test_hash_gpu.py ├── local_product │ ├── test_local_product_cpu.py │ └── test_local_product_cuda.py ├── recurrent │ ├── __init__.py │ ├── attention │ │ ├── 
__init__.py │ │ ├── cross_attention │ │ │ ├── __init__.py │ │ │ ├── test_attention_layer.py │ │ │ ├── test_full_attention.py │ │ │ └── test_linear_attention.py │ │ └── self_attention │ │ │ ├── __init__.py │ │ │ ├── test_attention_layer.py │ │ │ ├── test_full_attention.py │ │ │ └── test_linear_attention.py │ ├── test_transformer_decoder.py │ └── test_transformer_encoder.py ├── sparse_product │ ├── __init__.py │ ├── test_clustered_sparse_product_backward_cpu.py │ ├── test_clustered_sparse_product_backward_cpu_v2.py │ ├── test_clustered_sparse_product_backward_gpu.py │ ├── test_clustered_sparse_product_cpu.py │ ├── test_clustered_sparse_product_cpu_v2.py │ ├── test_clustered_sparse_product_gpu.py │ ├── test_clustered_sparse_weighted_average_cpu.py │ ├── test_clustered_sparse_weighted_average_cpu_v2.py │ ├── test_clustered_sparse_weighted_average_gpu.py │ ├── test_sparse_product_backward_cpu.py │ ├── test_sparse_product_backward_gpu.py │ ├── test_sparse_product_cpu.py │ ├── test_sparse_product_gpu.py │ ├── test_sparse_weighted_average_cpu.py │ └── test_sparse_weighted_average_gpu.py ├── test_builders.py ├── test_masking.py ├── test_transformer_decoder.py ├── test_transformer_encoder.py └── test_weight_mapper.py └── tools.py /.github/workflows/build-wheels.yml: -------------------------------------------------------------------------------- 1 | # Build wheels for easier installation 2 | name: build-wheels 3 | 4 | on: 5 | push: 6 | tags: ['v*'] 7 | branches: [build-wheels] 8 | 9 | # Build a wheel 10 | jobs: 11 | build: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: [3.6, 3.7, 3.8] 16 | pytorch-version: [1.7, 1.8, 1.9] 17 | cuda-version: [10.2, 11.1] 18 | exclude: 19 | - pytorch-version: 1.7 20 | cuda-version: 11.1 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v2 24 | - run: echo "FAST_TRANSFORMERS_VERSION_SUFFIX=+pytorch${{ matrix.pytorch-version }}+cu${{ matrix.cuda-version }}" >> $GITHUB_ENV 25 | - run: | 26 | if [ "${{ matrix.cuda-version }}" == "10.2" ]; then 27 | sudo apt install -y gcc-8 g++-8 28 | echo "CC=gcc-8" >> $GITHUB_ENV 29 | echo "CXX=g++-8" >> $GITHUB_ENV 30 | sudo rm /usr/bin/gcc /usr/bin/g++ 31 | sudo ln -s /usr/bin/gcc-8 /usr/bin/gcc 32 | sudo ln -s /usr/bin/g++-8 /usr/bin/g++ 33 | wget --quiet https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run 34 | sudo sh cuda_10.2.89_440.33.01_linux.run --silent --toolkit 35 | echo "/usr/local/cuda-10.2/bin" >> $GITHUB_PATH 36 | echo "TORCH_CUDA_ARCH_LIST=6.0;6.1;6.2;7.0;7.2;7.5" >> $GITHUB_ENV 37 | elif [ "${{ matrix.cuda-version }}" == "11.1" ]; then 38 | wget --quiet https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run 39 | sudo sh cuda_11.1.1_455.32.00_linux.run --silent --toolkit 40 | echo "/usr/local/cuda-11.1/bin" >> $GITHUB_PATH 41 | echo "TORCH_CUDA_ARCH_LIST=6.0;6.1;6.2;7.0;7.2;7.5;8.0" >> $GITHUB_ENV 42 | else 43 | exit 1 44 | fi 45 | - run: | 46 | mkdir miniconda 47 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda/miniconda.sh 48 | bash miniconda/miniconda.sh -b -u -p $(pwd)/miniconda 49 | rm miniconda/miniconda.sh 50 | - run: echo "$(pwd)/miniconda/bin" >> $GITHUB_PATH 51 | - run: conda install -y python=${{ matrix.python-version }} 52 | - run: conda install -y pytorch=${{ matrix.pytorch-version }} cudatoolkit=${{ matrix.cuda-version }} -c pytorch -c nvidia 53 | - run: python setup.py build_ext --inplace 54 | - run: python 
setup.py bdist_wheel 55 | - uses: actions/upload-artifact@v2 56 | with: 57 | name: dist-wheel 58 | path: dist/*.whl 59 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # Test at least the CPU part since there are no GPU runners 2 | name: tests 3 | 4 | # Test only pushes on master or pull requests on master 5 | on: 6 | push: 7 | branches: [master, test-workflow] 8 | pull_request: 9 | branches: [master] 10 | 11 | # Build and run the tests 12 | jobs: 13 | test: 14 | strategy: 15 | matrix: 16 | python-version: [3.6, 3.7, 3.8] 17 | pytorch-version: [1.6, 1.7, 1.8, 1.9] 18 | runs-on: ubuntu-latest 19 | steps: 20 | - run: sudo apt install -y nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc 21 | - uses: actions/checkout@v2 22 | - run: | 23 | mkdir miniconda 24 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda/miniconda.sh 25 | bash miniconda/miniconda.sh -b -u -p $(pwd)/miniconda 26 | rm miniconda/miniconda.sh 27 | - run: echo "$(pwd)/miniconda/bin" >> $GITHUB_PATH 28 | - run: conda install -y python=${{ matrix.python-version }} 29 | - run: conda install -y -c pytorch pytorch=${{ matrix.pytorch-version }} 30 | - run: python setup.py build_ext --inplace 31 | - run: pip install -e . 32 | - run: python -m unittest discover -s $GITHUB_WORKSPACE/tests -v 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build 3 | site 4 | dist 5 | *.egg-info 6 | *.so 7 | *.swp 8 | *.pyd 9 | -------------------------------------------------------------------------------- /.mailmap: -------------------------------------------------------------------------------- 1 | Angelos Katharopoulos 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 4 | Written by Angelos Katharopoulos , 5 | Apoorv Vyas 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of 8 | this software and associated documentation files (the "Software"), to deal in 9 | the Software without restriction, including without limitation the rights to 10 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 11 | of the Software, and to permit persons to whom the Software is furnished to do 12 | so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | -------------------------------------------------------------------------------- /docs/api_docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Redirecting to the API docs ... 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/custom_attention_layer.md: -------------------------------------------------------------------------------- 1 | Creating a custom attention layer 2 | ================================= 3 | 4 | In this page, we will go through the process of creating a custom attention 5 | module and integrating it with the library. We will implement a quadratic 6 | kernel attention instead of softmax attention. 7 | 8 | New Attention 9 | ------------- 10 | 11 | Our attention layer will follow closely the implementation of 12 | [FullAttention][1]. Let's start with the skeleton of our module. 13 | 14 | ```python 15 | class QuadraticAttention(Module): 16 | def __init__(self, quadratic_temp=1.0, eps=1e-6): 17 | super(QuadraticAttention, self).__init__() 18 | self.eps = eps 19 | self.quadratic_temp = quadratic_temp 20 | 21 | def forward(self, queries, keys, values, attn_mask, query_lengths, 22 | key_lengths): 23 | # implement the logic of the layer here 24 | ``` 25 | 26 | The queries, keys and values are already projected and split into multiple 27 | heads by the [AttentionLayer][2]. This means that we need only implement the 28 | attention part. 29 | 30 | ```python 31 | class QuadraticAttention(Module): 32 | def __init__(self, quadratic_temp=1.0, eps=1e-6): 33 | super(QuadraticAttention, self).__init__() 34 | self.eps = eps 35 | self.quadratic_temp = quadratic_temp 36 | 37 | def forward(self, queries, keys, values, attn_mask, query_lengths, 38 | key_lengths): 39 | # compute the unnormalized attention 40 | QK = torch.einsum("nlhe,nshe->nhls", queries, keys) # compute the dot products 41 | QK = torch.square(self.quadratic_temp * QK) # implement our custom attention twist 42 | QK = QK * attn_mask.float_matrix # use the attention mask as a multiplicative mask 43 | QK = QK * key_lengths.float_matrix[:, None, None] # also a multiplicative mask 44 | 45 | # normalize and compute the average 46 | A = QK / (QK.sum(dim=-1, keepdim=True) + self.eps) 47 | V = torch.einsum("nhls,nshd->nlhd", A, values) 48 | 49 | return V.contiguous() 50 | ``` 51 | 52 | Integrate with the Builder 53 | -------------------------- 54 | 55 | To add it as an option to the `TransformerEncoderBuilder` or the 56 | `TransformerDecoderBuilder` we have to register our new attention in the 57 | appropriate [attention registry](builders.md#attention-registry). The available 58 | registries are 59 | 60 | * AttentionRegistry 61 | * RecurrentAttentionRegistry 62 | * RecurrentCrossAttentionRegistry 63 | 64 | Similar to [FullAttention][1] we will use `AttentionRegistry` because our 65 | implementation is not recurrent. The following snippet integrates our quadratic 66 | attention with the builders. 
67 | 68 | ```python 69 | from fast_transformers.attention_registry import AttentionRegistry, \ 70 | Optional, Float # we also need these to add our new 71 | # parameter 'quadratic_temp' 72 | 73 | AttentionRegistry.register( 74 | "square", QuadraticAttention, # attention_type, class pair 75 | [ 76 | ("quadratic_temp", Optional(Float, 1.0)) # an optional parameter named 77 | # 'quadratic_temp' of type 78 | # float and with default 79 | # value 1.0 80 | ] 81 | ) 82 | ``` 83 | 84 | Afterwards we can use the builder to create transformers with our new 85 | attention layer. 86 | 87 | ```python 88 | quadratic_bert = TransformerEncoderBuilder.from_kwargs( 89 | attention_type="square", # here we select our custom attention layer 90 | n_layers=12, 91 | n_heads=12, 92 | query_dimensions=64, 93 | value_dimensions=64, 94 | feed_forward_dimensions=3072, 95 | activation="gelu", 96 | quadratic_temp=5.0 # set the temperature for our quadratic layer 97 | ) 98 | ``` 99 | 100 | 101 | [1]: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/full_attention.py 102 | [2]: attention.md 103 | [3]: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/builders/attention_builder.py 104 | [4]: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/builders/transformer_encoder_builder.py 105 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | Fast Transformers 2 | ================= 3 | 4 | Transformers are very succsessfull models that achieve state of the art 5 | performance in many natural language tasks. However, it is very difficult to 6 | scale them to long sequences due to the quadratic scaling of self-attention. 7 | 8 | This library was developed for our research on fast attention for transformers. 9 | You can find a list of our papers [below](#research) as well as related papers 10 | and papers that we have implemented. 11 | 12 | Quick-start 13 | ----------- 14 | 15 | The main interface of the library for using the implemented fast transformers 16 | is the [builder interface](api/fast_transformers/builders/). This allows for 17 | experimenting with different attention implentations with minimal code changes. 18 | For instance building a BERT-like transformer encoder is as simple as the 19 | following code: 20 | 21 | ```python 22 | import torch 23 | from fast_transformers.builders import TransformerEncoderBuilder 24 | 25 | # Build a transformer encoder 26 | bert = TransformerEncoderBuilder.from_kwargs( 27 | n_layers=12, 28 | n_heads=12, 29 | query_dimensions=64, 30 | value_dimensions=64, 31 | feed_forward_dimensions=3072, 32 | attention_type="full", # change this to use another 33 | # attention implementation 34 | activation="gelu" 35 | ).get() 36 | 37 | y = bert(torch.rand( 38 | 10, # batch_size 39 | 512, # sequence length 40 | 64*12 # features 41 | )) 42 | ``` 43 | 44 | Installation 45 | ------------ 46 | 47 | The fast transformers library has the following dependencies: 48 | 49 | * PyTorch 50 | * C++ toolchain 51 | * CUDA toolchain (if you want to compile for GPUs) 52 | 53 | For most machines installation should be as simple as: 54 | 55 | ```bash 56 | pip install --user pytorch-fast-transformers 57 | ``` 58 | 59 | Research 60 | -------- 61 | 62 | ### Ours 63 | 64 | To read about the theory behind some attention implementations in this library 65 | we encourage you to follow our research. 
66 | 67 | * Transformers are RNNs: Fast Autoregressive Transformers with 68 | Linear Attention ([arxiv](https://arxiv.org/abs/2006.16236), 69 | [video](https://youtu.be/KBWh7XCUAi8)) 70 | * Fast Transformers with Clustered Attention 71 | ([arxiv](https://arxiv.org/abs/2007.04825), 72 | [blog](https://clustered-transformers.github.io/blog/)) 73 | 74 | If you found our research helpful or influential please consider citing 75 | 76 | ``` 77 | @inproceedings{katharopoulos_et_al_2020, 78 | author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.}, 79 | title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention}, 80 | booktitle = {Proceedings of the International Conference on Machine Learning (ICML)}, 81 | year = {2020} 82 | } 83 | 84 | @article{vyas_et_al_2020, 85 | author={Vyas, A. and Katharopoulos, A. and Fleuret, F.}, 86 | title={Fast Transformers with Clustered Attention}, 87 | booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)}, 88 | year={2020} 89 | } 90 | ``` 91 | 92 | ### By others 93 | 94 | * Efficient Attention: Attention with Linear Complexities ([arxiv](https://arxiv.org/abs/1812.01243)) 95 | * Linformer: Self-Attention with Linear Complexity ([arxiv](https://arxiv.org/abs/2006.04768)) 96 | * Reformer: The Efficient Transformer ([arxiv](https://arxiv.org/abs/2001.04451)) 97 | 98 | Support, License and Copyright 99 | ------------------------------ 100 | 101 | This software is distributed with the **MIT** license which pretty much means that 102 | you can use it however you want and for whatever reason you want. All the 103 | information regarding support, copyright and the license can be found in the 104 | [LICENSE](https://github.com/idiap/fast-transformers/blob/master/LICENSE) file 105 | in the repository. 106 | -------------------------------------------------------------------------------- /docs/masking.md: -------------------------------------------------------------------------------- 1 | Masking 2 | ======= 3 | 4 | In this library, both for convenience and efficiency, we define a [BaseMask][1] 5 | interface that all masks should implement. The BaseMask interface allows 6 | accessing a mask in the following ways: 7 | 8 | 1. a bool tensor where True signifies what is kept 9 | 2. a float tensor where minus infinity signifies what is to be masked 10 | 2. a float tensor where zero signifies what is to be masked 11 | 3. a length tensor where everything after a certain length is to be masked 12 | 13 | This interface allows us to use the same mask definition with various attention 14 | implementations without compromising in performance or requiring code changes. 15 | For instance, softmax masks are usually implemented with additive masks that 16 | contain -inf and linear attention masks are efficiently implemented with 17 | multiplicative masks that contain zeros. 18 | 19 | BaseMask 20 | -------- 21 | 22 | Our [API docs][1] are quite thorough in explaining the BaseMask interface. 23 | 24 | Implementations 25 | --------------- 26 | 27 | We provide three implementations of the BaseMask interface *FullMask*, 28 | *LengthMask* and *TriangularCausalMask*. 29 | 30 | ### FullMask 31 | 32 | ``` 33 | fast_transformers.masking.FullMask(mask=None, N=None, M=None, device='cpu') 34 | ``` 35 | 36 | The FullMask is a simple wrapper over a pytorch boolean tensor. The arguments 37 | can be given both by keyword arguments and positional arguments. 
To imitate 38 | function overloading, the constructor checks the type of the first argument and 39 | if it is a tensor it treats it as the mask. otherwise it assumes that it was 40 | the N argument. 41 | 42 | **Arguments** 43 | 44 | * **mask**: The mask as a PyTorch tensor. 45 | * **N**: The rows of the all True mask to be created if the mask argument is 46 | not provided. 47 | * **M**: The columns of the all True mask to be created if the mask argument 48 | is not provided. If N is given M defaults to N. 49 | * **device**: The device to create the mask in (defaults to cpu) 50 | 51 | ### LengthMask 52 | 53 | ``` 54 | fast_transformers.masking.LengthMask(lengths, max_len=None, device=None) 55 | ``` 56 | 57 | The LengthMask is designed to be used for conveying different lengths of 58 | sequences. It can be accessed as an array of integers which may be beneficial 59 | for some attention implementations. 60 | 61 | **Arguments** 62 | 63 | * **lengths**: The lengths as a PyTorch long tensor 64 | * **max\_len**: The maximum length for the mask (defaults to lengths.max()) 65 | * **device**: The device to be used for creating the masks (defaults to 66 | lengths.device) 67 | 68 | ### TriangularCausalMask 69 | 70 | ``` 71 | fast_transformers.masking.TriangularCausalMask(N, device="cpu") 72 | ``` 73 | 74 | Represents a square matrix with everything masked above the main diagonal. It 75 | is meant to be used for training autoregressive transformers. 76 | 77 | **Arguments** 78 | 79 | * **N**: The size of the matrix 80 | * **device**: The device to create the mask in (defaults to cpu) 81 | 82 | 83 | [1]: /api_docs/fast_transformers/masking.html#fast_transformers.masking.BaseMask 84 | -------------------------------------------------------------------------------- /docs/recurrent_transformers.md: -------------------------------------------------------------------------------- 1 | Recurrent Transformers 2 | ====================== 3 | 4 | The transformer layers implemented in the [fast_transformers.transformers][1] 5 | module are processing the entire sequence simultaneously. On the other hand, 6 | this module implements transfomers as recurrent networks. Namely as networks 7 | that process the sequence one element at a time while updating some state. 8 | 9 | The TransformerEncoder and TransformerEncoderLayer give way to 10 | [RecurrentTransformerEncoder][2] and [RecurrentTransformerEncoderLayer][3] and 11 | for the decoders [RecurrentTransformerDecoder][7] and 12 | [RecurrentTransformerDecoderLayer][8] respectively. 13 | 14 | Forward method 15 | -------------- 16 | 17 | **RecurrentTransformerEncoder** or **RecurrentTransformerEncoderLayer** 18 | 19 | ``` 20 | forward(x, state=None) 21 | ``` 22 | 23 | **Arguments** 24 | 25 | * **x**: The input features of shape (N, E) where N is the batch size and E is 26 | `d_model` passed in the constructor. Note that x corresponds to a specific 27 | element in the sequence and not the entire sequence. 28 | * **state**: The state is a python object that varies depending on the 29 | attention implementation 30 | 31 | 32 | **RecurrentTransformerDecoder** or **RecurrentTransformerDecoderLayer** 33 | 34 | ``` 35 | forward(x, memory, memory_length_mask=None, state=None) 36 | ``` 37 | 38 | * **x**: The input features of shape (N, E) where N is the batch size and E is 39 | `d_model` passed in the constructor. Note that x corresponds to a specific 40 | element in the sequence and not the entire sequence. 
41 | * **memory**: A sequence of features (N, S, E) that the input will attend 42 | to. S is the sequence length and E is the same as for x. 43 | * **memory\_length\_mask**: An implementation of a BaseMask that encodes 44 | how many elements each memory sequence in the batch consists of. 45 | * **state**: The state is a python object that varies depending on the 46 | attention implementation 47 | 48 |
49 | !!! note 50 |     The masks are different in the recurrent implementations than in their 51 |     batch counterparts. Namely, recurrent encoders and decoders enforce a 52 |     triangular causal mask on self attention. In addition, recurrent decoders 53 |     enforce a full mask on cross attention. 54 |
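As an illustration of the decoder interface above, the following sketch decodes a short sequence element by element against a fixed memory. It assumes a `RecurrentDecoderBuilder` in `fast_transformers.builders` whose keyword arguments mirror those of the encoder builder, with separate `self_attention_type` and `cross_attention_type` options; treat the exact keyword names as indicative rather than definitive.

```python
import torch
from fast_transformers.builders import RecurrentDecoderBuilder

# Build a small recurrent decoder (keywords assumed analogous to the
# encoder builder used in the Example section below)
decoder = RecurrentDecoderBuilder.from_kwargs(
    self_attention_type="linear",
    cross_attention_type="linear",
    n_layers=4,
    n_heads=8,
    query_dimensions=32,
    value_dimensions=32,
    feed_forward_dimensions=1024
).get()

N, S, E = 10, 50, 8*32        # batch size, memory length, d_model
memory = torch.rand(N, S, E)  # the sequence the decoder cross-attends to

x = torch.rand(N, E)          # first input element
state = None
for i in range(20):           # generate 20 elements, one step at a time
    x, state = decoder(x, memory, state=state)
```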
55 | 56 | Available Attentions 57 | -------------------- 58 | 59 | Not all attention formulations can be written in an autoregressive fashion as a 60 | recurrent model. In particular, since the sequence is passed to the transformer 61 | element by element we have the same result as passing a causal mask to normal 62 | transformers. The current list for recurrent attention implementations is: 63 | 64 | * [LinearAttention][4] 65 | * [FullAttention][5] 66 | 67 | Example 68 | ------- 69 | 70 | The following example builds a random recurrent transformer encoder and applies 71 | its output as input 100 times. 72 | 73 | ```python 74 | # for simplicity ignore all the classification 75 | # layers and the embedding layers 76 | 77 | from fast_transformers.builders import RecurrentEncoderBuilder 78 | 79 | model = RecurrentEncoderBuilder.from_kwargs( 80 | attention_type="linear", 81 | n_layers=8, 82 | n_heads=12, 83 | feed_forward_dimensions=1536, 84 | query_dimensions=32, 85 | value_dimensions=32 86 | ).get() 87 | 88 | x0 = torch.rand( 89 | 10, # batch size 90 | 12*32 # feature size 91 | ) 92 | state = None 93 | 94 | x = x0 95 | for i in range(100): 96 | x, state = model(x, state=state) 97 | ``` 98 | 99 | 100 | [1]: /api_docs/fast_transformers/transformers.html 101 | [2]: /api_docs/fast_transformers/recurrent/transformers.html#fast_transformers.recurrent.transformers.RecurrentTransformerEncoder 102 | [3]: /api_docs/fast_transformers/recurrent/transformers.html#fast_transformers.recurrent.transformers.RecurrentTransformerEncoderLayer 103 | [4]: /api_docs/fast_transformers/recurrent/attention/self_attention/linear_attention.html 104 | [5]: /api_docs/fast_transformers/recurrent/attention/self_attention/full_attention.html 105 | [6]: /api_docs/fast_transformers/builders/transformer_builders.html 106 | [7]: /api_docs/fast_transformers/recurrent/transformers.html#fast_transformers.recurrent.transformers.RecurrentTransformerDecoder 107 | [8]: /api_docs/fast_transformers/recurrent/transformers.html#fast_transformers.recurrent.transformers.RecurrentTransformerDecoderLayer 108 | -------------------------------------------------------------------------------- /docs/tips_and_tricks.md: -------------------------------------------------------------------------------- 1 | Tips & Tricks 2 | ============= 3 | 4 | In this module we will provide examples of common usecases when using the fast 5 | transformers library. We will be adding more examples as more utilities are 6 | implemented. 7 | 8 | Mirrored networks 9 | --------------- 10 | 11 | We call mirrored networks, networks that _share the parameter instances_ but have 12 | different module implementations. The most common use case is to have mirrored 13 | batch and recurrent versions of the same transformer model in order to train 14 | with the batch version and evaluate using the recurrent version. 15 | 16 | We provide the utility `make_mirror(src_module, dst_module)` to automatically 17 | set the source module parameters to the destination module. 18 | 19 | ```python 20 | from fast_transformer.builders import TransformerEncoderBuilder, \ 21 | RecurrentEncoderBuilder 22 | from fast_transfomer.utils import make_mirror 23 | 24 | params = dict(...) 25 | transformer = TransformerEncoderBuilder.from_dictionary(params).get() 26 | recurrent_transformer = RecurrentEncoderBuilder.from_dictionary(params).get() 27 | make_mirror(transformer, recurrent_transformer) 28 | 29 | # Now training transformer also changes the parameters of recurrent transformer 30 | # and vice-versa. 
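#
# For example (a sketch; assume `x` is an (N, L, E) batch and `x_t` a single
# (N, E) element of the sequence):
#
#     y = transformer(x)                        # train with the batch version
#     y_t, state = recurrent_transformer(x_t)   # evaluate element by element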
31 | ``` 32 | 33 | Checkpointing 34 | --------------- 35 | 36 | [Checkpointing](https://pytorch.org/docs/stable/checkpoint.html) is important 37 | when training large neural networks to allow for more layers to fit in a single 38 | GPU. The default PyTorch method of checkpointing, only accepts tensors as 39 | arguments which unfortunately excludes our self-attention and transformer 40 | modules that expect `BaseMask` objects for masking. 41 | 42 | !!! tip "Under development" 43 | We are developing wrappers around the default checkpointing mechanisms that 44 | will allow users to checkpoint modules of their choosing or even checkpoint 45 | every transformer block in a transformer encoder or decoder. 46 | 47 | Check back for details or check our [github repository issue #21][1]. 48 | 49 | [1]: https://github.com/idiap/fast-transformers/issues/21 50 | -------------------------------------------------------------------------------- /fast_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Provide a library with fast transformer implementations.""" 8 | 9 | __author__ = "Angelos Katharopoulos, Apoorv Vyas" 10 | __copyright__ = "Copyright (c) 2020 Idiap Research Institute" 11 | __license__ = "MIT" 12 | __maintainer__ = "Angelos Katharopoulos, Apoorv Vyas" 13 | __email__ = "angelos.katharopoulos@idiap.ch, avyas@idiap.ch" 14 | __url__ = "https://github.com/idiap/fast-transformers" 15 | __version__ = "0.4.0" 16 | -------------------------------------------------------------------------------- /fast_transformers/aggregate/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import torch 9 | 10 | from .aggregate_cpu import aggregate as aggregate_cpu, \ 11 | broadcast as broadcast_cpu 12 | try: 13 | from .aggregate_cuda import aggregate as aggregate_gpu, \ 14 | broadcast as broadcast_gpu 15 | from .clustered_aggregate_cuda import \ 16 | clustered_broadcast as clustered_broadcast_gpu, \ 17 | clustered_aggregate as clustered_aggregate_gpu 18 | 19 | except ImportError: 20 | pass 21 | 22 | 23 | def aggregate(X, G, F, Y=None): 24 | device = X.device 25 | if Y is None: 26 | Y = torch.zeros( 27 | F.shape + (X.shape[-1],), 28 | device=device, 29 | dtype=X.dtype 30 | ) 31 | else: 32 | Y.zero_() 33 | 34 | if device.type == "cpu": 35 | aggregate_cpu(X, G, F, Y) 36 | else: 37 | aggregate_gpu(X, G, F, Y) 38 | 39 | return Y 40 | 41 | 42 | def broadcast(Y, G, F, X=None): 43 | device = Y.device 44 | if X is None: 45 | X = torch.zeros( 46 | G.shape + (Y.shape[-1],), 47 | device=device, 48 | dtype=Y.dtype 49 | ) 50 | 51 | if device.type == "cpu": 52 | broadcast_cpu(Y, G, F, X) 53 | else: 54 | broadcast_gpu(Y, G, F, X) 55 | 56 | return X 57 | 58 | 59 | # Divide the cluster into groups of equal size 60 | # as constrained by the shared memory 61 | def set_group(C, E): 62 | C_per_block = int(192 * 64 / (E+1)) 63 | G_min = (C + C_per_block - 1) // C_per_block 64 | for G in range(G_min, C+1): 65 | if C % G == 0: 66 | return G 67 | 68 | 69 | def clustered_broadcast(Y, groups, counts, factors, X=None): 70 | device = Y.device 71 | if X is None: 72 | X = torch.zeros( 73 | groups.shape + (Y.shape[-1],), 74 | device=device, 75 | 
dtype=Y.dtype 76 | ) 77 | if device.type == "cpu": 78 | broadcast_cpu(Y, groups, factors, X) 79 | else: 80 | N, H, C, E = Y.shape 81 | _, _, L, _ = X.shape 82 | 83 | # Following are some booking keeping parameters to facilitate the 84 | # broadcast kernel that takes advantage of clustering 85 | # More information can be found in the cuda file 86 | with torch.no_grad(): 87 | threads = 256 88 | G = set_group(C, E) 89 | group_counts = counts.view(N, H, G, -1).sum(-1) 90 | block_counts = (group_counts + threads - 1) // threads 91 | total_blocks = block_counts.sum().item() 92 | indx_maps = torch.ones( 93 | (total_blocks, 5), 94 | device=X.device, 95 | dtype=torch.int32 96 | ) 97 | 98 | clustered_broadcast_gpu( 99 | Y, 100 | groups, 101 | factors, 102 | X, 103 | block_counts.int(), 104 | group_counts.int(), 105 | threads, 106 | G, 107 | total_blocks, 108 | indx_maps 109 | ) 110 | return X 111 | 112 | 113 | def clustered_aggregate(X, G, F, lengths, Y=None): 114 | device = X.device 115 | if Y is None: 116 | Y = torch.zeros( 117 | F.shape + (X.shape[-1],), 118 | device=device, 119 | dtype=X.dtype 120 | ) 121 | else: 122 | Y.zero_() 123 | 124 | if device.type == "cpu": 125 | aggregate_cpu(X, G, F, Y) 126 | else: 127 | clustered_aggregate_gpu(X, G, F, lengths, Y) 128 | return Y 129 | -------------------------------------------------------------------------------- /fast_transformers/aggregate/aggregate_cpu.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | // Written by Angelos Katharopoulos , 4 | // Apoorv Vyas 5 | // 6 | 7 | #include 8 | 9 | 10 | /** 11 | * Aggregate the passed vectors X based on the group indices in G multiplied 12 | * by the factors F. 13 | */ 14 | void aggregate( 15 | const torch::Tensor X, 16 | const torch::Tensor G, 17 | const torch::Tensor F, 18 | torch::Tensor Y 19 | ) { 20 | int N = X.size(0); 21 | int H = X.size(1); 22 | int L = X.size(2); 23 | int E = X.size(3); 24 | 25 | int C = Y.size(2); 26 | const float *x = X.data_ptr(); 27 | const int32_t *g = G.data_ptr(); 28 | const float *f = F.data_ptr(); 29 | float *y = Y.data_ptr(); 30 | 31 | // Aggregate all the Xs to the destination 32 | #pragma omp parallel for 33 | for (int n=0; n= C)) { 38 | continue; 39 | } 40 | const float *src = x + n*H*L*E + h*L*E + l*E; 41 | float f_nhk = *(f + n*H*C + h*C + k); 42 | float *dst = y + n*H*C*E + h*C*E + k*E; 43 | 44 | for (int e=0; e(); 73 | const int32_t *g = G.data_ptr(); 74 | const float *f = F.data_ptr(); 75 | float *x = X.data_ptr(); 76 | 77 | // Broadcast all the Ys back into Xs 78 | // For now the parallelization is over L. 
79 | // TODO: Check if parallelization over n is faster 80 | #pragma omp parallel for 81 | for (int l=0; l= C)) { 86 | continue; 87 | } 88 | const float *src = y + n*H*C*E + h*C*E + k*E; 89 | float f_nhk = *(f + n*H*C + h*C + k); 90 | float *dst = x + n*H*L*E + h*L*E + l*E; 91 | 92 | for (int e=0; e, 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implementations of different types of attention mechanisms.""" 8 | 9 | 10 | from .attention_layer import AttentionLayer 11 | from .full_attention import FullAttention 12 | from .linear_attention import LinearAttention 13 | from .causal_linear_attention import CausalLinearAttention 14 | from .clustered_attention import ClusteredAttention 15 | from .improved_clustered_attention import ImprovedClusteredAttention 16 | from .reformer_attention import ReformerAttention 17 | from .conditional_full_attention import ConditionalFullAttention 18 | from .exact_topk_attention import ExactTopKAttention 19 | from .improved_clustered_causal_attention import ImprovedClusteredCausalAttention 20 | from .local_attention import LocalAttention 21 | from .aft_attention import AFTFullAttention, AFTSimpleAttention 22 | -------------------------------------------------------------------------------- /fast_transformers/attention/causal_linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement causally masked linear attention.""" 8 | 9 | import torch 10 | from torch.nn import Module 11 | 12 | from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \ 13 | EventDispatcherInstance 14 | from ..events import EventDispatcher 15 | from ..causal_product import causal_dot_product 16 | from ..feature_maps import elu_feature_map 17 | 18 | 19 | def causal_linear(Q, K, V): 20 | Q = Q.permute(0,2,1,3).contiguous() 21 | K = K.permute(0,2,1,3).contiguous() 22 | V = V.permute(0,2,1,3).contiguous() 23 | V_new = causal_dot_product(Q, K, V) 24 | return V_new.permute(0,2,1,3).contiguous() 25 | 26 | 27 | class CausalLinearAttention(Module): 28 | """Implement causally masked attention using dot product of feature maps in 29 | O(N D^2) complexity. 30 | 31 | See fast_transformers.attention.linear_attention.LinearAttention for the 32 | general concept of replacing the softmax with feature maps. In addition to 33 | that, we also make use of the fact that causal masking is a triangular mask 34 | which allows us to apply the masking and still compute the attention in O(N 35 | D^2) complexity. 
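    Concretely, for the i-th query the returned value is

        V'_i = Φ(Q_i).dot(sum_{j<=i} Φ(K_j) V_j^T) / Φ(Q_i).dot(sum_{j<=i} Φ(K_j)),

    so both the numerator and the denominator reduce to running sums over the
    sequence; these are exactly what causal_dot_product and the cumulative sum
    in forward() compute.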
36 | 37 | Arguments 38 | --------- 39 | feature_map: callable, a callable that applies the feature map to the 40 | last dimension of a tensor (default: elu(x)+1) 41 | eps: float, a small number to ensure the numerical stability of the 42 | denominator (default: 1e-6) 43 | event_dispatcher: str or EventDispatcher instance to be used by this 44 | module for dispatching events (default: the default 45 | global dispatcher) 46 | """ 47 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6, 48 | event_dispatcher=""): 49 | super(CausalLinearAttention, self).__init__() 50 | self.feature_map = ( 51 | feature_map(query_dimensions) if feature_map else 52 | elu_feature_map(query_dimensions) 53 | ) 54 | self.eps = eps 55 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 56 | 57 | def _make_sizes_compatible(self, Q, K): 58 | """Either slice or pad K in case that the sizes do not match between Q 59 | and K.""" 60 | N, L, H, E = Q.shape 61 | _, S, _, _ = K.shape 62 | if L == S: 63 | return Q, K 64 | 65 | if L < S: 66 | return Q, K[:, :L, :, :] 67 | 68 | if L > S: 69 | return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1) 70 | 71 | def forward(self, queries, keys, values, attn_mask, query_lengths, 72 | key_lengths): 73 | # Apply the feature map to the queries and keys 74 | self.feature_map.new_feature_map(queries.device) 75 | Q = self.feature_map.forward_queries(queries) 76 | K = self.feature_map.forward_keys(keys) 77 | 78 | # Apply the key padding mask and make sure the attn_mask is a 79 | # lower triangular causal mask 80 | if not attn_mask.lower_triangular: 81 | raise RuntimeError(("CausalLinearAttention only supports full " 82 | "lower triangular masks")) 83 | K = K * key_lengths.float_matrix[:, :, None, None] 84 | 85 | # Ensure that Q and K have compatible sizes for the following 86 | # computations, namely L == S 87 | Q, K = self._make_sizes_compatible(Q, K) 88 | 89 | # TODO: Shall we divide the Q and K with a relatively large number to 90 | # avoid numerical instabilities in computing the denominator? 91 | # We used to divide each with the max norm of all q and k but 92 | # that seems relatively costly for a simple normalization. 
93 | 94 | # Compute the normalizers 95 | Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps) 96 | 97 | # Compute the unnormalized result 98 | V = causal_linear( 99 | Q, 100 | K, 101 | values 102 | ) 103 | 104 | return V * Z[:, :, :, None] 105 | 106 | 107 | # Register the attention implementation so that it becomes available in our 108 | # builders 109 | AttentionRegistry.register( 110 | "causal-linear", CausalLinearAttention, 111 | [ 112 | ("query_dimensions", Int), 113 | ("feature_map", Optional(Callable)), 114 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 115 | ] 116 | ) 117 | -------------------------------------------------------------------------------- /fast_transformers/attention/conditional_full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement a self attention that delegates to full attention or another 8 | attention depending on the input sequence length.""" 9 | 10 | import torch 11 | from torch.nn import Module 12 | 13 | from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ 14 | EventDispatcherInstance 15 | from ..events import EventDispatcher 16 | from .full_attention import FullAttention 17 | 18 | 19 | class ConditionalFullAttention(Module): 20 | """"Delegate to full attention if the input sequence is short. 21 | 22 | Arguments 23 | --------- 24 | other_attention: Use the passed attention module if the sequence is 25 | longer than 'length_limit'. 26 | length_limit: An integer denoting the maximum sequence length to 27 | consider. 28 | softmax_temp: See fast_transformers.attention.full_attention. 29 | attention_dropout: See fast_transformers.attention.full_attention. 
30 | event_dispatcher: str or EventDispatcher instance to be used by this 31 | module for dispatching events (default: the default 32 | global dispatcher) 33 | """ 34 | def __init__(self, other_attention, length_limit=512, softmax_temp=None, 35 | attention_dropout=0.1, event_dispatcher=""): 36 | super(ConditionalFullAttention, self).__init__() 37 | self.full_attention = FullAttention(softmax_temp, attention_dropout) 38 | self.other_attention = other_attention 39 | self.length_limit = length_limit 40 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 41 | 42 | def forward(self, queries, keys, values, attn_mask, query_lengths, 43 | key_lengths): 44 | # Extract some shapes to compare with the length limit 45 | L = queries.shape[1] 46 | S = values.shape[1] 47 | 48 | if L > self.length_limit or S > self.length_limit: 49 | return self.other_attention(queries, keys, values, attn_mask, 50 | query_lengths, key_lengths) 51 | else: 52 | return self.full_attention(queries, keys, values, attn_mask, 53 | query_lengths, key_lengths) 54 | 55 | 56 | # Register the attention implementation so that it becomes available in our 57 | # builders 58 | AttentionRegistry.register( 59 | "conditional-full", ConditionalFullAttention, 60 | [ 61 | ("length_limit", Optional(Int, 512)), 62 | ("softmax_temp", Optional(Float)), 63 | ("attention_dropout", Optional(Float, 0.1)), 64 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 65 | ] 66 | ) 67 | -------------------------------------------------------------------------------- /fast_transformers/attention/exact_topk_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement the oracle top-k attention. The top-k keys are exact ones. 8 | MultiHeadAttention module. Note that this module is to be used in conjuction 9 | with the AttentionLayer in order to work.""" 10 | 11 | from math import sqrt 12 | 13 | import torch 14 | from torch.nn import Dropout, Module 15 | 16 | from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ 17 | EventDispatcherInstance 18 | from ..events import EventDispatcher 19 | 20 | 21 | class ExactTopKAttention(Module): 22 | """Implement the oracle top-k softmax attention. 23 | 24 | Arguments 25 | --------- 26 | top-k: The top k keys to attend to (default: 32) 27 | softmax_temp: The temperature to use for the softmax attention. 
28 | (default: 1/sqrt(d_keys) where d_keys is computed at 29 | runtime) 30 | attention_dropout: The dropout rate to apply to the attention 31 | (default: 0.1) 32 | event_dispatcher: str or EventDispatcher instance to be used by this 33 | module for dispatching events (default: the default 34 | global dispatcher) 35 | """ 36 | def __init__(self, topk=32, softmax_temp=None, attention_dropout=0.1, 37 | event_dispatcher=""): 38 | super(ExactTopKAttention, self).__init__() 39 | self.topk = topk 40 | self.softmax_temp = softmax_temp 41 | self.dropout = Dropout(attention_dropout) 42 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 43 | 44 | def forward(self, queries, keys, values, attn_mask, query_lengths, 45 | key_lengths): 46 | # Extract some shapes and compute the temperature 47 | N, L, H, E = queries.shape 48 | _, S, _, D = values.shape 49 | softmax_temp = self.softmax_temp or 1./sqrt(E) 50 | 51 | # Compute the unnormalized attention and apply the masks 52 | QK = torch.einsum("nlhe,nshe->nhls", queries, keys) 53 | topk = min(self.topk, S) 54 | 55 | if not attn_mask.all_ones: 56 | QK = QK + attn_mask.additive_matrix 57 | QK = QK + key_lengths.additive_matrix[:, None, None] 58 | 59 | topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1) 60 | mask = QK.new_ones(QK.shape) * float("-inf") 61 | mask[ 62 | torch.arange(N, device=QK.device).view(N, 1, 1, 1), 63 | torch.arange(H, device=QK.device).view(1, H, 1, 1), 64 | torch.arange(L, device=QK.device).view(1, 1, L, 1), 65 | topk_idx, 66 | ] = 0. 67 | 68 | QK = QK + mask 69 | 70 | # Compute the attention and the weighted average 71 | A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) 72 | V = torch.einsum("nhls,nshd->nlhd", A, values) 73 | 74 | # Make sure that what we return is contiguous 75 | return V.contiguous() 76 | 77 | 78 | # Register the attention implementation so that it becomes available in our 79 | # builders 80 | AttentionRegistry.register( 81 | "exact-topk", ExactTopKAttention, 82 | [ 83 | ("topk", Optional(Int, 32)), 84 | ("softmax_temp", Optional(Float)), 85 | ("attention_dropout", Optional(Float, 0.1)), 86 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 87 | ] 88 | ) 89 | -------------------------------------------------------------------------------- /fast_transformers/attention/full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement the full attention similar to the one implemented by PyTorch's 8 | MultiHeadAttention module. Note that this module is to be used in conjuction 9 | with the `fast_transformers.attention.attention_layer.AttentionLayer` in order 10 | to work.""" 11 | 12 | from math import sqrt 13 | 14 | import torch 15 | from torch.nn import Dropout, Module 16 | 17 | from ..attention_registry import AttentionRegistry, Optional, Float, \ 18 | EventDispatcherInstance 19 | from ..events import EventDispatcher, AttentionEvent 20 | 21 | 22 | class FullAttention(Module): 23 | """Implement the scaled dot product attention with softmax. 24 | 25 | Arguments 26 | --------- 27 | softmax_temp: The temperature to use for the softmax attention. 
28 | (default: 1/sqrt(d_keys) where d_keys is computed at 29 | runtime) 30 | attention_dropout: The dropout rate to apply to the attention 31 | (default: 0.1) 32 | event_dispatcher: str or EventDispatcher instance to be used by this 33 | module for dispatching events (default: the default 34 | global dispatcher) 35 | """ 36 | def __init__(self, softmax_temp=None, attention_dropout=0.1, 37 | event_dispatcher=""): 38 | super(FullAttention, self).__init__() 39 | self.softmax_temp = softmax_temp 40 | self.dropout = Dropout(attention_dropout) 41 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 42 | 43 | def forward(self, queries, keys, values, attn_mask, query_lengths, 44 | key_lengths): 45 | """Implements the multihead softmax attention. 46 | 47 | Arguments 48 | --------- 49 | queries: (N, L, H, E) The tensor containing the queries 50 | keys: (N, S, H, E) The tensor containing the keys 51 | values: (N, S, H, D) The tensor containing the values 52 | attn_mask: An implementation of BaseMask that encodes where each 53 | query can attend to 54 | query_lengths: An implementation of BaseMask that encodes how 55 | many queries each sequence in the batch consists of 56 | key_lengths: An implementation of BaseMask that encodes how 57 | many queries each sequence in the batch consists of 58 | """ 59 | # Extract some shapes and compute the temperature 60 | N, L, H, E = queries.shape 61 | _, S, _, D = values.shape 62 | softmax_temp = self.softmax_temp or 1./sqrt(E) 63 | 64 | # Scale the queries instead of applying the softmax temperature to the 65 | # dot products 66 | queries = queries * softmax_temp 67 | 68 | # Compute the unnormalized attention and apply the masks 69 | QK = torch.einsum("nlhe,nshe->nhls", queries, keys) 70 | if not attn_mask.all_ones: 71 | QK = QK + attn_mask.additive_matrix 72 | if not key_lengths.all_ones: 73 | QK = QK + key_lengths.additive_matrix[:, None, None] 74 | 75 | # Compute the attention and the weighted average 76 | A = self.dropout(torch.softmax(QK, dim=-1)) 77 | V = torch.einsum("nhls,nshd->nlhd", A, values) 78 | 79 | # Let the world know of the attention matrix 80 | self.event_dispatcher.dispatch(AttentionEvent(self, A)) 81 | 82 | # Make sure that what we return is contiguous 83 | return V.contiguous() 84 | 85 | 86 | # Register the attention implementation so that it becomes available in our 87 | # builders 88 | AttentionRegistry.register( 89 | "full", FullAttention, 90 | [ 91 | ("softmax_temp", Optional(Float)), 92 | ("attention_dropout", Optional(Float, 0.1)), 93 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 94 | ] 95 | ) 96 | -------------------------------------------------------------------------------- /fast_transformers/attention/linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement unmasked linear attention.""" 8 | 9 | import torch 10 | from torch.nn import Module 11 | 12 | from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \ 13 | EventDispatcherInstance 14 | from ..events import EventDispatcher 15 | from ..feature_maps import elu_feature_map 16 | 17 | 18 | class LinearAttention(Module): 19 | """Implement unmasked attention using dot product of feature maps in 20 | O(N D^2) complexity. 
21 | 22 | Given the queries, keys and values as Q, K, V instead of computing 23 | 24 | V' = softmax(Q.mm(K.t()), dim=-1).mm(V), 25 | 26 | we make use of a feature map function Φ(.) and perform the following 27 | computation 28 | 29 | V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V). 30 | 31 | The above can be computed in O(N D^2) complexity where D is the 32 | dimensionality of Q, K and V and N is the sequence length. Depending on the 33 | feature map, however, the complexity of the attention might be limited. 34 | 35 | Arguments 36 | --------- 37 | feature_map: callable, a callable that applies the feature map to the 38 | last dimension of a tensor (default: elu(x)+1) 39 | eps: float, a small number to ensure the numerical stability of the 40 | denominator (default: 1e-6) 41 | event_dispatcher: str or EventDispatcher instance to be used by this 42 | module for dispatching events (default: the default 43 | global dispatcher) 44 | """ 45 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6, 46 | event_dispatcher=""): 47 | super(LinearAttention, self).__init__() 48 | self.feature_map = ( 49 | feature_map(query_dimensions) if feature_map else 50 | elu_feature_map(query_dimensions) 51 | ) 52 | self.eps = eps 53 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 54 | 55 | def forward(self, queries, keys, values, attn_mask, query_lengths, 56 | key_lengths): 57 | # Apply the feature map to the queries and keys 58 | self.feature_map.new_feature_map(queries.device) 59 | Q = self.feature_map.forward_queries(queries) 60 | K = self.feature_map.forward_keys(keys) 61 | 62 | # Apply the key padding mask and make sure that the attn_mask is 63 | # all_ones 64 | if not attn_mask.all_ones: 65 | raise RuntimeError(("LinearAttention does not support arbitrary " 66 | "attention masks")) 67 | K = K * key_lengths.float_matrix[:, :, None, None] 68 | 69 | # Compute the KV matrix, namely the dot product of keys and values so 70 | # that we never explicitly compute the attention matrix and thus 71 | # decrease the complexity 72 | KV = torch.einsum("nshd,nshm->nhmd", K, values) 73 | 74 | # Compute the normalizer 75 | Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps) 76 | 77 | # Finally compute and return the new values 78 | V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) 79 | 80 | return V.contiguous() 81 | 82 | 83 | # Register the attention implementation so that it becomes available in our 84 | # builders 85 | AttentionRegistry.register( 86 | "linear", LinearAttention, 87 | [ 88 | ("query_dimensions", Int), 89 | ("feature_map", Optional(Callable)), 90 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 91 | ] 92 | ) 93 | -------------------------------------------------------------------------------- /fast_transformers/attention/local_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Implement local context attention.""" 7 | 8 | from math import sqrt 9 | 10 | import torch 11 | from torch.nn import Module, Dropout 12 | from torch.nn import functional as F 13 | 14 | from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ 15 | EventDispatcherInstance 16 | from ..events import EventDispatcher 17 | from ..local_product import local_dot_product, local_weighted_average 18 | 19 | 20 | class LocalAttention(Module): 21 | """Implement fast local attention where a query can only attend to 
22 | neighboring keys. 23 | 24 | In this attention module the query Q_i can only attend to a key K_j if 25 | |i-j| < local_context/2. 26 | 27 | Arguments 28 | --------- 29 | local_context: The neighborhood to consider for local attention. 30 | softmax_temp: The temperature to use for the softmax attention. 31 | (default: 1/sqrt(d_keys) where d_keys is computed at 32 | runtime) 33 | attention_dropout: The dropout rate to apply to the attention 34 | (default: 0.1) 35 | event_dispatcher: str or EventDispatcher instance to be used by this 36 | module for dispatching events (default: the default 37 | global dispatcher) 38 | """ 39 | def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1, 40 | event_dispatcher=""): 41 | super(LocalAttention, self).__init__() 42 | self.local_context = local_context 43 | self.softmax_temp = softmax_temp 44 | self.dropout = Dropout(attention_dropout) 45 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 46 | 47 | def forward(self, queries, keys, values, attn_mask, query_lengths, 48 | key_lengths): 49 | """Implements the local attention. 50 | 51 | The attn_mask can be anything but the only values that will be 52 | considered will be the ones in the neighborhood of each query. 53 | 54 | Arguments 55 | --------- 56 | queries: (N, L, H, E) The tensor containing the queries 57 | keys: (N, S, H, E) The tensor containing the keys 58 | values: (N, S, H, D) The tensor containing the values 59 | attn_mask: An implementation of BaseMask that encodes where each 60 | query can attend to 61 | query_lengths: An implementation of BaseMask that encodes how 62 | many queries each sequence in the batch consists of 63 | key_lengths: An implementation of BaseMask that encodes how 64 | many queries each sequence in the batch consists of 65 | """ 66 | # Extract some shapes and compute the temperature 67 | N, L, H, E = queries.shape 68 | _, S, _, D = values.shape 69 | context = self.local_context 70 | softmax_temp = self.softmax_temp or 1./sqrt(E) 71 | 72 | # Permute the dimensions to NHLE instead of NLHE 73 | queries = queries.permute(0, 2, 1, 3).contiguous() 74 | keys = keys.permute(0, 2, 1, 3).contiguous() 75 | values = values.permute(0, 2, 1, 3).contiguous() 76 | 77 | QK = local_dot_product( 78 | queries, 79 | keys, 80 | attn_mask.additive_matrix_finite, 81 | key_lengths.lengths, 82 | self.local_context 83 | ) 84 | A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) 85 | 86 | V_new = local_weighted_average(A, values) 87 | 88 | return V_new.permute(0, 2, 1, 3).contiguous() 89 | 90 | 91 | # Register the attention implementation so that it becomes available in our 92 | # builders 93 | AttentionRegistry.register( 94 | "local", LocalAttention, 95 | [ 96 | ("local_context", Int), 97 | ("softmax_temp", Optional(Float)), 98 | ("attention_dropout", Optional(Float, 0.1)), 99 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 100 | ] 101 | ) 102 | -------------------------------------------------------------------------------- /fast_transformers/attention_registry/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Allow for the dynamic registration of new attention implementations. 7 | 8 | This module provides a Registry implementation that other modules can use to 9 | register attention implementations for the builders. 
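For illustration, registering a hypothetical MyAttention class with a single
optional parameter would look like the following (see
docs/custom_attention_layer.md for a full walkthrough):

    from fast_transformers.attention_registry import AttentionRegistry, \
        Optional, Float

    AttentionRegistry.register(
        "my-attention", MyAttention,          # attention_type key, class pair
        [("softmax_temp", Optional(Float))]   # optional float parameter
    )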
10 | """ 11 | 12 | from .registry import \ 13 | AttentionRegistry, \ 14 | RecurrentAttentionRegistry, \ 15 | RecurrentCrossAttentionRegistry 16 | from .spec import Spec, Choice, Optional, Int, Float, Bool, Callable, \ 17 | EventDispatcherInstance 18 | -------------------------------------------------------------------------------- /fast_transformers/attention_registry/registry.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | 7 | class Registry(object): 8 | """Hold the available attention implementations and their required 9 | parameters.""" 10 | def __init__(self): 11 | self._classes = {} 12 | self._class_params = {} 13 | self._parameters = {} 14 | 15 | def register(self, key, class_object, parameter_tuples): 16 | # register the class if the key is new 17 | if key in self._classes: 18 | raise ValueError("{} is already registered".format(key)) 19 | self._classes[key] = class_object 20 | 21 | # register the parameters 22 | for parameter, spec in parameter_tuples: 23 | if ( 24 | parameter in self._parameters and 25 | self._parameters[parameter] != spec 26 | ): 27 | raise ValueError(("{} is already registered with " 28 | "spec {!r} instead of {!r}").format( 29 | parameter, 30 | self._parameters[parameter], 31 | spec 32 | )) 33 | self._parameters[parameter] = spec 34 | 35 | # note which parameters are needed by this class 36 | self._class_params[key] = [p for p, s in parameter_tuples] 37 | 38 | def __contains__(self, key): 39 | return key in self._classes 40 | 41 | def __getitem__(self, key): 42 | return self._classes[key], self._class_params[key] 43 | 44 | @property 45 | def keys(self): 46 | return list(self._classes.keys()) 47 | 48 | def contains_parameter(self, key): 49 | return key in self._parameters 50 | 51 | def validate_parameter(self, key, value): 52 | try: 53 | return self._parameters[key].get(value) 54 | except Exception as e: 55 | raise ValueError(("Invalid value {!r} for " 56 | "parameter {!r}").format(value, key)) from e 57 | 58 | 59 | AttentionRegistry = Registry() 60 | RecurrentAttentionRegistry = Registry() 61 | RecurrentCrossAttentionRegistry = Registry() 62 | -------------------------------------------------------------------------------- /fast_transformers/attention_registry/spec.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Spec instances allow to describe and check the type and value of 7 | parameters.""" 8 | 9 | from ..events import EventDispatcher 10 | 11 | 12 | class Spec(object): 13 | """Describe and validate a parameter type. 14 | 15 | Arguments 16 | --------- 17 | predicate: A callable that checks if the value is acceptable and 18 | returns its canonical value or raises ValueError. 19 | name: A name to create a human readable description of the Spec 20 | """ 21 | def __init__(self, predicate, name="CustomSpec"): 22 | self._predicate = predicate 23 | self._name = name 24 | 25 | def __repr__(self): 26 | return self._name 27 | 28 | def check(self, x): 29 | try: 30 | self._predicate(x) 31 | return True 32 | except ValueError: 33 | return False 34 | 35 | def get(self, x): 36 | return self._predicate(x) 37 | 38 | def __eq__(self, y): 39 | return self is y 40 | 41 | 42 | class Choice(Spec): 43 | """A parameter type for a set of options. 
44 | 45 | Arguments 46 | --------- 47 | choices: A set or list of possible values for this parameter 48 | """ 49 | def __init__(self, choices): 50 | self._choices = choices 51 | 52 | def get(self, x): 53 | if x in self._choices: 54 | return x 55 | raise ValueError("{!r} is not in {!r}".format(x, self._choices)) 56 | 57 | def __repr__(self): 58 | return "Choice({!r})".format(self._choices) 59 | 60 | def __eq__(self, x): 61 | if isinstance(x, Choice): 62 | return self._choices == x._choices 63 | return False 64 | 65 | 66 | class _Callable(Spec): 67 | def __init__(self): 68 | super(_Callable, self).__init__(None, "Callable") 69 | 70 | def get(self, x): 71 | if callable(x): 72 | return x 73 | raise ValueError("{!r} is not a callable".format(x)) 74 | 75 | 76 | class _EventDispatcherInstance(Spec): 77 | def __init__(self): 78 | super(_EventDispatcherInstance, self).__init__( 79 | _EventDispatcherInstance._get_event_dispatcher, 80 | "EventDispatcherInstance" 81 | ) 82 | 83 | @staticmethod 84 | def _get_event_dispatcher(x): 85 | if isinstance(x, str): 86 | return x 87 | if isinstance(x, EventDispatcher): 88 | return x 89 | raise ValueError("{!r} is not an event dispatcher".format(x)) 90 | 91 | 92 | class Optional(Spec): 93 | """Represent an optional parameter that can either have a value or it can 94 | be None. 95 | 96 | Arguments 97 | --------- 98 | spec: The spec for the value if it is not None 99 | default: The returned value in case it is None 100 | """ 101 | def __init__(self, spec, default=None): 102 | self._other_spec = spec 103 | self._default = default 104 | 105 | def __repr__(self): 106 | return "Optional[{!r}, {!r}]".format(self._other_spec, self._default) 107 | 108 | def get(self, x): 109 | if x is None: 110 | return self._default 111 | return self._other_spec.get(x) 112 | 113 | def __eq__(self, x): 114 | if isinstance(x, Optional): 115 | return ( 116 | self._other_spec == x._other_spec and 117 | self._default == x._default 118 | ) 119 | return False 120 | 121 | 122 | Int = Spec(int, "Int") 123 | Float = Spec(float, "Float") 124 | Bool = Spec(bool, "Bool") 125 | Callable = _Callable() 126 | EventDispatcherInstance = _EventDispatcherInstance() 127 | -------------------------------------------------------------------------------- /fast_transformers/builders/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """This module implements builders that simplify building complex transformer 8 | architectures with different attention mechanisms. 9 | 10 | The main idea is to facilitate the construction of various attention layers and 11 | transformer encoder layers and simplify their assembly into one transformer 12 | module. It also allows for flexibility in the scripts as many builder 13 | parameters can correspond 1-1 with command line arguments. 
14 | 15 | Example usage: 16 | 17 | builder = TransformerEncoderBuilder() 18 | builder.n_layers = 12 19 | builder.n_heads = 8 20 | builder.feed_forward_dimensions = 1024 21 | builder.query_dimensions = 64 22 | builder.value_dimensions = 64 23 | builder.dropout = 0.1 24 | builder.attention_dropout = 0.1 25 | builder.attention_type = "linear" 26 | transformer = builder.get() 27 | """ 28 | 29 | __all__ = [ 30 | "AttentionBuilder", 31 | "RecurrentAttentionBuilder", 32 | "RecurrentCrossAttentionBuilder" 33 | ] 34 | 35 | # Import the attention implementations so that they register themselves with 36 | # the builder. Attention implementations external to the library should be 37 | # imported before using the builders. 38 | # 39 | # TODO: Should this behaviour change? Namely, should all attention 40 | # implementations be imported in order to be useable? This also allows 41 | # using the library even partially built, for instance. 42 | from ..attention import \ 43 | FullAttention, \ 44 | LinearAttention, CausalLinearAttention, \ 45 | ClusteredAttention, ImprovedClusteredAttention, \ 46 | ReformerAttention, \ 47 | ExactTopKAttention, ImprovedClusteredCausalAttention, \ 48 | ConditionalFullAttention 49 | del FullAttention, \ 50 | LinearAttention, CausalLinearAttention, \ 51 | ClusteredAttention, ImprovedClusteredAttention, \ 52 | ReformerAttention, \ 53 | ExactTopKAttention, ImprovedClusteredCausalAttention, \ 54 | ConditionalFullAttention 55 | 56 | 57 | from .attention_builders import \ 58 | AttentionBuilder, \ 59 | RecurrentAttentionBuilder, \ 60 | RecurrentCrossAttentionBuilder 61 | 62 | from .transformer_builders import \ 63 | TransformerEncoderBuilder, \ 64 | RecurrentEncoderBuilder, \ 65 | TransformerDecoderBuilder, \ 66 | RecurrentDecoderBuilder 67 | -------------------------------------------------------------------------------- /fast_transformers/builders/base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Provide a class for the others to inherit some useful functionality.""" 8 | 9 | 10 | class BaseBuilder(object): 11 | @classmethod 12 | def from_kwargs(cls, **kwargs): 13 | """Construct a builder and set all the keyword arguments as parameters. 14 | 15 | The keyword argument strict is passed to 16 | BaseBuilder.from_dictionary separately. 17 | 18 | See BaseBuilder.from_dictionary(). 19 | """ 20 | strict = kwargs.pop("strict", True) 21 | return cls.from_dictionary(kwargs, strict=strict) 22 | 23 | @classmethod 24 | def from_namespace(cls, args, strict=False): 25 | """Construct a builder from an argparse Namespace. 26 | 27 | To be used for building transformers from command line arguments. 28 | 29 | See BaseBuilder.from_dictionary(). 30 | """ 31 | return cls.from_dictionary(vars(args), strict=strict) 32 | 33 | @classmethod 34 | def from_dictionary(cls, dictionary, strict=True): 35 | """Construct a builder and set all the parameters in the dictionary. 36 | 37 | Given a dictionary 38 | 39 | d = {"foo": "bar"} 40 | 41 | then 42 | 43 | builder = TransformerEncoderBuilder.from_dictionary(d) 44 | 45 | is equivalent to 46 | 47 | builder = TransformerEncoderBuilder() 48 | builder.foo = "bar" 49 | 50 | Arguments 51 | --------- 52 | dictionary: A dictionary of parameters to set to the builder. 
53 | strict: bool, If a key is not a parameter and strict is set to True 54 | then a ValueError is raised, otherwise that dictionary key 55 | is ignored (default: True) 56 | """ 57 | builder = cls() 58 | for k, v in dictionary.items(): 59 | try: 60 | setattr(builder, k, v) 61 | except AttributeError: 62 | if strict: 63 | raise ValueError(("The builder has no " 64 | "parameter {!r}").format(k)) 65 | else: 66 | continue 67 | return builder 68 | -------------------------------------------------------------------------------- /fast_transformers/causal_product/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import torch 8 | 9 | from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \ 10 | causal_dot_backward as causal_dot_backward_cpu 11 | 12 | try: 13 | from .causal_product_cuda import \ 14 | causal_dot_product as causal_dot_product_cuda, \ 15 | causal_dot_backward as causal_dot_backward_cuda 16 | except ImportError: 17 | causal_dot_product_cuda = causal_dot_backward_cuda = None 18 | 19 | 20 | class CausalDotProduct(torch.autograd.Function): 21 | """Compute the weighted sum of values but attending only to previous 22 | values.""" 23 | dot = { 24 | "cpu": causal_dot_product_cpu, 25 | "cuda": causal_dot_product_cuda 26 | } 27 | dot_backward = { 28 | "cpu": causal_dot_backward_cpu, 29 | "cuda": causal_dot_backward_cuda 30 | } 31 | 32 | @staticmethod 33 | def forward(ctx, Q, K, V): 34 | # Save the inputs for the gradient computation 35 | ctx.save_for_backward(Q, K, V) 36 | 37 | # Create the output tensor 38 | device = Q.device 39 | N, H, L, _ = Q.shape 40 | _, _, _, M = V.shape 41 | product = torch.zeros((N, H, L, M), device=device) 42 | 43 | # Actually perform the dot product 44 | CausalDotProduct.dot[device.type]( 45 | Q.data, 46 | K.data, 47 | V.data, 48 | product 49 | ) 50 | 51 | return product 52 | 53 | @staticmethod 54 | def backward(ctx, grad_out): 55 | # Extract the saved tensors 56 | Q, K, V = ctx.saved_tensors 57 | 58 | # Allocate memory for the gradients 59 | grad_Q = torch.zeros_like(Q) 60 | grad_K = torch.zeros_like(K) 61 | grad_V = torch.zeros_like(V) 62 | 63 | # Actually compute the gradients 64 | CausalDotProduct.dot_backward[Q.device.type]( 65 | Q.data, 66 | K.data, 67 | V.data, 68 | grad_out, 69 | grad_Q, 70 | grad_K, 71 | grad_V 72 | ) 73 | 74 | return grad_Q, grad_K, grad_V 75 | 76 | 77 | # Alias the autograd functions to python style snake case naming 78 | causal_dot_product = CausalDotProduct.apply 79 | -------------------------------------------------------------------------------- /fast_transformers/clustering/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/fast-transformers/2ad36b97e64cb93862937bd21fcc9568d989561f/fast_transformers/clustering/__init__.py -------------------------------------------------------------------------------- /fast_transformers/clustering/hamming/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import numpy as np 9 | 10 | import torch 11 | 12 | from .cluster_cpu import cluster as cluster_cpu 13 | try: 14 | from .cluster_cuda import cluster as 
cluster_gpu 15 | except ImportError: 16 | pass 17 | 18 | 19 | def cluster( 20 | hashes, 21 | lengths, 22 | groups=None, 23 | counts=None, 24 | centroids=None, 25 | distances=None, 26 | bitcounts=None, 27 | clusters=30, 28 | iterations=10, 29 | bits=32 30 | ): 31 | """Cluster hashes using a few iterations of K-Means with hamming distance. 32 | 33 | All the tensors default initialized to None are optional buffers to avoid 34 | memory allocations. distances and bitcounts are only used by the CUDA 35 | version of this call. clusters will be ignored if centroids is provided. 36 | 37 | Arguments 38 | --------- 39 | hashes: A long tensor of shape (N, H, L) containing a hashcode for each 40 | query. 41 | lengths: An int tensor of shape (N,) containing the sequence length for 42 | each sequence in hashes. 43 | groups: An int tensor buffer of shape (N, H, L) contaning the cluster 44 | in which the corresponding hash belongs to. 45 | counts: An int tensor buffer of shape (N, H, K) containing the number 46 | of elements in each cluster. 47 | centroids: A long tensor buffer of shape (N, H, K) containing the 48 | centroid for each cluster. 49 | distances: An int tensor of shape (N, H, L) containing the distance to 50 | the closest centroid for each hash. 51 | bitcounts: An int tensor of shape (N, H, K, bits) containing the number 52 | of elements that have 1 for a given bit. 53 | clusters: The number of clusters to use for each sequence. It is 54 | ignored if centroids is not None. 55 | iterations: How many k-means iterations to perform. 56 | bits: How many of the least-significant bits in hashes to consider. 57 | 58 | Returns 59 | ------- 60 | groups and counts as defined above. 61 | """ 62 | device = hashes.device 63 | N, H, L = hashes.shape 64 | 65 | # Unfortunately cpu and gpu have different APIs so the entire call must be 66 | # surrounded by an if-then-else 67 | if device.type == "cpu": 68 | if groups is None: 69 | groups = torch.empty((N, H, L), dtype=torch.int32) 70 | if centroids is None: 71 | centroids = torch.empty((N, H, clusters), dtype=torch.int64) 72 | centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)] 73 | K = centroids.shape[2] 74 | if counts is None: 75 | counts = torch.empty((N, H, K), dtype=torch.int32) 76 | 77 | cluster_cpu( 78 | hashes, lengths, 79 | centroids, groups, counts, 80 | iterations, bits 81 | ) 82 | 83 | return groups, counts 84 | 85 | else: 86 | if groups is None: 87 | groups = torch.empty((N, H, L), dtype=torch.int32, device=device) 88 | if centroids is None: 89 | centroids = torch.empty((N, H, clusters), dtype=torch.int64, 90 | device=device) 91 | centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)] 92 | K = centroids.numel() // N // H 93 | #K = clusters 94 | if counts is None: 95 | counts = torch.empty((N, H, K), dtype=torch.int32, device=device) 96 | if distances is None: 97 | distances = torch.empty((N, H, L), dtype=torch.int32, 98 | device=device) 99 | if bitcounts is None: 100 | bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32, 101 | device=device) 102 | groups = groups.view(N, H, L) 103 | counts = counts.view(N, H, K) 104 | centroids = centroids.view(N, H, K) 105 | distances = distances.view(N, H, L) 106 | bitcounts = bitcounts.view(N, H, K, -1) 107 | 108 | cluster_gpu( 109 | hashes, lengths, 110 | centroids, distances, bitcounts, groups, counts, 111 | iterations, bits 112 | ) 113 | 114 | return groups, counts 115 | 116 | -------------------------------------------------------------------------------- 
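A minimal usage sketch for the hamming clustering above. The shapes and dtypes
follow the docstring of cluster(); the random queries and hyperplanes are made
up for illustration only and are not part of the library:

    import torch
    from fast_transformers.hashing import compute_hashes
    from fast_transformers.clustering.hamming import cluster

    N, H, L, E, B = 2, 4, 100, 32, 31       # batch, heads, length, dims, hash bits
    queries = torch.randn(N * H * L, E)     # one row per (batch, head, position)
    planes = torch.randn(B, E + 1)          # hyperplanes, last column is the bias
    hashes = compute_hashes(queries, planes).view(N, H, L)
    lengths = torch.full((N,), L, dtype=torch.int32)
    groups, counts = cluster(hashes, lengths, clusters=10, iterations=10, bits=B)
    # groups: (N, H, L) cluster index per position, counts: (N, H, 10) cluster sizes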
/fast_transformers/events/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """This module implements a basic event system that allows the transformer 7 | internal components to make available any tensor with minimal overhead.""" 8 | 9 | from .event import Event, AttentionEvent, QKVEvent, IntermediateOutput 10 | from .event_dispatcher import EventDispatcher 11 | -------------------------------------------------------------------------------- /fast_transformers/events/event.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | 7 | class Event(object): 8 | """The Event is the base class for all events that are dispatched from any 9 | transformer module. 10 | 11 | This class defines only the basic attributes of an event without any 12 | payload. 13 | 14 | Arguments 15 | --------- 16 | source: torch.nn.Module instance that dispatched this event 17 | """ 18 | def __init__(self, source): 19 | self.source = source 20 | 21 | 22 | class AttentionEvent(Event): 23 | """An event containing an attention matrix. 24 | 25 | Arguments 26 | --------- 27 | source: torch.nn.Module instance that dispatched this event 28 | attention_matrix: torch.tensor of the multihead attention matrix 29 | computed in the corresponding attention layer 30 | """ 31 | def __init__(self, source, attention_matrix): 32 | super(AttentionEvent, self).__init__(source) 33 | self.attention_matrix = attention_matrix 34 | 35 | 36 | class QKVEvent(Event): 37 | """An event containing the queries, keys and values projected in their 38 | multiple heads. 39 | 40 | Arguments 41 | --------- 42 | source: torch.nn.Module instance that dispatched this event 43 | queries: torch.tensor containing the queries in shape NLHE 44 | keys: torch.tensor containing the keys in shape NSHE 45 | values: torch.tensor containing the values in shape NSHD 46 | """ 47 | def __init__(self, source, queries, keys, values): 48 | super(QKVEvent, self).__init__(source) 49 | self.queries = queries 50 | self.keys = keys 51 | self.values = values 52 | 53 | 54 | class IntermediateOutput(Event): 55 | """Used by the TransformerEncoder and the TransformerDecoder to provide the 56 | intermediate outputs to interested callers. 57 | 58 | Arguments 59 | --------- 60 | source: torch.nn.Module instance that dispatched this event 61 | x: torch.tensor containing the intermediate features in shape NLD 62 | """ 63 | def __init__(self, source, x): 64 | super().__init__(source) 65 | self.x = x 66 | -------------------------------------------------------------------------------- /fast_transformers/events/event_dispatcher.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from collections import OrderedDict 7 | 8 | from .event import Event 9 | from .filters import event_class 10 | 11 | 12 | class EventDispatcher(object): 13 | """An EventDispatcher is a simple way to implement an observer pattern for 14 | loose coupling of components. In our case it is used so that the internals 15 | of large neural networks can communicate with the outside world in an 16 | agnostic and efficient way. 
17 | 18 | Example usage 19 | ------------- 20 | 21 | from fast_transformers.events import EventDispatcher, AttentionEvent 22 | from fast_transformers.events.filters import \ 23 | layer_name_contains 24 | 25 | def attention_event_handler(event): 26 | print(event.attention_matrix) 27 | 28 | ed = EventDispatcher() 29 | ed.listen(AttentionEvent, attention_event_handler) 30 | ed.listen( 31 | AttentionEvent & layer_name_contains("layers.12"), 32 | attention_event_handler 33 | ) 34 | """ 35 | _dispatchers = {} 36 | 37 | def __init__(self): 38 | self._listeners = OrderedDict() 39 | 40 | def listen(self, event_filter, event_handler): 41 | """Add an event handler for the events that pass the event filter. 42 | 43 | Arguments 44 | --------- 45 | event_filter: callable or Event class to define for which events 46 | this handler will be called 47 | event_handler: callable that accepts an instance of Event 48 | """ 49 | if isinstance(event_filter, type) and issubclass(event_filter, Event): 50 | event_filter = event_class(event_filter) 51 | 52 | self._listeners[event_handler] = event_filter 53 | 54 | def remove(self, event_handler): 55 | """Remove the event_handler from the listeners so that no more events 56 | are dispatched to this handler.""" 57 | self._listeners.pop(event_handler, None) 58 | 59 | def clear(self): 60 | """Remove all listeners from the event dispatcher.""" 61 | self._listeners.clear() 62 | 63 | def dispatch(self, event): 64 | """Dispatch an event to the listeners. 65 | 66 | Arguments 67 | --------- 68 | event: Event instance 69 | """ 70 | for event_handler, event_filter in self._listeners.items(): 71 | if event_filter(event): 72 | event_handler(event) 73 | 74 | @classmethod 75 | def get(cls, key=""): 76 | """Factory method for creating global event dispatchers for loosely 77 | coupling parts of a larger codebase. 78 | 79 | Since global objects are a complete antipattern, we suggest that this 80 | is only used to set a default value for an event dispatcher passed as 81 | an argument. 82 | 83 | Argument 84 | -------- 85 | key: A key to uniquely identify a dispatcher or an instance of a 86 | dispatcher to be returned as is 87 | """ 88 | if isinstance(key, cls): 89 | return key 90 | if key not in cls._dispatchers: 91 | cls._dispatchers[key] = cls() 92 | return cls._dispatchers[key] 93 | -------------------------------------------------------------------------------- /fast_transformers/feature_maps/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Implementations of feature maps to be used with linear attention and causal 7 | linear attention.""" 8 | 9 | 10 | from .base import elu_feature_map, ActivationFunctionFeatureMap 11 | from .fourier_features import RandomFourierFeatures, Favor, \ 12 | SmoothedRandomFourierFeatures, GeneralizedRandomFeatures 13 | -------------------------------------------------------------------------------- /fast_transformers/feature_maps/base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Create the feature map interface and some commonly used feature maps. 
7 | 8 | All attention implementations that expect a feature map shall receive a factory 9 | function that returns a feature map instance when called with the query 10 | dimensions. 11 | """ 12 | 13 | from functools import partial 14 | 15 | import torch 16 | from torch.nn import Module 17 | 18 | 19 | class FeatureMap(Module): 20 | """Define the FeatureMap interface.""" 21 | def __init__(self, query_dims): 22 | super().__init__() 23 | self.query_dims = query_dims 24 | 25 | def new_feature_map(self, device): 26 | """Create a new instance of this feature map. In particular, if it is a 27 | random feature map sample new parameters.""" 28 | raise NotImplementedError() 29 | 30 | def forward_queries(self, x): 31 | """Encode the queries `x` using this feature map.""" 32 | return self(x) 33 | 34 | def forward_keys(self, x): 35 | """Encode the keys `x` using this feature map.""" 36 | return self(x) 37 | 38 | def forward(self, x): 39 | """Encode x using this feature map. For symmetric feature maps it 40 | suffices to define this function, but for asymmetric feature maps one 41 | needs to define the `forward_queries` and `forward_keys` functions.""" 42 | raise NotImplementedError() 43 | 44 | @classmethod 45 | def factory(cls, *args, **kwargs): 46 | """Return a function that when called with the query dimensions returns 47 | an instance of this feature map. 48 | 49 | It is inherited by the subclasses so it is available in all feature 50 | maps. 51 | """ 52 | def inner(query_dims): 53 | return cls(query_dims, *args, **kwargs) 54 | return inner 55 | 56 | 57 | class ActivationFunctionFeatureMap(FeatureMap): 58 | """Define a feature map that is simply an element-wise activation 59 | function.""" 60 | def __init__(self, query_dims, activation_function): 61 | super().__init__(query_dims) 62 | self.activation_function = activation_function 63 | 64 | def new_feature_map(self, device): 65 | return 66 | 67 | def forward(self, x): 68 | return self.activation_function(x) 69 | 70 | 71 | elu_feature_map = ActivationFunctionFeatureMap.factory( 72 | lambda x: torch.nn.functional.elu(x) + 1 73 | ) 74 | -------------------------------------------------------------------------------- /fast_transformers/hashing/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import torch 9 | 10 | from .hash_cpu import compute_hashes as compute_hashes_cpu 11 | try: 12 | from .hash_cuda import compute_hashes as compute_hashes_cuda 13 | except ImportError: 14 | pass 15 | 16 | 17 | def compute_hashes(X, A, H=None): 18 | device = X.device 19 | if H is None: 20 | H = torch.zeros(len(X), dtype=torch.int64, device=device) 21 | else: 22 | H.zero_() 23 | if A.shape[1] != X.shape[1] + 1: 24 | raise ValueError("The hash requires a bias") 25 | 26 | if device.type == "cpu": 27 | compute_hashes_cpu(X, A, H) 28 | else: 29 | compute_hashes_cuda(X, A, H) 30 | 31 | return H 32 | -------------------------------------------------------------------------------- /fast_transformers/hashing/hash_cpu.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | // Written by Angelos Katharopoulos , 4 | // Apoorv Vyas 5 | // 6 | 7 | #include 8 | 9 | #include 10 | 11 | 12 | /** 13 | * Hash the vectors in X with the hyperplanes A and store the result in H. 
14 | * The positive side of the plane gets a 1 the negative a 0. 15 | */ 16 | void compute_hashes(torch::Tensor X, torch::Tensor A, torch::Tensor H) { 17 | float *x = X.data_ptr(); 18 | float *a = A.data_ptr(); 19 | int64_t *h = H.data_ptr(); 20 | int N = X.size(0); 21 | int B = A.size(0); 22 | int D = X.size(1); 23 | assert(((void)"Bias expected for the parameters", D+1 == A.size(1))); 24 | #pragma omp parallel for 25 | for (int n=0; n (*aij))) << i; 37 | } 38 | h[n] = hash; 39 | } 40 | } 41 | 42 | 43 | /** 44 | * Hash the vectors given the projections on the B planes. 45 | * The positive side of the plane gets a 1 the negative a 0. 46 | */ 47 | void compute_hashes_from_projections(torch::Tensor P, torch::Tensor H) { 48 | float *p = P.data_ptr(); 49 | int64_t *h = H.data_ptr(); 50 | int N = P.size(0); 51 | int B = P.size(1); 52 | #pragma omp parallel for 53 | for (int n=0; n 0)) << i; 60 | } 61 | h[n] = hash; 62 | } 63 | } 64 | 65 | 66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 67 | m.def("compute_hashes", 68 | &compute_hashes, 69 | "Hash the vectors X using SIMPLE-LSH."); 70 | m.def("compute_hashes_from_projections", 71 | &compute_hashes_from_projections, 72 | "Hash the vectors X given the computed projections."); 73 | } 74 | -------------------------------------------------------------------------------- /fast_transformers/local_product/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import torch 7 | 8 | from .local_product_cpu import local_dot_product as local_dot_product_cpu, \ 9 | local_dot_backward as local_dot_backward_cpu, \ 10 | local_weighted_average as local_weighted_average_cpu, \ 11 | local_weighted_average_backward as local_weighted_average_backward_cpu 12 | 13 | try: 14 | from .local_product_cuda import \ 15 | local_dot_product as local_dot_product_cuda, \ 16 | local_dot_backward as local_dot_backward_cuda, \ 17 | local_weighted_average as local_weighted_average_cuda, \ 18 | local_weighted_average_backward as local_weighted_average_backward_cuda 19 | except ImportError: 20 | local_dot_product_cuda = None 21 | local_dot_backward_cuda = None 22 | local_weighted_average_cuda = None 23 | local_weighted_average_backward_cuda = None 24 | 25 | 26 | class LocalDotProduct(torch.autograd.Function): 27 | """Compute the dot product of the queries and keys but only consider a 28 | local neighborhood of each query.""" 29 | dot = { 30 | "cpu": local_dot_product_cpu, 31 | "cuda": local_dot_product_cuda 32 | } 33 | dot_backward = { 34 | "cpu": local_dot_backward_cpu, 35 | "cuda": local_dot_backward_cuda 36 | } 37 | 38 | @staticmethod 39 | def forward(ctx, queries, keys, attn_mask, key_lengths, local_context): 40 | # Save the inputs for the gradient computation 41 | ctx.save_for_backward(queries, keys, key_lengths) 42 | ctx.local_context = local_context 43 | 44 | return LocalDotProduct.dot[queries.device.type]( 45 | queries, 46 | keys, 47 | attn_mask, 48 | key_lengths, 49 | local_context 50 | ) 51 | 52 | @staticmethod 53 | def backward(ctx, grad_input): 54 | queries, keys, key_lengths = ctx.saved_tensors 55 | local_context = ctx.local_context 56 | 57 | grads = LocalDotProduct.dot_backward[queries.device.type]( 58 | queries, 59 | keys, 60 | key_lengths, 61 | grad_input, 62 | local_context 63 | ) 64 | 65 | # plus 3 None for masks and local_context 66 | return grads + (None, None, None) 67 | 68 | 69 | class 
LocalWeightedAverage(torch.autograd.Function): 70 | """Compute the weighted average of the values with the local attention.""" 71 | avg = { 72 | "cpu": local_weighted_average_cpu, 73 | "cuda": local_weighted_average_cuda 74 | } 75 | avg_backward = { 76 | "cpu": local_weighted_average_backward_cpu, 77 | "cuda": local_weighted_average_backward_cuda 78 | } 79 | 80 | @staticmethod 81 | def forward(ctx, A, V): 82 | # Save the inputs for the gradient computation 83 | ctx.save_for_backward(A, V) 84 | 85 | return LocalWeightedAverage.avg[A.device.type](A, V) 86 | 87 | @staticmethod 88 | def backward(ctx, grad_input): 89 | A, V = ctx.saved_tensors 90 | return LocalWeightedAverage.avg_backward[A.device.type]( 91 | A, V, grad_input 92 | ) 93 | 94 | 95 | # Alias the autograd functions to python style snake case naming 96 | local_dot_product = LocalDotProduct.apply 97 | local_weighted_average = LocalWeightedAverage.apply 98 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implementations of transformers as recurrent functions.""" 8 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import warnings 7 | 8 | 9 | def check_state(state=None, memory=None): 10 | if memory is not None: 11 | warnings.warn(("'memory' is deprecated for recurrent transformers " 12 | " and will be removed in the future, use 'state' " 13 | "instead"), DeprecationWarning) 14 | if state is None: 15 | state = memory 16 | return state 17 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implementations of different types of autoregressive attention 8 | mechanisms for self attention and cross attention.""" 9 | 10 | from .self_attention.attention_layer import RecurrentAttentionLayer 11 | from .self_attention.full_attention import RecurrentFullAttention 12 | from .self_attention.linear_attention import RecurrentLinearAttention 13 | 14 | from .cross_attention.attention_layer import RecurrentCrossAttentionLayer 15 | from .cross_attention.full_attention import RecurrentCrossFullAttention 16 | from .cross_attention.linear_attention import RecurrentCrossLinearAttention 17 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/cross_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Autoregressive implementations for cross attention as a recurrent module. 7 | 8 | The attention implementations in this module expect one input for query and a 9 | sequence of inputs for keys and values. 
The sequence for the keys and values is 10 | fixed for all queries. 11 | 12 | Example 13 | -------- 14 | 15 | import torch 16 | 17 | from fast_transformers.recurrent.attention import \ 18 | RecurrentCrossAttentionLayer, RecurrentCrossFullAttention 19 | 20 | att = RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), 16, 4) 21 | state = None 22 | x = torch.rand(8, 16) 23 | memory = torch.rand(8, 64, 16) 24 | for i in range(10): 25 | x, state = att(x, memory, memory, state=state) 26 | """ 27 | 28 | from .attention_layer import RecurrentCrossAttentionLayer 29 | from .full_attention import RecurrentCrossFullAttention 30 | from .linear_attention import RecurrentCrossLinearAttention 31 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/cross_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Similar to the corresponding module in fast_transformers.attention, this 7 | module performs all the query, key, value projections and output projections 8 | leaving the implementation of the attention to the inner attention module. 9 | 10 | The crucial difference with respect to the self attention recurrent module 11 | (fast_transformers.recurrent.attention.RecurrentAttentionLayer) is that it 12 | doesn't recompute the projections for the keys and values if the state is not 13 | None. 14 | """ 15 | 16 | from torch.nn import Linear, Module 17 | 18 | from ....events import EventDispatcher 19 | 20 | 21 | class RecurrentCrossAttentionLayer(Module): 22 | """See fast_transformers.attention.attention_layer.AttentionLayer . 23 | 24 | The differences with the aforementioned module as well as the 25 | RecurrentAttentionLayer are that this module projects the query every time 26 | and the keys and values only the first time they are provided. 27 | 28 | Arguments 29 | --------- 30 | attention: Specific inner attention implementation that just computes a 31 | weighted average of values given a similarity of queries and 32 | keys. 33 | d_model: The input feature dimensionality 34 | n_heads: The number of heads for the multi head attention 35 | d_keys: The dimensionality of the keys/queries 36 | (default: d_model/n_heads) 37 | d_values: The dimensionality of the values (default: d_model/n_heads) 38 | event_dispatcher: str or EventDispatcher instance to be used by this 39 | module for dispatching events (default: the default 40 | global dispatcher) 41 | """ 42 | def __init__(self, attention, d_model, n_heads, d_keys=None, 43 | d_values=None, d_model_keys=None, event_dispatcher=""): 44 | super(RecurrentCrossAttentionLayer, self).__init__() 45 | 46 | # Fill d_keys and d_values 47 | d_keys = d_keys or (d_model//n_heads) 48 | d_values = d_values or (d_model//n_heads) 49 | d_model_keys = d_model_keys or d_model 50 | 51 | self.inner_attention = attention 52 | self.query_projection = Linear(d_model, d_keys * n_heads) 53 | self.key_projection = Linear(d_model_keys, d_keys * n_heads) 54 | self.value_projection = Linear(d_model_keys, d_values * n_heads) 55 | self.out_projection = Linear(d_values * n_heads, d_model) 56 | self.n_heads = n_heads 57 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 58 | 59 | def forward(self, query, keys, values, key_lengths, state=None): 60 | """Attend to the keys and values based on the passed in query. 
61 | 62 | In the argument description we make use of the following sizes 63 | 64 | - N: the batch size 65 | - S: the sequence length of the keys and values 66 | - D: The input feature dimensionality passed in the constructor as 67 | 'd_model' 68 | 69 | Argument 70 | -------- 71 | query: (N, D) The tensor containing the queries 72 | keys: (N, S, D) The tensor containing the keys 73 | values: (N, S, D) The tensor containing the values 74 | key_lengths: A fast_transformers.masking.BaseMask implementation 75 | that defines the length of each key/value sequence 76 | state: The state varies depending on the inner attention 77 | implementation, but if it is not None then the keys and 78 | values are ignored 79 | """ 80 | #Extract some shapes 81 | N, _ = query.shape 82 | H = self.n_heads 83 | 84 | # Project the query 85 | query = self.query_projection(query).view(N, H, -1) 86 | 87 | # Project the keys and values if there is no state 88 | if state is None: 89 | _, S, _ = keys.shape 90 | keys = self.key_projection(keys).view(N, S, H, -1) 91 | values = self.value_projection(values).view(N, S, H, -1) 92 | else: 93 | keys = None 94 | values = None 95 | 96 | new_value, state = self.inner_attention( 97 | query, 98 | keys, 99 | values, 100 | key_lengths, 101 | state=state 102 | ) 103 | new_value = new_value.view(N, -1) 104 | 105 | # Project the output and return 106 | return self.out_projection(new_value), state 107 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/cross_attention/full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Implement the typical softmax attention as a recurrent cross attention 7 | module to speed up autoregressive decoding.""" 8 | 9 | from math import sqrt 10 | 11 | import torch 12 | from torch.nn import Dropout, Module 13 | 14 | from ....attention_registry import RecurrentCrossAttentionRegistry, Optional, \ 15 | Float, EventDispatcherInstance 16 | from ....events import EventDispatcher, AttentionEvent 17 | 18 | 19 | class RecurrentCrossFullAttention(Module): 20 | """Implement autoregressive softmax cross attention as a recurrent 21 | module. 22 | 23 | Arguments 24 | --------- 25 | softmax_temp: The temperature to use for the softmax attention. 26 | (default: 1/sqrt(d_keys) where d_keys is computed at 27 | runtime) 28 | attention_dropout: The dropout rate to apply to the attention 29 | (default: 0.1) 30 | event_dispatcher: str or EventDispatcher instance to be used by this 31 | module for dispatching events (default: the default 32 | global dispatcher) 33 | """ 34 | 35 | def __init__(self, softmax_temp=None, attention_dropout=0.1, 36 | event_dispatcher=""): 37 | super(RecurrentCrossFullAttention, self).__init__() 38 | self.softmax_temp = softmax_temp 39 | self.dropout = Dropout(attention_dropout) 40 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 41 | 42 | def forward(self, query, keys, values, key_lengths, state=None): 43 | # Extract some shapes and compute the temperature 44 | N, H, E = query.shape 45 | softmax_temp = self.softmax_temp or 1. 
/ sqrt(E) 46 | 47 | # Extract the keys and values either from the arguments or the state 48 | if state is not None: 49 | keys, values = state 50 | 51 | # Compute the unnormalized attention and apply the key length mask 52 | QK = torch.einsum("nhe,nshe->nsh", query, keys) 53 | QK = QK + key_lengths.additive_matrix[:, :, None] 54 | 55 | # Compute the attention and the weighted average 56 | A = self.dropout(torch.softmax(softmax_temp * QK, dim=1)) 57 | V = torch.einsum("nsh,nshd->nhd", A, values) 58 | 59 | # Let the world know of the attention matrix 60 | self.event_dispatcher.dispatch(AttentionEvent(self, A)) 61 | 62 | # Make sure that we return a contiguous value 63 | return V.contiguous(), [keys, values] 64 | 65 | 66 | # Register the attention implementation so that it becomes available in our 67 | # builders 68 | RecurrentCrossAttentionRegistry.register( 69 | "full", RecurrentCrossFullAttention, 70 | [ 71 | ("softmax_temp", Optional(Float)), 72 | ("attention_dropout", Optional(Float, 0.1)), 73 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 74 | ] 75 | ) 76 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/cross_attention/linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Implement unmasked linear attention as a recurrent cross attention module to 7 | speed up autoregressive decoding.""" 8 | 9 | import torch 10 | from torch.nn import Module 11 | 12 | from ....attention_registry import RecurrentCrossAttentionRegistry, Optional, Int, \ 13 | Callable, EventDispatcherInstance 14 | from ....events import EventDispatcher 15 | from ....feature_maps import elu_feature_map 16 | 17 | 18 | class RecurrentCrossLinearAttention(Module): 19 | """Implement autoregressive linear cross attention as a recurrent 20 | module. 21 | 22 | See fast_transformers.attention.linear_attention.LinearAttention . 23 | 24 | Arguments 25 | --------- 26 | feature_map: callable, a callable that applies the feature map to the 27 | last dimension of a tensor (default: elu(x)+1) 28 | eps: float, a small number to ensure the numerical stability of the 29 | denominator (default: 1e-6) 30 | event_dispatcher: str or EventDispatcher instance to be used by this 31 | module for dispatching events (default: the default 32 | global dispatcher) 33 | """ 34 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6, 35 | event_dispatcher=""): 36 | super(RecurrentCrossLinearAttention, self).__init__() 37 | self.feature_map = ( 38 | feature_map(query_dimensions) if feature_map else 39 | elu_feature_map(query_dimensions) 40 | ) 41 | self.eps = eps 42 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 43 | 44 | def forward(self, query, keys, values, key_lengths, state=None): 45 | # If this is a new sequence re initialize the feature map 46 | if state is None: 47 | self.feature_map.new_feature_map(query.device) 48 | 49 | # Compute the feature representation of the query 50 | Q = self.feature_map.forward_queries(query) 51 | 52 | # If the state is not given compute the key-value matrix and the 53 | # normalizers, namely compute whatever is needed in order to attend to 54 | # keys and values with a given query. 
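        # For linear cross attention this state is the pair [S, Z]: S accumulates
        # the outer products of Φ(K_s) and V_s over the key sequence and Z
        # accumulates the sum of Φ(K_s), so each new query is answered with two
        # small einsums, independently of the key/value sequence length.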
55 | if state is None: 56 | K = self.feature_map.forward_keys(keys) 57 | K = K * key_lengths.float_matrix[:, :, None, None] 58 | S = torch.einsum("nshd,nshm->nhmd", K, values) 59 | Z = K.sum(dim=1) 60 | else: 61 | S, Z = state 62 | 63 | # Given S and Z now we can efficiently compute the new value 64 | QZ = 1/(torch.einsum("nhd,nhd->nh", Q, Z)+self.eps) 65 | V = torch.einsum("nhd,nhmd,nh->nhm", Q, S, QZ) 66 | 67 | return V.contiguous(), [S, Z] 68 | 69 | 70 | # Register the attention implementation so that it becomes available in our 71 | # builders 72 | RecurrentCrossAttentionRegistry.register( 73 | "linear", RecurrentCrossLinearAttention, 74 | [ 75 | ("query_dimensions", Int), 76 | ("feature_map", Optional(Callable)), 77 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 78 | ] 79 | ) 80 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/self_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Autoregressive implementations for self attention as a recurrent module. 7 | 8 | The attention implementations in this module expect one input for query, one 9 | for key and one for value and attend to all the keys and values seen so far. No 10 | masking is necessary as an implicit lower triangular attention mask is assumed 11 | in all cases. 12 | 13 | Example 14 | ------- 15 | 16 | import torch 17 | 18 | from fast_transformers.recurrent.attention import \ 19 | RecurrentAttentionLayer, RecurrentFullAttention 20 | 21 | att = RecurrentAttentionLayer(RecurrentFullAttention(), 16, 4) 22 | state = None 23 | x = torch.rand(8, 16) 24 | for i in range(10): 25 | x, state = att(x, x, x, state=state) 26 | """ 27 | 28 | from .attention_layer import RecurrentAttentionLayer 29 | from .full_attention import RecurrentFullAttention 30 | from .linear_attention import RecurrentLinearAttention 31 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/self_attention/attention_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Similar to the corresponding module in fast_transformers.attention, this 8 | module performs all the query, key, value projections and output projections 9 | leaving the implementation of the attention to the inner attention module.""" 10 | 11 | from torch.nn import Linear, Module 12 | 13 | from ....events import EventDispatcher 14 | from ..._utils import check_state 15 | 16 | 17 | class RecurrentAttentionLayer(Module): 18 | """See fast_transformers.attention.attention_layer.AttentionLayer. 19 | 20 | The only difference with the corresponding module is that this projects 21 | only one input and then calls the inner attention with the provided 22 | previous state. 23 | 24 | Arguments 25 | --------- 26 | attention: Specific inner attention implementation that just computes a 27 | weighted average of values given a similarity of queries and 28 | keys. 
29 | d_model: The input feature dimensionality 30 | n_heads: The number of heads for the multi head attention 31 | d_keys: The dimensionality of the keys/queries 32 | (default: d_model/n_heads) 33 | d_values: The dimensionality of the values (default: d_model/n_heads) 34 | event_dispatcher: str or EventDispatcher instance to be used by this 35 | module for dispatching events (default: the default 36 | global dispatcher) 37 | """ 38 | def __init__(self, attention, d_model, n_heads, d_keys=None, 39 | d_values=None, d_model_keys=None, event_dispatcher=""): 40 | super(RecurrentAttentionLayer, self).__init__() 41 | 42 | # Fill d_keys and d_values 43 | d_keys = d_keys or (d_model//n_heads) 44 | d_values = d_values or (d_model//n_heads) 45 | d_model_keys = d_model_keys or d_model 46 | 47 | self.inner_attention = attention 48 | self.query_projection = Linear(d_model, d_keys * n_heads) 49 | self.key_projection = Linear(d_model_keys, d_keys * n_heads) 50 | self.value_projection = Linear(d_model_keys, d_values * n_heads) 51 | self.out_projection = Linear(d_values * n_heads, d_model) 52 | self.n_heads = n_heads 53 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 54 | 55 | def forward(self, query, key, value, state=None, memory=None): 56 | """Apply attention to the passed in query/key/value after projecting 57 | them to multiple heads. 58 | 59 | In the argument description we make use of the following sizes 60 | 61 | - N: the batch size 62 | - D: The input feature dimensionality passed in the constructor as 63 | 'd_model' 64 | 65 | Arguments 66 | --------- 67 | query: (N, D) The tensor containing the queries 68 | key: (N, D) The tensor containing the keys 69 | value: (N, D) The tensor containing the values 70 | state: The state varies depending on the inner attention implementation 71 | memory: **Deprecated** and replaced by state 72 | 73 | Returns 74 | ------- 75 | The new value for each query as a tensor of shape (N, D). 76 | """ 77 | # Normalize the state/memory 78 | state = check_state(state, memory) 79 | 80 | # Project the queries/keys/values 81 | query = self.query_projection(query) 82 | key = self.key_projection(key) 83 | value = self.value_projection(value) 84 | 85 | # Reshape them into many heads and compute the attention 86 | N, D = query.shape 87 | H = self.n_heads 88 | new_value, state = self.inner_attention( 89 | query.view(N, H, -1), 90 | key.view(N, H, -1), 91 | value.view(N, H, -1), 92 | state 93 | ) 94 | new_value = new_value.view(N, -1) 95 | 96 | # Project the output and return 97 | return self.out_projection(new_value), state 98 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/self_attention/full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement the typical softmax attention as a recurrent module to speed up 8 | autoregressive inference. 
See fast_transformers.attention.full_attention .""" 9 | 10 | from math import sqrt 11 | 12 | import torch 13 | from torch.nn import Dropout, Module 14 | 15 | from ....attention_registry import RecurrentAttentionRegistry, Optional, \ 16 | Float, EventDispatcherInstance 17 | from ....events import EventDispatcher, AttentionEvent 18 | from ..._utils import check_state 19 | 20 | 21 | class RecurrentFullAttention(Module): 22 | """Implement the full softmax attention as a recurrent module. 23 | 24 | Arguments 25 | --------- 26 | softmax_temp: The temperature to use for the softmax attention. 27 | (default: 1/sqrt(d_keys) where d_keys is computed at 28 | runtime) 29 | attention_dropout: The dropout rate to apply to the attention 30 | (default: 0.1) 31 | event_dispatcher: str or EventDispatcher instance to be used by this 32 | module for dispatching events (default: the default 33 | global dispatcher) 34 | """ 35 | def __init__(self, softmax_temp=None, attention_dropout=0.1, 36 | event_dispatcher=""): 37 | super(RecurrentFullAttention, self).__init__() 38 | self.softmax_temp = softmax_temp 39 | self.dropout = Dropout(attention_dropout) 40 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 41 | 42 | def forward(self, query, key, value, state=None, memory=None): 43 | # Normalize state/memory 44 | state = check_state(state, memory) 45 | 46 | # Extract some shapes and compute the temperature 47 | N, H, E = query.shape 48 | _, _, D = value.shape 49 | softmax_temp = self.softmax_temp or 1./sqrt(E) 50 | 51 | # Aggregate the list of keys and values 52 | if state is not None: 53 | keys, values = state 54 | keys = torch.cat([keys, key[:, :, None]], dim=2) 55 | values = torch.cat([values, value[:, :, None]], dim=2) 56 | else: 57 | keys = key[:, :, None] 58 | values = value[:, :, None] 59 | 60 | # Compute the unnormalized attention 61 | QK = torch.einsum("nhe,nhse->nhs", query, keys) 62 | 63 | # Compute the attention and the weighted average 64 | A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) 65 | V = torch.einsum("nhs,nhsd->nhd", A, values).contiguous() 66 | 67 | # Let the world know of the attention matrix 68 | self.event_dispatcher.dispatch(AttentionEvent(self, A)) 69 | 70 | # Make sure that what we return is contiguous 71 | return V, [keys, values] 72 | 73 | 74 | # Register the attention implementation so that it becomes available in our 75 | # builders 76 | RecurrentAttentionRegistry.register( 77 | "full", RecurrentFullAttention, 78 | [ 79 | ("softmax_temp", Optional(Float)), 80 | ("attention_dropout", Optional(Float, 0.1)), 81 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 82 | ] 83 | ) 84 | -------------------------------------------------------------------------------- /fast_transformers/recurrent/attention/self_attention/linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | """Implement the causally masked linear attention as a recurrent model.""" 8 | 9 | import torch 10 | from torch.nn import Module 11 | 12 | from ....attention_registry import RecurrentAttentionRegistry, Optional, Int, \ 13 | Callable, EventDispatcherInstance 14 | from ....events import EventDispatcher 15 | from ....feature_maps import elu_feature_map 16 | from ..._utils import check_state 17 | 18 | 19 | class RecurrentLinearAttention(Module): 20 | """Implement 
fast_transformers.attention.causal_linear_attention as a 21 | fixed-dimensional state recurrent model. 22 | 23 | See fast_transformers.attention.linear_attention and 24 | fast_transformers.attention.causal_linear_attention for the general concept 25 | of replacing the softmax with feature maps. 26 | 27 | Arguments 28 | --------- 29 | feature_map: callable, a callable that applies the feature map to the 30 | last dimension of a tensor (default: elu(x)+1) 31 | eps: float, a small number to ensure the numerical stability of the 32 | denominator (default: 1e-6) 33 | event_dispatcher: str or EventDispatcher instance to be used by this 34 | module for dispatching events (default: the default 35 | global dispatcher) 36 | """ 37 | def __init__(self, query_dimensions, feature_map=None, eps=1e-6, 38 | event_dispatcher=""): 39 | super(RecurrentLinearAttention, self).__init__() 40 | self.feature_map = ( 41 | feature_map(query_dimensions) if feature_map else 42 | elu_feature_map(query_dimensions) 43 | ) 44 | self.eps = eps 45 | self.event_dispatcher = EventDispatcher.get(event_dispatcher) 46 | 47 | def forward(self, query, key, value, state=None, memory=None): 48 | # Normalize state/memory 49 | state = check_state(state, memory) 50 | 51 | # If this is a new sequence reinitialize the feature map 52 | if state is None: 53 | self.feature_map.new_feature_map(query.device) 54 | 55 | # Apply the feature map to the query and key 56 | Q = self.feature_map.forward_queries(query) 57 | K = self.feature_map.forward_keys(key) 58 | 59 | # Extract some shapes 60 | N, H, D = Q.shape 61 | _, _, M = value.shape 62 | 63 | # Extract the memory or initialize it 64 | if state is None: 65 | Si = query.new_zeros((N, H, D, M)) 66 | Zi = query.new_zeros((N, H, D)) 67 | else: 68 | Si, Zi = state 69 | 70 | # Ensure the batch size did not change 71 | if len(Si) != N: 72 | raise ValueError("The batch size changed during iteration") 73 | 74 | # Update the internal state 75 | # 76 | # NOTE: The if clause is added due to GitHub PR #10. Simply using the 77 | # following two lines does not perform the operation in place which 78 | # means it is slower for inference. 79 | if K.grad_fn is not None or value.grad_fn is not None: 80 | Zi = Zi + K 81 | Si = Si + torch.einsum("nhd,nhm->nhdm", K, value) 82 | else: 83 | Zi += K 84 | Si += torch.einsum("nhd,nhm->nhdm", K, value) 85 | 86 | # Compute the output 87 | Z = 1. 
/ (torch.einsum("nhd,nhd->nh", Q, Zi) + self.eps) 88 | V = torch.einsum("nhd,nhdm,nh->nhm", Q, Si, Z) 89 | 90 | return V, [Si, Zi] 91 | 92 | 93 | # Register the attention implementation so that it becomes available in our 94 | # builders 95 | RecurrentAttentionRegistry.register( 96 | "linear", RecurrentLinearAttention, 97 | [ 98 | ("query_dimensions", Int), 99 | ("feature_map", Optional(Callable)), 100 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 101 | ] 102 | ) 103 | RecurrentAttentionRegistry.register( 104 | "causal-linear", RecurrentLinearAttention, 105 | [ 106 | ("query_dimensions", Int), 107 | ("feature_map", Optional(Callable)), 108 | ("event_dispatcher", Optional(EventDispatcherInstance, "")) 109 | ] 110 | ) 111 | -------------------------------------------------------------------------------- /fast_transformers/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | """Boilerplate code for dealing with fast_transformers modules.""" 7 | 8 | 9 | def make_mirror(src_module, dst_module): 10 | """Sets the parameters of src_module to dst_module so that they share the 11 | same parameters. 12 | 13 | The most notable use case is to make a recurrent transformer a mirror of a 14 | batch transformer for fast inference. 15 | 16 | Arguments 17 | --------- 18 | src_module: Module to take the parameters from 19 | dst_module: Module to set the parameters to 20 | 21 | Returns 22 | ------- 23 | None, it changes dst_module in place 24 | """ 25 | def setattr_recursive(mod, key, value): 26 | key, *next_key = key.split(".", maxsplit=1) 27 | if not next_key: 28 | setattr(mod, key, value) 29 | else: 30 | setattr_recursive(getattr(mod, key), next_key[0], value) 31 | 32 | for name, param in src_module.named_parameters(): 33 | setattr_recursive(dst_module, name, param) 34 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Fast Transformers for PyTorch 2 | repo_url: https://github.com/idiap/fast-transformers 3 | theme: readthedocs 4 | extra_javascript: 5 | - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML 6 | extra_css: 7 | - css/extra.css 8 | markdown_extensions: 9 | - mdx_math 10 | - tables 11 | - admonition 12 | nav: 13 | - Home: index.md 14 | - Transformers: transformers.md 15 | - Masking: masking.md 16 | - Attention: attention.md 17 | - Feature Maps: feature_maps.md 18 | - Builders: builders.md 19 | - Custom Attention Layer: custom_attention_layer.md 20 | - Recurrent Transformers: recurrent_transformers.md 21 | - Events: events.md 22 | - Tips and Tricks: tips_and_tricks.md 23 | - API Docs: /api_docs/fast_transformers/ 24 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/aggregate/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos
Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/aggregate/test_clustered_aggregate_cpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | import os 10 | import numpy as np 11 | import time 12 | 13 | import torch 14 | 15 | try: 16 | from fast_transformers.aggregate import clustered_aggregate, \ 17 | clustered_broadcast 18 | except ImportError: 19 | pass 20 | 21 | 22 | class TestAggregateCPU(unittest.TestCase): 23 | 24 | def test_aggregate(self): 25 | N = 2 26 | H = 4 27 | L = 80 28 | E = 2 29 | C = 4 30 | 31 | for i in range(30): 32 | C = np.random.randint(5, 10) 33 | L = np.random.randint(1, 30) * C 34 | E = np.random.randint(10, 128) 35 | if os.getenv("VERBOSE_TESTS", ""): 36 | print(("Testing: N H L E C: " 37 | "{} {} {} {} {}").format(N, H, L, E, C)) 38 | 39 | x = torch.rand((N, H, L, E)).cpu() 40 | g = (torch.arange(L) % C).view(1, 1, L).repeat(N, H, 1).int().cpu() 41 | f = torch.ones(N, H, C).cpu() * (C / L) 42 | counts = torch.ones_like(f, dtype=torch.int32) * (L // C) 43 | y = torch.zeros(N, H, C, E).cpu() 44 | lengths = torch.full((N,), L, dtype=torch.int32).to(x.device) 45 | 46 | sorted_g, sorted_gi = torch.sort(g.view(N*H, -1), dim=-1) 47 | sorted_rev_gi = torch.argsort(sorted_gi, dim=-1) 48 | 49 | q_offset = torch.arange(N*H, device=x.device).unsqueeze(-1) * L 50 | q_flat = (sorted_gi + q_offset).reshape(-1) 51 | 52 | # sorted queries, keys, values 53 | s_x = x.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E) 54 | y = clustered_aggregate( 55 | s_x, sorted_g.view(N, H, -1), f, lengths, y 56 | ) 57 | for i in range(C): 58 | self.assertLess( 59 | torch.abs( 60 | x[:, :, i::C, :].mean(2) - y[:, :, i, :] 61 | ).max().item(), 62 | 1e-6 63 | ) 64 | 65 | def test_aggregate_masked(self): 66 | N = 10 67 | H = 3 68 | L = 40 69 | E = 32 70 | C = 4 71 | 72 | for i in range(30): 73 | C = np.random.randint(5, 10) 74 | L = np.random.randint(2, 30) * C 75 | E = np.random.randint(10, 128) 76 | if os.getenv("VERBOSE_TESTS", ""): 77 | print(("Testing: N H L E C: " 78 | "{} {} {} {} {}").format(N, H, L, E, C)) 79 | 80 | x = torch.rand((N, H, L, E)).cpu() 81 | g = (torch.arange(L) % C).view(1, 1, L).repeat(N, H, 1).int().cpu() 82 | g[:, :, -C:] = C + 1 83 | c = (L // C) - 1 84 | 85 | lengths = torch.full((N,), L-C, dtype=torch.int32).to(x.device) 86 | f = torch.ones(N, H, C).cpu() / float(c) 87 | counts = torch.ones_like(f, dtype=torch.int32) * c 88 | y = torch.zeros(N, H, C, E).cpu() 89 | 90 | sorted_g, sorted_gi = torch.sort(g.view(N*H, -1), dim=-1) 91 | sorted_rev_gi = torch.argsort(sorted_gi, dim=-1) 92 | 93 | q_offset = torch.arange(N*H, device=x.device).unsqueeze(-1) * L 94 | q_flat = (sorted_gi + q_offset).reshape(-1) 95 | 96 | # sorted queries, keys, values 97 | s_x = x.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E) 98 | y = clustered_aggregate( 99 | s_x, sorted_g.view(N, H, -1), f, lengths, y 100 | ) 101 | 102 | for i in range(C): 103 | x_m = x[:, :, i::C, :][:, :, :-1, :].mean(2) 104 | self.assertLess( 105 | torch.abs( 106 | x_m - y[:, :, i, :] 107 | ).max().item(), 108 | 1e-6 109 | ) 110 | 111 | 112 | if __name__ == "__main__": 113 | unittest.main() 114 | -------------------------------------------------------------------------------- /tests/attention/test_aft_attention.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import os 7 | import time 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.masking import FullMask, LengthMask 13 | from fast_transformers.attention.aft_attention import AFTFullAttention, \ 14 | AFTSimpleAttention 15 | 16 | 17 | class TestAFTAttention(unittest.TestCase): 18 | def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=32, device="cpu"): 19 | return ( 20 | torch.rand(N, L, H, E).to(device), 21 | torch.rand(N, S, H, E).to(device), 22 | torch.rand(N, S, H, D).to(device), 23 | FullMask(L, S, device=device), 24 | FullMask(N, L, device=device), 25 | FullMask(N, S, device=device) 26 | ) 27 | 28 | def test_forward(self): 29 | att = AFTFullAttention() 30 | q, k, v, m1, m2, m3 = self._get_inputs() 31 | v = att(q, k, v, m1, m2, m3) 32 | self.assertTrue(v.is_contiguous()) 33 | 34 | att = AFTSimpleAttention() 35 | q, k, v, m1, m2, m3 = self._get_inputs() 36 | v = att(q, k, v, m1, m2, m3) 37 | self.assertTrue(v.is_contiguous()) 38 | 39 | def test_masking(self): 40 | q, k, v, m1, m2, m3 = self._get_inputs() 41 | m1 = FullMask(torch.rand(5, 8) > 0.5) 42 | 43 | att = AFTFullAttention() 44 | v = att(q, k, v, m1, m2, m3) 45 | 46 | att = AFTSimpleAttention() 47 | with self.assertRaises(ValueError): 48 | v = att(q, k, v, m1, m2, m3) 49 | 50 | q, k, v, m1, m2, m3 = self._get_inputs(L=8, S=8) 51 | m1 = FullMask(torch.tril(torch.ones(8, 8, dtype=torch.bool))) 52 | v = att(q, k, v, m1, m2, m3) 53 | 54 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 55 | def test_benchmark_cpu(self): 56 | q, k, v, m1, m2, m3 = self._get_inputs(L=256, S=256, E=64, D=64) 57 | att_full = AFTFullAttention() 58 | att_simple = AFTSimpleAttention() 59 | 60 | for name, att in zip(["full", "simple"], [att_full, att_simple]): 61 | # warmup the cache 62 | for i in range(10): 63 | v_new = att(q, k, v, m1, m2, m3) 64 | 65 | # measure 66 | start = time.time() 67 | for i in range(10): 68 | v_new = att(q, k, v, m1, m2, m3) 69 | end = time.time() 70 | print("AFT", name, "CPU Time taken:", (end-start)*1000, "(ms)") 71 | 72 | @unittest.skipUnless(torch.cuda.is_available(), "no CUDA capable device") 73 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 74 | def test_benchmark_gpu(self): 75 | q, k, v, m1, m2, m3 = self._get_inputs(L=256, S=256, E=64, D=64, 76 | device="cuda") 77 | att_full = AFTFullAttention().cuda() 78 | att_simple = AFTSimpleAttention() 79 | 80 | for name, att in zip(["full", "simple"], [att_full, att_simple]): 81 | # warmup the caches 82 | for i in range(10): 83 | v_new = att(q, k, v, m1, m2, m3) 84 | 85 | # measure 86 | start = torch.cuda.Event(enable_timing=True) 87 | end = torch.cuda.Event(enable_timing=True) 88 | start.record() 89 | for i in range(10): 90 | v_new = att(q, k, v, m1, m2, m3) 91 | end.record() 92 | torch.cuda.synchronize() 93 | print("AFT", name, "GPU time taken:", start.elapsed_time(end), "(ms)") 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /tests/attention/test_attention_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import 
unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.attention.attention_layer import AttentionLayer 13 | 14 | 15 | class TestAttentionLayer(unittest.TestCase): 16 | def _assert_sizes_attention(self, qshape, kshape, vshape): 17 | def inner(q, k, v, m1, m2, m3): 18 | self.assertEqual(q.shape, qshape) 19 | self.assertEqual(k.shape, kshape) 20 | self.assertEqual(v.shape, vshape) 21 | N, L, H, E = q.shape 22 | _, S, _, D = v.shape 23 | return v.new_zeros((N, L, H, D)) 24 | return inner 25 | 26 | def test_forward(self): 27 | att = AttentionLayer( 28 | self._assert_sizes_attention( 29 | (10, 5, 4, 25), 30 | (10, 8, 4, 25), 31 | (10, 8, 4, 25) 32 | ), 33 | 100, 34 | 4 35 | ) 36 | v = att( 37 | torch.rand(10, 5, 100), 38 | torch.rand(10, 8, 100), 39 | torch.rand(10, 8, 100), 40 | None, None, None 41 | ) 42 | self.assertEqual(v.shape, (10, 5, 100)) 43 | 44 | att = AttentionLayer( 45 | self._assert_sizes_attention( 46 | (10, 5, 4, 32), 47 | (10, 8, 4, 32), 48 | (10, 8, 4, 64) 49 | ), 50 | 100, 51 | 4, 52 | d_keys=32, 53 | d_values=64 54 | ) 55 | v = att( 56 | torch.rand(10, 5, 100), 57 | torch.rand(10, 8, 100), 58 | torch.rand(10, 8, 100), 59 | None, None, None 60 | ) 61 | self.assertEqual(v.shape, (10, 5, 100)) 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /tests/attention/test_causal_linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | 7 | import unittest 8 | 9 | import torch 10 | 11 | 12 | from fast_transformers.masking import TriangularCausalMask, FullMask 13 | from fast_transformers.attention import CausalLinearAttention 14 | 15 | 16 | class TestCausalLinearAttention(unittest.TestCase): 17 | def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"): 18 | return ( 19 | torch.rand(N, L, H, E).to(device), 20 | torch.rand(N, S, H, E).to(device), 21 | torch.rand(N, S, H, D).to(device), 22 | TriangularCausalMask(L, device=device), 23 | FullMask(N, L, device=device), 24 | FullMask(N, S, device=device) 25 | ) 26 | 27 | def test_forward(self): 28 | att = CausalLinearAttention(32) 29 | q, k, v, m1, m2, m3 = self._get_inputs(L=5, S=5) 30 | v = att(q, k, v, m1, m2, m3) 31 | self.assertTrue(v.is_contiguous()) 32 | 33 | q, k, v, m1, m2, m3 = self._get_inputs(L=5, S=10) 34 | v = att(q, k, v, m1, m2, m3) 35 | self.assertTrue(v.is_contiguous()) 36 | 37 | q, k, v, m1, m2, m3 = self._get_inputs(L=10, S=5) 38 | v = att(q, k, v, m1, m2, m3) 39 | self.assertTrue(v.is_contiguous()) 40 | 41 | 42 | if __name__ == "__main__": 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /tests/attention/test_clustered_transformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.attention import AttentionLayer, ClusteredAttention 13 | from fast_transformers.masking import FullMask 14 | from fast_transformers.transformers import TransformerEncoderLayer, TransformerEncoder 15 | 16 | 17 | class TestTransformerEncoder(unittest.TestCase): 18 | def test_full_attention_forward(self): 19 | d_model = 128 20 | n_heads = 4 21 | 
transformer = TransformerEncoder([ 22 | TransformerEncoderLayer( 23 | AttentionLayer( 24 | ClusteredAttention( 25 | clusters = 10 26 | ), 27 | d_model, 28 | n_heads 29 | ), 30 | d_model, 31 | n_heads 32 | ) 33 | for i in range(6) 34 | ]) 35 | 36 | x = transformer(torch.rand(100, 20, d_model)) 37 | self.assertEqual(x.shape, (100, 20, d_model)) 38 | 39 | 40 | if __name__ == "__main__": 41 | unittest.main() 42 | -------------------------------------------------------------------------------- /tests/attention/test_clustered_transformer_gpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.attention import AttentionLayer, ClusteredAttention 13 | from fast_transformers.masking import FullMask 14 | from fast_transformers.transformers import TransformerEncoderLayer, TransformerEncoder 15 | 16 | 17 | class TestTransformerEncoder(unittest.TestCase): 18 | @classmethod 19 | def setUpClass(cls): 20 | if not torch.cuda.is_available(): 21 | raise unittest.SkipTest("No CUDA capable device detected") 22 | 23 | def test_full_attention_forward(self): 24 | d_model = 128 25 | n_heads = 4 26 | transformer = TransformerEncoder([ 27 | TransformerEncoderLayer( 28 | AttentionLayer( 29 | ClusteredAttention( 30 | clusters = 10 31 | ), 32 | d_model, 33 | n_heads 34 | ), 35 | d_model, 36 | n_heads 37 | ) 38 | for i in range(6) 39 | ]) 40 | 41 | transformer = transformer.to("cuda") 42 | x = torch.rand(100, 20, d_model).to("cuda") 43 | y = transformer(x) 44 | self.assertEqual(y.shape, (100, 20, d_model)) 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/attention/test_full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import os 9 | import time 10 | import unittest 11 | 12 | import torch 13 | 14 | from fast_transformers.masking import FullMask 15 | from fast_transformers.attention.full_attention import FullAttention 16 | 17 | 18 | class TestFullAttention(unittest.TestCase): 19 | def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"): 20 | return ( 21 | torch.rand(N, L, H, E).to(device), 22 | torch.rand(N, S, H, E).to(device), 23 | torch.rand(N, S, H, D).to(device), 24 | FullMask(L, S, device=device), 25 | FullMask(N, L, device=device), 26 | FullMask(N, S, device=device) 27 | ) 28 | 29 | def test_forward(self): 30 | att = FullAttention(softmax_temp=1) 31 | q, k, v, m1, m2, m3 = self._get_inputs() 32 | v = att(q, k, v, m1, m2, m3) 33 | self.assertTrue(v.is_contiguous()) 34 | 35 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 36 | def test_benchmark_cpu(self): 37 | q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64) 38 | att = FullAttention() 39 | 40 | # warmup the cache 41 | for i in range(10): 42 | v_new = att(q, k, v, m1, m2, m3) 43 | 44 | # measure 45 | start = time.time() 46 | for i in range(10): 47 | v_new = att(q, k, v, m1, m2, m3) 48 | end = time.time() 49 | print("CPU Time taken:", (end-start)*1000, "(ms)") 50 | 51 | @unittest.skipUnless(torch.cuda.is_available(), "no CUDA capable device") 52 | 
@unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 53 | def test_benchmark_gpu(self): 54 | q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64, 55 | device="cuda") 56 | att = FullAttention() 57 | 58 | # warmup the caches 59 | for i in range(10): 60 | v_new = att(q, k, v, m1, m2, m3) 61 | 62 | # measure 63 | start = torch.cuda.Event(enable_timing=True) 64 | end = torch.cuda.Event(enable_timing=True) 65 | start.record() 66 | for i in range(10): 67 | v_new = att(q, k, v, m1, m2, m3) 68 | end.record() 69 | torch.cuda.synchronize() 70 | print("GPU time taken:", start.elapsed_time(end), "(ms)") 71 | 72 | 73 | if __name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /tests/attention/test_linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import os 9 | import time 10 | import unittest 11 | 12 | import torch 13 | 14 | from fast_transformers.masking import FullMask, LengthMask 15 | from fast_transformers.attention.linear_attention import LinearAttention 16 | 17 | 18 | class TestLinearAttention(unittest.TestCase): 19 | def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"): 20 | return ( 21 | torch.rand(N, L, H, E).to(device), 22 | torch.rand(N, S, H, E).to(device), 23 | torch.rand(N, S, H, D).to(device), 24 | FullMask(L, S, device=device), 25 | FullMask(N, L, device=device), 26 | FullMask(N, S, device=device) 27 | ) 28 | 29 | def test_forward(self): 30 | att = LinearAttention(32) 31 | q, k, v, m1, m2, m3 = self._get_inputs() 32 | v = att(q, k, v, m1, m2, m3) 33 | self.assertTrue(v.is_contiguous()) 34 | 35 | def test_masking(self): 36 | att = LinearAttention(32) 37 | q, k, v, m1, m2, m3 = self._get_inputs() 38 | 39 | # Make sure that we raise an error if m1 is not all ones 40 | with self.assertRaises(RuntimeError): 41 | att(q, k, v, FullMask(torch.rand(*m1.shape) > 0.5), m2, m3) 42 | 43 | # Make sure that the key lengths is paid attention to 44 | q, k, v, m1, m2, m3 = self._get_inputs(S=10, D=1) 45 | m3 = LengthMask(torch.tensor(list(range(10)))+1) 46 | for i in range(9): 47 | v[i, i+1:] = 1e9 48 | v_new = att(q, k, v, m1, m2, m3) 49 | self.assertLess(v_new.max().item(), 1) 50 | 51 | 52 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 53 | def test_benchmark_cpu(self): 54 | q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64) 55 | att = LinearAttention(64) 56 | 57 | # warmup the cache 58 | for i in range(10): 59 | v_new = att(q, k, v, m1, m2, m3) 60 | 61 | # measure 62 | start = time.time() 63 | for i in range(10): 64 | v_new = att(q, k, v, m1, m2, m3) 65 | end = time.time() 66 | print("CPU time taken:", (end-start)*1000, "(ms)") 67 | 68 | @unittest.skipUnless(torch.cuda.is_available(), "no CUDA capable device") 69 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 70 | def test_benchmark_gpu(self): 71 | q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64, 72 | device="cuda") 73 | att = LinearAttention(64) 74 | 75 | # warmup the caches 76 | for i in range(10): 77 | v_new = att(q, k, v, m1, m2, m3) 78 | 79 | # measure 80 | start = torch.cuda.Event(enable_timing=True) 81 | end = torch.cuda.Event(enable_timing=True) 82 | start.record() 83 | for i in range(10): 84 | v_new = att(q, k, v, m1, m2, m3) 85 | 
end.record() 86 | torch.cuda.synchronize() 87 | print("GPU time taken:", start.elapsed_time(end), "(ms)") 88 | 89 | 90 | if __name__ == "__main__": 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /tests/attention/test_local_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # 5 | 6 | 7 | import os 8 | import time 9 | import unittest 10 | 11 | import torch 12 | 13 | from fast_transformers.masking import FullMask, LengthMask 14 | from fast_transformers.attention.full_attention import FullAttention 15 | from fast_transformers.attention.local_attention import LocalAttention 16 | 17 | 18 | class TestLocalAttention(unittest.TestCase): 19 | def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"): 20 | return ( 21 | torch.rand(N, L, H, E).to(device), 22 | torch.rand(N, S, H, E).to(device), 23 | torch.rand(N, S, H, D).to(device), 24 | FullMask(L, S, device=device), 25 | FullMask(N, L, device=device), 26 | FullMask(N, S, device=device) 27 | ) 28 | 29 | def test_forward(self): 30 | att = LocalAttention(3, softmax_temp=1) 31 | q, k, v, m1, m2, m3 = self._get_inputs() 32 | v = att(q, k, v, m1, m2, m3) 33 | self.assertTrue(v.is_contiguous()) 34 | 35 | def test_masked(self): 36 | att = LocalAttention(16, softmax_temp=1) 37 | q, k, v, m1, m2, m3 = self._get_inputs(N=3, L=64, S=64, D=32) 38 | m2 = m3 = LengthMask(torch.tensor([8, 16, 64], dtype=torch.long)) 39 | v_hat = att(q, k, v, m1, m2, m3) 40 | self.assertFalse(torch.any(torch.isnan(v_hat))) 41 | 42 | def test_compare_with_full(self): 43 | local_att = LocalAttention(17, softmax_temp=1).eval() 44 | full_att = FullAttention(softmax_temp=1).eval() 45 | 46 | q, k, v, m1, m2, m3 = self._get_inputs(N=10, L=128, S=128, D=32) 47 | m = FullMask( 48 | torch.abs(torch.arange(128)[:, None] - torch.arange(128)[None]) < 9 49 | ) 50 | v_full = full_att(q, k, v, m, m2, m3) 51 | v_local = local_att(q, k, v, m1, m2, m3) 52 | 53 | self.assertTrue(torch.allclose(v_full, v_local, atol=1e-5, rtol=1e-5)) 54 | 55 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 56 | def test_benchmark_cpu(self): 57 | q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64) 58 | att = LocalAttention(128) 59 | 60 | # warmup the cache 61 | for i in range(10): 62 | v_new = att(q, k, v, m1, m2, m3) 63 | 64 | # measure 65 | start = time.time() 66 | for i in range(10): 67 | v_new = att(q, k, v, m1, m2, m3) 68 | end = time.time() 69 | print("CPU Time taken:", (end-start)*1000, "(ms)") 70 | 71 | @unittest.skipUnless(torch.cuda.is_available(), "no CUDA capable device") 72 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 73 | def test_benchmark_gpu(self): 74 | q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64, 75 | device="cuda") 76 | att = LocalAttention(128) 77 | 78 | # warmup the caches 79 | for i in range(10): 80 | v_new = att(q, k, v, m1, m2, m3) 81 | 82 | # measure 83 | start = torch.cuda.Event(enable_timing=True) 84 | end = torch.cuda.Event(enable_timing=True) 85 | start.record() 86 | for i in range(10): 87 | v_new = att(q, k, v, m1, m2, m3) 88 | end.record() 89 | torch.cuda.synchronize() 90 | print("GPU time taken:", start.elapsed_time(end), "(ms)") 91 | 92 | 93 | if __name__ == "__main__": 94 | unittest.main() 95 | 96 | 
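A minimal standalone sketch of the equivalence exercised by test_compare_with_full above: LocalAttention with a window of 2*w + 1 positions should produce the same output as FullAttention restricted to the band mask |i - j| <= w, once dropout is disabled with eval(). The tensor shapes and the half-window w below are illustrative assumptions rather than values taken from the test suite, and the snippet assumes the fast_transformers package is installed with its C++ extensions built.

import torch
from fast_transformers.masking import FullMask
from fast_transformers.attention.full_attention import FullAttention
from fast_transformers.attention.local_attention import LocalAttention

# Illustrative sizes (assumptions): batch N, length L, heads H, key dim E,
# value dim D, half-window w.
N, L, H, E, D, w = 4, 64, 4, 32, 32, 8

q = torch.rand(N, L, H, E)
k = torch.rand(N, L, H, E)
v = torch.rand(N, L, H, D)

attn_mask = FullMask(L)           # all-ones L x L attention mask for LocalAttention
length_mask = FullMask(N, L)      # every sequence uses its full length
# Band mask that keeps only the keys within w positions of each query.
band = FullMask(torch.abs(torch.arange(L)[:, None] - torch.arange(L)[None]) <= w)

local_att = LocalAttention(2 * w + 1, softmax_temp=1).eval()   # eval() disables dropout
full_att = FullAttention(softmax_temp=1).eval()

v_local = local_att(q, k, v, attn_mask, length_mask, length_mask)
v_full = full_att(q, k, v, band, length_mask, length_mask)
print(torch.allclose(v_full, v_local, atol=1e-5, rtol=1e-5))   # should print True

The softmax_temp=1 setting and the tolerances mirror the ones used in the test so that both modules scale the dot products identically.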
-------------------------------------------------------------------------------- /tests/causal_product/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/causal_product/test_causal_product.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import unittest 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from fast_transformers.causal_product import causal_dot_product 13 | 14 | 15 | class TestCausalProduct(unittest.TestCase): 16 | def _zero_grad(self, *tensors): 17 | for t in tensors: 18 | if t.grad is not None: 19 | t.grad[...] = 0 20 | 21 | def _test_api(self, device): 22 | for t in range(10): 23 | N = 2 24 | H = 4 25 | L = 100 26 | S = 100 27 | E = np.random.randint(10, 256) 28 | M = np.random.randint(10, 256) 29 | Q = torch.rand(N, H, L, E).to(device).requires_grad_(True) 30 | K = torch.rand(N, H, S, E).to(device).requires_grad_(True) 31 | V = torch.randn(N, H, S, M).to(device).requires_grad_(True) 32 | 33 | self._zero_grad(Q, K, V) 34 | QK = torch.einsum("nhle,nhse->nhls", Q, K) 35 | mask = torch.tril(torch.ones(L, S))[None, None].to(device) 36 | QK = QK * mask 37 | QK = QK / (QK.sum(-1, keepdim=True) + 1e-6) 38 | V_new = torch.einsum("nhls,nhsm->nhlm", QK, V) 39 | V_new.sum().backward() 40 | grad = [torch.clone(x.grad) for x in [Q, K, V]] 41 | 42 | self._zero_grad(Q, K, V) 43 | V_new_hat = causal_dot_product(Q, K, V) 44 | Z = torch.einsum( 45 | "nhle,nhle->nhl", 46 | Q, 47 | torch.cumsum(K, dim=-2) + 1e-6 48 | ).unsqueeze(-1) 49 | 50 | V_new_hat = V_new_hat / Z 51 | V_new_hat.sum().backward() 52 | grad_hat = [torch.clone(x.grad) for x in [Q, K, V]] 53 | 54 | self.assertLess( 55 | torch.abs(V_new - V_new_hat).max(), 56 | 5e-4 57 | ) 58 | for g1, g2 in zip(grad, grad_hat): 59 | self.assertLess( 60 | torch.abs(g1-g2).max(), 61 | 5e-4 62 | ) 63 | 64 | def test_api_cpu(self): 65 | self._test_api("cpu") 66 | 67 | @unittest.skipUnless(torch.cuda.is_available(), "No CUDA capable device") 68 | def test_api_cuda(self): 69 | self._test_api("cuda") 70 | 71 | 72 | if __name__ == "__main__": 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /tests/clustering/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | -------------------------------------------------------------------------------- /tests/clustering/hamming/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/clustering/hamming/time_python_api_gpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv 
Vyas 5 | # 6 | 7 | import numpy as np 8 | import torch 9 | import time 10 | 11 | from fast_transformers.hashing import compute_hashes 12 | from fast_transformers.clustering.hamming import cluster 13 | 14 | def simple_lsh(X, A): 15 | B = (torch.einsum("nj,ij->ni", [X, A]) > 0).long() 16 | bits = 2**torch.arange(A.shape[0]) 17 | return torch.einsum("ni,i->n", [B, bits]) 18 | 19 | 20 | def generate_hash(n_points, d, b, h): 21 | torch.manual_seed(0) 22 | x = torch.rand(n_points, d).cuda() 23 | a = torch.randn(b, d + 1).cuda() 24 | compute_hashes(x, a, h) 25 | return h 26 | 27 | 28 | def time_clustering(L, N, H, E, 29 | n_batches, n_attentions, 30 | k, n_buckets, n_iterations, verbose): 31 | n_points = L * N * H 32 | hashes = torch.zeros(n_points, dtype=torch.int64).cuda() 33 | hashes = generate_hash(n_points, E, n_buckets, hashes).view(N, H, L) 34 | 35 | groups = torch.zeros((N, H, L), dtype=torch.int32).cuda() 36 | counts = torch.zeros((N, H, k), dtype=torch.int32).cuda() 37 | centroids = torch.zeros((N, H, k), dtype=torch.int64).cuda() 38 | distances = torch.zeros((N, H, L), dtype=torch.int32).cuda() 39 | cluster_bit_counts = torch.zeros((N, H, k, n_buckets), 40 | dtype=torch.int32).cuda() 41 | sequence_lengths = torch.ones((N,), dtype=torch.int32).cuda() * L 42 | 43 | start = time.time() 44 | for batch_idx in range(int(n_batches)): 45 | for attention_idx in range(int(n_attentions)): 46 | #hashes = generate_hash(n_points, E, n_buckets, hashes).view(L, N, H) 47 | cluster( 48 | hashes, sequence_lengths, 49 | groups=groups, counts=counts, centroids=centroids, 50 | distances=distances, bitcounts=cluster_bit_counts, 51 | iterations=n_iterations, 52 | bits=n_buckets 53 | ) 54 | end = time.time() 55 | duration = end - start 56 | print("Time Elapsed: {}".format(duration)) 57 | 58 | 59 | if __name__ == "__main__": 60 | L = 1000 61 | N = 12 62 | H = 8 63 | E = 32 64 | 65 | n_batches = 50000/N 66 | n_attentions = 3 67 | 68 | k = 30 69 | n_buckets = 31 70 | n_iterations = 10 71 | verbose = 0 72 | 73 | time_clustering(L, N, H, E, 74 | n_batches, n_attentions, 75 | k, n_buckets, n_iterations, verbose) 76 | -------------------------------------------------------------------------------- /tests/events/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | -------------------------------------------------------------------------------- /tests/events/test_event_dispatcher.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | from fast_transformers.events import Event, EventDispatcher 9 | 10 | 11 | class MockEvent(Event): 12 | def __init__(self, source, payload): 13 | super(MockEvent, self).__init__(source) 14 | self.payload = payload 15 | 16 | 17 | class TestEventDispatcher(unittest.TestCase): 18 | def test_simple_listen_dispatch(self): 19 | d = {"x": 0} 20 | def listener1(event): 21 | d["x"] += 1 22 | 23 | def listener2(event): 24 | d["x"] += 1 25 | 26 | ed = EventDispatcher() 27 | ed.listen(Event, listener1) 28 | ed.listen(Event, listener2) 29 | ed.dispatch(Event(None)) 30 | self.assertEqual(d["x"], 2) 31 | ed.remove(listener1) 32 | ed.dispatch(Event(None)) 33 | self.assertEqual(d["x"], 3) 34 | ed.remove(listener2) 35 | 36 | def set_payload(event): 37 | 
d.update(event.payload) 38 | ed.listen(MockEvent, set_payload) 39 | ed.dispatch(Event(None)) 40 | self.assertTrue("y" not in d) 41 | ed.dispatch(MockEvent(None, {"y": 0})) 42 | self.assertEqual(d["y"], 0) 43 | self.assertEqual(d["x"], 3) 44 | 45 | def test_factory_method(self): 46 | ed1 = EventDispatcher.get() 47 | ed2 = EventDispatcher.get() 48 | self.assertTrue(ed1 is ed2) 49 | ed1 = EventDispatcher.get("foo") 50 | ed2 = EventDispatcher.get("bar") 51 | self.assertTrue(ed1 is not ed2) 52 | ed1 = EventDispatcher.get("foo") 53 | ed2 = EventDispatcher.get("foo") 54 | self.assertTrue(ed1 is ed2) 55 | ed1 = EventDispatcher() 56 | ed2 = EventDispatcher.get(ed1) 57 | self.assertTrue(ed1 is ed2) 58 | 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/events/test_event_filters.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | import torch 9 | 10 | from fast_transformers.events import Event 11 | from fast_transformers.events.filters import event_class, from_layer, \ 12 | layer_name_contains 13 | 14 | 15 | class MockEvent(Event): 16 | def __init__(self, source, payload): 17 | super(MockEvent, self).__init__(source) 18 | self.payload = payload 19 | 20 | 21 | class TestEventFilters(unittest.TestCase): 22 | def test_simple_filters(self): 23 | mock_event = event_class(MockEvent) 24 | self.assertTrue(mock_event(MockEvent(None, None))) 25 | self.assertFalse(mock_event(Event(None))) 26 | 27 | source = object() 28 | fl = from_layer(source) 29 | self.assertTrue(fl(Event(source))) 30 | self.assertFalse(fl(Event(None))) 31 | self.assertFalse(fl(Event(object()))) 32 | 33 | net = torch.nn.Sequential( 34 | torch.nn.Linear(2, 10), 35 | torch.nn.ReLU(), 36 | torch.nn.Sequential( 37 | torch.nn.Linear(10, 10), 38 | torch.nn.Linear(10, 10), 39 | torch.nn.Linear(10, 10) 40 | ), 41 | torch.nn.ReLU(), 42 | torch.nn.Linear(10, 1), 43 | torch.nn.Sigmoid() 44 | ) 45 | lnc = layer_name_contains(net, "2.1") 46 | self.assertTrue(lnc(Event(net[2][1]))) 47 | 48 | def test_filter_composition(self): 49 | net = torch.nn.Sequential( 50 | torch.nn.Linear(2, 10), 51 | torch.nn.ReLU(), 52 | torch.nn.Sequential( 53 | torch.nn.Linear(10, 10), 54 | torch.nn.Linear(10, 10), 55 | torch.nn.Linear(10, 10) 56 | ), 57 | torch.nn.ReLU(), 58 | torch.nn.Linear(10, 1), 59 | torch.nn.Sigmoid() 60 | ) 61 | 62 | event_filter = MockEvent & from_layer(net[2][1]) 63 | self.assertFalse(event_filter(Event(net[2][1]))) 64 | self.assertFalse(event_filter(MockEvent(net[2], None))) 65 | self.assertTrue(event_filter(MockEvent(net[2][1], None))) 66 | 67 | # should raise error because ev.payload is accessed before event is 68 | # made sure to be a MockEvent 69 | event_filter = (lambda ev: ev.payload == 0) & event_class(MockEvent) 70 | with self.assertRaises(AttributeError): 71 | event_filter(Event(None)) 72 | event_filter = event_class(MockEvent) & (lambda ev: ev.payload == 0) 73 | self.assertFalse(event_filter(Event(None))) 74 | self.assertFalse(event_filter(MockEvent(None, 1))) 75 | self.assertTrue(event_filter(MockEvent(None, 0))) 76 | 77 | 78 | if __name__ == "__main__": 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /tests/events/test_events.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | import torch 9 | 10 | from fast_transformers.events import EventDispatcher, QKVEvent, \ 11 | AttentionEvent, IntermediateOutput 12 | from fast_transformers.events.filters import layer_name_contains 13 | from fast_transformers.builders import TransformerEncoderBuilder 14 | 15 | 16 | class TestEvents(unittest.TestCase): 17 | def test_qkv(self): 18 | d = {} 19 | def store_qkv(event): 20 | d["q"] = event.queries 21 | d["k"] = event.keys 22 | d["v"] = event.values 23 | # default transformer is 4 layers 4 heads 24 | transformer = TransformerEncoderBuilder().get() 25 | x = transformer(torch.rand(1, 100, 64*4)) 26 | self.assertEqual(len(d), 0) 27 | 28 | EventDispatcher.get().listen(QKVEvent, store_qkv) 29 | x = transformer(torch.rand(1, 100, 64*4)) 30 | self.assertEqual(len(d), 3) 31 | d.clear() 32 | 33 | EventDispatcher.get().remove(store_qkv) 34 | x = transformer(torch.rand(1, 100, 64*4)) 35 | self.assertEqual(len(d), 0) 36 | d.clear() 37 | 38 | EventDispatcher.get().listen( 39 | QKVEvent & layer_name_contains(transformer, "layers.2.attention"), 40 | store_qkv 41 | ) 42 | x = transformer(torch.rand(1, 100, 64*4)) 43 | self.assertEqual(len(d), 3) 44 | d.clear() 45 | 46 | EventDispatcher.get().listen( 47 | QKVEvent & layer_name_contains(transformer, "layers.22.attention"), 48 | store_qkv 49 | ) 50 | x = transformer(torch.rand(1, 100, 64*4)) 51 | self.assertEqual(len(d), 0) 52 | d.clear() 53 | EventDispatcher.get().clear() 54 | 55 | def test_attention_matrix(self): 56 | A = [] 57 | def store_attention(event): 58 | A.append(event.attention_matrix) 59 | # default transformer is 4 layers 4 heads 60 | transformer = TransformerEncoderBuilder().get() 61 | x = transformer(torch.rand(1, 100, 64*4)) 62 | self.assertEqual(len(A), 0) 63 | 64 | EventDispatcher.get().listen(AttentionEvent, store_attention) 65 | x = transformer(torch.rand(1, 100, 64*4)) 66 | self.assertEqual(len(A), 4) 67 | 68 | def test_intermediate_output(self): 69 | intermediates = [] 70 | def store_values(event): 71 | intermediates.append(event.x) 72 | 73 | transformer = TransformerEncoderBuilder().get() 74 | x = transformer(torch.rand(1, 100, 64*4)) 75 | 76 | EventDispatcher.get().listen(IntermediateOutput, store_values) 77 | transformer(x) 78 | self.assertEqual(len(intermediates), 4) 79 | 80 | 81 | if __name__ == "__main__": 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /tests/feature_maps/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | -------------------------------------------------------------------------------- /tests/feature_maps/test_fourier_features.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import unittest 7 | 8 | import torch 9 | 10 | from fast_transformers.attention import AttentionLayer, LinearAttention 11 | from fast_transformers.feature_maps.fourier_features import \ 12 | RandomFourierFeatures, SmoothedRandomFourierFeatures, Favor, \ 13 | GeneralizedRandomFeatures 14 | from 
fast_transformers.masking import FullMask 15 | 16 | 17 | class TestFourierFeatures(unittest.TestCase): 18 | def test_omega(self): 19 | f = RandomFourierFeatures(32, n_dims=64, orthogonal=True) 20 | f.new_feature_map("cpu") 21 | self.assertLess( 22 | torch.abs( 23 | f.omega.t().matmul(f.omega)[torch.eye(32) == 0] 24 | ).max().item(), 25 | 1e-4 26 | ) 27 | 28 | def test_rff(self): 29 | for ortho in [False, True]: 30 | f = RandomFourierFeatures(32, n_dims=32*1000, softmax_temp=1, 31 | orthogonal=ortho) 32 | f.new_feature_map("cpu") 33 | 34 | x = torch.randn(100, 32) * 0.15 35 | y = torch.randn(100, 32) * 0.15 36 | phi_x = f(x) 37 | phi_y = f(y) 38 | 39 | rbf_xy = torch.exp(-((x[:, None] - y[None, :])**2).sum(-1)/2) 40 | rbf_xy_hat = phi_x.matmul(phi_y.t()) 41 | 42 | self.assertLess( 43 | ((rbf_xy - rbf_xy_hat)**2).mean().item(), 44 | 1e-4 45 | ) 46 | 47 | f = SmoothedRandomFourierFeatures(32, n_dims=32*1000, 48 | softmax_temp=1, orthogonal=ortho, 49 | smoothing=1.0) 50 | f.new_feature_map("cpu") 51 | phi_x = f(x) 52 | phi_y = f(y) 53 | rbf_xy = torch.exp(-((x[:, None] - y[None, :])**2).sum(-1)/2) + 1 54 | rbf_xy_hat = phi_x.matmul(phi_y.t()) 55 | 56 | self.assertLess( 57 | ((rbf_xy - rbf_xy_hat)**2).mean().item(), 58 | 1e-4 59 | ) 60 | 61 | def test_prf(self): 62 | for ortho in [False, True]: 63 | f = Favor(32, n_dims=32*1000, softmax_temp=1, orthogonal=ortho) 64 | 65 | f.new_feature_map("cpu") 66 | 67 | x = torch.randn(100, 32) * 0.15 68 | y = torch.randn(100, 32) * 0.15 69 | phi_x = f(x) 70 | phi_y = f(y) 71 | 72 | sm_xy = torch.exp(x.mm(y.t())) 73 | sm_xy_hat = phi_x.mm(phi_y.t()) 74 | 75 | self.assertLess( 76 | ((sm_xy - sm_xy_hat)**2).mean().item(), 77 | 1e-3 78 | ) 79 | 80 | def test_grf(self): 81 | f = GeneralizedRandomFeatures(32, n_dims=128) 82 | f.new_feature_map("cpu") 83 | x = torch.randn(100, 32) 84 | phi_x = f(x) 85 | self.assertEqual((100, 128), phi_x.shape) 86 | 87 | def test_feature_map_sharing(self): 88 | x = torch.rand(3, 100, 4*32) 89 | f = Favor.factory(n_dims=64) 90 | att = AttentionLayer( 91 | LinearAttention(32, f), 92 | 4*32, 93 | 4 94 | ) 95 | 96 | attn_mask = FullMask(100) 97 | lengths = FullMask(3, 100) 98 | y = att(x, x, x, attn_mask, lengths, lengths) 99 | y = att(y, y, y, attn_mask, lengths, lengths) 100 | y.sum().backward() 101 | 102 | def test_feature_map_redraw(self): 103 | x = torch.rand(3, 100, 32) 104 | f = Favor(32, n_dims=64) 105 | 106 | f.new_feature_map("cpu") 107 | fx1 = f(x) 108 | f.new_feature_map("cpu") 109 | fx2 = f(x) 110 | self.assertFalse(torch.allclose(fx1, fx2)) 111 | 112 | f = Favor(32, n_dims=64, redraw=2) 113 | f.new_feature_map("cpu") 114 | fx1 = f(x) 115 | f.new_feature_map("cpu") 116 | fx2 = f(x) 117 | f.new_feature_map("cpu") 118 | fx3 = f(x) 119 | self.assertTrue(torch.allclose(fx1, fx2)) 120 | self.assertFalse(torch.allclose(fx2, fx3)) 121 | 122 | f = Favor(32, n_dims=64, redraw=1, deterministic_eval=True) 123 | f.new_feature_map("cpu") 124 | fx1 = f(x) 125 | f.new_feature_map("cpu") 126 | fx2 = f(x) 127 | f.eval() 128 | f.new_feature_map("cpu") 129 | fx3 = f(x) 130 | self.assertFalse(torch.allclose(fx1, fx2)) 131 | self.assertTrue(torch.allclose(fx2, fx3)) 132 | 133 | 134 | if __name__ == "__main__": 135 | unittest.main() 136 | -------------------------------------------------------------------------------- /tests/hashing/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 
| # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/hashing/test_hash_cpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | import os 10 | import time 11 | 12 | import torch 13 | 14 | from fast_transformers.hashing import hash_cpu 15 | 16 | 17 | def simple_lsh(X, A): 18 | B = (torch.einsum("nj,ij->ni", [X, A[:, :-1]]) > A[None, :, -1]).long() 19 | bits = 2**torch.arange(A.shape[0]) 20 | return torch.einsum("ni,i->n", [B, bits]) 21 | 22 | 23 | class TestHashCPU(unittest.TestCase): 24 | def test_hash(self): 25 | for bits in range(10, 63): 26 | x = torch.rand(100, 32) 27 | a = torch.randn(bits, 33) 28 | a[:,-1] = 0.0 29 | h1 = simple_lsh(x, a) 30 | h2 = torch.zeros_like(h1) 31 | h3 = torch.zeros_like(h1) 32 | hash_cpu.compute_hashes(x, a, h2) 33 | self.assertTrue(torch.all(h1==h2)) 34 | B = torch.einsum("nj,ij->ni", [x, a[:, :-1]]) 35 | hash_cpu.compute_hashes_from_projections(B, h3) 36 | self.assertTrue(torch.all(h1==h3)) 37 | 38 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 39 | def test_benchmark_hash(self): 40 | N = 12 41 | L = 1000 42 | H = 8 43 | E = 32 44 | B = 63 45 | x = torch.rand(N*L*H, E) 46 | a = torch.randn(B, E+1) 47 | a[:,-1] = 0. 48 | h1 = simple_lsh(x, a) 49 | h2 = torch.zeros_like(h1) 50 | h3 = torch.zeros_like(h1) 51 | 52 | # Count simple pytorch 53 | for i in range(50): 54 | simple_lsh(x, a) 55 | t = time.time() 56 | for i in range(50): 57 | simple_lsh(x, a) 58 | d1 = time.time()-t 59 | 60 | # Count simple C++ pytorch 61 | for i in range(50): 62 | hash_cpu.compute_hashes(x, a, h2) 63 | t = time.time() 64 | for i in range(50): 65 | hash_cpu.compute_hashes(x, a, h2) 66 | d2 = time.time()-t 67 | 68 | # Count simple C++ pytorch version 2 69 | for i in range(50): 70 | P = torch.einsum("nj,ij->ni", [x, a[:, :-1]]) 71 | hash_cpu.compute_hashes_from_projections(P, h3) 72 | t = time.time() 73 | for i in range(50): 74 | P = torch.einsum("nj,ij->ni", [x, a[:, :-1]]) 75 | hash_cpu.compute_hashes_from_projections(P, h3) 76 | d3 = time.time()-t 77 | 78 | print(d1, d2, d3, d1/d2, d2/d3, d1/d3) 79 | 80 | 81 | if __name__ == "__main__": 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /tests/hashing/test_hash_gpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | import os 10 | import time 11 | 12 | import torch 13 | 14 | try: 15 | from fast_transformers.hashing import hash_cuda 16 | except ImportError: 17 | pass 18 | 19 | def simple_lsh(X, A): 20 | X = X.cpu() 21 | A = A.cpu() 22 | B = (torch.einsum("nj,ij->ni", [X, A[:, :-1]]) > A[None, :, -1]).long() 23 | bits = 2**torch.arange(A.shape[0]) 24 | return torch.einsum("ni,i->n", [B, bits]).cuda() 25 | 26 | 27 | class TestHashGPU(unittest.TestCase): 28 | @classmethod 29 | def setUpClass(cls): 30 | if not torch.cuda.is_available(): 31 | raise unittest.SkipTest("No CUDA capable device detected") 32 | 33 | def test_hash(self): 34 | for bits in range(10, 63): 35 | x = torch.rand(100, 32).to("cuda") 36 | a = torch.randn(bits, 33).to("cuda") 37 | a[:,-1] = 0.0 38 | h1 = simple_lsh(x, a) 
39 | h2 = torch.zeros_like(h1) 40 | h3 = torch.zeros_like(h1) 41 | hash_cuda.compute_hashes(x, a, h2) 42 | self.assertTrue(torch.all(h1==h2)) 43 | 44 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 45 | def test_benchmark_hash(self): 46 | N = 12 47 | L = 1000 48 | H = 8 49 | E = 32 50 | B = 63 51 | x = torch.rand(N*L*H, 32).to("cuda") 52 | a = torch.randn(B, 33).to("cuda") 53 | h1 = simple_lsh(x, a) 54 | h2 = torch.zeros_like(h1) 55 | h3 = torch.zeros_like(h1) 56 | 57 | # Count simple pytorch 58 | for i in range(50): 59 | simple_lsh(x, a) 60 | 61 | s = torch.cuda.Event(enable_timing=True) 62 | e = torch.cuda.Event(enable_timing=True) 63 | s.record() 64 | simple_lsh(x, a) 65 | e.record() 66 | torch.cuda.synchronize() 67 | t_simple = s.elapsed_time(e) 68 | 69 | # Count simple C++ pytorch 70 | for i in range(50): 71 | hash_cuda.compute_hashes(x, a, h2) 72 | 73 | s = torch.cuda.Event(enable_timing=True) 74 | e = torch.cuda.Event(enable_timing=True) 75 | s.record() 76 | hash_cuda.compute_hashes(x, a, h2) 77 | e.record() 78 | torch.cuda.synchronize() 79 | t_cuda = s.elapsed_time(e) 80 | 81 | print(t_simple, t_cuda, t_simple/t_cuda) 82 | 83 | 84 | 85 | if __name__ == "__main__": 86 | unittest.main() 87 | 88 | -------------------------------------------------------------------------------- /tests/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/recurrent/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/recurrent/attention/cross_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/fast-transformers/2ad36b97e64cb93862937bd21fcc9568d989561f/tests/recurrent/attention/cross_attention/__init__.py -------------------------------------------------------------------------------- /tests/recurrent/attention/cross_attention/test_attention_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | 7 | import unittest 8 | 9 | import torch 10 | 11 | from fast_transformers.recurrent.attention import RecurrentCrossAttentionLayer 12 | 13 | 14 | class TestRecurrentCrossAttentionLayer(unittest.TestCase): 15 | def _assert_sizes_attention(self, qshape, kshape, vshape): 16 | def inner(q, k, v, kl, state=None): 17 | if state is not None: 18 | k, v = state 19 | self.assertEqual(q.shape, qshape) 20 | self.assertEqual(k.shape, kshape) 21 | self.assertEqual(v.shape, vshape) 22 | N, H, E = q.shape 23 | _, _, _, D = v.shape 24 | return v.new_zeros((N, H, D)), [k, v] 25 | return inner 26 | 27 | def test_forward(self): 28 | att = RecurrentCrossAttentionLayer( 29 | self._assert_sizes_attention( 30 | (10, 4, 25), 31 | (10, 42, 4, 25), 32 | (10, 42, 4, 25) 33 | ), 34 | 100, 35 | 4 36 | ) 37 | 38 | v, s = att( 39 | torch.rand(10, 100), 40 | torch.rand(10, 42, 100), 41 | 
torch.rand(10, 42, 100), 42 | None, 43 | state=None 44 | ) 45 | self.assertEqual(v.shape, (10, 100)) 46 | self.assertEqual(s[0].shape, (10, 42, 4, 25)) 47 | self.assertEqual(s[1].shape, (10, 42, 4, 25)) 48 | 49 | v, s = att( 50 | torch.rand(10, 100), 51 | None, 52 | None, 53 | None, 54 | state=[torch.rand(10, 42, 4, 25), torch.rand(10, 42, 4, 25)] 55 | ) 56 | self.assertEqual(v.shape, (10, 100)) 57 | self.assertEqual(s[0].shape, (10, 42, 4, 25)) 58 | self.assertEqual(s[1].shape, (10, 42, 4, 25)) 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/recurrent/attention/cross_attention/test_full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | 7 | import os 8 | import time 9 | import unittest 10 | 11 | import torch 12 | 13 | from fast_transformers.attention import FullAttention 14 | from fast_transformers.masking import FullMask, LengthMask 15 | from fast_transformers.recurrent.attention import RecurrentCrossFullAttention 16 | 17 | 18 | class TestRecurrentCrossFullAttention(unittest.TestCase): 19 | def test_correctness(self): 20 | # Prepare the inputs 21 | N = 10 22 | H = 4 23 | E = 25 24 | M = 64 25 | L = 42 26 | S = 100 27 | q = torch.rand(N, L, H, E) 28 | k = torch.rand(N, S, H, E) 29 | v = torch.rand(N, S, H, M) 30 | m1 = FullMask(L, S) 31 | m2 = LengthMask(torch.full((N,), L, dtype=torch.int64)) 32 | m3 = LengthMask(torch.full((N,), S, dtype=torch.int64)) 33 | 34 | # Get the outputs from the attention in batch mode 35 | att = FullAttention() 36 | att.eval() 37 | v_out1 = att(q, k, v, m1, m2, m3) 38 | 39 | # Get the output from the attention in recurrent mode 40 | att = RecurrentCrossFullAttention() 41 | att.eval() 42 | v_out2_unstacked = [] 43 | state = None 44 | for i in range(L): 45 | vi, state = att(q[:, i], k, v, m3, state=state) 46 | v_out2_unstacked.append(vi) 47 | v_out2 = torch.stack(v_out2_unstacked, dim=1) 48 | 49 | # Check that they match 50 | self.assertLess(torch.abs(v_out1 - v_out2).max(), 1e-6) 51 | 52 | 53 | if __name__ == "__main__": 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /tests/recurrent/attention/cross_attention/test_linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | 7 | import os 8 | import time 9 | import unittest 10 | 11 | import torch 12 | 13 | from fast_transformers.attention import LinearAttention 14 | from fast_transformers.masking import FullMask, LengthMask 15 | from fast_transformers.recurrent.attention import RecurrentCrossLinearAttention 16 | 17 | 18 | class TestRecurrentCrossLinearAttention(unittest.TestCase): 19 | def test_correctness(self): 20 | # Prepare the inputs 21 | N = 10 22 | H = 4 23 | E = 25 24 | M = 64 25 | L = 42 26 | S = 100 27 | q = torch.rand(N, L, H, E) 28 | k = torch.rand(N, S, H, E) 29 | v = torch.rand(N, S, H, M) 30 | m1 = FullMask(L, S) 31 | m2 = LengthMask(torch.full((N,), L, dtype=torch.int64)) 32 | m3 = LengthMask(torch.full((N,), S, dtype=torch.int64)) 33 | 34 | # Get the outputs from the attention in batch mode 35 | att = LinearAttention(E) 36 | att.eval() 37 | v_out1 = att(q, k, v, m1, m2, m3) 38 | 39 | # Get 
the output from the attention in recurrent mode 40 | att = RecurrentCrossLinearAttention(E) 41 | att.eval() 42 | v_out2_unstacked = [] 43 | state = None 44 | for i in range(L): 45 | vi, state = att(q[:, i], k, v, m3, state=state) 46 | v_out2_unstacked.append(vi) 47 | v_out2 = torch.stack(v_out2_unstacked, dim=1) 48 | 49 | # Check that they match 50 | self.assertLess(torch.abs(v_out1 - v_out2).max(), 1e-6) 51 | 52 | 53 | if __name__ == "__main__": 54 | unittest.main() 55 | 56 | -------------------------------------------------------------------------------- /tests/recurrent/attention/self_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idiap/fast-transformers/2ad36b97e64cb93862937bd21fcc9568d989561f/tests/recurrent/attention/self_attention/__init__.py -------------------------------------------------------------------------------- /tests/recurrent/attention/self_attention/test_attention_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.recurrent.attention import RecurrentAttentionLayer 13 | 14 | 15 | class TestRecurrentAttentionLayer(unittest.TestCase): 16 | def _assert_sizes_attention(self, qshape, kshape, vshape): 17 | def inner(q, k, v, m): 18 | self.assertEqual(q.shape, qshape) 19 | self.assertEqual(k.shape, kshape) 20 | self.assertEqual(v.shape, vshape) 21 | N, H, E = q.shape 22 | _, _, D = v.shape 23 | return v.new_zeros((N, H, D)), m 24 | return inner 25 | 26 | def test_forward(self): 27 | att = RecurrentAttentionLayer( 28 | self._assert_sizes_attention( 29 | (10, 4, 25), 30 | (10, 4, 25), 31 | (10, 4, 25) 32 | ), 33 | 100, 34 | 4 35 | ) 36 | v, m = att( 37 | torch.rand(10, 100), 38 | torch.rand(10, 100), 39 | torch.rand(10, 100), 40 | "test memory" 41 | ) 42 | self.assertEqual(v.shape, (10, 100)) 43 | self.assertEqual(m, "test memory") 44 | 45 | att = RecurrentAttentionLayer( 46 | self._assert_sizes_attention( 47 | (10, 4, 32), 48 | (10, 4, 32), 49 | (10, 4, 64) 50 | ), 51 | 100, 52 | 4, 53 | d_keys=32, 54 | d_values=64 55 | ) 56 | v, m = att( 57 | torch.rand(10, 100), 58 | torch.rand(10, 100), 59 | torch.rand(10, 100), 60 | "test memory" 61 | ) 62 | self.assertEqual(v.shape, (10, 100)) 63 | self.assertEqual(m, "test memory") 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /tests/recurrent/attention/self_attention/test_full_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import os 9 | import time 10 | import unittest 11 | 12 | import torch 13 | 14 | from fast_transformers.attention import FullAttention 15 | from fast_transformers.masking import TriangularCausalMask, LengthMask 16 | from fast_transformers.recurrent.attention import RecurrentFullAttention 17 | 18 | 19 | class TestRecurrentFullAttention(unittest.TestCase): 20 | def test_forward(self): 21 | # Prepare the inputs 22 | N = 10 23 | H = 4 24 | E = 25 25 | M = 64 26 | L = 100 27 | q = torch.rand(N, H, E) 28 | k = torch.rand(N, H, E) 29 | v = torch.rand(N, H, M) 30 | memory = [ 
31 | torch.rand(N, H, L, E), 32 | torch.rand(N, H, L, M) 33 | ] 34 | 35 | # Test the attention module 36 | att = RecurrentFullAttention(softmax_temp=1) 37 | v_new, mem_new = att(q, k, v) 38 | self.assertEqual(v_new.shape, (N, H, M)) 39 | self.assertEqual(len(mem_new), 2) 40 | self.assertEqual(mem_new[0].shape, (N, H, 1, E)) 41 | self.assertEqual(mem_new[1].shape, (N, H, 1, M)) 42 | v_new, mem_new = att(q, k, v, mem_new) 43 | self.assertEqual(v_new.shape, (N, H, M)) 44 | self.assertEqual(len(mem_new), 2) 45 | self.assertEqual(mem_new[0].shape, (N, H, 2, E)) 46 | self.assertEqual(mem_new[1].shape, (N, H, 2, M)) 47 | 48 | v_new, mem_new = att(q, k, v, memory) 49 | self.assertEqual(v_new.shape, (N, H, M)) 50 | self.assertEqual(len(mem_new), 2) 51 | self.assertEqual(mem_new[0].shape, (N, H, L+1, E)) 52 | self.assertEqual(mem_new[1].shape, (N, H, L+1, M)) 53 | 54 | def test_correctness(self): 55 | # Prepare the inputs 56 | N = 10 57 | H = 4 58 | E = 25 59 | M = 64 60 | L = 100 61 | q = torch.rand(N, L, H, E) 62 | k = torch.rand(N, L, H, E) 63 | v = torch.rand(N, L, H, M) 64 | m1 = TriangularCausalMask(L) 65 | m2 = LengthMask(torch.full((N,), L, dtype=torch.int64)) 66 | m3 = LengthMask(torch.full((N,), L, dtype=torch.int64)) 67 | att = FullAttention() 68 | rec_att = RecurrentFullAttention() 69 | att.eval() 70 | rec_att.eval() 71 | 72 | v1 = att(q, k, v, m1, m2, m3) 73 | v2 = [] 74 | memory = None 75 | for i in range(L): 76 | v2i, memory = rec_att(q[:, i], k[:, i], v[:, i], memory) 77 | v2.append(v2i) 78 | v2 = torch.stack(v2, dim=1) 79 | self.assertLess(torch.abs(v1-v2).max(), 1e-5) 80 | 81 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 82 | def test_benchmark_cpu(self): 83 | # Prepare the inputs 84 | N = 10 85 | H = 12 86 | E = 25 87 | M = 64 88 | L = 100 89 | q = torch.rand(N, H, E) 90 | k = torch.rand(N, H, E) 91 | v = torch.rand(N, H, M) 92 | memory = None 93 | att = RecurrentFullAttention(softmax_temp=1) 94 | 95 | start = time.time() 96 | for i in range(100): 97 | v, memory = att(q, k, v, memory) 98 | end = time.time() 99 | print("CPU Time taken:", (end-start)*1000, "(ms)") 100 | 101 | @unittest.skipUnless(torch.cuda.is_available(), "no CUDA capable device") 102 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 103 | def test_benchmark_gpu(self): 104 | # Prepare the inputs 105 | N = 10 106 | H = 12 107 | E = 25 108 | M = 64 109 | L = 100 110 | q = torch.rand(N, H, E).cuda() 111 | k = torch.rand(N, H, E).cuda() 112 | v = torch.rand(N, H, M).cuda() 113 | memory = None 114 | att = RecurrentFullAttention(softmax_temp=1) 115 | 116 | start = torch.cuda.Event(enable_timing=True) 117 | end = torch.cuda.Event(enable_timing=True) 118 | start.record() 119 | for i in range(100): 120 | v, memory = att(q, k, v, memory) 121 | end.record() 122 | torch.cuda.synchronize() 123 | print("GPU time taken:", start.elapsed_time(end), "(ms)") 124 | 125 | 126 | if __name__ == "__main__": 127 | unittest.main() 128 | -------------------------------------------------------------------------------- /tests/recurrent/attention/self_attention/test_linear_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import os 9 | import time 10 | import unittest 11 | 12 | import torch 13 | 14 | from fast_transformers.attention import CausalLinearAttention 15 | from fast_transformers.masking 
import TriangularCausalMask, LengthMask 16 | from fast_transformers.recurrent.attention import RecurrentLinearAttention 17 | 18 | 19 | class TestRecurrentLinearAttention(unittest.TestCase): 20 | def test_forward(self): 21 | # Prepare the inputs 22 | N = 10 23 | H = 4 24 | E = 25 25 | M = 64 26 | L = 100 27 | q = torch.rand(N, H, E) 28 | k = torch.rand(N, H, E) 29 | v = torch.rand(N, H, M) 30 | memory = [ 31 | torch.rand(N, H, E, M), 32 | torch.rand(N, H, E) 33 | ] 34 | 35 | # Test the attention module 36 | att = RecurrentLinearAttention(E) 37 | v_new, mem_new = att(q, k, v) 38 | self.assertEqual(v_new.shape, (N, H, M)) 39 | self.assertEqual(len(mem_new), 2) 40 | self.assertEqual(mem_new[0].shape, (N, H, E, M)) 41 | self.assertEqual(mem_new[1].shape, (N, H, E)) 42 | v_new, mem_new = att(q, k, v, mem_new) 43 | self.assertEqual(v_new.shape, (N, H, M)) 44 | self.assertEqual(len(mem_new), 2) 45 | self.assertEqual(mem_new[0].shape, (N, H, E, M)) 46 | self.assertEqual(mem_new[1].shape, (N, H, E)) 47 | 48 | v_new, mem_new = att(q, k, v, memory) 49 | self.assertEqual(v_new.shape, (N, H, M)) 50 | self.assertEqual(len(mem_new), 2) 51 | self.assertEqual(mem_new[0].shape, (N, H, E, M)) 52 | self.assertEqual(mem_new[1].shape, (N, H, E)) 53 | 54 | def test_correctness(self): 55 | # Prepare the inputs 56 | N = 10 57 | H = 4 58 | E = 25 59 | M = 64 60 | L = 100 61 | q = torch.rand(N, L, H, E) 62 | k = torch.rand(N, L, H, E) 63 | v = torch.rand(N, L, H, M) 64 | m1 = TriangularCausalMask(L) 65 | m2 = LengthMask(torch.full((N,), L, dtype=torch.long)) 66 | m3 = LengthMask(torch.full((N,), L, dtype=torch.long)) 67 | att = CausalLinearAttention(E) 68 | rec_att = RecurrentLinearAttention(E) 69 | att.eval() 70 | rec_att.eval() 71 | 72 | v1 = att(q, k, v, m1, m2, m3) 73 | v2 = [] 74 | memory = None 75 | for i in range(L): 76 | v2i, memory = rec_att(q[:, i], k[:, i], v[:, i], memory) 77 | v2.append(v2i) 78 | v2 = torch.stack(v2, dim=1) 79 | self.assertLess(torch.abs(v1-v2).max(), 1e-5) 80 | 81 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 82 | def test_benchmark_cpu(self): 83 | # Prepare the inputs 84 | N = 10 85 | H = 12 86 | E = 25 87 | M = 64 88 | L = 100 89 | q = torch.rand(N, H, E) 90 | k = torch.rand(N, H, E) 91 | v = torch.rand(N, H, M) 92 | memory = None 93 | att = RecurrentLinearAttention(E) 94 | 95 | start = time.time() 96 | for i in range(100): 97 | v, memory = att(q, k, v, memory) 98 | end = time.time() 99 | print("CPU Time taken:", (end-start)*1000, "(ms)") 100 | 101 | @unittest.skipUnless(torch.cuda.is_available(), "no CUDA capable device") 102 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 103 | def test_benchmark_gpu(self): 104 | # Prepare the inputs 105 | N = 10 106 | H = 12 107 | E = 25 108 | M = 64 109 | L = 100 110 | q = torch.rand(N, H, E).cuda() 111 | k = torch.rand(N, H, E).cuda() 112 | v = torch.rand(N, H, M).cuda() 113 | memory = None 114 | att = RecurrentLinearAttention(E) 115 | 116 | start = torch.cuda.Event(enable_timing=True) 117 | end = torch.cuda.Event(enable_timing=True) 118 | start.record() 119 | for i in range(100): 120 | v, memory = att(q, k, v, memory) 121 | end.record() 122 | torch.cuda.synchronize() 123 | print("GPU time taken:", start.elapsed_time(end), "(ms)") 124 | 125 | 126 | if __name__ == "__main__": 127 | unittest.main() 128 | 129 | -------------------------------------------------------------------------------- /tests/recurrent/test_transformer_decoder.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | from functools import partial 7 | import unittest 8 | 9 | import torch 10 | 11 | from fast_transformers.attention import AttentionLayer, FullAttention, \ 12 | LinearAttention, CausalLinearAttention 13 | from fast_transformers.masking import TriangularCausalMask, FullMask, \ 14 | LengthMask 15 | from fast_transformers.recurrent.attention import \ 16 | RecurrentAttentionLayer, RecurrentCrossAttentionLayer, \ 17 | RecurrentFullAttention, RecurrentCrossFullAttention, \ 18 | RecurrentLinearAttention, RecurrentCrossLinearAttention 19 | from fast_transformers.recurrent.transformers import \ 20 | RecurrentTransformerDecoderLayer, RecurrentTransformerDecoder 21 | from fast_transformers.transformers import TransformerDecoderLayer, \ 22 | TransformerDecoder 23 | 24 | 25 | class TestRecurrentTransformerDecoder(unittest.TestCase): 26 | def test_compare_with_batch(self): 27 | N = 10 28 | L = 42 29 | S = 100 30 | D = 1024 31 | E = D // 4 32 | x = torch.rand(N, L, D) 33 | m = torch.rand(N, S, D) 34 | 35 | tests = [ 36 | ("full", FullAttention, FullAttention, 37 | RecurrentFullAttention, RecurrentCrossFullAttention), 38 | ("linear", partial(CausalLinearAttention, E), 39 | partial(LinearAttention, E), partial(RecurrentLinearAttention, E), 40 | partial(RecurrentCrossLinearAttention, E)) 41 | ] 42 | 43 | for name, a1, a2, a3, a4 in tests: 44 | dec = TransformerDecoder([ 45 | TransformerDecoderLayer( 46 | AttentionLayer(a1(), D, 4), 47 | AttentionLayer(a2(), D, 4), 48 | D 49 | ) 50 | for i in range(4) 51 | ]) 52 | rdec = RecurrentTransformerDecoder([ 53 | RecurrentTransformerDecoderLayer( 54 | RecurrentAttentionLayer(a3(), D, 4), 55 | RecurrentCrossAttentionLayer(a4(), D, 4), 56 | D 57 | ) 58 | for i in range(4) 59 | ]) 60 | dec.eval() 61 | rdec.eval() 62 | rdec.load_state_dict(dec.state_dict()) 63 | 64 | x_mask = TriangularCausalMask(L) 65 | x_length = LengthMask(torch.full((N,), L, dtype=torch.int64)) 66 | m_mask = FullMask(L, S) 67 | m_length = LengthMask(torch.full((N,), S, dtype=torch.int64)) 68 | 69 | y1 = dec(x, m, x_mask=x_mask, x_length_mask=x_length, 70 | memory_mask=m_mask, memory_length_mask=m_length) 71 | state = None 72 | y2 = [] 73 | for i in range(L): 74 | y2i, state = rdec(x[:, i], m, memory_length_mask=m_length, 75 | state=state) 76 | y2.append(y2i) 77 | y2 = torch.stack(y2, dim=1) 78 | 79 | self.assertLess(torch.abs(y1-y2).max(), 1e-5) 80 | 81 | def test_mask_creation(self): 82 | N = 10 83 | L = 42 84 | S = 100 85 | D = 1024 86 | x = torch.rand(N, D) 87 | m = torch.rand(N, S, D) 88 | 89 | rdec = RecurrentTransformerDecoder([ 90 | RecurrentTransformerDecoderLayer( 91 | RecurrentAttentionLayer(RecurrentFullAttention(), D, 4), 92 | RecurrentCrossAttentionLayer( 93 | RecurrentCrossFullAttention(), D, 4 94 | ), 95 | D 96 | ) 97 | for i in range(4) 98 | ]) 99 | rdec(x, m) 100 | 101 | 102 | if __name__ == "__main__": 103 | unittest.main() 104 | -------------------------------------------------------------------------------- /tests/recurrent/test_transformer_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.recurrent.attention 
import RecurrentAttentionLayer, \ 13 | RecurrentFullAttention, RecurrentLinearAttention 14 | from fast_transformers.recurrent.transformers import \ 15 | RecurrentTransformerEncoderLayer, RecurrentTransformerEncoder 16 | 17 | 18 | class TestRecurrentTransformerEncoder(unittest.TestCase): 19 | def test_full_attention_forward(self): 20 | d_model = 128 21 | n_heads = 4 22 | transformer = RecurrentTransformerEncoder([ 23 | RecurrentTransformerEncoderLayer( 24 | RecurrentAttentionLayer( 25 | RecurrentFullAttention(), 26 | d_model, 27 | n_heads 28 | ), 29 | d_model, 30 | n_heads 31 | ) 32 | for i in range(6) 33 | ]) 34 | 35 | xs = [] 36 | memory = None 37 | for i in range(7): 38 | x, memory = transformer(torch.rand(10, d_model), state=memory) 39 | xs.append(x) 40 | for i in range(7): 41 | self.assertEqual(xs[i].shape, (10, d_model)) 42 | self.assertEqual(len(memory), 6) 43 | for i in range(6): 44 | self.assertEqual(len(memory[i]), 2) 45 | self.assertEqual(memory[i][0].shape, (10, n_heads, 7, 32)) 46 | self.assertEqual(memory[i][1].shape, (10, n_heads, 7, 32)) 47 | 48 | def test_linear_attention_forward(self): 49 | d_model = 128 50 | n_heads = 4 51 | d_head = d_model // n_heads 52 | transformer = RecurrentTransformerEncoder([ 53 | RecurrentTransformerEncoderLayer( 54 | RecurrentAttentionLayer( 55 | RecurrentLinearAttention(d_head), 56 | d_model, 57 | n_heads 58 | ), 59 | d_model, 60 | n_heads 61 | ) 62 | for i in range(6) 63 | ]) 64 | 65 | xs = [] 66 | memory = None 67 | for i in range(7): 68 | x, memory = transformer(torch.rand(10, d_model), state=memory) 69 | xs.append(x) 70 | for i in range(7): 71 | self.assertEqual(xs[i].shape, (10, d_model)) 72 | self.assertEqual(len(memory), 6) 73 | for i in range(6): 74 | self.assertEqual(len(memory[i]), 2) 75 | self.assertEqual(memory[i][0].shape, (10, n_heads, 32, 32)) 76 | self.assertEqual(memory[i][1].shape, (10, n_heads, 32)) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/sparse_product/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | -------------------------------------------------------------------------------- /tests/sparse_product/test_sparse_product_backward_cpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import os 8 | import time 9 | from os import getenv 10 | import unittest 11 | 12 | import torch 13 | 14 | from fast_transformers.sparse_product import sparse_dot_product 15 | 16 | 17 | class TestSparseProductBackward(unittest.TestCase): 18 | @property 19 | def device(self): 20 | return "cpu" 21 | 22 | def _zero_grad(self, Q, K): 23 | for x in [Q, K]: 24 | if x.grad is not None: 25 | x.grad[...] 
= 0 26 | 27 | def test_simple_grad(self): 28 | N = 2 29 | H = 4 30 | L = 100 31 | S = 100 32 | E = 32 33 | k = 10 34 | Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True) 35 | K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True) 36 | topk = torch.round( 37 | torch.cumsum(torch.rand(N, H, L, k)*10, dim=-1) 38 | ).long().to(self.device) 39 | 40 | self._zero_grad(Q, K) 41 | QK_full = torch.einsum("nhle,nhse->nhls", Q, K) 42 | QK_selected = QK_full[ 43 | torch.arange(N).view(N, 1, 1, 1).to(self.device), 44 | torch.arange(H).view(1, H, 1, 1).to(self.device), 45 | torch.arange(L).view(1, 1, L, 1).to(self.device), 46 | topk 47 | ] 48 | QK_selected.sum().backward() 49 | grad = [torch.clone(Q.grad), torch.clone(K.grad)] 50 | 51 | self._zero_grad(Q, K) 52 | QK_selected_hat = sparse_dot_product(Q, K, topk) 53 | QK_selected_hat.sum().backward() 54 | grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)] 55 | 56 | self.assertLess( 57 | torch.abs(QK_selected - QK_selected_hat).max(), 58 | 1e-4 59 | ) 60 | for g1, g2 in zip(grad, grad_hat): 61 | self.assertLess( 62 | torch.abs(g1 - g2).max(), 63 | 1e-4 64 | ) 65 | 66 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 67 | def test_benchmark_forward(self): 68 | N = 12 69 | H = 8 70 | L = 1024 71 | S = 1024 72 | E = 32 73 | k = 32 74 | Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True) 75 | K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True) 76 | topk = torch.round( 77 | torch.cumsum(torch.rand(N, H, L, k)*(S//k), dim=-1) 78 | ).long().to(self.device) 79 | n_runs = 10 80 | s = time.time() 81 | for i in range(n_runs): 82 | QK = torch.einsum("nhle,nhse->nhls", Q, K) 83 | QK.sum() 84 | e = time.time() 85 | t_full = (e - s) / n_runs 86 | 87 | s = time.time() 88 | for i in range(n_runs): 89 | QK = sparse_dot_product(Q, K, topk) 90 | QK.sum() 91 | e = time.time() 92 | t_sparse = (e - s) / n_runs 93 | print("Benchmark Forward: T_Full: {}, T_Sparse: {}".format(t_full, t_sparse)) 94 | 95 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 96 | def test_benchmark_forward_backward(self): 97 | N = 12 98 | H = 8 99 | L = 1024 100 | S = 1024 101 | E = 32 102 | k = 32 103 | Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True) 104 | K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True) 105 | topk = torch.round( 106 | torch.cumsum(torch.rand(N, H, L, k)*(S//k), dim=-1) 107 | ).long().to(self.device) 108 | n_runs = 10 109 | self._zero_grad(Q, K) 110 | s = time.time() 111 | for i in range(n_runs): 112 | QK = torch.einsum("nhle,nhse->nhls", Q, K) 113 | QK.sum().backward() 114 | e = time.time() 115 | t_full = (e - s) / n_runs 116 | 117 | self._zero_grad(Q, K) 118 | s = time.time() 119 | for i in range(n_runs): 120 | QK = sparse_dot_product(Q, K, topk) 121 | QK.sum().backward() 122 | e = time.time() 123 | t_sparse = (e - s) / n_runs 124 | print("Benchmark Forward-Backward: T_Full: {}, T_Sparse: {}".format(t_full, t_sparse)) 125 | 126 | 127 | if __name__ == "__main__": 128 | unittest.main() 129 | -------------------------------------------------------------------------------- /tests/sparse_product/test_sparse_product_cpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import os 8 | import time 9 | import unittest 10 | 11 | import torch 12 | 13 | from 
fast_transformers.sparse_product import sparse_dot_product 14 | 15 | 16 | class TestSparseProductCPU(unittest.TestCase): 17 | def test_simple_product(self): 18 | X = torch.randn(10, 4, 100, 32) 19 | Y = torch.randn(10, 4, 100, 32) 20 | topk = (torch.cumsum(torch.rand(10, 4, 100, 10)*10, dim=-1)).long() 21 | 22 | products = sparse_dot_product( 23 | X, 24 | Y, 25 | topk, 26 | ) 27 | 28 | all_products = torch.einsum("nhle,nhse->nhls", X, Y) 29 | self.assertLess( 30 | torch.max(torch.abs( 31 | products - 32 | all_products[ 33 | torch.arange(10).view(10, 1, 1, 1), 34 | torch.arange(4).view(1, 4, 1, 1), 35 | torch.arange(100).view(1, 1, 100, 1), 36 | topk 37 | ] 38 | )), 39 | 1e-4 40 | ) 41 | 42 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 43 | def test_small_benchmark(self): 44 | N = 12 45 | H = 8 46 | L = 1000 47 | S = 1000 48 | E = 32 49 | k = 32 50 | X = torch.randn(N, H, L, E) 51 | Y = torch.randn(N, H, S, E) 52 | topk = (torch.cumsum(torch.rand(N, H, L, k)*40, dim=-1)).long() 53 | 54 | n_runs = 10 55 | s = time.time() 56 | for run in range(n_runs): 57 | products = sparse_dot_product( 58 | X, 59 | Y, 60 | topk, 61 | ) 62 | e = time.time() 63 | t_s = (e - s) / n_runs 64 | 65 | s = time.time() 66 | for run in range(n_runs): 67 | torch.einsum("nhle,nhse->nhls", X, Y) 68 | e = time.time() 69 | t_f = (e - s) / n_runs 70 | print("Sparse: {}, Full: {}, F/S: {}".format(t_s, t_f, t_f/t_s)) 71 | 72 | 73 | if __name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /tests/sparse_product/test_sparse_product_gpu.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import os 8 | import time 9 | import unittest 10 | 11 | import torch 12 | 13 | from fast_transformers.sparse_product import sparse_dot_product 14 | 15 | 16 | class TestSparseProductCUDA(unittest.TestCase): 17 | @classmethod 18 | def setUpClass(cls): 19 | if not torch.cuda.is_available(): 20 | raise unittest.SkipTest("No CUDA capable device detected") 21 | 22 | def test_single_query(self): 23 | X = torch.randn(1, 1, 1, 32).cuda() 24 | Y = torch.randn(1, 1, 100, 32).cuda() 25 | lengths = torch.full((1,), 1, dtype=torch.int32).cuda() 26 | topk = (torch.cumsum(torch.rand(1, 1, 1, 10)*10, dim=-1)).long().cuda() 27 | 28 | products = sparse_dot_product( 29 | X, 30 | Y, 31 | topk, 32 | ) 33 | all_products = torch.einsum("nhle,nhse->nhls", X, Y) 34 | 35 | self.assertLess( 36 | torch.max(torch.abs( 37 | products.squeeze() - 38 | all_products[0, 0, 0, topk[0, 0, 0]] 39 | )), 40 | 1e-4 41 | ) 42 | 43 | def test_simple_product(self): 44 | X = torch.randn(10, 4, 100, 32).cuda() 45 | Y = torch.randn(10, 4, 100, 32).cuda() 46 | lengths = torch.full((10,), 100, dtype=torch.int32).cuda() 47 | topk = (torch.cumsum(torch.rand(10, 4, 100, 10)*10, dim=-1)).long().cuda() 48 | 49 | A = torch.randn(10, 4, 100, 100).to(X.device).requires_grad_(False) 50 | topk_v, topk = torch.topk(A, 10, dim=-1) 51 | topk = topk.contiguous() 52 | 53 | products = sparse_dot_product( 54 | X, 55 | Y, 56 | topk, 57 | ) 58 | all_products = torch.einsum("nhle,nhse->nhls", X, Y) 59 | 60 | self.assertLess( 61 | torch.max(torch.abs( 62 | products - 63 | all_products[ 64 | torch.arange(10).view(10, 1, 1, 1), 65 | torch.arange(4).view(1, 4, 1, 1), 66 | torch.arange(100).view(1, 1, 100, 1), 67 | topk 68 | ] 69 | )), 
70 | 1e-4 71 | ) 72 | 73 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS", ""), "no benchmarks") 74 | def test_small_benchmark(self): 75 | N = 12 76 | H = 8 77 | L = 1000 78 | S = 1000 79 | E = 32 80 | k = 32 81 | X = torch.randn(N, H, L, E).cuda() 82 | Y = torch.randn(N, H, S, E).cuda() 83 | 84 | A = torch.randn(N, H, L, S).to(X.device).requires_grad_(False) 85 | topk_v, topk = torch.topk(A, k, dim=-1) 86 | topk = topk.contiguous() 87 | 88 | for i in range(1000): 89 | products = sparse_dot_product( 90 | X, 91 | Y, 92 | topk, 93 | ) 94 | torch.cuda.synchronize() 95 | s = torch.cuda.Event(enable_timing=True) 96 | e = torch.cuda.Event(enable_timing=True) 97 | s.record() 98 | products = sparse_dot_product( 99 | X, 100 | Y, 101 | topk, 102 | ) 103 | e.record() 104 | torch.cuda.synchronize() 105 | t_s = s.elapsed_time(e) 106 | for i in range(1000): 107 | torch.einsum("nhle,nhse->nhls", X, Y) 108 | s = torch.cuda.Event(enable_timing=True) 109 | e = torch.cuda.Event(enable_timing=True) 110 | s.record() 111 | torch.einsum("nhle,nhse->nhls", X, Y) 112 | e.record() 113 | torch.cuda.synchronize() 114 | t_f = s.elapsed_time(e) 115 | print("Sparse: {}, Full: {}, F/S: {}".format(t_s, t_f, t_f/t_s)) 116 | 117 | 118 | if __name__ == "__main__": 119 | unittest.main() 120 | -------------------------------------------------------------------------------- /tests/test_masking.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import unittest 8 | 9 | import torch 10 | 11 | from fast_transformers.masking import FullMask, LengthMask, TriangularCausalMask 12 | 13 | 14 | class TestMasking(unittest.TestCase): 15 | def test_full_mask(self): 16 | m = FullMask(N=10) 17 | self.assertEqual(m.shape, (10, 10)) 18 | self.assertTrue(torch.all(m.bool_matrix)) 19 | self.assertTrue(torch.all(m.float_matrix == 1)) 20 | self.assertTrue(torch.all(m.additive_matrix == 0)) 21 | 22 | with self.assertRaises(ValueError): 23 | m = FullMask(torch.rand(10)) 24 | 25 | m = FullMask(torch.rand(10, 5) > 0.5) 26 | self.assertEqual(m.shape, (10, 5)) 27 | 28 | def test_lengths(self): 29 | m = LengthMask(torch.tensor([1, 2, 3])) 30 | self.assertEqual(m.shape, (3, 3)) 31 | self.assertTrue(torch.all( 32 | m.float_matrix.sum(axis=1) == torch.tensor([1, 2, 3.]) 33 | )) 34 | self.assertTrue(torch.all( 35 | m.lengths == torch.tensor([1, 2, 3]) 36 | )) 37 | for i, n in enumerate(m.lengths): 38 | self.assertTrue(torch.all(torch.isinf(m.additive_matrix[i, n:]))) 39 | 40 | def test_max_lengths(self): 41 | m = LengthMask(torch.tensor([1, 2, 3]), max_len=10) 42 | self.assertEqual(m.shape, (3, 10)) 43 | self.assertTrue(torch.all( 44 | m.float_matrix.sum(axis=1) == torch.tensor([1, 2, 3.]) 45 | )) 46 | self.assertTrue(torch.all( 47 | m.lengths == torch.tensor([1, 2, 3]) 48 | )) 49 | for i, n in enumerate(m.lengths): 50 | self.assertTrue(torch.all(torch.isinf(m.additive_matrix[i, n:]))) 51 | 52 | def test_casting_to_lengths(self): 53 | m = FullMask(torch.tensor([ 54 | [1, 0, 0], 55 | [1, 1, 0], 56 | [1, 1, 1] 57 | ]) > 0) 58 | self.assertEqual(m.shape, (3, 3)) 59 | self.assertTrue(torch.all(m.lengths == torch.tensor([1, 2, 3]))) 60 | 61 | m = FullMask(torch.tensor([ 62 | [1, 0, 1], 63 | [1, 1, 0], 64 | [1, 1, 1] 65 | ]) > 0) 66 | with self.assertRaises(ValueError): 67 | m.lengths 68 | 69 | def test_full_mask_constructor_arguments(self): 70 | m = FullMask(torch.rand(10, 10) > 
0.5) 71 | self.assertEqual(m.shape, (10, 10)) 72 | self.assertFalse(m.all_ones) 73 | 74 | m = FullMask(10) 75 | self.assertEqual(m.shape, (10, 10)) 76 | self.assertTrue(m.all_ones) 77 | 78 | m = FullMask(10, 5) 79 | self.assertEqual(m.shape, (10, 5)) 80 | self.assertTrue(m.all_ones) 81 | 82 | def test_lower_triangular(self): 83 | m = TriangularCausalMask(3) 84 | self.assertTrue(m.lower_triangular) 85 | self.assertTrue(torch.all(m.bool_matrix == (torch.tensor([ 86 | [1, 0, 0], 87 | [1, 1, 0], 88 | [1, 1, 1] 89 | ]) > 0))) 90 | 91 | m = FullMask(torch.tensor([ 92 | [1, 0, 0], 93 | [1, 1, 0], 94 | [1, 1, 1] 95 | ]) > 0) 96 | self.assertTrue(m.lower_triangular) 97 | 98 | m = FullMask(torch.tensor([ 99 | [1, 0, 1], 100 | [1, 1, 0], 101 | [1, 1, 1] 102 | ]) > 0) 103 | self.assertFalse(m.lower_triangular) 104 | 105 | m = LengthMask(torch.tensor([1, 1, 3])) 106 | self.assertFalse(m.lower_triangular) 107 | m = LengthMask(torch.tensor([1, 2, 3])) 108 | self.assertTrue(m.lower_triangular) 109 | m = LengthMask(torch.tensor([1, 2, 3]), max_len=4) 110 | self.assertTrue(m.lower_triangular) 111 | 112 | 113 | if __name__ == "__main__": 114 | unittest.main() 115 | -------------------------------------------------------------------------------- /tests/test_transformer_decoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos 4 | # 5 | 6 | import os 7 | import time 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.attention import AttentionLayer, FullAttention 13 | from fast_transformers.builders import RecurrentDecoderBuilder 14 | from fast_transformers.masking import FullMask, LengthMask 15 | from fast_transformers.transformers import TransformerDecoderLayer, \ 16 | TransformerDecoder 17 | 18 | 19 | class TestTransformerDecoder(unittest.TestCase): 20 | def test_full_attention_forward(self): 21 | d_model = 128 22 | n_heads = 4 23 | transformer = TransformerDecoder([ 24 | TransformerDecoderLayer( 25 | AttentionLayer(FullAttention(), d_model, n_heads), # self 26 | AttentionLayer(FullAttention(), d_model, n_heads), # cross 27 | d_model 28 | ) 29 | for i in range(6) 30 | ]) 31 | x = torch.rand(10, 7, d_model) 32 | mem = torch.rand(10, 12, d_model) 33 | y = transformer(x, mem) 34 | self.assertEqual(y.shape, (10, 7, d_model)) 35 | 36 | @unittest.skipUnless(os.getenv("BENCHMARK_TESTS"), "no benchmarks") 37 | def test_decoder_inference_benchmark(self): 38 | builder = RecurrentDecoderBuilder.from_kwargs( 39 | n_layers=4, 40 | n_heads=8, 41 | query_dimensions=64, 42 | value_dimensions=64 43 | ) 44 | t1 = builder.get() 45 | builder.self_attention_type = "linear" 46 | builder.cross_attention_type = "linear" 47 | t2 = builder.get() 48 | 49 | B = 128 50 | L = 100 51 | S = 100 52 | D = 512 53 | memory = torch.rand(B, S, D) 54 | memory_lengths = LengthMask(torch.full((B,), S, dtype=torch.int64)) 55 | 56 | x = torch.rand(B, D) 57 | state = None 58 | start = time.time() 59 | with torch.no_grad(): 60 | for i in range(L): 61 | x, state = t1(x, memory, memory_lengths, state=state) 62 | end = time.time() 63 | print("Softmax attention took", round(end-start, 2), "s") 64 | 65 | x = torch.rand(B, D) 66 | state = None 67 | start = time.time() 68 | with torch.no_grad(): 69 | for i in range(L): 70 | x, state = t2(x, memory, memory_lengths, state=state) 71 | end = time.time() 72 | print("Linear attention took", round(end-start, 2), "s") 73 | 74 | 75 | if __name__ 
== "__main__": 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /tests/test_transformer_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fast_transformers.attention import AttentionLayer, FullAttention, \ 13 | ClusteredAttention, ImprovedClusteredAttention, ReformerAttention 14 | from fast_transformers.masking import FullMask 15 | from fast_transformers.transformers import TransformerEncoderLayer, TransformerEncoder 16 | 17 | 18 | class TestTransformerEncoder(unittest.TestCase): 19 | def test_full_attention_forward(self): 20 | d_model = 128 21 | n_heads = 4 22 | transformer = TransformerEncoder([ 23 | TransformerEncoderLayer( 24 | AttentionLayer(FullAttention(), d_model, n_heads), 25 | d_model, 26 | n_heads 27 | ) 28 | for i in range(6) 29 | ]) 30 | x = transformer(torch.rand(10, 7, d_model)) 31 | self.assertEqual(x.shape, (10, 7, d_model)) 32 | 33 | def test_clustered_attention_forward(self): 34 | d_model = 128 35 | n_heads = 4 36 | transformer = TransformerEncoder([ 37 | TransformerEncoderLayer( 38 | AttentionLayer( 39 | ClusteredAttention( 40 | clusters=10 41 | ), 42 | d_model, 43 | n_heads 44 | ), 45 | d_model, 46 | n_heads 47 | ) 48 | for i in range(6) 49 | ]) 50 | x = transformer(torch.rand(100, 20, d_model)) 51 | self.assertEqual(x.shape, (100, 20, d_model)) 52 | 53 | def test_improved_clustered_attention_forward(self): 54 | d_model = 128 55 | n_heads = 4 56 | transformer = TransformerEncoder([ 57 | TransformerEncoderLayer( 58 | AttentionLayer( 59 | ImprovedClusteredAttention( 60 | clusters=10, 61 | topk=5 62 | ), 63 | d_model, 64 | n_heads 65 | ), 66 | d_model, 67 | n_heads 68 | ) 69 | for i in range(6) 70 | ]) 71 | x = torch.rand(100, 20, d_model) 72 | y = transformer(x) 73 | self.assertEqual(y.shape, (100, 20, d_model)) 74 | 75 | def test_reformer_attention_forward(self): 76 | d_model = 128 77 | n_heads = 4 78 | transformer = TransformerEncoder([ 79 | TransformerEncoderLayer( 80 | AttentionLayer( 81 | ReformerAttention( 82 | chunk_size=32, 83 | rounds=4, 84 | bits=8, 85 | masked=False, 86 | ), 87 | d_model, 88 | n_heads 89 | ), 90 | d_model, 91 | n_heads 92 | ) 93 | for i in range(6) 94 | ]) 95 | x = torch.rand(12, 128, d_model) 96 | y = transformer(x) 97 | self.assertEqual(y.shape, (12, 128, d_model)) 98 | 99 | 100 | if __name__ == "__main__": 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /tests/test_weight_mapper.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | 8 | import unittest 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from fast_transformers.builders import TransformerEncoderBuilder 14 | from fast_transformers.weight_mapper import PytorchMapper, \ 15 | HugginfaceBertEncoderMapper, LongformerMapper 16 | 17 | try: 18 | from transformers import BertConfig, BertModel 19 | except ImportError: 20 | BertConfig = BertModel = None 21 | 22 | try: 23 | from longformer.longformer import LongformerConfig, Longformer 24 | except ImportError: 25 | LongformerConfig = Longformer = None 26 | 27 | 28 | class
TestWeightMapper(unittest.TestCase): 29 | def test_mapping(self): 30 | t1 = nn.TransformerEncoder( 31 | nn.TransformerEncoderLayer( 32 | 128, 4, dim_feedforward=256 33 | ), 34 | 4 35 | ) 36 | t2 = TransformerEncoderBuilder.from_kwargs( 37 | n_layers=4, 38 | n_heads=4, 39 | query_dimensions=128//4, 40 | value_dimensions=128//4, 41 | feed_forward_dimensions=256, 42 | attention_type="full", 43 | final_normalization=False 44 | ).get() 45 | t1.eval() 46 | t2.eval() 47 | 48 | with self.assertRaises(RuntimeError): 49 | t2.load_state_dict(t1.state_dict()) 50 | 51 | t2.load_state_dict(PytorchMapper().map(t1.state_dict())) 52 | x = torch.rand(3, 10, 128) 53 | o1 = t2(x) 54 | o2 = t1(x.permute(1, 0, 2)).permute(1, 0, 2) 55 | self.assertLess(torch.abs(o1 - o2).max().item(), 1e-5) 56 | 57 | @unittest.skipUnless(BertConfig, "Hugginface is not installed") 58 | def test_huggin_bert(self): 59 | bert = BertModel(BertConfig()) 60 | encoder = TransformerEncoderBuilder.from_kwargs( 61 | n_layers=12, 62 | n_heads=12, 63 | query_dimensions=64, 64 | value_dimensions=64, 65 | feed_forward_dimensions=3072, 66 | attention_type="full", 67 | final_normalization=False, 68 | activation="gelu" 69 | ).get() 70 | bert.eval() 71 | encoder.eval() 72 | 73 | # Before the weight copy they should be different 74 | x = torch.rand(3, 10, 768) 75 | o1 = bert.encoder(x, head_mask=[None]*12)[0] 76 | o2 = encoder(x) 77 | self.assertGreater(torch.abs(o1-o2).max().item(), 1) 78 | 79 | # And after the copy they should be exactly the same 80 | encoder.load_state_dict( 81 | HugginfaceBertEncoderMapper().map(bert.encoder.state_dict()) 82 | ) 83 | o1 = bert.encoder(x, head_mask=[None]*12)[0] 84 | o2 = encoder(x) 85 | self.assertLess(torch.abs(o1-o2).max().item(), 1e-4) 86 | 87 | @unittest.skipUnless(Longformer, "Longformer is not installed") 88 | def test_longformer(self): 89 | config = LongformerConfig() 90 | config.attention_mode = "n2" 91 | config.attention_window = [256]*12 92 | config.attention_dilation = [1]*12 93 | longformer = Longformer(config) 94 | encoder = TransformerEncoderBuilder.from_kwargs( 95 | n_layers=12, 96 | n_heads=12, 97 | query_dimensions=64, 98 | value_dimensions=64, 99 | feed_forward_dimensions=3072, 100 | attention_type="full", 101 | final_normalization=False, 102 | activation="gelu" 103 | ).get() 104 | longformer.eval() 105 | encoder.eval() 106 | 107 | # Before the weight copy they should be different 108 | x = torch.rand(3, 10, 768) 109 | o1 = longformer.encoder(x, head_mask=[None]*12)[0] 110 | o2 = encoder(x) 111 | self.assertGreater(torch.abs(o1-o2).max().item(), 1) 112 | 113 | # And after the copy they should be exactly the same 114 | encoder.load_state_dict( 115 | LongformerMapper().map(longformer.encoder.state_dict()) 116 | ) 117 | o1 = longformer.encoder(x, head_mask=[None]*12)[0] 118 | o2 = encoder(x) 119 | self.assertLess(torch.abs(o1-o2).max().item(), 1e-4) 120 | 121 | 122 | if __name__ == "__main__": 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 4 | # Written by Angelos Katharopoulos 5 | # 6 | 7 | """A script to contain some package administration tools and automations such 8 | as building the documentation. 9 | 10 | Maybe this script should be a shell-script, maybe it shouldn't. 
:-) 11 | """ 12 | 13 | import argparse 14 | from functools import partial 15 | from http.server import HTTPServer, SimpleHTTPRequestHandler 16 | import os 17 | from shutil import rmtree 18 | from subprocess import call 19 | import sys 20 | import time 21 | 22 | from watchdog.events import FileSystemEventHandler 23 | from watchdog.observers import Observer 24 | 25 | 26 | def throttled(once_every): 27 | last_time = [0] 28 | def decorator(f): 29 | def decorated(*args, **kwargs): 30 | if time.time() - last_time[0] > once_every: 31 | last_time[0] = time.time() 32 | return f(*args, **kwargs) 33 | return decorated 34 | return decorator 35 | 36 | 37 | @throttled(3) 38 | def build_docs(args): 39 | # Remove the directory 40 | rmtree(args.output_dir) 41 | call(["mkdocs", "build", "-d", args.output_dir]) 42 | call(["pdoc", "--html", "-o", os.path.join(args.output_dir, "api_docs"), 43 | "fast_transformers"]) 44 | 45 | 46 | def serve_docs(args): 47 | class BuildDocsEventHandler(FileSystemEventHandler): 48 | def on_any_event(self, event): 49 | if os.path.splitext(event.src_path)[1] in [".md", ".py"]: 50 | build_docs(args) 51 | 52 | build_docs(args) 53 | this_dir = os.path.dirname(os.path.realpath(__file__)) 54 | observer = Observer() 55 | observer.schedule(BuildDocsEventHandler(), this_dir, recursive=True) 56 | observer.start() 57 | try: 58 | handler = partial(SimpleHTTPRequestHandler, directory=args.output_dir) 59 | httpd = HTTPServer(args.bind, handler) 60 | httpd.serve_forever() 61 | except KeyboardInterrupt: 62 | observer.stop() 63 | observer.join() 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser( 68 | description="Build the documentation site" 69 | ) 70 | subparsers = parser.add_subparsers(dest="command") 71 | 72 | # Documentation command 73 | docs = subparsers.add_parser( 74 | "build_docs", 75 | help="Build the documentation site" 76 | ) 77 | docs.add_argument( 78 | "--output_dir", "-o", 79 | default="site", 80 | help="Choose the output directory to store the html (default: site)" 81 | ) 82 | 83 | # Serve the documentation (for writing the docs) 84 | serve = subparsers.add_parser( 85 | "serve_docs", 86 | help="Serve the documentation site for development purposes" 87 | ) 88 | serve.add_argument( 89 | "--bind", "-b", 90 | type=lambda x: (x.split(":")[0], int(x.split(":")[1])), 91 | default=("", 8000), 92 | help="The address and port to bind the server to (default: :8000)" 93 | ) 94 | serve.add_argument( 95 | "--output_dir", "-o", 96 | default="site", 97 | help="Choose the output directory to store the html (default: site)" 98 | ) 99 | 100 | # Parse the arguments 101 | args = parser.parse_args() 102 | if args.command is None: 103 | parser.print_help() 104 | sys.exit(1) 105 | 106 | # Dispatch the command 107 | dict( 108 | build_docs=build_docs, 109 | serve_docs=serve_docs 110 | )[args.command](args) 111 | --------------------------------------------------------------------------------
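A minimal usage sketch for tools.py, assuming mkdocs, pdoc and watchdog are installed and the script is invoked from the repository root; the subcommands, flags and defaults below are read off the argparse definitions above rather than verified against a live checkout:

    # Build the MkDocs site and the pdoc API docs into ./site (hypothetical invocation)
    python tools.py build_docs --output_dir site

    # Rebuild on .md/.py changes and serve the result at http://localhost:8000 (hypothetical invocation)
    python tools.py serve_docs --bind localhost:8000 --output_dir site

Running python tools.py with no subcommand prints the help text and exits with status 1. Note that build_docs removes the output directory before rebuilding, so the directory must already exist (otherwise rmtree raises) on the first run.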