├── LICENSE ├── luna_transformer ├── __init__.py ├── feed_forward.py ├── embedding.py ├── mask.py ├── encoder.py ├── model.py └── attention.py ├── setup.py ├── .gitignore └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Soohwan Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /luna_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | from .model import LunaTransformerEncoder 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from setuptools import setup, find_packages 24 | 25 | setup( 26 | name='luna-transformer', 27 | packages=find_packages(), 28 | version='latest', 29 | description='Luna: Linear Unified Nested Attention', 30 | author='Soohwan Kim', 31 | author_email='sh951011@gmail.com', 32 | url='https://github.com/sooftware/luna-transformer', 33 | install_requires=[ 34 | 'torch>=1.4.0', 35 | 'numpy', 36 | ], 37 | python_requires='>=3.7', 38 | ) 39 | -------------------------------------------------------------------------------- /luna_transformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | 27 | class PositionwiseFeedForwardNetwork(nn.Module): 28 | """ 29 | Position-wise Feedforward Networks proposed in "Attention Is All You Need". 
30 | Fully connected feed-forward network, which is applied to each position separately and identically. 31 | This consists of two linear transformations with a ReLU activation in between. 32 | Another way of describing this is as two convolutions with kernel size 1. 33 | """ 34 | def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout_p: float = 0.3) -> None: 35 | super(PositionwiseFeedForwardNetwork, self).__init__() 36 | self.feed_forward = nn.Sequential( 37 | nn.Linear(d_model, d_ff), 38 | nn.Dropout(dropout_p), 39 | nn.ReLU(), 40 | nn.Linear(d_ff, d_model), 41 | nn.Dropout(dropout_p), 42 | ) 43 | 44 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 45 | return self.feed_forward(inputs) 46 | -------------------------------------------------------------------------------- /luna_transformer/embedding.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import math 24 | import torch 25 | import torch.nn as nn 26 | from torch import Tensor 27 | 28 | 29 | class PositionalEncoding(nn.Module): 30 | """ 31 | Positional Encoding proposed in "Attention Is All You Need". 32 | Since transformer contains no recurrence and no convolution, in order for the model to make 33 | use of the order of the sequence, we must add some positional information. 
34 | 35 | "Attention Is All You Need" use sine and cosine functions of different frequencies: 36 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) 37 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) 38 | """ 39 | def __init__(self, d_model: int = 80, max_length: int = 5000) -> None: 40 | super(PositionalEncoding, self).__init__() 41 | pe = torch.zeros(max_length, d_model, requires_grad=False) 42 | position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1).float() 43 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 44 | pe[:, 0::2] = torch.sin(position * div_term) 45 | pe[:, 1::2] = torch.cos(position * div_term) 46 | pe = pe.unsqueeze(0) 47 | self.register_buffer('pe', pe) 48 | 49 | def forward(self, length: int) -> Tensor: 50 | return self.pe[:, :length] 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | *.DS_Store 29 | .DS_Store 30 | .idea/* 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ -------------------------------------------------------------------------------- /luna_transformer/mask.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import torch 24 | from torch import Tensor 25 | 26 | 27 | def get_attn_pad_mask(inputs, input_lengths, expand_length): 28 | """ mask position is set to 1 """ 29 | 30 | def get_transformer_non_pad_mask(inputs: Tensor, input_lengths: Tensor) -> Tensor: 31 | """ Padding position is set to 0, either use input_lengths or pad_id """ 32 | batch_size = inputs.size(0) 33 | 34 | if len(inputs.size()) == 2: 35 | non_pad_mask = inputs.new_ones(inputs.size()) # B x T 36 | elif len(inputs.size()) == 3: 37 | non_pad_mask = inputs.new_ones(inputs.size()[:-1]) # B x T 38 | else: 39 | raise ValueError(f"Unsupported input shape {inputs.size()}") 40 | 41 | for i in range(batch_size): 42 | non_pad_mask[i, input_lengths[i]:] = 0 43 | 44 | return non_pad_mask 45 | 46 | non_pad_mask = get_transformer_non_pad_mask(inputs, input_lengths) 47 | pad_mask = non_pad_mask.lt(1) 48 | attn_pad_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1) 49 | return attn_pad_mask 50 | 51 | 52 | def get_attn_subsequent_mask(seq): 53 | assert seq.dim() == 2 54 | attn_shape = [seq.size(0), seq.size(1), seq.size(1)] 55 | subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) 56 | 57 | if seq.is_cuda: 58 | subsequent_mask = subsequent_mask.cuda() 59 | 60 | return subsequent_mask 61 | -------------------------------------------------------------------------------- /luna_transformer/encoder.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from luna_transformer.attention import LinearUnifiedNestedAttention 27 | from luna_transformer.feed_forward import PositionwiseFeedForwardNetwork 28 | 29 | 30 | class LunaTransformerEncoderLayer(nn.Module): 31 | def __init__( 32 | self, 33 | d_model: int = 512, 34 | num_attention_heads: int = 8, 35 | d_ff: int = 2048, 36 | dropout_p: float = 0.3, 37 | ) -> None: 38 | super(LunaTransformerEncoderLayer, self).__init__() 39 | self.luna_attention = LinearUnifiedNestedAttention(d_model, num_attention_heads) 40 | self.feed_forward = PositionwiseFeedForwardNetwork(d_model, d_ff, dropout_p) 41 | self.packed_context_layer_norm = nn.LayerNorm(d_model) 42 | self.unpacked_context_layer_norm = nn.LayerNorm(d_model) 43 | 44 | self.feed_forward_layer_norm = nn.LayerNorm(d_model) 45 | 46 | def forward( 47 | self, 48 | inputs: torch.FloatTensor, 49 | p: torch.FloatTensor, 50 | attention_padding_mask: torch.FloatTensor = None, 51 | ): 52 | unpacked_context, packed_context = self.luna_attention( 53 | query=inputs, 54 | key=inputs, 55 | value=inputs, 56 | p=p, 57 | attention_padding_mask=attention_padding_mask, 58 | ) 59 | 60 | packed_context = self.packed_context_layer_norm(packed_context + p) 61 | unpacked_context = self.unpacked_context_layer_norm(unpacked_context + inputs) 62 | 63 | outputs = self.feed_forward(unpacked_context) 64 | outputs = self.feed_forward_layer_norm(outputs + unpacked_context) 65 | 66 | return outputs, packed_context 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | **Unofficial PyTorch implementation of [Luna: Linear Unified Nested Attention](https://arxiv.org/abs/2106.01540.pdf)** 8 | 9 | 10 |
11 | 12 | *** 13 | 14 | 15 | The quadratic computational and memory complexities of the Transformer’s attention mechanism have limited its scalability for modeling long sequences. In 16 | this paper, we propose Luna, a linear unified nested attention mechanism that 17 | approximates softmax attention with two nested linear attention functions, yielding 18 | only linear (as opposed to quadratic) time and space complexity. As compared to 19 | a more traditional attention mechanism, Luna introduces an additional sequence 20 | with a fixed length as input and an additional corresponding output, which allows 21 | Luna to perform attention operation linearly, while also storing adequate contextual 22 | information. We perform extensive evaluations on three benchmarks of sequence 23 | modeling tasks: long-context sequence modeling, neural machine translation and 24 | masked language modeling for large-scale pretraining. Competitive or even better 25 | experimental results demonstrate both the effectiveness and efficiency of Luna 26 | compared to a variety of strong baseline methods including the full-rank attention 27 | and other efficient sparse and dense attention methods. 28 | 29 | ![image](https://user-images.githubusercontent.com/42150335/127543497-0b4a5513-4ac6-48c7-9595-d38c880ad8ed.png) 30 | 31 | ## Installation 32 | This project requires Python 3.7 or higher. 33 | We recommend creating a new virtual environment for this project (using virtualenv or conda). 34 | 35 | ### Prerequisites 36 | * NumPy: `pip install numpy` (refer [here](https://github.com/numpy/numpy) for problems installing NumPy). 37 | * PyTorch: Refer to the [PyTorch website](http://pytorch.org/) to install the version appropriate for your environment. 38 | 39 | ### Install from source 40 | Currently we only support installation from source code using setuptools. Check out the source code and run the 41 | following command: 42 | 43 | ``` 44 | pip install -e . 45 | ``` 46 | 47 | ## Usage 48 | 49 | ```python 50 | import torch 51 | from luna_transformer import LunaTransformerEncoder 52 | 53 | DUMMY_INPUTS = torch.LongTensor([ 54 | [2, 3, 3, 3, 3, 3, 2, 2, 0], 55 | [2, 3, 3, 3, 3, 3, 2, 3, 2], 56 | [2, 3, 3, 3, 3, 3, 2, 2, 0], 57 | ]) 58 | DUMMY_INPUT_LENGTHS = torch.LongTensor([9, 8, 7]) 59 | 60 | model = LunaTransformerEncoder(vocab_size=4, d_model=512, num_layers=6, 61 | num_attention_heads=8, project_embedding_length=32, 62 | dropout_p=0.1, max_length=1024) 63 | outputs = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS) 64 | ``` 65 | 66 | ## Troubleshooting and Contributing 67 | If you have any questions, bug reports, or feature requests, please [open an issue](https://github.com/sooftware/conformer/issues) on GitHub or 68 | contact sh951011@gmail.com. 69 | 70 | I appreciate any kind of feedback or contribution. Feel free to proceed directly with small issues like bug fixes and documentation improvements. For major contributions and new features, please discuss them with the collaborators in the corresponding issues. 71 | 72 | ## Code Style 73 | I follow [PEP-8](https://www.python.org/dev/peps/pep-0008/) for code style. The style of docstrings is especially important for generating documentation.
74 | 75 | ## Author 76 | 77 | * Soohwan Kim [@sooftware](https://github.com/sooftware) 78 | * Contact: sh951011@gmail.com 79 | -------------------------------------------------------------------------------- /luna_transformer/model.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import math 24 | import torch 25 | import torch.nn as nn 26 | 27 | from luna_transformer.embedding import PositionalEncoding 28 | from luna_transformer.encoder import LunaTransformerEncoderLayer 29 | from luna_transformer.mask import get_attn_pad_mask 30 | 31 | 32 | class LunaTransformerEncoder(nn.Module): 33 | """ 34 | Transformer encoder architecture with Linear Unified Nested Attention (Luna) applied.
35 | Luna was proposed in the paper "Luna: Linear Unified Nested Attention" (https://arxiv.org/abs/2106.01540.pdf) 36 | """ 37 | def __init__( 38 | self, 39 | vocab_size: int, 40 | d_model: int, 41 | num_layers: int = 6, 42 | num_attention_heads: int = 8, 43 | d_ff: int = 2048, 44 | dropout_p: float = 0.1, 45 | project_embedding_length: int = 32, 46 | max_length: int = 1024, 47 | ): 48 | super(LunaTransformerEncoder, self).__init__() 49 | self.d_model = d_model 50 | self.projected_embedding_length = project_embedding_length 51 | 52 | self.projected_embeddings = nn.Parameter(torch.Tensor(project_embedding_length, self.d_model)) 53 | self.projected_positions = PositionalEncoding(self.d_model, project_embedding_length) 54 | nn.init.normal_(self.projected_embeddings, mean=0.0, std=self.d_model ** -0.5) 55 | 56 | self.input_embedding = nn.Embedding(vocab_size, d_model) 57 | self.dropout = nn.Dropout(p=dropout_p) 58 | self.input_positions = PositionalEncoding(d_model, max_length) 59 | 60 | self.input_norm = nn.LayerNorm(d_model) 61 | self.embed_scale = math.sqrt(self.d_model) 62 | self.layers = nn.ModuleList([ 63 | LunaTransformerEncoderLayer( 64 | d_model=d_model, 65 | num_attention_heads=num_attention_heads, 66 | d_ff=d_ff, 67 | dropout_p=dropout_p, 68 | ) for _ in range(num_layers) 69 | ]) 70 | 71 | def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): 72 | batch_size, seq_length = inputs.size() 73 | 74 | attention_padding_mask = get_attn_pad_mask(inputs, input_lengths, self.projected_embedding_length) 75 | 76 | embedded = self.input_embedding(inputs) 77 | 78 | embedded *= self.embed_scale 79 | projected_embedded = self.projected_embeddings * self.embed_scale 80 | 81 | embedded += self.input_positions(embedded.size(1)) 82 | projected_embedded += self.projected_positions(self.projected_embedding_length).squeeze(0) 83 | 84 | seq_length, dim = projected_embedded.size() 85 | projected_embedded = projected_embedded.unsqueeze(0).expand(batch_size, seq_length, dim) 86 | 87 | outputs = self.dropout(embedded) 88 | p = self.dropout(projected_embedded) 89 | 90 | for layer in self.layers: 91 | outputs, p = layer(outputs, p, attention_padding_mask) 92 | 93 | return outputs 94 | -------------------------------------------------------------------------------- /luna_transformer/attention.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 Soohwan Kim 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | import numpy as np 27 | from typing import Optional, Tuple 28 | 29 | 30 | class DotProductAttention(nn.Module): 31 | r""" 32 | Scaled Dot-Product Attention proposed in "Attention Is All You Need" 33 | Compute the dot products of the query with all keys, divide each by sqrt(dim), 34 | and apply a softmax function to obtain the weights on the values 35 | 36 | Args: dim, mask 37 | dim (int): dimension of attention 38 | mask (torch.Tensor): tensor containing indices to be masked 39 | 40 | Inputs: query, key, value, mask 41 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoders. 42 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoders. 43 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence. 44 | - **mask** (-): tensor containing indices to be masked 45 | 46 | Returns: context, attn 47 | - **context**: tensor containing the context vector from attention mechanism. 48 | - **attn**: tensor containing the attention (alignment) from the encoders outputs. 49 | """ 50 | def __init__(self, dim: int, scale: bool = True) -> None: 51 | super(DotProductAttention, self).__init__() 52 | if scale: 53 | self.sqrt_dim = np.sqrt(dim) 54 | else: 55 | self.sqrt_dim = 1 56 | 57 | def forward( 58 | self, 59 | query: torch.FloatTensor, 60 | key: torch.FloatTensor, 61 | value: torch.FloatTensor, 62 | mask: Optional[torch.FloatTensor] = None, 63 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 64 | score = torch.matmul(query, key.transpose(2, 3)) / self.sqrt_dim 65 | 66 | if mask is not None: 67 | score.masked_fill_(mask, -1e4) 68 | 69 | attn = F.softmax(score, -1) 70 | 71 | if len(query.size()) == 3: 72 | context = torch.bmm(attn, value) 73 | else: 74 | context = torch.matmul(attn, value) 75 | 76 | return context, attn 77 | 78 | 79 | class MultiHeadAttention(nn.Module): 80 | r""" 81 | Multi-Head Attention proposed in "Attention Is All You Need" 82 | Instead of performing a single attention function with d_model-dimensional keys, values, and queries, 83 | project the queries, keys and values h times with different, learned linear projections to d_head dimensions. 84 | These are concatenated and once again projected, resulting in the final values. 85 | Multi-head attention allows the model to jointly attend to information from different representation 86 | subspaces at different positions. 87 | 88 | MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o 89 | where head_i = Attention(Q · W_q, K · W_k, V · W_v) 90 | 91 | Args: 92 | dim (int): The dimension of model (default: 512) 93 | num_attention_heads (int): The number of attention heads. (default: 8) 94 | 95 | Inputs: query, key, value, mask 96 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoders. 97 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoders. 98 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence. 
99 | - **mask** (-): tensor containing indices to be masked 100 | 101 | Returns: output, attn 102 | - **output** (batch, output_len, dimensions): tensor containing the attended output features. 103 | - **attn** (batch * num_attention_heads, v_len): tensor containing the attention (alignment) from the encoders outputs. 104 | """ 105 | def __init__(self, dim: int = 512, num_attention_heads: int = 8) -> None: 106 | super(MultiHeadAttention, self).__init__() 107 | 108 | assert dim % num_attention_heads == 0, "hidden_dim % num_attention_heads should be zero." 109 | 110 | self.d_head = int(dim / num_attention_heads) 111 | self.num_attention_heads = num_attention_heads 112 | self.query_proj = nn.Linear(dim, self.d_head * num_attention_heads) 113 | self.key_proj = nn.Linear(dim, self.d_head * num_attention_heads) 114 | self.value_proj = nn.Linear(dim, self.d_head * num_attention_heads) 115 | self.scaled_dot_attn = DotProductAttention(dim, scale=True) 116 | 117 | def forward( 118 | self, 119 | query: torch.FloatTensor, 120 | key: torch.FloatTensor, 121 | value: torch.FloatTensor, 122 | mask: Optional[torch.FloatTensor] = None, 123 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 124 | batch_size = value.size(0) 125 | 126 | query = self.query_proj(query).view(batch_size, -1, self.num_attention_heads, self.d_head).transpose(1, 2) 127 | key = self.key_proj(key).view(batch_size, -1, self.num_attention_heads, self.d_head).transpose(1, 2) 128 | value = self.value_proj(value).view(batch_size, -1, self.num_attention_heads, self.d_head).transpose(1, 2) 129 | 130 | if mask is not None: 131 | mask = mask.unsqueeze(1).repeat(1, self.num_attention_heads, 1, 1) 132 | 133 | context, attn = self.scaled_dot_attn(query, key, value, mask) 134 | 135 | context = context.transpose(1, 2).reshape(batch_size, -1, self.num_attention_heads * self.d_head) 136 | 137 | return context, attn 138 | 139 | 140 | class LinearUnifiedNestedAttention(nn.Module): 141 | def __init__(self, dim, num_attention_heads: int = 8) -> None: 142 | super(LinearUnifiedNestedAttention, self).__init__() 143 | self.pack_attention = MultiHeadAttention(dim, num_attention_heads) 144 | self.unpack_attention = MultiHeadAttention(dim, num_attention_heads) 145 | 146 | def forward( 147 | self, 148 | query: torch.FloatTensor, 149 | key: torch.FloatTensor, 150 | value: torch.FloatTensor, 151 | p: torch.FloatTensor, 152 | attention_padding_mask: torch.BoolTensor = None, 153 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 154 | packed_context, _ = self.pack_attention(p, key, value, attention_padding_mask) 155 | unpacked_context, _ = self.unpack_attention(query, packed_context, packed_context) 156 | return unpacked_context, packed_context 157 | --------------------------------------------------------------------------------
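
As a quick sanity check on the nested attention defined in `attention.py`, the following is a minimal shape-check sketch (not part of the package) showing how `LinearUnifiedNestedAttention` packs the input sequence into the fixed-length projected sequence `p` and then unpacks it back to the input length; the batch size, sequence length, and projected length below are illustrative assumptions.

```python
# Hypothetical shape check for LinearUnifiedNestedAttention (illustrative only).
import torch
from luna_transformer.attention import LinearUnifiedNestedAttention

batch_size, seq_length, proj_length, d_model = 3, 9, 32, 512
luna = LinearUnifiedNestedAttention(dim=d_model, num_attention_heads=8)

x = torch.randn(batch_size, seq_length, d_model)   # input sequence (used as query/key/value)
p = torch.randn(batch_size, proj_length, d_model)  # fixed-length projected sequence

# Pack attention attends from p to x; unpack attention attends from x back to the packed context.
unpacked, packed = luna(query=x, key=x, value=x, p=p)

print(unpacked.shape)  # torch.Size([3, 9, 512])  -> same length as the input sequence
print(packed.shape)    # torch.Size([3, 32, 512]) -> fixed projected length, independent of seq_length
```

Because the packed context always has the fixed projected length, both attention steps cost time linear in the input length, which is the point of the nested design.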