├── figures ├── .keep ├── perceiver.png ├── perceiver-io.png ├── benchmark_speedup.png └── benchmark_memory_usage_reduction.png ├── flash_perceiver ├── __init__.py ├── utils │ ├── training.py │ ├── __init__.py │ └── encodings.py ├── adapters.py └── perceiver.py ├── scratch.py ├── tests ├── __init__.py ├── test_mha.py ├── test_perceiver_io.py ├── test_perceiver.py ├── test_adapters.py └── test_perceiver_base.py ├── pyproject.toml ├── create_plots.py ├── .gitignore ├── examples └── cifar_classification.py ├── run_benchmarks.py └── README.md /figures/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flash_perceiver/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceiver import Perceiver, PerceiverIO 2 | -------------------------------------------------------------------------------- /scratch.py: -------------------------------------------------------------------------------- 1 | from flash_attn.modules.mha import CrossAttention 2 | 3 | 4 | -------------------------------------------------------------------------------- /figures/perceiver.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kklemon/FlashPerceiver/HEAD/figures/perceiver.png -------------------------------------------------------------------------------- /figures/perceiver-io.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kklemon/FlashPerceiver/HEAD/figures/perceiver-io.png -------------------------------------------------------------------------------- /figures/benchmark_speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kklemon/FlashPerceiver/HEAD/figures/benchmark_speedup.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | torch.set_default_device('cuda') 5 | torch.set_default_dtype(torch.float16) 6 | -------------------------------------------------------------------------------- /figures/benchmark_memory_usage_reduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kklemon/FlashPerceiver/HEAD/figures/benchmark_memory_usage_reduction.png -------------------------------------------------------------------------------- /flash_perceiver/utils/training.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | 7 | class CosineWithWarmupLR(LambdaLR): 8 | def __init__(self, optimizer: Optimizer, training_steps: int, warmup_steps: int = 0, num_cycles: float = 0.5): 9 | def lr_lambda(current_step): 10 | if current_step < warmup_steps: 11 | return float(current_step) / float(max(1, warmup_steps)) 12 | progress = float(current_step - warmup_steps) / float(max(1, training_steps - warmup_steps)) 13 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 14 | 15 | super().__init__(optimizer, lr_lambda, -1) 16 | -------------------------------------------------------------------------------- /tests/test_mha.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from flash_perceiver.perceiver import PatchedMHA 5 | 6 | 7 | @pytest.mark.parametrize('kv_dim', [None, 128]) 8 | @pytest.mark.parametrize('with_fa', [False, True]) 9 | def test_mha_fa_change(kv_dim, with_fa): 10 | embed_dim = 64 11 | 12 | mha = PatchedMHA( 13 | embed_dim=embed_dim, 14 | kv_dim=kv_dim, 15 | cross_attn=kv_dim is not None, 16 | use_flash_attn=with_fa 17 | ) 18 | 19 | x = torch.randn(32, 128, embed_dim) 20 | 21 | x_kv = None 22 | 23 | if kv_dim is not None: 24 | x_kv = torch.randn(32, 256, kv_dim) 25 | 26 | out_before = mha(x, x_kv=x_kv) 27 | 28 | mha.set_flash_attn(not with_fa) 29 | 30 | out_after = mha(x, x_kv=x_kv) 31 | 32 | assert out_before.shape == (32, 128, embed_dim) 33 | 34 | # Is this high tolerance reasonable? 35 | assert torch.allclose(out_before, out_after, atol=5e-4) 36 | -------------------------------------------------------------------------------- /tests/test_perceiver_io.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from flash_perceiver import PerceiverIO 5 | 6 | 7 | @pytest.mark.parametrize('query_dim', [32, 64]) 8 | @pytest.mark.parametrize('num_queries', [8, 16]) 9 | @pytest.mark.parametrize('proj_dim', [None, 128]) 10 | @pytest.mark.parametrize('num_zero_tokens', [None, 0, 32]) 11 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 12 | @pytest.mark.parametrize('latent_drop', [0.0, 0.5]) 13 | def test_perceiver_io( 14 | query_dim, 15 | num_queries, 16 | proj_dim, 17 | num_zero_tokens, 18 | use_flash_attn, 19 | latent_drop 20 | ): 21 | model = PerceiverIO( 22 | input_dim=128, 23 | depth=4, 24 | query_dim=query_dim, 25 | proj_dim=proj_dim, 26 | num_zero_tokens=num_zero_tokens, 27 | use_flash_attn=use_flash_attn, 28 | latent_drop=latent_drop 29 | ) 30 | 31 | x = torch.randn(32, 64, 128) 32 | queries = torch.randn(num_queries, query_dim) 33 | 34 | out = model(x, queries=queries) 35 | 36 | if proj_dim is None: 37 | assert out.shape == (32, num_queries, query_dim) 38 | else: 39 | assert out.shape == (32, num_queries, proj_dim) -------------------------------------------------------------------------------- /tests/test_perceiver.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from flash_perceiver import Perceiver 5 | 6 | 7 | @pytest.mark.parametrize('output_dim', [None, 128]) 8 | @pytest.mark.parametrize('output_mode', ['average', 'concat', 'first']) 9 | @pytest.mark.parametrize('num_zero_tokens', [None, 0, 32]) 10 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 11 | @pytest.mark.parametrize('latent_drop', [0.0, 0.5]) 12 | def test_output_projection( 13 | output_dim, 14 | output_mode, 15 | num_zero_tokens, 16 | use_flash_attn, 17 | latent_drop 18 | ): 19 | build_fn = lambda: Perceiver( 20 | input_dim=128, 21 | depth=4, 22 | output_dim=output_dim, 23 | output_mode=output_mode, 24 | num_zero_tokens=num_zero_tokens, 25 | use_flash_attn=use_flash_attn, 26 | latent_drop=latent_drop 27 | ) 28 | 29 | if output_mode == 'concat' and latent_drop > 0: 30 | with pytest.raises(ValueError): 31 | model = build_fn() 32 | return 33 | 34 | model = build_fn() 35 | 36 | x = torch.randn(32, 64, 128) 37 | 38 | out = model(x) 39 | 40 | if output_dim is None: 41 | assert out.shape == (32, int(model.num_latents * (1 - latent_drop)), model.latent_dim) 42 | else: 43 | assert out.shape == 
(32, output_dim) 44 | -------------------------------------------------------------------------------- /flash_perceiver/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from einops import repeat 2 | import torch 3 | import torch.nn as nn 4 | 5 | from functools import wraps 6 | 7 | 8 | def cache_fn(f): 9 | cache = dict() 10 | @wraps(f) 11 | def cached_fn(*args, _cache=True, key=None, **kwargs): 12 | if not _cache: 13 | return f(*args, **kwargs) 14 | nonlocal cache 15 | if key in cache: 16 | return cache[key] 17 | result = f(*args, **kwargs) 18 | cache[key] = result 19 | return result 20 | return cached_fn 21 | 22 | 23 | def identity(x, *args, **kwargs): 24 | return x 25 | 26 | 27 | def random_mask(x): 28 | bs, n = x.shape[:2] 29 | 30 | seq_lens = torch.randint(1, n + 1, (bs,), device=x.device) 31 | mask = torch.arange(n, device=x.device)[None, :] < seq_lens[:, None] 32 | 33 | return mask 34 | 35 | 36 | def numel(m: nn.Module, only_trainable: bool = True): 37 | return sum(p.numel() for p in m.parameters() if not only_trainable or p.requires_grad) 38 | 39 | 40 | def meshgrid(*size: int, batch_size: int | None = None): 41 | tensors = [torch.linspace(-1, 1, n) for n in size] 42 | grid = torch.stack( 43 | torch.meshgrid(tensors, indexing='ij'), -1 44 | ) 45 | 46 | if batch_size is not None: 47 | return repeat(grid, '... -> b ...', b=batch_size) 48 | 49 | return grid 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "flash-perceiver" 3 | version = "0.2.0" 4 | description = "Fast and memory efficient PyTorch implementation of the Perceiver with FlashAttention." 5 | authors = ["Kristian Klemon "] 6 | readme = "README.md" 7 | packages = [{include = "flash_perceiver"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.9" 11 | einops = "^0.7.0" 12 | # FlashAttention has recently dropped PEP 517 support as it led to issues with 13 | # declaring torch as a dependency. 14 | # Until this is resolved, we can't declare flash-attn as a dependency and the user 15 | # needs to install it manually.
16 | # See https://github.com/Dao-AILab/flash-attention/pull/193 17 | # flash-attn = "^2.2.5" 18 | 19 | 20 | [tool.poetry.group.dev.dependencies] 21 | tqdm = "^4.65.0" 22 | pandas = "^2.0.3" 23 | seaborn = "^0.12.2" 24 | jupyter = "^1.0.0" 25 | perceiver-pytorch = "^0.8.7" 26 | pytest = "^7.4.0" 27 | pytorch-lamb = {git = "https://github.com/cybertronai/pytorch-lamb.git"} 28 | pytest-readme = "^1.0.1" 29 | torch = {version = "^2.0.1", source = "pytorch-gpu-src"} 30 | torchvision = {version = "^0.15.2", source = "pytorch-gpu-src"} 31 | 32 | [[tool.poetry.source]] 33 | name = "pytorch-gpu-src" 34 | url = "https://download.pytorch.org/whl/cu118" 35 | priority = "explicit" 36 | 37 | [build-system] 38 | requires = ["poetry-core"] 39 | build-backend = "poetry.core.masonry.api" 40 | -------------------------------------------------------------------------------- /create_plots.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | from pathlib import Path 7 | 8 | sns.set_theme() 9 | 10 | hue_model_order = [ 11 | 'perceiver-pytorch', 12 | 'flash-perceiver', 13 | ] 14 | 15 | 16 | def calc_relative_improvement(s, col): 17 | ref_value = s[s['implementation'] == 'perceiver-pytorch'][col].iloc[0] 18 | rel_improvement = ref_value / s[col] 19 | return rel_improvement 20 | 21 | 22 | def create_plot(df, y_col): 23 | fig, ax = plt.subplots() 24 | 25 | g = sns.barplot( 26 | df, 27 | x='input sequence length', 28 | y=y_col, 29 | hue='implementation', 30 | hue_order=hue_model_order, 31 | width=0.5, 32 | ax=ax 33 | ) 34 | g.set_ylim(0, g.get_ylim()[1] * 1.2) 35 | g.bar_label(g.containers[1]) 36 | 37 | return fig 38 | 39 | def main(args): 40 | df = pd.read_csv(args.results_file) 41 | df = df.rename(columns={ 42 | 'model': 'implementation', 43 | 'input_size': 'input sequence length' 44 | }).sort_values('input sequence length') 45 | 46 | for res_col, col in [ 47 | ['speedup', 'time_per_it'], 48 | ['memory usage reduction', 'peak_memory'] 49 | ]: 50 | df[res_col] = ( 51 | df 52 | .groupby(['input sequence length']) 53 | .apply(calc_relative_improvement, col=col) 54 | .reset_index(drop=True).values 55 | ) 56 | 57 | for col in df.columns: 58 | if df[col].dtype == 'float64': 59 | df[col] = df[col].round(2) 60 | 61 | out_dir = Path(args.output_dir) 62 | 63 | for y_col in ['speedup', 'memory usage reduction']: 64 | savename = y_col.replace(' ', '_') 65 | 66 | fig = create_plot(df, y_col) 67 | fig.savefig(out_dir / f'benchmark_{savename}.png', bbox_inches='tight') 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--results_file', type=str, default='benchmark_results.csv') 73 | parser.add_argument('--output_dir', type=str, default='figures') 74 | main(parser.parse_args()) -------------------------------------------------------------------------------- /tests/test_adapters.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from flash_perceiver.adapters import ImageAdapter 5 | from flash_perceiver.utils.encodings import NeRFPositionalEncoding 6 | 7 | 8 | @pytest.fixture 9 | def image(): 10 | return torch.randn(32, 3, 128, 128) 11 | 12 | 13 | def test_image_adapter(image): 14 | _, _, w, h = image.shape 15 | 16 | adapter = ImageAdapter(num_channels=3, embed_dim=512) 17 | out = adapter(image) 18 | 19 | assert out.shape == (32, w * h, 512) 20 | 21 | 22 | 
def test_image_adapter_patching(image): 23 | _, _, w, h = image.shape 24 | 25 | adapter = ImageAdapter( 26 | num_channels=3, 27 | embed_dim=512, 28 | patch_size=(16, 16) 29 | ) 30 | out = adapter(image) 31 | 32 | assert out.shape == (32, w * h / (16 ** 2), 512) 33 | 34 | 35 | def test_image_adapter_pos_encoding(image): 36 | _, _, w, h = image.shape 37 | 38 | pos_encoding = NeRFPositionalEncoding(2) 39 | 40 | adapter = ImageAdapter( 41 | num_channels=3, 42 | embed_dim=512, 43 | pos_encoding=pos_encoding 44 | ) 45 | out = adapter(image) 46 | 47 | assert out.shape == (32, w * h, 512) 48 | 49 | 50 | def test_image_adapter_pos_encoding_with_patching(image): 51 | _, _, w, h = image.shape 52 | 53 | pos_encoding = NeRFPositionalEncoding(2) 54 | 55 | adapter = ImageAdapter( 56 | num_channels=3, 57 | embed_dim=512, 58 | pos_encoding=pos_encoding, 59 | patch_size=(16, 16) 60 | ) 61 | out = adapter(image) 62 | 63 | assert out.shape == (32, w * h / (16 ** 2), 512) 64 | 65 | def test_image_adapter_pos_encoding_with_patching(image): 66 | _, _, w, h = image.shape 67 | 68 | pos_encoding = NeRFPositionalEncoding(2) 69 | 70 | adapter = ImageAdapter( 71 | num_channels=3, 72 | embed_dim=512, 73 | pos_encoding=pos_encoding, 74 | patch_size=(16, 16) 75 | ) 76 | out = adapter(image) 77 | 78 | assert out.shape == (32, w * h / (16 ** 2), 512) 79 | 80 | 81 | @pytest.mark.parametrize('channel_first', [False, True]) 82 | def test_image_adapter_channel_first(image, channel_first): 83 | _, _, w, h = image.shape 84 | 85 | if not channel_first: 86 | image = image.permute(0, 2, 3, 1) 87 | 88 | pos_encoding = NeRFPositionalEncoding(2) 89 | 90 | adapter = ImageAdapter( 91 | num_channels=3, 92 | embed_dim=512, 93 | pos_encoding=pos_encoding, 94 | patch_size=(16, 16), 95 | channel_first=channel_first 96 | ) 97 | out = adapter(image) 98 | 99 | assert out.shape == (32, w * h / (16 ** 2), 512) 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | data/ -------------------------------------------------------------------------------- /flash_perceiver/adapters.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, repeat 2 | from einops.layers.torch import Rearrange 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from flash_perceiver import utils 8 | from flash_perceiver.utils.encodings import BasePositionalEncoding 9 | 10 | 11 | class ImageAdapter(nn.Module): 12 | """ 13 | Adapter for images as input to the Perceiver. 14 | 15 | Can optionally patch the image and add positional encodings. 
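The output has shape `(batch, num_tokens, embed_dim)`, where `num_tokens` is the product of the spatial dimensions after optional patching.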
16 | 17 | Args: 18 | embed_dim: Dimensionality of the final output, the Perceiver input respectively. 19 | num_channels: Number of channels of the input image. 20 | pos_encoding: Positional encoding module to use. 21 | Defaults to `None`, i.e. no positional encoding is applied. 22 | patch_size: Size of the patches to extract from the image. 23 | Can be a single integer or a tuple representing `(width, height)` of the patches. 24 | Default to `None`, i.e. no patching. 25 | channel_first: Whether the input image has the channels first or last. 26 | Defaults to `True`, i.e. channels first. 27 | """ 28 | def __init__( 29 | self, 30 | embed_dim: int, 31 | num_channels: int = 3, 32 | pos_encoding: BasePositionalEncoding | None = None, 33 | patch_size: int | tuple[int, int ] | None = None, 34 | channel_first: bool = True 35 | ): 36 | super().__init__() 37 | 38 | self.embed_dim = embed_dim 39 | self.num_channels = num_channels 40 | self.patch_size = patch_size 41 | self.pos_encoding = pos_encoding 42 | self.channel_first = channel_first 43 | 44 | if patch_size is not None: 45 | if isinstance(patch_size, int): 46 | patch_size = (patch_size, patch_size) 47 | 48 | self.patchify = Rearrange( 49 | 'b (h p1) (w p2) c -> b h w (p1 p2 c)', 50 | p1=self.patch_size[0], 51 | p2=self.patch_size[1] 52 | ) 53 | self.patch_dim = np.prod(patch_size) * num_channels 54 | else: 55 | self.patchify = None 56 | self.patch_dim = num_channels 57 | 58 | if pos_encoding is not None: 59 | self.patch_dim += pos_encoding.out_dim 60 | 61 | self.proj = nn.Sequential( 62 | nn.LayerNorm(self.patch_dim), 63 | nn.Linear(self.patch_dim, embed_dim), 64 | nn.LayerNorm(embed_dim), 65 | ) 66 | 67 | self.pos_grid = None 68 | 69 | def get_pos_grid(self, x): 70 | b, h, w, _ = x.shape 71 | 72 | if self.pos_grid is None or self.pos_grid.shape[0] < w or self.pos_grid.shape[1] < h: 73 | self.pos_grid = utils.meshgrid(h, w).to(x.device) 74 | 75 | pos_grid = self.pos_grid[:h, :w] 76 | pos_grid = repeat(pos_grid, 'h w c -> b h w c', b=b) 77 | 78 | return pos_grid 79 | 80 | def forward(self, x): 81 | assert x.ndim == 4, \ 82 | f'Expected input to have four dimensions but found {x.ndim} dimensions instead' 83 | 84 | if self.channel_first: 85 | x = rearrange(x, 'b c h w -> b h w c') 86 | 87 | assert x.shape[-1] == self.num_channels, \ 88 | f'Expected input to have {self.num_channels} channels but found {x.shape[-1]} channels instead' 89 | 90 | if self.patchify is not None: 91 | x = self.patchify(x) 92 | 93 | if self.pos_encoding is not None: 94 | pos_grid = self.get_pos_grid(x) 95 | x = torch.cat([x, self.pos_encoding(pos_grid)], dim=-1) 96 | 97 | x = self.proj(x) 98 | x = rearrange(x, 'b h w c -> b (h w) c') 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /examples/cifar_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | 6 | from torch.utils.data import DataLoader, random_split 7 | from torchvision.datasets import CIFAR10 8 | 9 | from pytorch_lamb import Lamb 10 | from tqdm import tqdm 11 | from flash_perceiver import utils, Perceiver 12 | from flash_perceiver.adapters import ImageAdapter 13 | from flash_perceiver.utils.encodings import NeRFPositionalEncoding 14 | from flash_perceiver.utils.training import CosineWithWarmupLR 15 | 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 18 | 
parser.add_argument('--lr', default=5e-4, type=float, help='learning rate') 19 | parser.add_argument('--device', default='cuda') 20 | parser.add_argument('--data_root', default='./data') 21 | parser.add_argument('--epochs', type=int, default=100) 22 | parser.add_argument('--batch_size', type=int, default=128) 23 | args = parser.parse_args() 24 | 25 | print('==> Preparing data..') 26 | 27 | transform_train = transforms.Compose([ 28 | transforms.RandomCrop(32, padding=4), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 32 | ]) 33 | 34 | transform_test = transforms.Compose([ 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 37 | ]) 38 | 39 | rng = torch.Generator(device=args.device) 40 | 41 | train_set = CIFAR10(args.data_root, train=True, download=True,transform=transform_train) 42 | test_set = CIFAR10(args.data_root, train=False, download=True,transform=transform_test) 43 | 44 | valid_set, train_set = random_split(train_set, [5000, 45000]) 45 | 46 | train_batches = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) 47 | valid_batches = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False) 48 | test_batches = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) 49 | 50 | print('==> Building model..') 51 | 52 | model = nn.Sequential( 53 | ImageAdapter( 54 | embed_dim=64, 55 | pos_encoding=NeRFPositionalEncoding(2), 56 | ), 57 | Perceiver( 58 | input_dim=64, 59 | depth=1, 60 | output_dim=10, 61 | num_latents=128, 62 | latent_dim=256, 63 | cross_attn_dropout=0.2, 64 | latent_attn_dropout=0.2, 65 | self_per_cross_attn=4 66 | ) 67 | ).to(args.device) 68 | 69 | criterion = nn.CrossEntropyLoss() 70 | optimizer = Lamb( 71 | model.parameters(), 72 | lr=args.lr, 73 | weight_decay=1e-4 74 | ) 75 | scheduler = CosineWithWarmupLR( 76 | optimizer, 77 | training_steps=args.epochs * len(train_batches), 78 | warmup_steps=1000 79 | ) 80 | 81 | print(f'Number of parameters: {utils.numel(model):,}') 82 | 83 | def train(dataset, log_prefix): 84 | model.train() 85 | 86 | with tqdm(dataset) as pbar: 87 | for inputs, targets in pbar: 88 | inputs, targets = inputs.to(args.device), targets.to(args.device) 89 | 90 | optimizer.zero_grad() 91 | 92 | with torch.autocast(args.device): 93 | outputs = model(inputs) 94 | loss = criterion(outputs, targets) 95 | 96 | loss.backward() 97 | 98 | optimizer.step() 99 | scheduler.step() 100 | 101 | acc = outputs.argmax(-1).eq(targets).float().mean().item() 102 | lr = scheduler.get_last_lr()[0] 103 | 104 | pbar.set_description( 105 | f'{log_prefix} | loss: {loss.item():.3f}, acc: {100.0 * acc:.3f}, lr: {lr:.3e}' 106 | ) 107 | 108 | 109 | @torch.no_grad() 110 | def evaluate(dataset, log_prefix='VALID'): 111 | model.eval() 112 | 113 | loss = 0 114 | correct = 0 115 | total = 0 116 | 117 | with torch.autocast(args.device): 118 | for inputs, targets in dataset: 119 | inputs, targets = inputs.to(args.device), targets.to(args.device) 120 | 121 | outputs = model(inputs) 122 | 123 | loss += criterion(outputs, targets).item() 124 | total += targets.size(0) 125 | correct += outputs.argmax(-1).eq(targets).sum().item() 126 | 127 | print(f'{log_prefix} | loss: {loss / total:.3f}, acc: {100.0 * correct / total:.3f}') 128 | 129 | 130 | for epoch in range(args.epochs): 131 | train(train_batches, f'TRAIN | EPOCH {epoch}') 132 | evaluate(valid_batches, f'VALID | EPOCH {epoch}') 133 | 134 | evaluate(test_batches, 
f'TEST ') -------------------------------------------------------------------------------- /flash_perceiver/utils/encodings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from abc import abstractmethod, ABC 5 | from typing import Optional 6 | from math import pi 7 | from einops import rearrange 8 | 9 | 10 | T = torch.Tensor 11 | 12 | 13 | class BasePositionalEncoding(nn.Module, ABC): 14 | """Base class for positional encoders. 15 | 16 | An implementation set values for in_dim and out_dim. 17 | 18 | Attributes: 19 | in_dim: Expected input dimensionality of the encoder. 20 | out_dim: Output dimensionality of the encoder. 21 | """ 22 | in_dim: int 23 | out_dim: int 24 | 25 | @abstractmethod 26 | def forward(self, x: T) -> T: 27 | """foo""" 28 | raise NotImplementedError 29 | 30 | 31 | class FourierPositionalEncoding(BasePositionalEncoding): 32 | """Projects an input by the given projection matrix before applying a sinus function. 33 | The input will be concatenated along the last axis. 34 | 35 | Args: 36 | proj_matrix: Projection matrix of shape ``(m, n)``. 37 | is_trainable: Whether the projection should be stored as trainable parameter. Default: ``False`` 38 | 39 | Raises: 40 | ValueError: Raised if the given projection matrix does not have two dimensions. 41 | """ 42 | def __init__(self, proj_matrix: T, is_trainable: bool = False): 43 | super().__init__() 44 | 45 | if proj_matrix.ndim != 2: 46 | raise ValueError(f'Expected projection matrix to have two dimensions but found {proj_matrix.ndim}') 47 | 48 | self.is_trainable = is_trainable 49 | 50 | if is_trainable: 51 | self.register_parameter('proj_matrix', nn.Parameter(proj_matrix)) 52 | else: 53 | self.register_buffer('proj_matrix', proj_matrix) 54 | 55 | self.in_dim, self.out_dim = self.proj_matrix.shape 56 | 57 | def forward(self, x: T) -> T: 58 | channels = x.shape[-1] 59 | 60 | assert channels == self.in_dim, \ 61 | f'Expected input to have {self.in_dim} channels but found {channels} channels instead)' 62 | 63 | x = torch.einsum('... i, i j -> ... j', x, self.proj_matrix) 64 | x = 2 * pi * x 65 | 66 | return torch.sin(x) 67 | 68 | 69 | class IdentityPositionalEncoding(BasePositionalEncoding): 70 | """Positional encoder that returns the identity of the input.""" 71 | 72 | def __init__(self, in_dim: int): 73 | super().__init__() 74 | self.in_dim = in_dim 75 | self.out_dim = in_dim 76 | 77 | def forward(self, x: T) -> T: 78 | return x 79 | 80 | 81 | class GaussianFourierFeatureTransform(FourierPositionalEncoding): 82 | """Implements the positional encoder proposed in (Tancik et al., 2020). 83 | 84 | Args: 85 | in_dim: Dimensionality of inputs. 86 | mapping_size: Dimensionality to map inputs to. Default: ``32`` 87 | sigma: SD of the gaussian projection matrix. Default: ``1.0`` 88 | is_trainable: Whether the projection should be stored as trainable parameter. Default: ``False`` 89 | seed: Optional seed for the random number generator. 90 | 91 | Attributes: 92 | in_dim: Expected input dimensionality. 93 | out_dim: Output dimensionality (mapping_size * 2). 
94 | """ 95 | def __init__( 96 | self, 97 | in_dim: int, 98 | mapping_size: int = 32, 99 | sigma: float = 1.0, 100 | is_trainable: bool = False, 101 | seed: Optional[int] = None 102 | ): 103 | super().__init__(self.get_proj_matrix(in_dim, mapping_size, sigma, seed=seed), is_trainable=is_trainable) 104 | self.mapping_size = mapping_size 105 | self.sigma = sigma 106 | self.seed = seed 107 | 108 | @classmethod 109 | def get_proj_matrix(cls, in_dim, mapping_size, sigma, seed=None): 110 | generator = None 111 | if seed is not None: 112 | generator = torch.Generator().manual_seed(seed) 113 | return torch.normal(mean=0, std=sigma, size=(in_dim, mapping_size), generator=generator) 114 | 115 | @classmethod 116 | def from_proj_matrix(cls, projection_matrix): 117 | in_dim, mapping_size = projection_matrix.shape 118 | feature_transform = cls(in_dim, mapping_size) 119 | feature_transform.projection_matrix.data = projection_matrix 120 | return feature_transform 121 | 122 | 123 | class NeRFPositionalEncoding(FourierPositionalEncoding): 124 | """Implements the NeRF positional encoding from (Mildenhall et al., 2020). 125 | 126 | Args: 127 | in_dim: Dimensionality of inputs. 128 | num_frequency_bands: Number of frequency bands where the i-th band has frequency :math:`f_{i} = 2^{i}`. 129 | Default: ``10`` 130 | 131 | Attributes: 132 | in_dim: Expected input dimensionality. 133 | out_dim: Output dimensionality (in_dim * n * 2). 134 | """ 135 | def __init__(self, in_dim: int, num_frequency_bands: int = 10): 136 | super().__init__((2.0 ** torch.arange(num_frequency_bands))[None, :]) 137 | self.num_frequency_bands = num_frequency_bands 138 | self.out_dim = num_frequency_bands * 2 * in_dim 139 | 140 | def forward(self, x: T) -> T: 141 | x = rearrange(x, '... -> ... 1') * self.proj_matrix 142 | x = pi * x 143 | x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 144 | x = rearrange(x, '... i j -> ... (i j)') 145 | return x 146 | 147 | 148 | def get_encoder(name: str, in_dim: int, **kwargs): 149 | encoders = { 150 | 'identity': IdentityPositionalEncoding, 151 | 'gaussian_fourier_features': GaussianFourierFeatureTransform, 152 | 'nerf': NeRFPositionalEncoding 153 | } 154 | 155 | if name not in encoders: 156 | raise ValueError(f'Unknown encoder {name}. 
Must be one of {list(encoders)}.') 157 | 158 | return encoders[name](in_dim, **kwargs) 159 | -------------------------------------------------------------------------------- /run_benchmarks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | from functools import partial 4 | import gc 5 | import logging 6 | from timeit import default_timer 7 | import torch 8 | import pandas as pd 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | from itertools import islice, product 12 | from tqdm import tqdm 13 | from torch.utils.data import IterableDataset, DataLoader 14 | from torch.cuda import OutOfMemoryError 15 | from perceiver_pytorch import Perceiver as LucidrainsPerceiver 16 | from flash_perceiver import Perceiver 17 | 18 | 19 | sns.set_theme() 20 | 21 | logging.basicConfig(level=logging.INFO, 22 | format='[%(levelname)s] - %(asctime)s - %(message)s', 23 | datefmt='%Y-%m-%d %H:%M:%S') 24 | logger = logging.getLogger() 25 | 26 | 27 | def build_lucidrains_perceiver(config, **kwargs): 28 | return LucidrainsPerceiver( 29 | input_channels=config['input_dim'], 30 | input_axis=1, 31 | num_freq_bands=None, 32 | max_freq=None, 33 | num_latents=config['num_latents'], 34 | latent_dim=config['latent_dim'], 35 | depth=config['depth'], 36 | final_classifier_head=False, 37 | fourier_encode_data=False, 38 | **kwargs 39 | ) 40 | 41 | 42 | def build_flash_perceiver(config, **kwargs): 43 | return Perceiver( 44 | input_dim=config['input_dim'], 45 | depth=config['depth'], 46 | num_latents=config['num_latents'], 47 | latent_dim=config['latent_dim'], 48 | **kwargs 49 | ) 50 | 51 | 52 | num_batches = 100 53 | 54 | default_config = { 55 | 'batch_size': 256, 56 | 'input_dim': 128, 57 | 'input_size': 512, 58 | 'depth': 8, 59 | 'latent_dim': 256, 60 | 'num_latents': 256 61 | } 62 | 63 | benchmark_configs = [ 64 | # Use different batch size to prevent OOM errors 65 | { 66 | 'input_size': [128, 256, 512], 67 | 'batch_size': 256, 68 | }, 69 | { 70 | 'input_size': [1024, 2048, 4096], 71 | 'batch_size': 128, 72 | }, 73 | { 74 | 'input_size': [8192], 75 | 'batch_size': 64, 76 | }, 77 | { 78 | 'input_size': [16384], 79 | 'batch_size': 48, 80 | }, 81 | # { 82 | # 'input_size': [2048, 4096, 8196, 16392], 83 | # 'batch_size': 128, 84 | # }, 85 | 86 | # {'depth': [6, 12, 24]}, 87 | # {'num_latents': [32, 64, 128, 256, 512]}, 88 | # {'masking_rate': [0.0, 0.2, 0.4, 0.6, 0.8]}, 89 | ] 90 | 91 | models = { 92 | 'perceiver-pytorch': build_lucidrains_perceiver, 93 | 'flash-perceiver': build_flash_perceiver, 94 | } 95 | 96 | 97 | class DummyDataset(IterableDataset): 98 | def __init__(self, dim, seq_len, batch_size, mask_rate=None): 99 | self.dim = dim 100 | self.seq_len = seq_len 101 | self.batch_size = batch_size 102 | self.mask_rate = mask_rate 103 | 104 | def __iter__(self): 105 | while True: 106 | yield torch.randn(self.batch_size, self.seq_len, self.dim) 107 | 108 | 109 | def default_list(o): 110 | if isinstance(o, list): 111 | return o 112 | elif isinstance(o, tuple): 113 | return list(o) 114 | else: 115 | return [o] 116 | 117 | 118 | def create_configs(configs): 119 | for config in configs: 120 | config = {k: default_list(v) for k, v in config.items()} 121 | config_tuples = [[(k, v) for v in vs] for k, vs in config.items()] 122 | yield from map(dict, product(*config_tuples)) 123 | 124 | 125 | @contextmanager 126 | def elapsed_timer(): 127 | start = default_timer() 128 | elapser = lambda: default_timer() - start 129 | 
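# While the with-block runs, the yielded callable returns the time elapsed so far; after the block exits it is rebound below to the fixed total duration.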
130 | yield lambda: elapser() 131 | 132 | end = default_timer() 133 | elapser = lambda: end - start 134 | 135 | 136 | def benchmark_single(model_factory, config, pbar=True, handle_oom=False): 137 | orig_config = config 138 | 139 | while True: 140 | try: 141 | config = {**default_config, **orig_config} 142 | 143 | model = model_factory(config) 144 | dataset = DummyDataset(config['input_dim'], config['input_size'], config['batch_size']) 145 | 146 | data_loader = DataLoader(dataset, batch_size=None) 147 | batches = list(islice(data_loader, num_batches)) 148 | 149 | def run_epoch(_batches): 150 | for batch in _batches: 151 | out = model(batch) 152 | out.mean().backward() 153 | 154 | torch.cuda.synchronize() 155 | 156 | with torch.autocast('cuda'): 157 | # Do some warmup first 158 | run_epoch(batches[:1]) 159 | 160 | if pbar: 161 | batches = tqdm(batches[1:]) 162 | 163 | with elapsed_timer() as elapser: 164 | run_epoch(batches) 165 | return elapser() 166 | 167 | except OutOfMemoryError: 168 | if not handle_oom: 169 | raise 170 | 171 | logger.info('OOM, retrying; reducing batch size from ' 172 | f'{config["batch_size"]} to {config["batch_size"] // 2}') 173 | orig_config["batch_size"] //= 2 174 | 175 | 176 | def reset_all(): 177 | gc.collect() 178 | 179 | torch.cuda.empty_cache() 180 | torch.cuda.reset_peak_memory_stats() 181 | torch.cuda.synchronize() 182 | 183 | 184 | def main(args): 185 | torch.set_default_device(args.device) 186 | torch.set_default_dtype(torch.float16) 187 | 188 | if args.quiet: 189 | logger.removeHandler(logger.handlers[0]) 190 | 191 | results = [] 192 | 193 | for model_name, model_factory in models.items(): 194 | logger.info(f'Benchmarking {model_name}') 195 | 196 | for config in create_configs(benchmark_configs): 197 | logger.info(config) 198 | 199 | reset_all() 200 | 201 | run_time = benchmark_single(model_factory, config, pbar=not args.quiet) 202 | 203 | mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000) 204 | 205 | samples_sum = config['batch_size'] * num_batches 206 | 207 | results.append({ 208 | **config, 209 | 'model': model_name, 210 | 'run_time': run_time, 211 | 'peak_memory': mem, 212 | 'it_per_sec': samples_sum / run_time, 213 | 'time_per_it': run_time / samples_sum, 214 | }) 215 | 216 | reset_all() 217 | 218 | df = pd.DataFrame(results) 219 | df.to_csv(args.output_path, index=False) 220 | 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument('--output_path', type=str, default='benchmark_results.csv') 225 | parser.add_argument('--quiet', '-q', action='store_true') 226 | parser.add_argument('--device', default='cuda') 227 | 228 | main(parser.parse_args()) 229 | -------------------------------------------------------------------------------- /tests/test_perceiver_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from flash_perceiver import utils 5 | from flash_perceiver.perceiver import PerceiverBase 6 | 7 | 8 | @pytest.mark.parametrize('input_dim', [32, 64]) 9 | @pytest.mark.parametrize('depth', [1, 4]) 10 | @pytest.mark.parametrize('num_latents', [32]) 11 | @pytest.mark.parametrize('latent_dim', [128]) 12 | @pytest.mark.parametrize('self_per_cross_attn', [1, 2]) 13 | @pytest.mark.parametrize('input_length', [128, 256]) 14 | @pytest.mark.parametrize('batch_size', [32]) 15 | @pytest.mark.parametrize('mask', [False, True]) 16 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 17 | 
@pytest.mark.parametrize('num_zero_tokens', [None, 0, 32]) 18 | @pytest.mark.parametrize('latent_drop', [0.0, 0.5]) 19 | def test_model( 20 | input_dim, 21 | depth, 22 | num_latents, 23 | latent_dim, 24 | self_per_cross_attn, 25 | input_length, 26 | batch_size, 27 | mask, 28 | use_flash_attn, 29 | num_zero_tokens, 30 | latent_drop 31 | ): 32 | model = PerceiverBase( 33 | input_dim=input_dim, 34 | depth=depth, 35 | num_latents=num_latents, 36 | latent_dim=latent_dim, 37 | self_per_cross_attn=self_per_cross_attn, 38 | use_flash_attn=use_flash_attn, 39 | num_zero_tokens=num_zero_tokens, 40 | latent_drop=latent_drop 41 | ) 42 | model.train() 43 | 44 | x = torch.randn(batch_size, input_length, input_dim) 45 | 46 | if mask: 47 | mask = utils.random_mask(x) 48 | else: 49 | mask = None 50 | 51 | out = model(x, mask=mask) 52 | 53 | assert out.shape == (batch_size, int(num_latents * (1 - latent_drop)), latent_dim) 54 | 55 | 56 | @pytest.mark.parametrize('input_dim', [64]) 57 | @pytest.mark.parametrize('depth', [4]) 58 | @pytest.mark.parametrize('cross_heads', [None, 1, 4]) 59 | @pytest.mark.parametrize('cross_head_dim', [None, 32]) 60 | @pytest.mark.parametrize('latent_heads', [None, 1, 4]) 61 | @pytest.mark.parametrize('latent_head_dim', [None, 32]) 62 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 63 | def test_attn_heads( 64 | input_dim, 65 | depth, 66 | cross_heads, 67 | cross_head_dim, 68 | latent_heads, 69 | latent_head_dim, 70 | use_flash_attn 71 | ): 72 | build_model = lambda: PerceiverBase( 73 | input_dim=input_dim, 74 | depth=depth, 75 | cross_heads=cross_heads, 76 | cross_head_dim=cross_head_dim, 77 | latent_heads=latent_heads, 78 | latent_head_dim=latent_head_dim, 79 | use_flash_attn=use_flash_attn 80 | ) 81 | 82 | if ( 83 | (cross_heads is None and cross_head_dim is None) or 84 | (latent_heads is None and latent_head_dim is None) 85 | ): 86 | with pytest.raises(AssertionError): 87 | build_model() 88 | else: 89 | build_model() 90 | 91 | 92 | 93 | @pytest.mark.parametrize('input_dim', [64]) 94 | @pytest.mark.parametrize('depth', [4]) 95 | @pytest.mark.parametrize('cross_attn_dropout', [0.0, 0.2]) 96 | @pytest.mark.parametrize('latent_attn_dropout', [0.0, 0.2]) 97 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 98 | def test_dropout( 99 | input_dim, 100 | depth, 101 | cross_attn_dropout, 102 | latent_attn_dropout, 103 | use_flash_attn 104 | ): 105 | model = PerceiverBase( 106 | input_dim=input_dim, 107 | depth=depth, 108 | cross_attn_dropout=cross_attn_dropout, 109 | latent_attn_dropout=latent_attn_dropout, 110 | use_flash_attn=use_flash_attn 111 | ) 112 | 113 | input = torch.randn(32, 128, input_dim) 114 | 115 | pass_a = model(input) 116 | pass_b = model(input) 117 | 118 | if cross_attn_dropout > 0.0 or latent_attn_dropout > 0.0: 119 | assert not torch.allclose(pass_a, pass_b) 120 | else: 121 | assert torch.allclose(pass_a, pass_b) 122 | 123 | 124 | @pytest.mark.parametrize('num_latents', [None, 128]) 125 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 126 | def test_num_latents(num_latents, use_flash_attn): 127 | model = PerceiverBase( 128 | input_dim=64, 129 | depth=4, 130 | num_latents=num_latents, 131 | latent_dim=128, 132 | use_flash_attn=use_flash_attn 133 | ) 134 | 135 | data = torch.randn(32, 128, 64) 136 | 137 | if num_latents is None: 138 | with pytest.raises(AssertionError): 139 | model(data) 140 | else: 141 | model(data) 142 | 143 | 144 | @pytest.mark.parametrize('latent_dim', [64, 128]) 145 | @pytest.mark.parametrize('use_flash_attn', 
[False, True]) 146 | def test_custom_latents(latent_dim, use_flash_attn): 147 | model = PerceiverBase( 148 | input_dim=64, 149 | depth=4, 150 | num_latents=128, 151 | latent_dim=64, 152 | use_flash_attn=use_flash_attn 153 | ) 154 | 155 | latents = torch.randn(128, latent_dim) 156 | data = torch.randn(32, 128, 64) 157 | 158 | if latent_dim != model.latent_dim: 159 | with pytest.raises(AssertionError): 160 | model(data, latents=latents) 161 | else: 162 | out = model(data, latents=latents) 163 | assert out.shape == (32, 128, latent_dim) 164 | 165 | 166 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 167 | def test_weight_tying(use_flash_attn): 168 | build_model = lambda weight_tie_layers: PerceiverBase( 169 | input_dim=64, 170 | depth=4, 171 | weight_tie_layers=weight_tie_layers, 172 | use_flash_attn=use_flash_attn 173 | ) 174 | 175 | weight_tied_model = build_model(True) 176 | not_weight_tied_model = build_model(False) 177 | 178 | assert utils.numel(not_weight_tied_model) > utils.numel(weight_tied_model) 179 | 180 | 181 | @pytest.mark.parametrize('input_dim', [ 182 | 64, 183 | [64, 64], 184 | [64, 128, 256] 185 | ]) 186 | @pytest.mark.parametrize('use_mask', [False, True]) 187 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 188 | def test_multi_input(input_dim, use_mask, use_flash_attn): 189 | mask = None 190 | 191 | if isinstance(input_dim, int): 192 | data = torch.randn(32, 128, input_dim) 193 | depth = 1 194 | 195 | if use_mask: 196 | mask = utils.random_mask(data) 197 | else: 198 | data = [torch.randn(32, 128, dim) for dim in input_dim] 199 | depth = len(input_dim) 200 | 201 | if use_mask: 202 | mask = [utils.random_mask(x) for x in data] 203 | 204 | model = PerceiverBase( 205 | input_dim=input_dim, 206 | depth=depth, 207 | use_flash_attn=use_flash_attn 208 | ) 209 | 210 | out = model(data, mask=mask) 211 | 212 | assert out.shape == (32, model.num_latents, model.latent_dim) 213 | 214 | 215 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 216 | @pytest.mark.parametrize('use_mask', [False, True]) 217 | def test_flash_attn(use_flash_attn, use_mask): 218 | model = PerceiverBase( 219 | input_dim=128, 220 | depth=4, 221 | use_flash_attn=use_flash_attn, 222 | ) 223 | 224 | x = torch.randn(32, 64, 128) 225 | 226 | if use_mask: 227 | mask = utils.random_mask(x) 228 | else: 229 | mask = None 230 | 231 | out = model(x, mask=mask) 232 | 233 | assert out.shape == (32, model.num_latents, model.latent_dim) 234 | 235 | 236 | @pytest.mark.parametrize('with_fa', [False, True]) 237 | @pytest.mark.parametrize('use_mask', [False, True]) 238 | def test_setting_fa(with_fa, use_mask): 239 | model = PerceiverBase( 240 | input_dim=128, 241 | depth=4, 242 | use_flash_attn=with_fa, 243 | ) 244 | 245 | x = torch.randn(32, 64, 128) 246 | 247 | if use_mask: 248 | mask = utils.random_mask(x) 249 | else: 250 | mask = None 251 | 252 | out_before = model(x, mask=mask) 253 | 254 | model.set_flash_attn(not with_fa) 255 | 256 | out_after = model(x, mask=mask) 257 | 258 | assert out_before.shape == (32, model.num_latents, model.latent_dim) 259 | 260 | # Is this high tolerance reasonable? 
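# The FlashAttention and vanilla attention paths accumulate in different orders and precisions on fp16 inputs, so bit-exact equality is not expected; only approximate agreement is asserted.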
261 | assert torch.allclose(out_before, out_after, atol=1e-3), \ 262 | str(abs(out_before - out_after)) 263 | 264 | 265 | @pytest.mark.parametrize('return_attn_weights', [False, True]) 266 | @pytest.mark.parametrize('use_flash_attn', [False, True]) 267 | @pytest.mark.parametrize('use_mask', [False, True]) 268 | def test_return_attn_weights(return_attn_weights, use_flash_attn, use_mask): 269 | model = PerceiverBase( 270 | input_dim=128, 271 | depth=4, 272 | latent_dim=256, 273 | use_flash_attn=use_flash_attn, 274 | ) 275 | 276 | x = torch.randn(32, 64, 128) 277 | 278 | if use_mask: 279 | mask = utils.random_mask(x) 280 | else: 281 | mask = None 282 | 283 | 284 | if return_attn_weights and use_flash_attn: 285 | with pytest.raises(NotImplementedError): 286 | out = model(x, return_attn_weights=return_attn_weights, mask=mask) 287 | else: 288 | if return_attn_weights: 289 | out, all_attn_weights = model(x, return_attn_weights=True, mask=mask) 290 | 291 | assert len(all_attn_weights) == model.num_attention_layers 292 | 293 | for i, attn_weights in enumerate(all_attn_weights): 294 | # Even layers are cross-attention, odd layers are self-attention 295 | if i % 2 == 0: 296 | assert attn_weights.shape == (32, model.cross_heads, model.num_latents, x.shape[1]) 297 | else: 298 | assert attn_weights.shape == (32, model.latent_heads, model.num_latents, model.num_latents) 299 | else: 300 | out = model(x, return_attn_weights=False) 301 | 302 | assert out.shape == (32, model.num_latents, model.latent_dim) 303 | 304 | 305 | @pytest.mark.skip(reason='rotary positional embeddings are not yet supported for cross-attention') 306 | def test_rotary_positional_embeddings(): 307 | model = PerceiverBase( 308 | input_dim=128, 309 | depth=4, 310 | cross_rotary_emb_dim=32, 311 | ) 312 | 313 | x = torch.randn(32, 64, 128) 314 | 315 | out = model(x) 316 | 317 | assert out.shape == (32, model.num_latents, model.latent_dim) 318 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | FlashPerceiver 2 | ========================= 3 | 4 | Fast and memory efficient PyTorch implementation of the Perceiver [1, 2, 3] architecture with FlashAttention [4, 5] as attention backend. 5 | 6 | **Features:** 7 | 8 | * :zap: **More than 2x speedup over naive implementation.** 9 | * :zap: **Sub-linear1 memory usage with respect to input sequence length and linear usage with respect to number of latent vectors.** 10 | * :zap: **Out-of-the-box support for rotary positional embeddings [6]** 11 | * :zap: **Uses the new and improved FlashAttention-2 implementation** 12 | * :zap: **Support for multiple inputs and flexible masking** 13 | 14 | 1 For the attention components. See [Performance](#performance) for more information. 15 | 16 | Installation 17 | ------------ 18 | 19 | **Note:** The `pyproject.toml` has recently been removed from the flash-attn repository and so did the PEP 517 compliance. 
This means that the flash-attn package cannot be declared as a dependency for this project anymore and thus needs to be installed manually until the situation changes in the future: 20 | 21 | ```bash 22 | pip install flash-attn --no-build-isolation 23 | ``` 24 | 25 | Afterwards, install the actual `flash-perceiver` package: 26 | 27 | 28 | ```bash 29 | pip install flash-perceiver 30 | ``` 31 | 32 | Usage 33 | ----- 34 | 35 | ### Perceiver 36 | 37 | ![The Perceiver architecture](./figures/perceiver.png) 38 | 39 | ```python 40 | import torch 41 | 42 | from flash_perceiver import Perceiver, utils 43 | 44 | batch_size, seq_len, in_dim = 32, 128, 256 45 | 46 | latent_dim = 512 47 | num_latents = 512 48 | out_dim = 128 49 | 50 | model = Perceiver( 51 | input_dim=in_dim, 52 | depth=8, 53 | output_dim=out_dim, 54 | num_latents=num_latents, 55 | latent_dim=latent_dim, 56 | cross_heads=1, 57 | cross_head_dim=64, 58 | cross_rotary_emb_dim=0, 59 | cross_attn_dropout=0.0, 60 | latent_heads=8, 61 | latent_head_dim=64, 62 | latent_rotary_emb_dim=0, 63 | latent_attn_dropout=0.0, 64 | weight_tie_layers=False, 65 | gated_mlp=True, 66 | self_per_cross_attn=1, 67 | num_zero_tokens=None, 68 | use_flash_attn=True, 69 | ).cuda() 70 | 71 | data = torch.randn(batch_size, seq_len, in_dim, device='cuda') 72 | 73 | # `out_dim` specified; averages and projects output 74 | # Note: FlashAttention only supports half-precision. 75 | # We need to use `torch.autocast` for the forward pass 76 | with torch.autocast('cuda'): 77 | out = model(data) 78 | 79 | assert out.shape == (32, out_dim) 80 | ``` 81 | 82 | **Multiple inputs** 83 | 84 | A separate input for each cross-attention block can be used by providing a list of inputs to the `forward` method. The number of inputs must correspond to the `depth` configuration of the model. 85 | 86 | By providing a list of integers to the `input_dim` argument in the constructor, each input can be configured to have a different dimension. 87 | 88 | ```python 89 | input_dims = [256, 512] 90 | 91 | model = Perceiver( 92 | input_dim=input_dims, 93 | depth=2, # must equal len(input_dim) 94 | ).cuda() 95 | 96 | inputs = [ 97 | torch.randn(batch_size, seq_len, in_dim, device='cuda') 98 | for in_dim in input_dims 99 | ] 100 | 101 | with torch.autocast('cuda'): 102 | out = model(inputs) 103 | 104 | assert out.shape == (batch_size, num_latents, latent_dim) 105 | ``` 106 | 107 | **Masking** 108 | 109 | A boolean element-wise mask for the input can be provided. All non-True elements will be masked out within the cross-attention operation. If a list of inputs is provided, a list of masks for each input can be provided as well. This can also include `None` values for inputs without a mask. 110 | 111 | ```python 112 | mask = utils.random_mask(data) # [batch_size, seq_len] 113 | 114 | with torch.autocast('cuda'): 115 | out = model(data, mask=mask) 116 | ``` 117 | 118 | **Extract Embeddings** 119 | 120 | If a value for `output_dim` has been provided to the constructor, the final latent vectors will be averaged and then projected to the desired dimension. To extract the representations prior to the projection step, set `return_embeddings=True`: 121 | 122 | ```python 123 | with torch.autocast('cuda'): 124 | embeds = model(data, return_embeddings=True) 125 | 126 | assert embeds.shape == (32, num_latents, latent_dim) 127 | ``` 128 | 129 | **Custom Latents** 130 | 131 | For some applications it can be useful to have custom sets of latent vectors.
For instance, for a multi-task setting, each task could have a separate set of learned latents. 132 | 133 | The `forward` method supports custom latents via the `latents` argument. If none are explicitly provided, the module's own latent vectors will be used; otherwise, the provided ones. These must have shape `[m, latent_dim]` or `[batch_size, m, latent_dim]` where $m$ can be arbitrary. 134 | 135 | To disable initializing random latent vectors as part of the model construction, pass `num_latents=None` to the constructor. 136 | 137 | **Extract Attention Weights** 138 | 139 | > :warning: This is an experimental feature and requires a modified implementation of FlashAttention [until the changes are eventually merged](https://github.com/Dao-AILab/flash-attention/pull/589). 140 | 141 | `return_attn_weights=True` can be passed to the `forward` method of a model to extract the normalized attention weights of each attention layer. A tuple of `(output, attn_weights)` will be returned in this case, where `attn_weights` is a list with one tensor per attention layer. This list follows the pattern `[cross_attn_0, self_attn_0_0, ..., cross_attn_1, self_attn_1_0]` where attention maps for cross-attention layers will have shape `(batch_size, cross_heads, num_latents, seq_len)` and self-attention maps have shape `(batch_size, latent_heads, num_latents, num_latents)`. 142 | 143 | ```python 144 | with torch.autocast('cuda'): 145 | out, all_attn_weights = model(data, return_attn_weights=True) 146 | 147 | for i, attn_weights in enumerate(all_attn_weights): 148 | if i % model.num_attention_layers_per_block == 0: 149 | print('cross-attention map with shape', attn_weights.shape) 150 | else: 151 | print('self-attention map with shape', attn_weights.shape) 152 | 153 | ``` 154 | 155 | 156 | ### PerceiverIO 157 | 158 | The [PerceiverIO](https://arxiv.org/abs/2107.14795) is a variant of the Perceiver architecture where the encoder tower is followed by a decoder module that allows task-specific computation of outputs via sets of queries. 159 | 160 | This makes the architecture more flexible and allows it to be used for cases such as position-specific decoding of values or multi-task settings. 161 | 162 | ![The PerceiverIO architecture](./figures/perceiver-io.png) 163 | 164 | ```python 165 | import torch 166 | 167 | from flash_perceiver import PerceiverIO, utils 168 | 169 | batch_size, seq_len, in_dim = 32, 128, 256 170 | 171 | depth = 8 172 | latent_dim = 512 173 | num_latents = 512 174 | query_dim = 128 175 | num_queries = 32 176 | proj_dim = 64 177 | 178 | model = PerceiverIO( 179 | input_dim=in_dim, 180 | query_dim=query_dim, 181 | depth=depth, 182 | proj_dim=proj_dim, 183 | num_latents=num_latents, 184 | latent_dim=latent_dim, 185 | cross_heads=1, 186 | cross_head_dim=64, 187 | cross_rotary_emb_dim=0, 188 | cross_attn_dropout=0.0, 189 | latent_heads=8, 190 | latent_head_dim=64, 191 | latent_rotary_emb_dim=0, 192 | latent_attn_dropout=0.0, 193 | latent_drop=0.0, 194 | query_heads=1, 195 | query_head_dim=64, 196 | query_rotary_emb_dim=0, 197 | query_attn_dropout=0.0, 198 | weight_tie_layers=False, 199 | gated_mlp=True, 200 | use_flash_attn=True, 201 | ).cuda() 202 | 203 | data = torch.randn(batch_size, seq_len, in_dim, device='cuda') 204 | 205 | # Can be learned or correspond to positions, tokens, etc.
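# (For learned queries, one could e.g. register them as a parameter instead: nn.Parameter(torch.randn(num_queries, query_dim)).)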
206 | queries = torch.randn(num_queries, query_dim, device='cuda') 207 | 208 | with torch.autocast('cuda'): 209 | out = model(data, queries=queries) 210 | 211 | assert out.shape == (batch_size, num_queries, proj_dim) 212 | ``` 213 | 214 | Examples 215 | -------- 216 | 217 | Other usage examples are provided in the `examples/` folder. 218 | 219 | Performance 220 | ----------- 221 | 222 | The Perceiver is already designed and intended as an attention architecture with sub-quadratic compute and memory complexity in comparison to the quadratic requirements of a vanilla Transformer. 223 | 224 | A naive implementation will have $\mathcal{O}(nm)$ memory usage for the cross-attention modules and $\mathcal{O}(n^2)$ complexity for the self-attention or _latent_ blocks, where $n$ is the number of input elements, $m$ the number of latent vectors (a fixed hyperparameter), and $n \gg m$ should generally hold. 225 | 226 | FlashAttention can reduce the memory usage to $\mathcal{O}(\sqrt{nm})$ for the cross-attention layers and $\mathcal{O}(m)$ for the latent self-attention layers. However, this only accounts for the computation of the attention mechanism. The input sequence and corresponding keys and values within the cross-attention modules will still grow with $n$. 227 | 228 | Until the latter starts to dominate memory usage, this implementation allows the input sequence length to be scaled up greatly. For instance, 16x larger input lengths can be achieved in comparison to [perceiver-pytorch](https://github.com/lucidrains/perceiver-pytorch) on an RTX 4090, keeping the other hyperparameters fixed (see `run_benchmarks.py` for the exact configuration). 229 | 230 | ### Benchmarks 231 | 232 | Benchmarks against other implementations (currently only [perceiver-pytorch](https://github.com/lucidrains/perceiver-pytorch)) can be performed with: 233 | 234 | ```bash 235 | python run_benchmarks.py 236 | ``` 237 | 238 | The script will create a `benchmark_results.csv`. The `create_plots.py` script can then be used to create plots. 239 | 240 | The following data has been obtained with an RTX 4090 and 24GB of VRAM. 241 | 242 | ![Benchmark results on speedup](figures/benchmark_speedup.png) 243 | 244 | ![Benchmark results on memory usage reduction](figures/benchmark_memory_usage_reduction.png) 245 | 246 | **Note:** The batch size for each configuration corresponds to the smallest value that works for all implementations. Especially for longer sequence lengths, this leads to decreasing GPU utilization and thus a lower speedup than theoretically possible. There are some ways to fix this, but my attempts so far have led to distorted results. 247 | 248 | Acknowledgements 249 | ---------------- 250 | 251 | The implementation is inspired by lucidrains' [Perceiver implementation](https://github.com/lucidrains/perceiver-pytorch) and would not have been possible without Tri Dao's [FlashAttention](https://github.com/Dao-AILab/flash-attention). 252 | 253 | Planned features 254 | --------------- 255 | 256 | These are a few features that are either planned or WIP. If you have urgent demand for any of them, feel free to open an issue: 257 | 258 | - [X] Perceiver IO [2] 259 | - [ ] Perceiver AR [3] (or an AR demo in general) 260 | - [X] Demos 261 | - [X] Tests (see `tests/`) 262 | - [X] Allow more flexible cross-attention configurations 263 | - [ ] Benchmarks against other Perceiver implementations, e.g.
[DeepMind's](https://github.com/deepmind/deepmind-research/tree/master/perceiver) or [Krasser's](https://github.com/krasserm/perceiver-io)
264 | - [ ] If FA2 is eventually merged into PyTorch, drop the flash-attn dependency
265 | - [ ] Configure and provide multiple inputs as a dict
266 | - [ ] TensorDict / tensorclass inputs
267 | - [X] Extract attention weights
268 | - [ ] Add fancy badges to the README
269 | - [ ] Use custom attention modules for more flexibility
270 | 
271 | References
272 | ----------
273 | 
274 | [1] Jaegle, Andrew, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, and Joao Carreira. “Perceiver: General Perception with Iterative Attention.” arXiv, June 22, 2021. http://arxiv.org/abs/2103.03206.
275 | 
276 | [2] Jaegle, Andrew, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, et al. “Perceiver IO: A General Architecture for Structured Inputs & Outputs.” arXiv, March 15, 2022. http://arxiv.org/abs/2107.14795.
277 | 
278 | [3] Hawthorne, Curtis, Andrew Jaegle, Cătălina Cangea, Sebastian Borgeaud, Charlie Nash, Mateusz Malinowski, Sander Dieleman, et al. “General-Purpose, Long-Context Autoregressive Modeling with Perceiver AR.” arXiv, June 14, 2022. http://arxiv.org/abs/2202.07765.
279 | 
280 | [4] Dao, Tri, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv, June 23, 2022. https://doi.org/10.48550/arXiv.2205.14135.
281 | 
282 | [5] Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv, July 17, 2023. https://doi.org/10.48550/arXiv.2307.08691.
283 | 
284 | [6] Su, Jianlin, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” arXiv, August 8, 2022. https://doi.org/10.48550/arXiv.2104.09864.
285 | --------------------------------------------------------------------------------
/flash_perceiver/perceiver.py:
--------------------------------------------------------------------------------
1 | from typing import List, Literal, Optional
2 | import torch
3 | 
4 | from functools import partial
5 | from torch import nn
6 | 
7 | from einops import repeat
8 | 
9 | from flash_attn.bert_padding import pad_input, unpad_input
10 | from flash_attn.modules.mha import MHA, ParallelMHA
11 | from flash_attn.modules.block import Block
12 | from flash_attn.modules.mlp import Mlp, GatedMlp
13 | from flash_perceiver.utils import cache_fn
14 | 
15 | 
16 | def patched_mha(base_mha_cls):
17 |     class PatchedMHA(base_mha_cls):
18 |         """
19 |         Wrapper around FA's MHA to support a separate key/value dimension and more API flexibility.
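        Compared to the base class, `kv_dim` lets keys and values have a different
        feature dimension than the queries (used for the cross-attention layers), and
        `head_dim` can be given instead of `num_heads`, in which case the number of
        heads is derived as `embed_dim // head_dim`.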
20 | """ 21 | def __init__( 22 | self, 23 | embed_dim: int, 24 | *args, 25 | kv_dim: Optional[int] = None, 26 | num_heads: Optional[int] = 8, 27 | head_dim: Optional[int] = None, 28 | **kwargs 29 | ): 30 | if num_heads is None: 31 | assert head_dim is not None, 'Must specify either num_heads or head_dim' 32 | num_heads = embed_dim // head_dim 33 | 34 | super().__init__(embed_dim, num_heads, *args, **kwargs) 35 | 36 | # Missing attributes 37 | self.causal = kwargs.get('causal', False) 38 | self.dropout = kwargs.get('dropout', 0.0) 39 | self.softmax_scale = kwargs.get('softmax_scale', None) 40 | 41 | self.kv_dim = kv_dim or self.embed_dim 42 | 43 | if head_dim is not None: 44 | self.head_dim = head_dim 45 | 46 | inner_dim = self.num_heads * self.head_dim 47 | linear_cls = self.out_proj.__class__ 48 | 49 | qkv_proj_bias = kwargs.get('qkv_proj_bias', True) 50 | out_proj_bias = kwargs.get('out_proj_bias', True) 51 | 52 | if self.cross_attn: 53 | self.Wq = linear_cls(self.embed_dim, inner_dim, bias=qkv_proj_bias) 54 | self.Wkv = linear_cls(self.kv_dim, 2 * inner_dim, bias=qkv_proj_bias) 55 | else: 56 | self.Wqkv = linear_cls(self.embed_dim, 3 * inner_dim, bias=qkv_proj_bias) 57 | 58 | self.out_proj = linear_cls(inner_dim, self.embed_dim, bias=out_proj_bias) 59 | 60 | return PatchedMHA 61 | 62 | 63 | PatchedMHA = patched_mha(MHA) 64 | PatchedParallelMHA = patched_mha(ParallelMHA) 65 | 66 | T = torch.Tensor 67 | 68 | 69 | class TokenDrop(nn.Module): 70 | def __init__(self, p: float): 71 | super().__init__() 72 | assert 0 <= p < 1, 'Drop probability p must be 0 <= p < 1' 73 | self.p = p 74 | 75 | def forward(self, x: T): 76 | b, n, *_ = x.shape 77 | device = x.device 78 | 79 | if self.training and self.p > 0: 80 | probas = torch.ones(b, n, device=device) / n 81 | keep_indices = torch.multinomial(probas, int((1 - self.p) * n), replacement=False) 82 | batch_indices = torch.arange(b, device=device).unsqueeze(-1) 83 | x = x[batch_indices, keep_indices] 84 | 85 | return x 86 | 87 | 88 | class PerceiverBase(nn.Module): 89 | """ 90 | Base class for FlashAttention-based implementations of Perceiver and PerceiverIO. 91 | 92 | Fast and memory efficient [Perceiver](https://arxiv.org/abs/2103.03206) implementation in PyTorch 93 | with [FlashAttention](https://arxiv.org/abs/2205.14135) as the underlying attention implementation. 94 | 95 | Args: 96 | input_dim: Number of features of the input data. Can be a single integer or a list of integers 97 | to specify different input dimensions for each cross-attention block. 98 | `len(input_dim)` must be equal to `depth` in that case. 99 | depth: Number of cross-self-attention blocks. One such block corresponds to 100 | a cross-attention module followed by `self_per_cross_attn` self-attention modules. 101 | The number of overall attention modules is therefore `depth * (1 + self_per_cross_attn)`. 102 | num_latents: Number of latent vectors. 103 | latent_dim: Dimension of latent vectors. 104 | cross_heads: Number of heads for cross-attention. Defaults to 1. 105 | cross_head_dim: Dimension of cross-attention heads. 106 | cross_rotary_emb_dim: Dimension of cross-attention rotary embeddings. 107 | Defaults to 0 (no rotary embeddings). 108 | cross_attn_dropout: Dropout for cross-attention. 109 | latent_heads: Number of heads for latent self-attention. Defaults to 8. 110 | latent_head_dim: Dimension of latent self-attention heads. 111 | latent_rotary_emb_dim: Dimension of latent self-attention rotary embeddings. 112 | Defaults to 0 (no rotary embeddings). 
113 | latent_attn_dropout: Dropout for latent self-attention. 114 | latent_drop: Dropout rate for the latent vectors. 115 | Defaults to 0 (no latent dropout). 116 | weight_tie_layers: Whether to share the weights of the cross-attention and 117 | latent self-attention blocks. Defaults to False. 118 | gated_mlp: Whether to use gated MLPs. Doubles the number of parameters 119 | in those layers. Defaults to True. 120 | self_per_cross_attn: Number of self-attention blocks per cross-attention block. 121 | Defaults to 1. 122 | num_zero_tokens: Number of learned *zero* tokens to prepend to the inputs. 123 | These zero tokens can be seen as alternate tokens to which to attend if no informative 124 | other tokens are available. 125 | The idea that such a mechanism could be useful has been discussed in 126 | [Attention Is Off By One](https://www.evanmiller.org/attention-is-off-by-one.html) and 127 | [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588). 128 | Defaults to `None` (no zero tokens). 129 | use_flash_attn: Whether to use FlashAttention or a naive and thus less efficient attention 130 | implementation. Defaults to True. 131 | """ 132 | def __init__( 133 | self, 134 | *, 135 | input_dim: List[int] | int, 136 | depth: int, 137 | num_latents: Optional[int] = 512, 138 | latent_dim: int = 512, 139 | cross_heads: int = 1, 140 | cross_head_dim: int = 64, 141 | cross_rotary_emb_dim: int = 0, 142 | cross_attn_dropout: float = 0.0, 143 | latent_heads: int = 8, 144 | latent_head_dim: int = 64, 145 | latent_rotary_emb_dim: int = 0, 146 | latent_attn_dropout: float = 0.0, 147 | latent_drop: float = 0.0, 148 | weight_tie_layers: bool = False, 149 | gated_mlp: bool = True, 150 | self_per_cross_attn: int = 1, 151 | num_zero_tokens: int | None = None, 152 | use_flash_attn: bool = True, 153 | ): 154 | super().__init__() 155 | 156 | if isinstance(input_dim, (tuple, list)): 157 | assert len(input_dim) == depth, 'Must specify input_dim for each layer' 158 | assert not weight_tie_layers, 'Cannot weight tie layers with different input dimensions' 159 | 160 | self.input_dims = input_dim 161 | else: 162 | self.input_dims = [input_dim] * depth 163 | 164 | self.num_latents = num_latents 165 | self.latent_dim = latent_dim 166 | self.cross_heads = 1 167 | self.cross_head_dim = 64 168 | self.latent_heads = 8 169 | self.latent_head_dim = 64 170 | self.depth = depth 171 | self.self_per_cross_attn = self_per_cross_attn 172 | self.use_flash_attn = use_flash_attn 173 | 174 | if self.num_latents is not None: 175 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 176 | else: 177 | self.latents = None 178 | 179 | if gated_mlp: 180 | self.mlp_cls = partial(GatedMlp, hidden_features=latent_dim * 4) 181 | else: 182 | self.mlp_cls = Mlp 183 | 184 | self.mha_cls = PatchedMHA 185 | 186 | get_cross_attn_block = lambda in_dim: Block( 187 | dim=latent_dim, 188 | mixer_cls=partial( 189 | self.mha_cls, 190 | kv_dim=in_dim, 191 | num_heads=cross_heads, 192 | head_dim=cross_head_dim, 193 | cross_attn=True, 194 | dropout=cross_attn_dropout, 195 | qkv_proj_bias=False, 196 | rotary_emb_dim=cross_rotary_emb_dim, 197 | use_flash_attn=use_flash_attn, 198 | ), 199 | mlp_cls=self.mlp_cls 200 | ) 201 | 202 | get_self_attn_block = lambda: Block( 203 | dim=latent_dim, 204 | mixer_cls=partial( 205 | self.mha_cls, 206 | num_heads=latent_heads, 207 | head_dim=latent_head_dim, 208 | dropout=latent_attn_dropout, 209 | rotary_emb_dim=latent_rotary_emb_dim, 210 | use_flash_attn=use_flash_attn 211 | ), 212 | 
mlp_cls=self.mlp_cls 213 | ) 214 | 215 | get_cross_attn_block, get_self_attn_block = map(cache_fn, (get_cross_attn_block, get_self_attn_block)) 216 | 217 | self.layers = nn.ModuleList([]) 218 | self.zero_tokens = nn.ParameterList([]) 219 | 220 | for i, in_dim in enumerate(self.input_dims): 221 | should_cache = i > 0 and weight_tie_layers 222 | cache_args = {'_cache': should_cache} 223 | 224 | self_attns = nn.ModuleList([]) 225 | 226 | for block_idx in range(self_per_cross_attn): 227 | self_attns.append(get_self_attn_block(**cache_args, key=block_idx)) 228 | 229 | self.layers.append(nn.ModuleList([ 230 | get_cross_attn_block(in_dim=in_dim, **cache_args), 231 | self_attns, 232 | ])) 233 | 234 | if num_zero_tokens: 235 | zero_tokens = nn.Parameter(torch.randn(num_zero_tokens, in_dim)) 236 | else: 237 | zero_tokens = None 238 | 239 | self.zero_tokens.append(zero_tokens) 240 | 241 | self.latent_drop = TokenDrop(latent_drop) 242 | 243 | 244 | def set_flash_attn(self, use_flash_attn: bool): 245 | """ 246 | Enable or disbale use of FlashAttention. 247 | """ 248 | for cross_block, self_attn_blocks in self.layers: 249 | cross_block.mixer.set_flash_attn(use_flash_attn) 250 | 251 | for self_attn_block in self_attn_blocks: 252 | self_attn_block.mixer.set_flash_attn(use_flash_attn) 253 | 254 | self.use_flash_attn = use_flash_attn 255 | 256 | @property 257 | def num_attention_layers_per_block(self): 258 | return 1 + self.self_per_cross_attn 259 | 260 | @property 261 | def num_attention_layers(self): 262 | return self.depth * self.num_attention_layers_per_block 263 | 264 | @property 265 | def num_self_attention_layers(self): 266 | return self.depth * self.self_per_cross_attn 267 | 268 | @property 269 | def num_cross_attention_layers(self): 270 | return self.depth 271 | 272 | def _validate_data(self, data: List[T] | T, mask: List[T] | T | None = None): 273 | if isinstance(data, T): 274 | data = [data] * self.depth 275 | mask = [mask] * self.depth 276 | else: 277 | assert len(data) == self.depth, f'Expected {self.depth} inputs, but found {len(data)}' 278 | 279 | if mask is not None: 280 | assert isinstance(mask, (tuple, list)), \ 281 | 'If a list of data tensors is provided, mask must have the same format' 282 | assert len(mask) == self.depth, f'Expected {self.depth} masks, but found {len(mask)}' 283 | else: 284 | mask = [None] * self.depth 285 | 286 | assert all(d.shape[-1] == in_dim for d, in_dim in zip(data, self.input_dims)), \ 287 | 'Data dimensions do not match cross-attention dimensions' 288 | 289 | assert len(set(d.shape[0] for d in data)) == 1, 'All data tensors must have the same batch size' 290 | 291 | return data, mask 292 | 293 | def forward( 294 | self, 295 | data: List[T] | T, 296 | mask: List[T] | T | None = None, 297 | latents: T | None = None, 298 | return_attn_weights: bool = False 299 | ): 300 | """ 301 | Args: 302 | data: Input data which interacts with the latents via cross-attention. 303 | Must have shape `(batch_size, num_tokens, input_dim)`. 304 | If a single tensor is provided, it will be used for all cross-attention layers. 305 | To provide a separate input for each cross-attention block, a list of tensors with length `depth` 306 | can be provided. The different inputs can have different lengths, optional masking per tensor, 307 | and differ in the feature dimension as configured via `input_dim` during model initialization. 308 | mask: Optional boolean mask for the input data of shape `(batch_size, num_tokens)`. 309 | `False` indicates that a given token should not be attended. 
310 | Can be a single tensor or a list of tensors, depending on the format of the `data` argument. 311 | In the multi-input case, `None` masks for some of the inputs are allowed. 312 | latents: Optional custom latent vectors. Must be of shape `([batch_size,] num_latents, latent_dim)`. 313 | If not provided, the model's learned latent vectors will be used. 314 | If `None` has been provided as `num_latents` argument during model initialization, custom latents 315 | must be provided. 316 | return_attn_weights: Whether to return the attention weights of the attention modules. 317 | """ 318 | if self.use_flash_attn and return_attn_weights: 319 | raise NotImplementedError( 320 | 'FlashAttention does not support returning attention weights. ' 321 | 'Please disable use of FA with `set_flash_attention(False)`.' 322 | ) 323 | 324 | is_multi_data = isinstance(data, (tuple, list)) 325 | 326 | data, masks = self._validate_data(data, mask) 327 | 328 | batch_size = data[0].shape[0] 329 | 330 | if latents is None: 331 | assert self.latents is not None, \ 332 | 'Must explicitly provide latents if not initialized with num_latents' 333 | latents = self.latents 334 | else: 335 | assert latents.shape[-1] == self.latent_dim, \ 336 | f'Latents must have {self.latent_dim} dimensions, but found {latents.shape[-1]}' 337 | 338 | if latents.ndim == 2: 339 | latents = repeat(latents, 'n d -> b n d', b=batch_size) 340 | 341 | x = self.latent_drop(latents) 342 | 343 | num_latents = x.shape[1] 344 | 345 | mixer_kwargs = {} 346 | cross_block_mixer_kwargs = {} 347 | attn_weights = [] 348 | 349 | def handle_output(args): 350 | if return_attn_weights: 351 | assert isinstance(args, tuple) and len(args) == 2 352 | out, attn_weight = args 353 | attn_weights.append(attn_weight) 354 | return out 355 | else: 356 | return args 357 | 358 | if return_attn_weights: 359 | mixer_kwargs['return_attn_weights'] = True 360 | 361 | for (cross_block, self_attn_blocks), zero_tokens, datum, mask in zip( 362 | self.layers, self.zero_tokens, data, masks 363 | ): 364 | if zero_tokens is not None: 365 | zero_tokens = repeat(zero_tokens, 'n d -> b n d', b=batch_size) 366 | datum = torch.cat([zero_tokens, datum], dim=1) 367 | 368 | if mask is not None: 369 | zero_token_mask = torch.ones( 370 | zero_tokens.shape[:2], device=zero_tokens.device, dtype=torch.bool 371 | ) 372 | mask = torch.cat([zero_token_mask, mask], dim=1) 373 | 374 | if is_multi_data or not cross_block_mixer_kwargs: 375 | cross_block_mixer_kwargs = {'x_kv': datum} 376 | 377 | if mask is not None: 378 | if self.use_flash_attn: 379 | datum, _, cu_seqlens_k, max_seqlen_in_batch_k = unpad_input(datum, mask) 380 | 381 | cross_block_mixer_kwargs = { 382 | 'x_kv': datum, 383 | 'cu_seqlens_k': cu_seqlens_k, 384 | 'max_seqlen_k': max_seqlen_in_batch_k 385 | } 386 | else: 387 | cross_block_mixer_kwargs = { 388 | 'x_kv': datum, 389 | 'key_padding_mask': mask, 390 | } 391 | 392 | # FlashAttention currently does not support key-value-only padding 393 | # We therefore have to _unpad_ the queries (aka latents) as well. 394 | # In the future, this could be used for a Perceiver AR implementation. 395 | # TODO: We could compute the dummy mask tensors for the queries directly here 396 | # without calling the unpad_input function. 
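            # `unpad_input` packs the padded latent batch into a flat (total_tokens, dim)
            # tensor and returns the cumulative sequence lengths (`cu_seqlens`) plus the
            # gather `indices` expected by FlashAttention's varlen interface;
            # `pad_input` below uses these `indices` to restore the (batch, num_latents, dim) shape.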
397 | if mask is not None and self.use_flash_attn: 398 | x_mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device) 399 | x_cross, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, x_mask) 400 | 401 | cross_block_mixer_kwargs.update({ 402 | 'cu_seqlens': cu_seqlens, 403 | 'max_seqlen': max_seqlen_in_batch 404 | }) 405 | else: 406 | x_cross = x 407 | 408 | x = handle_output(cross_block( 409 | x_cross, 410 | mixer_kwargs={**mixer_kwargs, **cross_block_mixer_kwargs} 411 | ))[0] 412 | 413 | if mask is not None and self.use_flash_attn: 414 | x = pad_input(x, indices, batch_size, num_latents) 415 | 416 | for self_attn_block in self_attn_blocks: 417 | x = handle_output(self_attn_block(x, mixer_kwargs=mixer_kwargs))[0] 418 | 419 | if return_attn_weights: 420 | return x, attn_weights 421 | 422 | return x 423 | 424 | 425 | class Perceiver(PerceiverBase): 426 | """ 427 | Implementation of the [Perceiver architecture](https://arxiv.org/abs/2103.03206) with compute and 428 | memory efficient [FlashAttention](https://arxiv.org/abs/2205.14135) as underlying attention implementation 429 | 430 | Args: 431 | input_dim: Number of features of the input data. Can be a single integer or a list of integers 432 | to specify different input dimensions for each cross-attention block. 433 | `len(input_dim)` must be equal to `depth` in that case. 434 | depth: Number of cross-self-attention blocks. One such block corresponds to 435 | a cross-attention module followed by `self_per_cross_attn` self-attention modules. 436 | The number of overall attention modules is therefore `depth * (1 + self_per_cross_attn)`. 437 | output_dim: Dimension of output. If `None`, no output projection is applied 438 | and the final latents are returned. 439 | num_latents: Number of latent vectors. 440 | latent_dim: Dimension of latent vectors. 441 | cross_heads: Number of heads for cross-attention. Defaults to 1. 442 | cross_head_dim: Dimension of cross-attention heads. 443 | cross_rotary_emb_dim: Dimension of cross-attention rotary embeddings. 444 | Defaults to 0 (no rotary embeddings). 445 | cross_attn_dropout: Dropout for cross-attention. 446 | latent_heads: Number of heads for latent self-attention. Defaults to 8. 447 | latent_head_dim: Dimension of latent self-attention heads. 448 | latent_rotary_emb_dim: Dimension of latent self-attention rotary embeddings. 449 | Defaults to 0 (no rotary embeddings). 450 | latent_attn_dropout: Dropout for latent self-attention. 451 | latent_drop: Dropout rate for the latent vectors. 452 | Defaults to 0 (no latent dropout). 453 | weight_tie_layers: Whether to share the weights of the cross-attention and 454 | latent self-attention blocks. Defaults to False. 455 | gated_mlp: Whether to use gated MLPs. Doubles the number of parameters 456 | in those layers. Defaults to True. 457 | self_per_cross_attn: Number of self-attention blocks per cross-attention block. 458 | Defaults to 1. 459 | num_zero_tokens: Number of learned *zero* tokens to prepend to the inputs. 460 | These zero tokens can be seen as alternate tokens to which to attend if no informative 461 | other tokens are available. 462 | The idea that such a mechanism could be useful has been discussed in 463 | [Attention Is Off By One](https://www.evanmiller.org/attention-is-off-by-one.html) and 464 | [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588). 465 | Defaults to `None` (no zero tokens). 466 | use_flash_attn: Whether to use FlashAttention or a naive and thus less efficient attention 467 | implementation. 
Defaults to True. 468 | """ 469 | def __init__( 470 | self, 471 | *, 472 | input_dim: List[int] | int, 473 | depth: int, 474 | output_dim: int | None = None, 475 | output_mode: Literal['average', 'concat', 'first'] = 'average', 476 | num_latents: Optional[int] = 512, 477 | latent_dim: int = 512, 478 | cross_heads: int = 1, 479 | cross_head_dim: int = 64, 480 | cross_rotary_emb_dim: int = 0, 481 | cross_attn_dropout: float = 0.0, 482 | latent_heads: int = 8, 483 | latent_head_dim: int = 64, 484 | latent_rotary_emb_dim: int = 0, 485 | latent_attn_dropout: float = 0.0, 486 | latent_drop: float = 0.0, 487 | weight_tie_layers: bool = False, 488 | gated_mlp: bool = True, 489 | self_per_cross_attn: int = 1, 490 | num_zero_tokens: int | None = None, 491 | use_flash_attn: bool = True, 492 | ): 493 | if latent_drop > 0 and output_mode == 'concat': 494 | raise ValueError('Cannot use latent dropout with output mode concat') 495 | 496 | super().__init__( 497 | input_dim=input_dim, 498 | depth=depth, 499 | num_latents=num_latents, 500 | latent_dim=latent_dim, 501 | cross_heads=cross_heads, 502 | cross_head_dim=cross_head_dim, 503 | cross_rotary_emb_dim=cross_rotary_emb_dim, 504 | cross_attn_dropout=cross_attn_dropout, 505 | latent_heads=latent_heads, 506 | latent_head_dim=latent_head_dim, 507 | latent_rotary_emb_dim=latent_rotary_emb_dim, 508 | latent_attn_dropout=latent_attn_dropout, 509 | latent_drop=latent_drop, 510 | weight_tie_layers=weight_tie_layers, 511 | gated_mlp=gated_mlp, 512 | self_per_cross_attn=self_per_cross_attn, 513 | num_zero_tokens=num_zero_tokens, 514 | use_flash_attn=use_flash_attn 515 | ) 516 | 517 | self.output_dim = output_dim 518 | self.output_mode = output_mode 519 | 520 | if self.output_dim is not None: 521 | if output_mode == 'concat': 522 | assert self.num_latents is not None, \ 523 | 'Must explicitly provide num_latents if output_mode is concat' 524 | inter_dim = self.num_latents * self.latent_dim 525 | else: 526 | inter_dim = self.latent_dim 527 | 528 | self.out_proj = nn.Sequential( 529 | nn.LayerNorm(inter_dim), 530 | nn.Linear(inter_dim, self.output_dim) 531 | ) 532 | else: 533 | self.out_proj = nn.Identity() 534 | 535 | def forward( 536 | self, 537 | data: List[T] | T, 538 | mask: List[T] | T | None = None, 539 | latents: T | None = None, 540 | return_embeddings: bool = False, 541 | return_attn_weights: bool = False 542 | ): 543 | """ 544 | Args: 545 | data: Input data which interacts with the latents via cross-attention. 546 | Must have shape `(batch_size, num_tokens, input_dim)`. 547 | If a single tensor is provided, it will be used for all cross-attention layers. 548 | To provide a separate input for each cross-attention block, a list of tensors with length `depth` 549 | can be provided. The different inputs can have different lengths, optional masking per tensor, 550 | and differ in the feature dimension as configured via `input_dim` during model initialization. 551 | mask: Optional boolean mask for the input data of shape `(batch_size, num_tokens)`. 552 | `False` indicates that a given token should not be attended. 553 | Can be a single tensor or a list of tensors, depending on the format of the `data` argument. 554 | In the multi-input case, `None` masks for some of the inputs are allowed. 555 | latents: Optional custom latent vectors. Must be of shape `([batch_size,] num_latents, latent_dim)`. 556 | If not provided, the model's learned latent vectors will be used. 
557 | If `None` has been provided as `num_latents` argument during model initialization, custom latents 558 | must be provided. 559 | return_embeddings: Whether to return the final latent vectors instead of the output projection. 560 | return_attn_weights: Whether to return the attention weights of the attention modules. 561 | """ 562 | outputs = super().forward(data, mask, latents, return_attn_weights) 563 | 564 | if return_attn_weights: 565 | x, attn_weights = outputs 566 | else: 567 | x = outputs 568 | 569 | def make_output(x): 570 | if return_attn_weights: 571 | return x, attn_weights 572 | else: 573 | return x 574 | 575 | if not return_embeddings and self.output_dim is not None: 576 | if self.output_mode == 'average': 577 | x = x.mean(1) 578 | elif self.output_mode == 'concat': 579 | x = x.flatten(1) 580 | elif self.output_mode == 'first': 581 | x = x[:, 0] 582 | else: 583 | raise ValueError(f'Unknown output mode {self.output_mode}. Valid modes are "average", "concat", "first"') 584 | 585 | x = self.out_proj(x) 586 | 587 | return make_output(x) 588 | 589 | 590 | class PerceiverIO(PerceiverBase): 591 | """ 592 | Implementation of the [PerceiverIO architecture](https://arxiv.org/abs/2103.03206) with compute and 593 | memory efficient [FlashAttention](https://arxiv.org/abs/2205.14135) as underlying attention implementation 594 | 595 | Args: 596 | input_dim: Number of features of the input data. Can be a single integer or a list of integers 597 | to specify different input dimensions for each cross-attention block. 598 | `len(input_dim)` must be equal to `depth` in that case. 599 | query_dim: Dimensionality of the query vectors for decoding from the latents. 600 | depth: Number of self-attention blocks. Following the original paper, there is only one 601 | cross-attention layer at the beginning followed by `depth` self-attention layers. 602 | proj_dim: If provided, the final output from the query attention operation will be projected 603 | to `output_dim`. 604 | num_latents: Number of latent vectors. 605 | latent_dim: Dimension of latent vectors. 606 | cross_heads: Number of heads for cross-attention. Defaults to 1. 607 | cross_head_dim: Dimension of cross-attention heads. 608 | cross_rotary_emb_dim: Dimension of cross-attention rotary embeddings. 609 | Defaults to 0 (no rotary embeddings). 610 | cross_attn_dropout: Dropout for cross-attention. 611 | latent_heads: Number of heads for latent self-attention. Defaults to 8. 612 | latent_head_dim: Dimension of latent self-attention heads. 613 | latent_rotary_emb_dim: Dimension of latent self-attention rotary embeddings. 614 | Defaults to 0 (no rotary embeddings). 615 | latent_attn_dropout: Dropout for latent self-attention. 616 | latent_drop: Dropout rate for the latent vectors. 617 | Defaults to 0 (no latent dropout). 618 | query_heads: Number of heads for the latent-query cross-attention. Defaults to 1. 619 | query_head_dim: Dimension of the latent-query cross-attention heads. 620 | query_rotary_emb_dim: Dimension of the rotary embeddings for the latent-query cross-attention layer. 621 | Defaults to 0 (no rotary embeddings). 622 | query_attn_dropout: Dropout for the latent-query cross-attention. 623 | weight_tie_layers: Whether to share the weights of the cross-attention and 624 | latent self-attention blocks. Defaults to False. 625 | gated_mlp: Whether to use gated MLPs. Doubles the number of parameters 626 | in those layers. Defaults to True. 627 | num_zero_tokens: Number of learned *zero* tokens to prepend to the inputs. 
628 | These zero tokens can be seen as alternate tokens to which to attend if no informative 629 | other tokens are available. 630 | The idea that such a mechanism could be useful has been discussed in 631 | [Attention Is Off By One](https://www.evanmiller.org/attention-is-off-by-one.html) and 632 | [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588). 633 | Defaults to `None` (no zero tokens). 634 | use_flash_attn: Whether to use FlashAttention or a naive and thus less efficient attention 635 | implementation. Defaults to True. 636 | """ 637 | def __init__( 638 | self, 639 | *, 640 | input_dim: List[int] | int, 641 | query_dim: int, 642 | depth: int, 643 | proj_dim: int | None = None, 644 | num_latents: Optional[int] = 512, 645 | latent_dim: int = 512, 646 | cross_heads: int = 1, 647 | cross_head_dim: int = 64, 648 | cross_rotary_emb_dim: int = 0, 649 | cross_attn_dropout: float = 0.0, 650 | latent_heads: int = 8, 651 | latent_head_dim: int = 64, 652 | latent_rotary_emb_dim: int = 0, 653 | latent_attn_dropout: float = 0.0, 654 | latent_drop: float = 0.0, 655 | query_heads: int = 1, 656 | query_head_dim: int = 64, 657 | query_rotary_emb_dim: int = 0, 658 | query_attn_dropout: float = 0.0, 659 | weight_tie_layers: bool = False, 660 | gated_mlp: bool = True, 661 | num_zero_tokens: int | None = None, 662 | use_flash_attn: bool = True, 663 | ): 664 | super().__init__( 665 | input_dim=input_dim, 666 | depth=1, 667 | num_latents=num_latents, 668 | latent_dim=latent_dim, 669 | cross_heads=cross_heads, 670 | cross_head_dim=cross_head_dim, 671 | cross_rotary_emb_dim=cross_rotary_emb_dim, 672 | cross_attn_dropout=cross_attn_dropout, 673 | latent_heads=latent_heads, 674 | latent_head_dim=latent_head_dim, 675 | latent_rotary_emb_dim=latent_rotary_emb_dim, 676 | latent_attn_dropout=latent_attn_dropout, 677 | latent_drop=latent_drop, 678 | weight_tie_layers=weight_tie_layers, 679 | gated_mlp=gated_mlp, 680 | self_per_cross_attn=depth, 681 | num_zero_tokens=num_zero_tokens, 682 | use_flash_attn=use_flash_attn 683 | ) 684 | 685 | self.query_dim = query_dim 686 | self.proj_dim = proj_dim 687 | 688 | self.query_block = Block( 689 | dim=query_dim, 690 | mixer_cls=partial( 691 | self.mha_cls, 692 | kv_dim=self.latent_dim, 693 | num_heads=query_heads, 694 | head_dim=query_head_dim, 695 | cross_attn=True, 696 | dropout=query_attn_dropout, 697 | qkv_proj_bias=False, 698 | rotary_emb_dim=query_rotary_emb_dim, 699 | use_flash_attn=self.use_flash_attn, 700 | ), 701 | mlp_cls=self.mlp_cls 702 | ) 703 | 704 | if self.proj_dim is not None: 705 | self.out_proj = nn.Sequential( 706 | nn.LayerNorm(self.query_dim), 707 | nn.Linear(self.query_dim, self.proj_dim) 708 | ) 709 | else: 710 | self.out_proj = nn.Identity() 711 | 712 | def forward( 713 | self, 714 | data: List[T] | T, 715 | mask: List[T] | T | None = None, 716 | latents: T | None = None, 717 | queries: T | None = None, 718 | query_mask: T | None = None, 719 | return_attn_weights: bool = False 720 | ): 721 | """ 722 | Args: 723 | data: Input data which interacts with the latents via cross-attention. 724 | Must have shape `(batch_size, num_tokens, input_dim)`. 725 | If a single tensor is provided, it will be used for all cross-attention layers. 726 | To provide a separate input for each cross-attention block, a list of tensors with length `depth` 727 | can be provided. The different inputs can have different lengths, optional masking per tensor, 728 | and differ in the feature dimension as configured via `input_dim` during model initialization. 
729 | mask: Optional boolean mask for the input data of shape `(batch_size, num_tokens)`. 730 | `False` indicates that a given token should not be attended. 731 | Can be a single tensor or a list of tensors, depending on the format of the `data` argument. 732 | In the multi-input case, `None` masks for some of the inputs are allowed. 733 | latents: Optional custom latent vectors. Must be of shape `([batch_size,] num_latents, latent_dim)`. 734 | If not provided, the model's learned latent vectors will be used. 735 | If `None` has been provided as `num_latents` argument during model initialization, custom latents 736 | must be provided. 737 | queries: Optional query vectors which will interact with the latents via cross-attention to produce the output. 738 | Must have shape `(batch_size, num_queries, query_dim)`. 739 | query_mask: Not supported yet. 740 | return_attn_weights: Whether to return the attention weights of the attention modules. 741 | """ 742 | outputs = super().forward(data, mask, latents, return_attn_weights) 743 | 744 | if return_attn_weights: 745 | embeds, attn_weights = outputs 746 | else: 747 | embeds = outputs 748 | 749 | def make_output(x): 750 | if return_attn_weights: 751 | return x, attn_weights 752 | return x 753 | 754 | if queries is None: 755 | return make_output(embeds) 756 | 757 | assert query_mask is None, \ 758 | 'query_mask is not supported yet' 759 | 760 | if queries.ndim == 2: 761 | queries = repeat(queries, 'n d -> b n d', b=embeds.shape[0]) 762 | else: 763 | assert queries.ndim == 3 764 | 765 | out = self.query_block(queries, mixer_kwargs={ 766 | 'x_kv': embeds 767 | })[0] 768 | out = self.out_proj(out) 769 | 770 | return make_output(out) 771 | --------------------------------------------------------------------------------