├── requirements.txt
├── src
│   └── perceiver_io
│       ├── __init__.py
│       ├── perceiver.py
│       ├── positional_encoding.py
│       ├── decoders.py
│       ├── encoder.py
│       └── attention.py
├── setup.py
├── LICENSE
├── README.md
└── examples
    └── language_modelling.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
--------------------------------------------------------------------------------
/src/perceiver_io/__init__.py:
--------------------------------------------------------------------------------
from perceiver_io.encoder import PerceiverEncoder
from perceiver_io.decoders import (ClassificationDecoder,
                                   PerceiverDecoder,
                                   ProjectionDecoder)
from perceiver_io.perceiver import PerceiverIO
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup


setup(
    name='perceiver-io-pytorch',
    version='0.1.5',
    packages=['perceiver_io'],
    package_dir={'': 'src'},
    url='https://github.com/esceptico/perceiver-io',
    license='MIT',
    author='Timur Ganiev',
    author_email='ganiev.tmr@gmail.com',
    description='Unofficial Perceiver IO implementation',
    install_requires=['torch']
)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Timur Ganiev

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/perceiver_io/perceiver.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
from torch import nn

from perceiver_io.decoders import BasePerceiverDecoder
from perceiver_io.encoder import PerceiverEncoder


class PerceiverIO(nn.Module):
    """Perceiver IO encoder-decoder architecture."""
    def __init__(
        self,
        encoder: PerceiverEncoder,
        decoder: BasePerceiverDecoder
    ):
        """Constructor.

        Args:
            encoder: Instance of Perceiver IO encoder.
            decoder: Instance of Perceiver IO decoder.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(
        self,
        inputs: torch.Tensor,
        query: Optional[torch.Tensor] = None,
        input_mask: Optional[torch.Tensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            inputs: Input tensor.
            query: Decoder query tensor. Can be either trainable or
                hand-crafted. Defaults to None.
            input_mask: Input mask tensor. Mask values selected in [0, 1].
                Defaults to None.
            query_mask: Decoder query mask tensor. Mask values selected in
                [0, 1]. Defaults to None.

        Returns:
            Output tensor.
        """
        latents = self.encoder(inputs, kv_mask=input_mask)
        outputs = self.decoder(
            query=query,
            latents=latents,
            q_mask=query_mask
        )
        return outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Perceiver IO
Unofficial implementation of
[Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795)


# Installation
**From PyPI**
```shell
pip install -U perceiver-io-pytorch
```


# Usage

```python
import torch

from perceiver_io.decoders import PerceiverDecoder
from perceiver_io.encoder import PerceiverEncoder
from perceiver_io import PerceiverIO

num_latents = 128
latent_dim = 256
input_dim = 64

decoder_query_dim = 4

encoder = PerceiverEncoder(
    num_latents=num_latents,
    latent_dim=latent_dim,
    input_dim=input_dim,
    num_self_attn_per_block=8,
    num_blocks=1
)
decoder = PerceiverDecoder(
    latent_dim=latent_dim,
    query_dim=decoder_query_dim
)
perceiver = PerceiverIO(encoder, decoder)

inputs = torch.randn(2, 16, input_dim)
output_query = torch.randn(2, 3, decoder_query_dim)

perceiver(inputs, output_query)  # shape = (2, 3, 4)
```

# List of implemented decoders
* ProjectionDecoder
* ClassificationDecoder
* PerceiverDecoder
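
`ClassificationDecoder` can be wired up the same way; it queries the latents
with a set of learned class embeddings and returns per-class logits, so no
output query needs to be passed. A minimal sketch (the sizes below are
illustrative and not taken from the original README):

```python
import torch

from perceiver_io import PerceiverEncoder, PerceiverIO
from perceiver_io.decoders import ClassificationDecoder

latent_dim = 256
input_dim = 64

encoder = PerceiverEncoder(
    num_latents=128,
    latent_dim=latent_dim,
    input_dim=input_dim
)
decoder = ClassificationDecoder(
    num_classes=10,
    latent_dim=latent_dim
)
classifier = PerceiverIO(encoder, decoder)

inputs = torch.randn(2, 16, input_dim)
classifier(inputs)  # shape = (2, 10)
```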

# Example architectures
* [Perceiver for LM](examples/language_modelling.py)

# Citation
```bibtex
@misc{jaegle2021perceiver,
    title = {Perceiver IO: A General Architecture for Structured Inputs & Outputs},
    author = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
    year = {2021},
    eprint = {2107.14795},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```
--------------------------------------------------------------------------------
/src/perceiver_io/positional_encoding.py:
--------------------------------------------------------------------------------
import math
from typing import Sequence

import torch


def fourier_encoding(
    dims: Sequence[int],
    num_bands: int,
    resolutions: Sequence[int],
    concatenate_positions: bool = True
) -> torch.Tensor:
    """Generate Fourier positional encodings.

    Args:
        dims: Sequence of dimensions.
        num_bands: Number of frequency bands.
        resolutions: Sequence of resolutions for each dimension.
        concatenate_positions: Indicates whether to concatenate positions to
            the encodings. Defaults to True.

    Returns:
        Tensor of shape (dims[0], ..., dims[d], 2 * num_bands * D),
        where D is the number of dimensions. If concatenate_positions is
        True, D extra position channels are appended to the last dimension.
    """
    # make sure that the number of resolutions equals the number of dimensions
    assert len(resolutions) == len(dims)

    # generate a position indices grid of shape (dims[0], ..., dims[d], D)
    ranges = [torch.linspace(-1, 1, dim) for dim in dims]
    grid = torch.meshgrid(*ranges)
    grid = torch.stack(grid, dim=-1)

    # frequency bands for each resolution of shape (len(resolutions), num_bands)
    freq_bands = torch.stack([
        torch.linspace(1, res / 2, steps=num_bands)
        for res in resolutions
    ], dim=0)

    # frequency features of shape (dims[0], ..., dims[d], D, 2 * num_bands)
    features = grid[..., None] * freq_bands[None, ...]
    sin = torch.sin(features * math.pi)
    cos = torch.cos(features * math.pi)
    features = torch.cat([sin, cos], dim=-1)

    # reshape the encodings as a tensor of shape
    # (dims[0], dims[1], ..., dims[d], 2 * num_bands * D)
    features = features.view(*grid.shape[:-1], -1)

    if concatenate_positions:
        features = torch.cat([features, grid], dim=-1)
    return features
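

# Usage sketch (illustrative values, not part of the original module):
# encoding a 224x224 image grid with 64 frequency bands,
#     fourier_encoding(dims=(224, 224), num_bands=64, resolutions=(224, 224)),
# yields a tensor of shape (224, 224, 2 * 64 * 2 + 2) = (224, 224, 258):
# sin and cos features for 64 bands over 2 dimensions, plus the 2 position
# channels appended when concatenate_positions is True.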
--------------------------------------------------------------------------------
/src/perceiver_io/decoders.py:
--------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod
from typing import Optional

import torch
from torch import nn

from perceiver_io.attention import CrossAttention


class BasePerceiverDecoder(nn.Module, metaclass=ABCMeta):
    """Abstract decoder class."""
    @abstractmethod
    def forward(
        self,
        *,
        query: torch.Tensor,
        latents: torch.Tensor,
        q_mask: Optional[torch.Tensor] = None
    ):
        raise NotImplementedError


class ProjectionDecoder(BasePerceiverDecoder):
    """Projection decoder without using a cross-attention layer."""
    def __init__(self, latent_dim: int, num_classes: int):
        super().__init__()
        self.projection = nn.Linear(latent_dim, num_classes)

    def forward(
        self,
        *,
        query: torch.Tensor,
        latents: torch.Tensor,
        q_mask: Optional[torch.Tensor] = None
    ):
        latents = latents.mean(dim=1)
        logits = self.projection(latents)
        return logits


class PerceiverDecoder(BasePerceiverDecoder):
    """Basic cross-attention decoder."""
    def __init__(
        self,
        latent_dim: int,
        query_dim: int,
        widening_factor: int = 1,
        num_heads: int = 1,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        projection_dim: Optional[int] = None,
        use_query_residual: bool = False
    ):
        super().__init__()
        self.cross_attention = CrossAttention(
            kv_dim=latent_dim,
            q_dim=query_dim,
            widening_factor=widening_factor,
            num_heads=num_heads,
            qk_out_dim=qk_out_dim,
            v_out_dim=v_out_dim,
            use_query_residual=use_query_residual
        )
        if projection_dim is not None:
            self.projection = nn.Linear(query_dim, projection_dim)
        else:
            self.projection = nn.Identity()

    def forward(
        self,
        *,
        query: torch.Tensor,
        latents: torch.Tensor,
        q_mask: Optional[torch.Tensor] = None
    ):
        if q_mask is not None:
            q_mask = q_mask[:, None, None, :].transpose(-2, -1)
        outputs = self.cross_attention(
            inputs_kv=latents,
            inputs_q=query,
            attention_mask=q_mask
        )
        return self.projection(outputs)


class ClassificationDecoder(BasePerceiverDecoder):
    """Classification decoder. Based on PerceiverDecoder."""
    def __init__(
        self,
        num_classes: int,
        latent_dim: int,
        widening_factor: int = 1,
        num_heads: int = 1,
        head_dim: Optional[int] = None
    ):
        super().__init__()
        self.task_ids = nn.Parameter(torch.randn(1, num_classes))
        self.decoder = PerceiverDecoder(
            latent_dim=latent_dim,
            query_dim=num_classes,
            widening_factor=widening_factor,
            num_heads=num_heads,
            # PerceiverDecoder has no `head_dim` argument; map the per-head
            # size onto the total query/key projection size instead.
            qk_out_dim=None if head_dim is None else head_dim * num_heads,
            projection_dim=None,
            use_query_residual=False
        )

    def forward(
        self,
        *,
        query: torch.Tensor,
        latents: torch.Tensor,
        q_mask: Optional[torch.Tensor] = None
    ):
        batch_size = latents.size(0)
        logits = self.decoder.forward(
            query=self.task_ids.repeat(batch_size, 1, 1),
            latents=latents,
            q_mask=q_mask
        )
        return logits.squeeze(1)
--------------------------------------------------------------------------------
/examples/language_modelling.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
from torch import nn

from perceiver_io import PerceiverEncoder, PerceiverDecoder, PerceiverIO


class PerceiverLM(nn.Module):
    """Encoder-decoder based language model."""
    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        embedding_dim: int,
        num_latents: int = 256,
        latent_dim: int = 512,
        num_self_attn_heads=8,
        self_attn_head_dim=None,
        cross_attn_head_dim=None,
        self_attn_widening_factor=1,
        cross_attn_widening_factor=1,
        num_blocks=1,
        num_self_attn_per_block=12,
        dropout: float = 0.0
    ):
        """Constructor.

        Args:
            vocab_size: Size of vocabulary.
            max_seq_len: Maximum length of token sequence.
            embedding_dim: Dimension of token embedding.
            num_latents: Number of latent vectors. Defaults to 256.
            latent_dim: Dimension of latent vector. Defaults to 512.
            num_self_attn_heads: Number of self-attention heads. Defaults to 8.
            self_attn_head_dim: Size of self-attention head. If None, this
                value will be calculated as latent_dim / num_self_attn_heads.
                Defaults to None.
            cross_attn_head_dim: Size of cross-attention head. If None, this
                value will be equal to latent_dim. Defaults to None.
            self_attn_widening_factor: Widening factor in self-attention
                feed-forward layer. Defaults to 1.
            cross_attn_widening_factor: Widening factor in cross-attention
                feed-forward layer. Defaults to 1.
            num_blocks: Number of transformer blocks. Defaults to 1.
            num_self_attn_per_block: Number of self-attention modules per
                transformer block. Defaults to 12.
            dropout: Dropout probability. Defaults to 0.
        """
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embedding_dim)
        # PerceiverEncoder exposes a single qk_out_dim shared by its cross-
        # and self-attention layers rather than separate per-head sizes, so
        # the self-attention head size is mapped onto it here;
        # cross_attn_head_dim has no counterpart in this encoder.
        qk_out_dim = (
            None if self_attn_head_dim is None
            else self_attn_head_dim * num_self_attn_heads
        )
        encoder = PerceiverEncoder(
            num_latents=num_latents,
            latent_dim=latent_dim,
            input_dim=embedding_dim,
            num_self_attn_per_block=num_self_attn_per_block,
            num_blocks=num_blocks,
            qk_out_dim=qk_out_dim,
            num_self_attn_heads=num_self_attn_heads,
            cross_attn_widening_factor=cross_attn_widening_factor,
            self_attn_widening_factor=self_attn_widening_factor,
            dropout=dropout,
        )
        decoder = PerceiverDecoder(
            latent_dim=latent_dim,
            query_dim=embedding_dim,
            widening_factor=cross_attn_widening_factor,
            projection_dim=vocab_size
        )
        self.perceiver = PerceiverIO(encoder, decoder)

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ):
        """
        Args:
            inputs: Tensor of token ids.
            mask: Token mask. Mask values selected in [0, 1]. Defaults to None.

        Returns:
            Tensor of shape (batch_size, seq_len, vocab_size).
        """
        seq_len = inputs.size(1)
        token_embeddings = self.token_embedding(inputs)
        positions_ids = torch.arange(seq_len, device=inputs.device).view(1, -1)
        position_embeddings = self.position_embedding(positions_ids)
        embeddings = token_embeddings + position_embeddings

        outputs = self.perceiver(
            inputs=embeddings,
            query=position_embeddings,
            input_mask=mask,
            query_mask=mask
        )
        return outputs
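

if __name__ == '__main__':
    # Quick smoke test; the sizes below are illustrative values, not taken
    # from the original example.
    model = PerceiverLM(vocab_size=1000, max_seq_len=128, embedding_dim=256)
    token_ids = torch.randint(0, 1000, (2, 128))
    logits = model(token_ids)
    print(logits.shape)  # expected: torch.Size([2, 128, 1000])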
--------------------------------------------------------------------------------
/src/perceiver_io/encoder.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
from torch import nn

from perceiver_io.attention import CrossAttention, SelfAttention


class PerceiverEncoder(nn.Module):
    """Perceiver encoder module. Consists of two components: a cross-attention
    module that maps an input tensor and a trainable latent tensor to a latent
    tensor, and a stack of Transformer blocks with shared weights.
    """
    def __init__(
        self,
        num_latents: int,
        latent_dim: int,
        input_dim: int,
        num_self_attn_per_block: int = 2,
        num_blocks: int = 4,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        num_cross_attn_heads: int = 1,
        num_self_attn_heads: int = 8,
        cross_attn_widening_factor: int = 1,
        self_attn_widening_factor: int = 1,
        use_query_residual: bool = True,
        dropout: float = 0.0,
        cross_attention_dropout: float = 0.0,
        self_attention_dropout: float = 0.0
    ):
        """Constructor.

        Args:
            num_latents: Number of latent vectors.
            latent_dim: Dimension of latent vector.
            input_dim: Dimension of input tensor.
            num_self_attn_per_block: Number of self-attention modules per
                transformer block. Defaults to 2.
            num_blocks: Number of transformer blocks. Defaults to 4.
            qk_out_dim: Size of Query and Key matrices last dimension.
                Defaults to None.
            v_out_dim: Size of Value matrix last dimension.
                Defaults to None.
            num_cross_attn_heads: Number of cross-attention heads.
                Defaults to 1.
            num_self_attn_heads: Number of self-attention heads.
                Defaults to 8.
            cross_attn_widening_factor: Widening factor in cross-attention
                feed-forward layer. Defaults to 1.
            self_attn_widening_factor: Widening factor in self-attention
                feed-forward layer. Defaults to 1.
            use_query_residual: Indicates whether to use query residual in
                cross-attention. Defaults to True.
            dropout: Feed-forward dropout probability. Defaults to 0.
            cross_attention_dropout: Cross-attention scores dropout
                probability. Defaults to 0.
            self_attention_dropout: Self-attention scores dropout probability.
                Defaults to 0.
        """
        super().__init__()
        self.num_blocks = num_blocks

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        self.cross_attn = CrossAttention(
            kv_dim=input_dim,
            q_dim=latent_dim,
            widening_factor=cross_attn_widening_factor,
            num_heads=num_cross_attn_heads,
            qk_out_dim=qk_out_dim,
            v_out_dim=v_out_dim,
            use_query_residual=use_query_residual,
            dropout=dropout,
            attention_dropout=cross_attention_dropout
        )
        self.self_attention_block = nn.ModuleList([
            SelfAttention(
                hidden_dim=latent_dim,
                widening_factor=self_attn_widening_factor,
                num_heads=num_self_attn_heads,
                qk_out_dim=qk_out_dim,
                v_out_dim=v_out_dim,
                dropout=dropout,
                attention_dropout=self_attention_dropout
            ) for _ in range(num_self_attn_per_block)
        ])

    def forward(self, x: torch.Tensor, kv_mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: Input tensor of shape (B, M, C).
            kv_mask: Input mask tensor of shape (B, M). Mask values selected
                in [0, 1]. Defaults to None.

        Returns:
            Latent tensor of shape (B, num_latents, latent_dim).
        """
        batch_size = x.size(0)
        if kv_mask is not None:
            kv_mask = kv_mask[:, None, None, :]

        latents = self.cross_attn(
            inputs_kv=x,
            inputs_q=self.latents.repeat(batch_size, 1, 1),
            attention_mask=kv_mask
        )
        for _ in range(self.num_blocks):
            for self_attn_layer in self.self_attention_block:
                latents = self_attn_layer(latents)
        return latents
--------------------------------------------------------------------------------
/src/perceiver_io/attention.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(
        self,
        kv_dim: int,
        q_dim: int,
        *,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        num_heads: int = 1,
        dropout: float = 0.0
    ):
        """Constructor.

        Args:
            kv_dim: Size of input key and value vectors.
            q_dim: Size of input query vector.
            qk_out_dim: Size of Query and Key matrices last dimension.
                If None, it will be equal to q_dim. Defaults to None.
            v_out_dim: Size of Value matrix last dimension.
                If None, it will be equal to qk_out_dim. Defaults to None.
            output_dim: Size of output after the QKV attention.
                If None, it will be equal to v_out_dim. Defaults to None.
            num_heads: Number of heads. Defaults to 1.
            dropout: Dropout probability. Defaults to 0.0.
        """
        super().__init__()

        if qk_out_dim is None:
            qk_out_dim = q_dim
        if v_out_dim is None:
            v_out_dim = qk_out_dim
        if output_dim is None:
            output_dim = v_out_dim

        self.num_heads = num_heads
        self.qk_head_dim = qk_out_dim // num_heads
        self.v_head_dim = v_out_dim // num_heads

        self.k = nn.Linear(kv_dim, qk_out_dim)
        self.q = nn.Linear(q_dim, qk_out_dim)
        self.v = nn.Linear(kv_dim, v_out_dim)
        self.projection = nn.Linear(v_out_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.qk_head_dim ** -0.5

    def transform_for_scores(self, x: torch.Tensor, head_dim: int):
        # (..., seq_len, dim) -> (..., n_heads, seq_len, head_dim)
        *dims, seq, hid = x.size()
        x = x.view(*dims, seq, self.num_heads, head_dim)
        return x.transpose(-3, -2)

    def forward(
        self,
        inputs_kv: torch.Tensor,
        inputs_q: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ):
        """
        Args:
            inputs_kv: Key/Value embeddings of shape (B, ..., M, C).
            inputs_q: Query embeddings of shape (B, ..., N, D).
            attention_mask: Tensor of shape (B, ..., N, M).

        Returns:
            Tensor of shape (B, ..., N, D).
        """
        keys, queries, values = self.k(inputs_kv), self.q(inputs_q), self.v(inputs_kv)
        keys = self.transform_for_scores(keys, self.qk_head_dim)
        queries = self.transform_for_scores(queries, self.qk_head_dim)
        values = self.transform_for_scores(values, self.v_head_dim)
        attention = (queries @ keys.transpose(-2, -1) * self.scale)
        if attention_mask is not None:
            min_value = torch.finfo(attention.dtype).min
            extended_mask = (1 - attention_mask) * min_value
            attention = attention + extended_mask
        attention = attention.softmax(dim=-1)
        attention = self.dropout(attention)
        if attention_mask is not None:
            # masked_fill expects a boolean mask, so compare against zero
            # instead of computing `1 - attention_mask`.
            attention = attention.masked_fill(attention_mask == 0, value=0)
        weighted = attention @ values
        # (..., n_heads, seq_len, head_dim) -> (..., seq_len, hid)
        *dims, n_heads, seq, hid = weighted.size()
        weighted = weighted.transpose(-3, -2)
        weighted = weighted.reshape(*dims, seq, n_heads * hid)
        return self.projection(weighted)


class FeedForward(nn.Module):
    """Transformer Feed-Forward network."""
    def __init__(
        self,
        dim: int,
        widening_factor: int = 4,
        dropout: float = 0.0
    ):
        """Constructor.

        Args:
            dim: Dimension of input tensor.
            widening_factor: Widening factor. Defaults to 4.
            dropout: Dropout probability. Defaults to 0.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * widening_factor),
            nn.GELU(),
            nn.Linear(dim * widening_factor, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor):
        return self.mlp(x)


class SelfAttention(nn.Module):
    """Self-attention module."""
    def __init__(
        self,
        *,
        hidden_dim: int,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        widening_factor: int = 4,
        num_heads: int = 1,
        dropout: float = 0.0,
        attention_dropout: float = 0.0
    ):
        """Constructor.

        Args:
            hidden_dim: Dimension of input tensor.
            qk_out_dim: Size of Query and Key matrices last dimension.
                Defaults to None.
            v_out_dim: Size of Value matrix last dimension.
                Defaults to None.
            widening_factor: Feed-forward network widening factor.
                Defaults to 4.
            num_heads: Number of attention heads. Defaults to 1.
            dropout: Dropout probability. Defaults to 0.
            attention_dropout: Attention scores dropout probability.
                Defaults to 0.
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.qkv_layer_norm = nn.LayerNorm(hidden_dim)
        self.attention = MultiHeadAttention(
            kv_dim=hidden_dim,
            q_dim=hidden_dim,
            qk_out_dim=qk_out_dim,
            v_out_dim=v_out_dim,
            output_dim=hidden_dim,
            num_heads=num_heads,
            dropout=attention_dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.mlp = FeedForward(hidden_dim, widening_factor, dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ):
        """
        Args:
            x: Input tensor of shape (B, ..., M, C).
            attention_mask: Input mask tensor of shape (B, ..., M, M).
                Mask values selected in [0, 1]. Defaults to None.
        """
        x_norm = self.layer_norm(x)
        attention = self.attention(
            inputs_kv=x_norm,
            inputs_q=x_norm,
            attention_mask=attention_mask
        )
        attention = self.dropout(attention)
        x = x + attention
        x = x + self.mlp(self.qkv_layer_norm(x))
        return x


class CrossAttention(nn.Module):
    """Cross-attention module."""
    def __init__(
        self,
        *,
        kv_dim: int,
        q_dim: int,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        widening_factor: int = 1,
        num_heads: int = 1,
        use_query_residual: bool = True,
        dropout: float = 0.0,
        attention_dropout: float = 0.0
    ):
        """Constructor.

        Args:
            kv_dim: Dimension of key/value input tensor.
            q_dim: Dimension of query input tensor.
            qk_out_dim: Size of Query and Key matrices last dimension.
                Defaults to None.
            v_out_dim: Size of Value matrix last dimension.
                Defaults to None.
            widening_factor: Feed-forward network widening factor.
                Defaults to 1.
            num_heads: Number of attention heads. Defaults to 1.
            use_query_residual: Indicates whether to use query residual in
                cross-attention. Defaults to True.
            dropout: Dropout probability. Defaults to 0.
            attention_dropout: Attention scores dropout probability.
                Defaults to 0.
        """
        super().__init__()
        self.use_query_residual = use_query_residual
        self.kv_layer_norm = nn.LayerNorm(kv_dim)
        self.q_layer_norm = nn.LayerNorm(q_dim)
        self.qkv_layer_norm = nn.LayerNorm(q_dim)
        self.attention = MultiHeadAttention(
            kv_dim=kv_dim,
            q_dim=q_dim,
            qk_out_dim=qk_out_dim,
            v_out_dim=v_out_dim,
            output_dim=q_dim,
            num_heads=num_heads,
            dropout=attention_dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.mlp = FeedForward(q_dim, widening_factor, dropout)

    def forward(
        self,
        inputs_kv: torch.Tensor,
        inputs_q: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ):
        """
        Args:
            inputs_kv: Key/Value embeddings of shape (B, ..., M, C).
            inputs_q: Query embeddings of shape (B, ..., N, D).
            attention_mask: Tensor of shape (B, ..., N, M). Mask values selected
                in [0, 1]. Defaults to None.
        """
        attention = self.attention(
            inputs_kv=self.kv_layer_norm(inputs_kv),
            inputs_q=self.q_layer_norm(inputs_q),
            attention_mask=attention_mask
        )
        attention = self.dropout(attention)
        if self.use_query_residual:
            x = inputs_q + attention
        else:
            x = attention
        x = x + self.mlp(self.qkv_layer_norm(x))
        return x
--------------------------------------------------------------------------------