├── perceiver_pytorch ├── __init__.py ├── caching.py ├── common.py ├── modalities.py ├── gated.py ├── experimental.py ├── multi_modality_with_text_perceiver.py ├── perceiver_pytorch.py ├── multi_modality_perceiver.py └── hierarchical_multi_modality_perceiver.py ├── perceiver.png ├── RELEASING.md ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── tests ├── fixtures.py ├── test_perceiver.py ├── compare_params.py ├── test_multimodality_with_text_perceiver.py ├── test_multimodality_perceiver.py └── test_hierarchical_multimodality_perceiver.py ├── LICENSE ├── pyproject.toml ├── .gitignore ├── poetry.lock └── README.md /perceiver_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from perceiver_pytorch.perceiver_pytorch import Perceiver 2 | -------------------------------------------------------------------------------- /perceiver.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fac2003/perceiver-multi-modality-pytorch/HEAD/perceiver.png -------------------------------------------------------------------------------- /RELEASING.md: -------------------------------------------------------------------------------- 1 | # Build package and test deployment to testpypi repo: 2 | ```bash 3 | poetry build 4 | poetry config repositories.testpypi https://test.pypi.org/legacy/ 5 | poetry publish -r testpypi 6 | ``` 7 | 8 | # Build package and deploy to pypi repo: 9 | ```bash 10 | poetry publish 11 | ``` -------------------------------------------------------------------------------- /perceiver_pytorch/caching.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from torch import nn 4 | 5 | 6 | def cache_by_name_fn(f, name=None): 7 | cache = {} 8 | 9 | @wraps(f) 10 | def cached_fn(*args,name, _cache=True, **kwargs): 11 | if not _cache: 12 | return f(*args, **kwargs) 13 | nonlocal cache 14 | if name in cache: 15 | return cache[name] 16 | cache[name] = f(*args, **kwargs) 17 | return cache[name] 18 | 19 | return cached_fn 20 | 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'perceiver-pytorch', 5 | packages = find_packages(), 6 | version = '0.1.19', 7 | license='MIT', 8 | description = 'Perceiver - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/perceiver-pytorch', 12 | keywords = [ 13 | 'artificial intelligence', 14 | 'deep learning', 15 | 'transformer', 16 | 'attention mechanism' 17 | ], 18 | install_requires=[ 19 | 'einops>=0.3', 20 | 'torch>=1.6' 21 | ], 22 | classifiers=[ 23 | 'Development Status :: 4 - Beta', 24 | 'Intended Audience :: Developers', 25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Programming Language :: Python :: 3.6', 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: 
https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytest import fixture 3 | 4 | batch_size = 3 5 | num_classes = 32 6 | depth = 2 7 | text_embedding_dim=256 8 | 9 | @fixture() 10 | def targets(): 11 | # batch of 3, 32 frames, 3 channels each frame 260 x 260 12 | targets = torch.randint(high=num_classes, size=(batch_size, 1), requires_grad=False).view(-1) 13 | return targets 14 | 15 | 16 | @fixture() 17 | def image_inputs(): 18 | return torch.rand(size=(3, 260, 260, 3), requires_grad=True) 19 | 20 | 21 | @fixture() 22 | def video_inputs(): 23 | # batch of 3, 32 frames, 3 channels each frame 260 x 260 24 | return torch.rand(size=(3, 32, 260, 260, 3), requires_grad=True) 25 | 26 | @fixture() 27 | def audio_inputs(): 28 | # one second of audio sampled at 44100 (one channel/mono) 29 | return torch.rand(size=(3, 44100,1), requires_grad=True) 30 | 31 | 32 | @fixture() 33 | def text_inputs(): 34 | # text token ids of length 512 (1 channel). 32000 tokens. 35 | return torch.randint(high=32000, size=(3, 512, 1)).long() 36 | 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "perceiver-multi-modality-pytorch" 3 | version = "1.4.0" 4 | description = "A fork of perceiver-pytorch that supports multiple modalities for the Perceiver architecture." 5 | authors = ["Fabien Campagne "] 6 | license = "MIT" 7 | 8 | readme = "README.md" 9 | homepage = "https://github.com/fac2003/perceiver-multi-modality-pytorch" 10 | repository = "https://github.com/fac2003/perceiver-multi-modality-pytorch" 11 | keywords = ["machine learning", "perceiver", "pytorch", "multi-modality", "image", "video", "text", "audio"] 12 | classifiers = [ 13 | "Intended Audience :: Developers", 14 | "License :: OSI Approved :: MIT License", 15 | "Development Status :: 4 - Beta", 16 | "Topic :: Scientific/Engineering", 17 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 18 | "Operating System :: OS Independent", 19 | ] 20 | include = [ 21 | "LICENSE", 22 | ] 23 | 24 | packages = [ 25 | { include = "perceiver_pytorch" }, 26 | ] 27 | 28 | [tool.poetry.dependencies] 29 | python = "^3.7" 30 | torch = "^1.6.0" 31 | einops = '^0.3' 32 | #pytorch-memlab = "^0.2.3" 33 | 34 | 35 | [tool.poetry.dev-dependencies] 36 | 37 | [build-system] 38 | requires = ["poetry-core>=1.0.0"] 39 | build-backend = "poetry.core.masonry.api" 40 | -------------------------------------------------------------------------------- /tests/test_perceiver.py: -------------------------------------------------------------------------------- 1 | from torch.nn import CrossEntropyLoss 2 | from torch.optim import SGD 3 | 4 | from fixtures import * 5 | from perceiver_pytorch import Perceiver 6 | from tests.compare_params import capture_params, compare_parameters 7 | 8 | 9 | def test_all_parameters_change(image_inputs, targets): 10 | model = Perceiver( 11 | input_channels=3, # number of channels for each token of the input 12 | input_axis=2, # number of axis for input data (2 for images, 3 for video) 13 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 14 | max_freq=10., # maximum frequency, hyperparameter depending on how fine the data is 15 | depth=depth, # depth of net 16 | num_latents=256, 17 | # number of latents, or induced set points, or centroids. different papers giving it different names 18 | latent_dim=512, # latent dimension 19 | cross_heads=1, # number of heads for cross attention. 
paper said 1 20 | latent_heads=8, # number of heads for latent self attention, 8 21 | cross_dim_head=64, 22 | latent_dim_head=64, 23 | num_classes=num_classes, # output number of classes 24 | attn_dropout=0., 25 | ff_dropout=0., 26 | weight_tie_layers=False # whether to weight tie layers (optional, as indicated in the diagram) 27 | ) 28 | 29 | result = model(image_inputs) 30 | 31 | optimizer = SGD( 32 | # Make learning rate large enough that differences in paramerers are clear: 33 | lr=0.1, 34 | params=model.parameters()) 35 | criterion = CrossEntropyLoss() 36 | loss = criterion(result, targets) 37 | loss.backward() 38 | before_params = capture_params(model) 39 | optimizer.step() 40 | after_params = capture_params(model) 41 | compare_parameters(before_params, after_params) 42 | -------------------------------------------------------------------------------- /perceiver_pytorch/common.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import Module 3 | 4 | 5 | class LatentTransformer(Module): 6 | def __init__(self, get_latent_attn, get_latent_ff, num_latent_blocks_per_layer, 7 | weight_tie_layers): 8 | super().__init__() 9 | self.latent_blocks = nn.ModuleList([]) 10 | self.num_latent_blocks_per_layer = num_latent_blocks_per_layer 11 | for latent_block_index in range(num_latent_blocks_per_layer): 12 | should_cache = latent_block_index > 0 and weight_tie_layers 13 | cache_args = {'_cache': should_cache} 14 | self.latent_blocks.append(nn.ModuleList([ 15 | get_latent_attn(**cache_args, name=f"latent_attn_{latent_block_index}"), 16 | get_latent_ff(**cache_args, name=f"latent_ff_{latent_block_index}")])) 17 | 18 | def forward(self, x): 19 | for latent_attn, latent_ff in self.latent_blocks: 20 | x = latent_attn(x) + x 21 | x = latent_ff(x) + x 22 | return x 23 | 24 | 25 | def build_perceiver_layers(layers, depth, get_cross_attn, get_cross_ff, 26 | get_latent_attn, get_latent_ff, 27 | weight_tie_layers, 28 | num_latent_blocks_per_layer=1, 29 | ): 30 | for i in range(depth): 31 | should_cache = i > 0 and weight_tie_layers 32 | cache_args = {'_cache': should_cache} 33 | layers.append(nn.ModuleList([ 34 | get_cross_attn(**cache_args, name="cross_attn"), 35 | get_cross_ff(**cache_args, name="cross_ff"), 36 | LatentTransformer(get_latent_attn, get_latent_ff, 37 | num_latent_blocks_per_layer=num_latent_blocks_per_layer, 38 | weight_tie_layers=weight_tie_layers)])) 39 | -------------------------------------------------------------------------------- /tests/compare_params.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | from torch.nn import Module 4 | 5 | 6 | def capture_params(model: Module) -> List[Tuple[str, List[float]]]: 7 | """ 8 | Take a copy of the model parameters at the time of invocation. 9 | :param model: Pytorch module. 10 | :return: List of named parameters with their value snapshot (list of floats). 11 | """ 12 | return [(n, p.data.tolist()) for n, p in model.named_parameters()] 13 | 14 | 15 | def compare_parameters(before_params: List[Tuple[str, List[float]]], 16 | after_params: List[Tuple[str, List[float]]]): 17 | """ 18 | Check that all model parameters have changed. An assertion will be thrown if any 19 | paramater remains unchanged after the optimization step. 
20 | :param before_params: result of capture_params used before optimization step 21 | :param after_params: result of capture_params used after optimization step 22 | :return: None 23 | """ 24 | before_param_values = {} 25 | after_param_values = {} 26 | 27 | for name, p in before_params: 28 | before_param_values[name] = p 29 | for name, p in after_params: 30 | after_param_values[name] = p 31 | 32 | assert list(before_param_values.keys()) == list(after_param_values.keys()) 33 | param_names_did_not_change = [] 34 | param_names_changed = [] 35 | for name in before_param_values.keys(): 36 | after_values = after_param_values[name] 37 | before_values = before_param_values[name] 38 | if after_values == before_values: 39 | param_names_did_not_change.append(name) 40 | else: 41 | param_names_changed.append(name) 42 | assert len( 43 | param_names_did_not_change) == 0, f"some parameters did not change and shoud have: {param_names_did_not_change}," \ 44 | f" these parameters did change: {param_names_changed}" 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /perceiver_pytorch/modalities.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Embedding 6 | 7 | 8 | @dataclass 9 | class InputModality: 10 | name: str 11 | input_channels: int 12 | input_axis: int 13 | num_freq_bands: int 14 | max_freq: float 15 | freq_base: int = 2 16 | 17 | @property 18 | def input_dim(self) -> int: 19 | # Calculate the dimension of this modality. 20 | input_dim = self.input_axis * ((self.num_freq_bands * 2) + 1) + self.input_channels 21 | return input_dim 22 | 23 | 24 | def modality_encoding(batch_size: int, axes, modality_index: int, num_modalities: int, 25 | device=torch.device('cpu')) -> Tensor: 26 | """ 27 | Return one-hot encoding of modality given num_modalities, batch size and axes. 28 | The result need to be compatible with the modality data for concatenation. 29 | :param modality_index: 30 | :param num_modalities: 31 | :return: 32 | """ 33 | one_hot = torch.eye(num_modalities, num_modalities, device=device)[modality_index] 34 | to_expand = [batch_size] 35 | one_hot = one_hot.unsqueeze(0) 36 | for i, axis in enumerate(axes): 37 | one_hot = one_hot.unsqueeze(0) 38 | to_expand.append(axis) 39 | to_expand.append(num_modalities) 40 | 41 | one_hot = one_hot.expand(to_expand) 42 | return one_hot 43 | 44 | 45 | @dataclass 46 | class InputModalityWithEmbedding(InputModality): 47 | embedding: Embedding = None 48 | 49 | def embedding_dim(self, depth: int) -> int: 50 | if not self.embedding: 51 | return self.input_dim 52 | else: 53 | # each layer sees a subset of the embedding output. 54 | pos_encoding_dim = ((self.num_freq_bands * 2) + 1) 55 | return self.embedding.embedding_dim // depth + pos_encoding_dim 56 | 57 | def embedding_for_layer(self, embedded: Tensor, layer_index: int, depth: int): 58 | if not self.embedding: 59 | # This modality does not require embedding, we return the features: 60 | return embedded 61 | assert self.input_axis == 1, "embedding for layer is not supported with axis !=1" 62 | assert embedded.dim()==3, "embedded text tensor must have 3 dimensions: B x L x D" 63 | # embedded has dimension B x L x D, B: batch, L: sequence length, D: full embedding dim. 
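# Each of the `depth` layers receives a contiguous slice of size D // depth along the last dimension;
# e.g., a 256-dim embedding split across depth=4 layers gives each layer a 64-dim slice
# (see tests/test_multimodality_with_text_perceiver.py).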
64 | dim_per_layer = embedded.size(-1) // depth 65 | start_dim_index = layer_index * dim_per_layer 66 | end_dim_index = (layer_index+1) * dim_per_layer 67 | return embedded[:, :, start_dim_index:end_dim_index] 68 | 69 | def maybe_embed(self, data: Tensor): 70 | if self.embedding: 71 | return self.embedding(data.squeeze(2)) 72 | else: 73 | return data 74 | -------------------------------------------------------------------------------- /tests/test_multimodality_with_text_perceiver.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Embedding 2 | import pytest 3 | 4 | from fixtures import * 5 | from perceiver_pytorch.modalities import InputModalityWithEmbedding 6 | from perceiver_pytorch.multi_modality_with_text_perceiver import MultiModalityWithTextPerceiver 7 | 8 | 9 | def test_embedding_for_layer(text_inputs): 10 | text_modality = InputModalityWithEmbedding( 11 | name='text', 12 | input_channels=1, # 1 channel for long ids representing tokens 13 | input_axis=1, # number of axes, 2 for images 14 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 15 | max_freq=8., # maximum frequency, hyperparameter depending on how fine the data is 16 | embedding=Embedding(32000, text_embedding_dim) 17 | ) 18 | assert text_inputs.size() == (3, 512,1) 19 | embedded = text_modality.embedding(text_inputs) 20 | assert embedded.size()==(3, 512,1, 256) 21 | assert text_modality.embedding_for_layer(embedded=embedded.squeeze(2), layer_index=0, depth=4).size() == (3, 512, 256//4) 22 | 23 | 24 | def test_multimodality_forward_image_text(image_inputs, 25 | text_inputs, 26 | targets): 27 | image_modality = InputModalityWithEmbedding( 28 | name='image', 29 | input_channels=3, # number of channels for each token of the input 30 | input_axis=2, # number of axes, 2 for images 31 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 32 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 33 | ) 34 | text_modality = InputModalityWithEmbedding( 35 | name='text', 36 | input_channels=1, # 1 channel for long ids representing tokens 37 | input_axis=1, # number of axes, 2 for images 38 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 39 | max_freq=8., # maximum frequency, hyperparameter depending on how fine the data is 40 | embedding=Embedding(32000, text_embedding_dim) 41 | ) 42 | model = MultiModalityWithTextPerceiver( 43 | modalities=(image_modality, text_modality), 44 | depth=depth, # depth of net 45 | num_latent_blocks_per_layer=2, 46 | num_latents=12, 47 | # number of latents, or induced set points, or centroids. different papers giving it different names 48 | latent_dim=64, # latent dimension 49 | cross_heads=1, # number of heads for cross attention. 
paper said 1 50 | latent_heads=8, # number of heads for latent self attention, 8 51 | cross_dim_head=64, 52 | latent_dim_head=64, 53 | num_classes=num_classes, # output number of classes 54 | attn_dropout=0., 55 | ff_dropout=0., 56 | weight_tie_layers=True, 57 | # whether to weight tie layers (optional, as indicated in the diagram) 58 | ) 59 | result = model({'image': image_inputs, 60 | 'text': text_inputs}) 61 | assert result is not None 62 | -------------------------------------------------------------------------------- /perceiver_pytorch/gated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | 7 | from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention 8 | 9 | # helpers 10 | 11 | class Residual(nn.Module): 12 | def __init__(self, fn): 13 | super().__init__() 14 | self.fn = fn 15 | 16 | def forward(self, x, **kwargs): 17 | return x + self.fn(x, **kwargs) 18 | 19 | class GRUGating(nn.Module): 20 | def __init__(self, dim, fn): 21 | super().__init__() 22 | self.dim = dim 23 | self.fn = fn 24 | self.gru = nn.GRUCell(dim, dim) 25 | 26 | def forward(self, x, **kwargs): 27 | b, dim = x.shape[0], self.dim 28 | y = self.fn(x, **kwargs) 29 | 30 | gated_output = self.gru( 31 | rearrange(y, '... d -> (...) d'), 32 | rearrange(x, '... d -> (...) d') 33 | ) 34 | 35 | gated_output = rearrange(gated_output, '(b n) d -> b n d', b = b) 36 | return gated_output 37 | 38 | # main class 39 | 40 | class Perceiver(nn.Module): 41 | def __init__( 42 | self, 43 | *, 44 | num_freq_bands, 45 | depth, 46 | max_freq, 47 | freq_base = 2, 48 | input_channels = 3, 49 | input_axis = 2, 50 | num_latents = 512, 51 | cross_dim = 512, 52 | latent_dim = 512, 53 | cross_heads = 1, 54 | latent_heads = 8, 55 | cross_dim_head = 64, 56 | latent_dim_head = 64, 57 | num_classes = 1000, 58 | attn_dropout = 0., 59 | ff_dropout = 0., 60 | weight_tie_layers = False 61 | ): 62 | super().__init__() 63 | self.input_axis = input_axis 64 | self.max_freq = max_freq 65 | self.num_freq_bands = num_freq_bands 66 | self.freq_base = freq_base 67 | 68 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels 69 | 70 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 71 | 72 | get_cross_attn = lambda: GRUGating(latent_dim, PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)) 73 | get_latent_attn = lambda: GRUGating(latent_dim, PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))) 74 | get_cross_ff = lambda: Residual(PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))) 75 | get_latent_ff = lambda: Residual(PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))) 76 | 77 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)) 78 | 79 | self.layers = nn.ModuleList([]) 80 | for i in range(depth): 81 | should_cache = i > 0 and weight_tie_layers 82 | cache_args = {'_cache': should_cache} 83 | 84 | self.layers.append(nn.ModuleList([ 85 | get_cross_attn(**cache_args), 86 | get_cross_ff(**cache_args), 87 | get_latent_attn(**cache_args), 88 | get_latent_ff(**cache_args) 89 | ])) 90 | 91 | self.to_logits = nn.Sequential( 92 | 
nn.LayerNorm(latent_dim), 93 | nn.Linear(latent_dim, num_classes) 94 | ) 95 | 96 | def forward(self, data, mask = None): 97 | b, *axis, _, device = *data.shape, data.device 98 | assert len(axis) == self.input_axis, 'input data must have the right number of axis' 99 | 100 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 101 | 102 | axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) 103 | pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1) 104 | enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base) 105 | enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') 106 | enc_pos = repeat(enc_pos, '... -> b ...', b = b) 107 | 108 | # concat to channels of data and flatten axis 109 | 110 | data = torch.cat((data, enc_pos), dim = -1) 111 | data = rearrange(data, 'b ... d -> b (...) d') 112 | 113 | x = repeat(self.latents, 'n d -> b n d', b = b) 114 | 115 | for cross_attn, cross_ff, latent_attn, latent_ff in self.layers: 116 | x = cross_attn(x, context = data, mask = mask) 117 | x = cross_ff(x) 118 | x = latent_attn(x) 119 | x = latent_ff(x) 120 | 121 | x = x.mean(dim = -2) 122 | return self.to_logits(x) 123 | -------------------------------------------------------------------------------- /tests/test_multimodality_perceiver.py: -------------------------------------------------------------------------------- 1 | from fixtures import * 2 | from perceiver_pytorch.modalities import modality_encoding 3 | from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality, \ 4 | MultiModalityPerceiverNoPooling 5 | 6 | 7 | def test_modality_encoding(): 8 | x = modality_encoding(batch_size=3, axes=(32, 12), modality_index=0, num_modalities=2) 9 | assert x.size() == (3, 32, 12, 2) 10 | 11 | 12 | def test_multimodality_forward_image_video(image_inputs, video_inputs, audio_inputs, 13 | targets): 14 | video_modality = InputModality( 15 | name='video', 16 | input_channels=3, # number of channels for each token of the input 17 | input_axis=3, # number of axes, 3 for video) 18 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 19 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 20 | ) 21 | image_modality = InputModality( 22 | name='image', 23 | input_channels=3, # number of channels for each token of the input 24 | input_axis=2, # number of axes, 2 for images 25 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 26 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 27 | ) 28 | audio_modality = InputModality( 29 | name='audio', 30 | input_channels=1, # number of channels for mono audio 31 | input_axis=1, # number of axes, 2 for images 32 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 33 | max_freq=8., # maximum frequency, hyperparameter depending on how fine the data is 34 | ) 35 | model = MultiModalityPerceiver( 36 | modalities=(video_modality, image_modality, audio_modality), 37 | depth=depth, # depth of net 38 | num_latents=12, 39 | # number of latents, or induced set points, or centroids. different papers giving it different names 40 | latent_dim=64, # latent dimension 41 | cross_heads=1, # number of heads for cross attention. 
paper said 1 42 | latent_heads=8, # number of heads for latent self attention, 8 43 | cross_dim_head=64, 44 | latent_dim_head=64, 45 | num_classes=num_classes, # output number of classes 46 | attn_dropout=0., 47 | ff_dropout=0., 48 | weight_tie_layers=True 49 | # whether to weight tie layers (optional, as indicated in the diagram) 50 | ) 51 | result = model({'image': image_inputs, 52 | 'video': video_inputs, 53 | 'audio': audio_inputs}) 54 | assert result is not None 55 | 56 | 57 | def test_multimodality_forward_image_video_no_pooling(image_inputs, video_inputs, audio_inputs, 58 | targets): 59 | video_modality = InputModality( 60 | name='video', 61 | input_channels=3, # number of channels for each token of the input 62 | input_axis=3, # number of axes, 3 for video) 63 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 64 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 65 | ) 66 | image_modality = InputModality( 67 | name='image', 68 | input_channels=3, # number of channels for each token of the input 69 | input_axis=2, # number of axes, 2 for images 70 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 71 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 72 | ) 73 | audio_modality = InputModality( 74 | name='audio', 75 | input_channels=1, # number of channels for mono audio 76 | input_axis=1, # number of axes, 2 for images 77 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 78 | max_freq=8., # maximum frequency, hyperparameter depending on how fine the data is 79 | ) 80 | num_latents = 12 81 | latent_dim = 17 82 | 83 | model = MultiModalityPerceiverNoPooling( 84 | modalities=(video_modality, image_modality, audio_modality), 85 | depth=depth, # depth of net 86 | num_latents=num_latents, 87 | # number of latents, or induced set points, or centroids. different papers giving it different names 88 | latent_dim=latent_dim, # latent dimension 89 | cross_heads=1, # number of heads for cross attention. 
paper said 1 90 | latent_heads=8, # number of heads for latent self attention, 8 91 | cross_dim_head=64, 92 | latent_dim_head=64, 93 | attn_dropout=0., 94 | ff_dropout=0., 95 | weight_tie_layers=True 96 | # whether to weight tie layers (optional, as indicated in the diagram) 97 | ) 98 | result = model({'image': image_inputs, 99 | 'video': video_inputs, 100 | 'audio': audio_inputs}) 101 | assert result is not None 102 | assert result.size() == (image_inputs.size()[0], num_latents, latent_dim) 103 | -------------------------------------------------------------------------------- /tests/test_hierarchical_multimodality_perceiver.py: -------------------------------------------------------------------------------- 1 | from fixtures import * 2 | from perceiver_pytorch.hierarchical_multi_modality_perceiver import HierarchicalMultiModalityPerceiver, \ 3 | HierarchicalConfigurator, HierarchicalMultiModalityPerceiverNoPooling 4 | from perceiver_pytorch.modalities import modality_encoding 5 | from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality, \ 6 | MultiModalityPerceiverNoPooling 7 | 8 | 9 | def test_modality_encoding(): 10 | x = modality_encoding(batch_size=3, axes=(32, 12), modality_index=0, num_modalities=2) 11 | assert x.size() == (3, 32, 12, 2) 12 | 13 | 14 | def test_hierarchical_multimodality_forward_image_video(image_inputs, video_inputs, audio_inputs, 15 | targets): 16 | video_modality = InputModality( 17 | name='video', 18 | input_channels=3, # number of channels for each token of the input 19 | input_axis=3, # number of axes, 3 for video) 20 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 21 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 22 | ) 23 | image_modality = InputModality( 24 | name='image', 25 | input_channels=3, # number of channels for each token of the input 26 | input_axis=2, # number of axes, 2 for images 27 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 28 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 29 | ) 30 | audio_modality = InputModality( 31 | name='audio', 32 | input_channels=1, # number of channels for mono audio 33 | input_axis=1, # number of axes, 2 for images 34 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 35 | max_freq=8., # maximum frequency, hyperparameter depending on how fine the data is 36 | ) 37 | model = HierarchicalMultiModalityPerceiver( 38 | modalities=(video_modality, image_modality, audio_modality), 39 | depth=depth, # depth of net 40 | 41 | # number of latents, or induced set points, or centroids. different papers giving it different names 42 | 43 | cross_heads=1, # number of heads for cross attention. 
paper said 1 44 | latent_heads=8, # number of heads for latent self attention, 8 45 | cross_dim_head=64, 46 | latent_dim_head=64, 47 | num_classes=num_classes, # output number of classes 48 | attn_dropout=0., 49 | ff_dropout=0., 50 | weight_tie_layers=False, 51 | num_latent_blocks_per_layer=1, 52 | configurator=HierarchicalConfigurator(num_latents_begin=12,latent_dim_begin=24, 53 | depth=depth) 54 | # whether to weight tie layers (optional, as indicated in the diagram) 55 | ) 56 | print(model) 57 | result = model({'image': image_inputs, 58 | 'video': video_inputs, 59 | 'audio': audio_inputs}) 60 | assert result is not None 61 | 62 | 63 | def test_multimodality_forward_image_video_no_pooling(image_inputs, video_inputs, audio_inputs, 64 | targets): 65 | video_modality = InputModality( 66 | name='video', 67 | input_channels=3, # number of channels for each token of the input 68 | input_axis=3, # number of axes, 3 for video) 69 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 70 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 71 | ) 72 | image_modality = InputModality( 73 | name='image', 74 | input_channels=3, # number of channels for each token of the input 75 | input_axis=2, # number of axes, 2 for images 76 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 77 | max_freq=4., # maximum frequency, hyperparameter depending on how fine the data is 78 | ) 79 | audio_modality = InputModality( 80 | name='audio', 81 | input_channels=1, # number of channels for mono audio 82 | input_axis=1, # number of axes, 2 for images 83 | num_freq_bands=6, # number of freq bands, with original value (2 * K + 1) 84 | max_freq=8., # maximum frequency, hyperparameter depending on how fine the data is 85 | ) 86 | num_latents = 12 87 | latent_dim = 17 88 | 89 | configurator = HierarchicalConfigurator(num_latents_begin=num_latents, latent_dim_begin=latent_dim, depth=depth) 90 | assert configurator.get_num_latents(depth - 1) == 6 91 | assert configurator.get_latent_dim(depth - 1) == 34 92 | model = HierarchicalMultiModalityPerceiverNoPooling( 93 | modalities=(video_modality, image_modality, audio_modality), 94 | depth=depth, # depth of net 95 | num_latents=num_latents, 96 | # number of latents, or induced set points, or centroids. different papers giving it different names 97 | latent_dim=latent_dim, # latent dimension 98 | cross_heads=1, # number of heads for cross attention. 
paper said 1 99 | latent_heads=8, # number of heads for latent self attention, 8 100 | cross_dim_head=64, 101 | latent_dim_head=64, 102 | attn_dropout=0., 103 | ff_dropout=0., 104 | weight_tie_layers=False, 105 | configurator=configurator 106 | # whether to weight tie layers (optional, as indicated in the diagram) 107 | ) 108 | 109 | result = model({'image': image_inputs, 110 | 'video': video_inputs, 111 | 'audio': audio_inputs}) 112 | assert result is not None 113 | 114 | assert result.size() == (image_inputs.size()[0], configurator.get_num_latents(depth-1), configurator.get_latent_dim(depth-1)) 115 | -------------------------------------------------------------------------------- /perceiver_pytorch/experimental.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | 7 | from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention 8 | 9 | # linear attention 10 | 11 | class LinearAttention(nn.Module): 12 | def __init__( 13 | self, 14 | dim, 15 | *, 16 | heads = 4, 17 | dim_head = 64, 18 | dropout = 0. 19 | ): 20 | super().__init__() 21 | inner_dim = heads * dim_head 22 | self.heads = heads 23 | self.scale = dim_head ** -0.5 24 | 25 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 26 | self.to_out = nn.Sequential( 27 | nn.Linear(inner_dim, dim), 28 | nn.Dropout(dropout) 29 | ) 30 | 31 | def forward(self, x, mask = None): 32 | h = self.heads 33 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 34 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v)) 35 | 36 | q *= self.scale 37 | q, k = q.softmax(dim = -1), k.softmax(dim = -2) 38 | 39 | if exists(mask): 40 | k.masked_fill_(mask, 0.) 
41 | 42 | context = einsum('b n d, b n e -> b d e', q, k) 43 | out = einsum('b d e, b n d -> b n e', context, v) 44 | out = rearrange(out, ' (b h) n d -> b n (h d)', h = h) 45 | return self.to_out(out) 46 | 47 | # main class 48 | 49 | class Perceiver(nn.Module): 50 | def __init__( 51 | self, 52 | *, 53 | num_freq_bands, 54 | depth, 55 | max_freq, 56 | freq_base = 2, 57 | input_channels = 3, 58 | input_axis = 2, 59 | num_latents = 512, 60 | cross_dim = 512, 61 | latent_dim = 512, 62 | cross_heads = 1, 63 | latent_heads = 8, 64 | cross_dim_head = 64, 65 | latent_dim_head = 64, 66 | num_classes = 1000, 67 | attn_dropout = 0., 68 | ff_dropout = 0., 69 | weight_tie_layers = False 70 | ): 71 | super().__init__() 72 | self.input_axis = input_axis 73 | self.max_freq = max_freq 74 | self.num_freq_bands = num_freq_bands 75 | self.freq_base = freq_base 76 | 77 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels 78 | 79 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 80 | 81 | self.data_proj = nn.Linear(input_dim, input_dim) 82 | 83 | get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim) 84 | get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)) 85 | 86 | get_input_attn = lambda: PreNorm(input_dim, LinearAttention(input_dim, dropout = attn_dropout)) 87 | get_rev_cross_attn = lambda: PreNorm(input_dim, Attention(input_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = latent_dim) 88 | get_rev_cross_ff = lambda: PreNorm(input_dim, FeedForward(input_dim, dropout = ff_dropout)) 89 | 90 | get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout)) 91 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)) 92 | 93 | get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff)) 94 | 95 | self.layers = nn.ModuleList([]) 96 | for i in range(depth): 97 | should_cache = i > 0 and weight_tie_layers 98 | cache_args = {'_cache': should_cache} 99 | 100 | self.layers.append(nn.ModuleList([ 101 | get_cross_attn(**cache_args), 102 | get_cross_ff(**cache_args), 103 | get_rev_cross_attn(**cache_args), 104 | get_rev_cross_ff(**cache_args), 105 | get_input_attn(**cache_args), 106 | get_latent_attn(**cache_args), 107 | get_latent_ff(**cache_args) 108 | ])) 109 | 110 | self.to_logits = nn.Sequential( 111 | RMSNorm(latent_dim), 112 | nn.Linear(latent_dim, num_classes) 113 | ) 114 | 115 | def forward(self, data, mask = None): 116 | b, *axis, _, device = *data.shape, data.device 117 | assert len(axis) == self.input_axis, 'input data must have the right number of axis' 118 | 119 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 120 | 121 | axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) 122 | pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1) 123 | enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base) 124 | enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') 125 | enc_pos = repeat(enc_pos, '... 
-> b ...', b = b) 126 | 127 | # concat to channels of data and flatten axis 128 | 129 | data = torch.cat((data, enc_pos), dim = -1) 130 | data = rearrange(data, 'b ... d -> b (...) d') 131 | 132 | data = self.data_proj(data) 133 | 134 | x = repeat(self.latents, 'n d -> b n d', b = b) 135 | 136 | for i, (cross_attn, cross_ff, rev_cross_attn, rev_cross_ff, input_attn, latent_attn, latent_ff) in enumerate(self.layers): 137 | is_last = i == (len(self.layers) - 1) 138 | 139 | x = cross_attn(x, context = data, mask = mask) + x 140 | x = cross_ff(x) + x 141 | 142 | if not is_last: 143 | data = input_attn(data, mask = mask) + data 144 | data = rev_cross_attn(data, context = x) + data 145 | data = rev_cross_ff(data) + data 146 | 147 | x = latent_attn(x) + x 148 | x = latent_ff(x) + x 149 | 150 | x = x.mean(dim = -2) 151 | return self.to_logits(x) 152 | -------------------------------------------------------------------------------- /perceiver_pytorch/multi_modality_with_text_perceiver.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Dict, List 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | from torch import Tensor 6 | from torch import nn 7 | from torch.nn import ModuleDict 8 | 9 | from perceiver_pytorch.caching import cache_by_name_fn 10 | from perceiver_pytorch.common import build_perceiver_layers 11 | from perceiver_pytorch.modalities import InputModalityWithEmbedding, modality_encoding 12 | from perceiver_pytorch.perceiver_pytorch import PreNorm, Attention, FeedForward, fourier_encode 13 | 14 | 15 | # An implementation of Perceiver that can accept multiple data modalities in the same forward, including 16 | # modalities which requires embedding. 17 | class MultiModalityWithTextPerceiver(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | modalities: Iterable[InputModalityWithEmbedding], 22 | depth, 23 | num_latents=512, 24 | latent_dim=512, 25 | cross_heads=1, 26 | latent_heads=8, 27 | cross_dim_head=64, 28 | latent_dim_head=64, 29 | num_classes=1000, 30 | attn_dropout=0., 31 | ff_dropout=0., 32 | weight_tie_layers=False, 33 | num_latent_blocks_per_layer=6 34 | ): 35 | super().__init__() 36 | self.depth = depth 37 | self.modalities = {modality.name: modality for modality in modalities} 38 | # we encode modality with one hot encoding, so need one dim per modality: 39 | modality_encoding_dim = sum([1 for _ in modalities]) 40 | # Register any embeddings inside this torch module: 41 | self.embeddings = ModuleDict({modality.name: modality.embedding for modality 42 | in modalities if hasattr(modality, 'embedding') and 43 | modality.embedding}) 44 | 45 | # input_dim is the maximum dimension over all input modalities: 46 | input_dim = max(modality.embedding_dim(self.depth) for modality in modalities) + modality_encoding_dim 47 | self.max_modality_dim = input_dim 48 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 49 | 50 | get_cross_attn = lambda: PreNorm(latent_dim, 51 | Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, 52 | dropout=attn_dropout), context_dim=input_dim) 53 | get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 54 | get_latent_attn = lambda: PreNorm(latent_dim, 55 | Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, 56 | dropout=attn_dropout)) 57 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 58 | 59 | get_cross_attn, get_cross_ff, get_latent_attn, 
get_latent_ff = map(cache_by_name_fn, ( 60 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)) 61 | 62 | self.layers = nn.ModuleList([]) 63 | build_perceiver_layers(self.layers, depth, get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff, 64 | weight_tie_layers, num_latent_blocks_per_layer=num_latent_blocks_per_layer) 65 | 66 | self.to_logits = nn.Sequential( 67 | nn.LayerNorm(latent_dim), 68 | nn.Linear(latent_dim, num_classes) 69 | ) 70 | 71 | def forward(self, multi_modality_data: Dict[str, Tensor], mask=None): 72 | """ 73 | 74 | :param data: a dictionary where keys are modality names and Tensor contain a batch 75 | of modality input data. 76 | :param mask: 77 | :return: 78 | """ 79 | batch_sizes = set() 80 | num_modalities = len(multi_modality_data) 81 | 82 | linearized_data_per_layer: Dict[int, List[Tensor]] = {} 83 | 84 | for modality_index, modality_name in enumerate(sorted(multi_modality_data.keys())): 85 | assert modality_name in self.modalities, f"modality {modality_name} was not defined in constructor" 86 | data = multi_modality_data[modality_name] 87 | modality = self.modalities[modality_name] 88 | b, *axis, _, device = *data.shape, data.device 89 | assert len(axis) == \ 90 | modality.input_axis, f'input data must have the right number of axes for modality {modality_name}. ' \ 91 | f'Expected {modality.input_axis} while forward argument offered {len(axis)}' 92 | batch_sizes.add(b) 93 | assert len(batch_sizes) == 1, "batch size must be the same across all modalities" 94 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 95 | 96 | axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device), axis)) 97 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 98 | enc_pos = fourier_encode(pos, 99 | modality.max_freq, modality.num_freq_bands, modality.freq_base) 100 | enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') 101 | enc_pos = repeat(enc_pos, '... -> b ...', b=b) 102 | 103 | # Figure out padding for this modality, given max dimension across all modalities: 104 | padding_size = self.max_modality_dim - modality.embedding_dim(self.depth) - num_modalities 105 | current_data_modality_shape_without_channels = data.size()[0:-1] 106 | padding = torch.zeros(size=current_data_modality_shape_without_channels + (padding_size,), 107 | device=data.device) 108 | # concat to channels of data and flatten axis 109 | modality_encodings = modality_encoding(b, axis, modality_index, num_modalities, device=device) 110 | 111 | if modality_name in self.embeddings: 112 | # restore modality embedding from this torch module: 113 | modality.embedding = self.embeddings[modality_name] 114 | data = modality.maybe_embed(data) 115 | 116 | for i in range(self.depth): 117 | layer_data = modality.embedding_for_layer(data, i, self.depth) 118 | to_concat = (layer_data, padding, enc_pos, modality_encodings) 119 | 120 | layer_data = torch.cat(to_concat, dim=-1) 121 | layer_data = rearrange(layer_data, 'b ... d -> b (...) 
d') 122 | 123 | if i not in linearized_data_per_layer: 124 | linearized_data_per_layer[i] = [] 125 | linearized_data_per_layer[i].append(layer_data) 126 | 127 | b = batch_sizes.pop() 128 | x = repeat(self.latents, 'n d -> b n d', b=b) 129 | 130 | for i, (cross_attn, cross_ff, latent_transformer) in enumerate(self.layers): 131 | # Concatenate all the modalities: 132 | data = torch.cat(linearized_data_per_layer[i], dim=1) 133 | x = cross_attn(x, context=data, mask=mask) + x 134 | x = cross_ff(x) + x 135 | x = latent_transformer(x) + x 136 | 137 | x = x.mean(dim=-2) 138 | return self.to_logits(x) 139 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "einops" 3 | version = "0.3.0" 4 | description = "A new flavour of deep learning operations" 5 | category = "main" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [[package]] 10 | name = "numpy" 11 | version = "1.20.2" 12 | description = "NumPy is the fundamental package for array computing with Python." 13 | category = "main" 14 | optional = false 15 | python-versions = ">=3.7" 16 | 17 | [[package]] 18 | name = "torch" 19 | version = "1.8.1" 20 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 21 | category = "main" 22 | optional = false 23 | python-versions = ">=3.6.2" 24 | 25 | [package.dependencies] 26 | numpy = "*" 27 | typing-extensions = "*" 28 | 29 | [[package]] 30 | name = "typing-extensions" 31 | version = "3.7.4.3" 32 | description = "Backported and Experimental Type Hints for Python 3.5+" 33 | category = "main" 34 | optional = false 35 | python-versions = "*" 36 | 37 | [metadata] 38 | lock-version = "1.1" 39 | python-versions = "^3.7" 40 | content-hash = "182ef14637f9332ab6f05c78a95d0c32e23d308f3c65939c7fb373402a0b5303" 41 | 42 | [metadata.files] 43 | einops = [ 44 | {file = "einops-0.3.0-py2.py3-none-any.whl", hash = "sha256:a91c6190ceff7d513d74ca9fd701dfa6a1ffcdd98ea0ced14350197c07f75c73"}, 45 | {file = "einops-0.3.0.tar.gz", hash = "sha256:a3b0935a4556f012cd5fa1851373f63366890a3f6698d117afea55fd2a40c1fc"}, 46 | ] 47 | numpy = [ 48 | {file = "numpy-1.20.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e9459f40244bb02b2f14f6af0cd0732791d72232bbb0dc4bab57ef88e75f6935"}, 49 | {file = "numpy-1.20.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a8e6859913ec8eeef3dbe9aed3bf475347642d1cdd6217c30f28dee8903528e6"}, 50 | {file = "numpy-1.20.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:9cab23439eb1ebfed1aaec9cd42b7dc50fc96d5cd3147da348d9161f0501ada5"}, 51 | {file = "numpy-1.20.2-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:9c0fab855ae790ca74b27e55240fe4f2a36a364a3f1ebcfd1fb5ac4088f1cec3"}, 52 | {file = "numpy-1.20.2-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:61d5b4cf73622e4d0c6b83408a16631b670fc045afd6540679aa35591a17fe6d"}, 53 | {file = "numpy-1.20.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:d15007f857d6995db15195217afdbddfcd203dfaa0ba6878a2f580eaf810ecd6"}, 54 | {file = "numpy-1.20.2-cp37-cp37m-win32.whl", hash = "sha256:d76061ae5cab49b83a8cf3feacefc2053fac672728802ac137dd8c4123397677"}, 55 | {file = "numpy-1.20.2-cp37-cp37m-win_amd64.whl", hash = "sha256:bad70051de2c50b1a6259a6df1daaafe8c480ca98132da98976d8591c412e737"}, 56 | {file = "numpy-1.20.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:719656636c48be22c23641859ff2419b27b6bdf844b36a2447cb39caceb00935"}, 57 | {file = 
"numpy-1.20.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:aa046527c04688af680217fffac61eec2350ef3f3d7320c07fd33f5c6e7b4d5f"}, 58 | {file = "numpy-1.20.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:2428b109306075d89d21135bdd6b785f132a1f5a3260c371cee1fae427e12727"}, 59 | {file = "numpy-1.20.2-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:e8e4fbbb7e7634f263c5b0150a629342cc19b47c5eba8d1cd4363ab3455ab576"}, 60 | {file = "numpy-1.20.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:edb1f041a9146dcf02cd7df7187db46ab524b9af2515f392f337c7cbbf5b52cd"}, 61 | {file = "numpy-1.20.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:c73a7975d77f15f7f68dacfb2bca3d3f479f158313642e8ea9058eea06637931"}, 62 | {file = "numpy-1.20.2-cp38-cp38-win32.whl", hash = "sha256:6c915ee7dba1071554e70a3664a839fbc033e1d6528199d4621eeaaa5487ccd2"}, 63 | {file = "numpy-1.20.2-cp38-cp38-win_amd64.whl", hash = "sha256:471c0571d0895c68da309dacee4e95a0811d0a9f9f532a48dc1bea5f3b7ad2b7"}, 64 | {file = "numpy-1.20.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4703b9e937df83f5b6b7447ca5912b5f5f297aba45f91dbbbc63ff9278c7aa98"}, 65 | {file = "numpy-1.20.2-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:abc81829c4039e7e4c30f7897938fa5d4916a09c2c7eb9b244b7a35ddc9656f4"}, 66 | {file = "numpy-1.20.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:377751954da04d4a6950191b20539066b4e19e3b559d4695399c5e8e3e683bf6"}, 67 | {file = "numpy-1.20.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6e51e417d9ae2e7848314994e6fc3832c9d426abce9328cf7571eefceb43e6c9"}, 68 | {file = "numpy-1.20.2-cp39-cp39-win32.whl", hash = "sha256:780ae5284cb770ade51d4b4a7dce4faa554eb1d88a56d0e8b9f35fca9b0270ff"}, 69 | {file = "numpy-1.20.2-cp39-cp39-win_amd64.whl", hash = "sha256:924dc3f83de20437de95a73516f36e09918e9c9c18d5eac520062c49191025fb"}, 70 | {file = "numpy-1.20.2-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:97ce8b8ace7d3b9288d88177e66ee75480fb79b9cf745e91ecfe65d91a856042"}, 71 | {file = "numpy-1.20.2.zip", hash = "sha256:878922bf5ad7550aa044aa9301d417e2d3ae50f0f577de92051d739ac6096cee"}, 72 | ] 73 | torch = [ 74 | {file = "torch-1.8.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:f23eeb1a48cc39209d986c418ad7e02227eee973da45c0c42d36b1aec72f4940"}, 75 | {file = "torch-1.8.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4ace9c5bb94d5a7b9582cd089993201658466e9c59ff88bd4e9e08f6f072d1cf"}, 76 | {file = "torch-1.8.1-cp36-cp36m-win_amd64.whl", hash = "sha256:6ffa1e7ae079c7cb828712cb0cdaae5cc4fb87c16a607e6d14526b62c20bcc17"}, 77 | {file = "torch-1.8.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:16f2630d9604c4ee28ea7d6e388e2264cd7bc6031c6ecd796bae3f56b5efa9a3"}, 78 | {file = "torch-1.8.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95b7bbbacc3f28fe438f418392ceeae146a01adc03b29d44917d55214ac234c9"}, 79 | {file = "torch-1.8.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:55137feb2f5a0dc7aced5bba690dcdb7652054ad3452b09a2bbb59f02a11e9ff"}, 80 | {file = "torch-1.8.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8ad2252bf09833dcf46a536a78544e349b8256a370e03a98627ebfb118d9555b"}, 81 | {file = "torch-1.8.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:1388b30fbd262c1a053d6c9ace73bb0bd8f5871b4892b6f3e02d1d7bc9768563"}, 82 | {file = "torch-1.8.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:e7ad1649adb7dc2a450e70a3e51240b84fa4746c69c8f98989ce0c254f9fba3a"}, 83 | {file = "torch-1.8.1-cp38-cp38-manylinux2014_aarch64.whl", hash = 
"sha256:3e4190c04dfd89c59bad06d5fe451446643a65e6d2607cc989eb1001ee76e12f"}, 84 | {file = "torch-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:5c2e9a33d44cdb93ebd739b127ffd7da786bf5f740539539195195b186a05f6c"}, 85 | {file = "torch-1.8.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:c6ede2ae4dcd8214b63e047efabafa92493605205a947574cf358216ca4e440a"}, 86 | {file = "torch-1.8.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ce7d435426f3dd14f95710d779aa46e9cd5e077d512488e813f7589fdc024f78"}, 87 | {file = "torch-1.8.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a50ea8ed900927fb30cadb63aa7a32fdd59c7d7abe5012348dfbe35a8355c083"}, 88 | {file = "torch-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:dac4d10494e74f7e553c92d7263e19ea501742c4825ddd26c4decfa27be95981"}, 89 | {file = "torch-1.8.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:225ee4238c019b28369c71977327deeeb2bd1c6b8557e6fcf631b8866bdc5447"}, 90 | ] 91 | typing-extensions = [ 92 | {file = "typing_extensions-3.7.4.3-py2-none-any.whl", hash = "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"}, 93 | {file = "typing_extensions-3.7.4.3-py3-none-any.whl", hash = "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918"}, 94 | {file = "typing_extensions-3.7.4.3.tar.gz", hash = "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c"}, 95 | ] 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![PyPI](https://img.shields.io/pypi/v/perceiver-multi-modality-pytorch.svg) 2 | ![PyPI](https://img.shields.io/pypi/pyversions/perceiver-multi-modality-pytorch.svg) 3 | ![PyPI](https://img.shields.io/github/license/fac2003/perceiver-mutli-modality-pytorch.svg) 4 | 5 | 6 | ## Multi Modality Perceiver - Pytorch 7 | 8 | Implementation of Perceiver, with support for multi-modality inputs. Fork 9 | of (lucidrains repo)[https://github.com/lucidrains/perceiver-pytorch] extended for multi-modality and support for text 10 | embedding splits chunking across layers. This repo also is closer to the Perceiver preprint because you can use GELU 11 | activation in feedforward, while Lucidrains' repo substitutes GEGLU instead. Set use_gelu to true in the 12 | MultiModalityPerceiver constructor. 13 | 14 | MultiModalityPerceiver also provides means to customize pooling method. You can subclass 15 | perceiver_pytorch.multi_modality_perceiver.MultiModalityPerceiver and override the pool() method, or use the 16 | perceiver_pytorch.multi_modality_perceiver.MultiModalityPerceiverNoPooling implementation that returns the hidden 17 | representation without any pooling. This is useful if you need to train multitask models and want to experiment with, 18 | say, using the first 3 latent outputs to predict each a different task. 
19 | 
20 | ## Install
21 | 
22 | To install the Perceiver implementation with multi-modality support (the single-modality Perceiver is also included):
23 | 
24 | ```bash
25 | $ pip install perceiver-multi-modality-pytorch
26 | ```
27 | 
28 | Import with:
29 | 
30 | ```python
31 | from perceiver_pytorch.modalities import modality_encoding
32 | from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality
33 | ```
34 | 
35 | See tests/test_multimodality_perceiver.py, or, for the text-aware variant:
36 | 
37 | ```python
38 | from perceiver_pytorch.modalities import InputModalityWithEmbedding
39 | from perceiver_pytorch.multi_modality_with_text_perceiver import MultiModalityWithTextPerceiver
40 | ```
41 | 
42 | See tests/test_multimodality_with_text_perceiver.py.
43 | 
44 | To install the original Perceiver implementation, follow the instructions at the
45 | [lucidrains repo](https://github.com/lucidrains/perceiver-pytorch).
46 | 
47 | ## Usage
48 | 
49 | ```python
50 | import torch
51 | from perceiver_pytorch import Perceiver
52 | 
53 | model = Perceiver(
54 |     input_channels=3,  # number of channels for each token of the input
55 |     input_axis=2,  # number of axes for input data (2 for images, 3 for video)
56 |     num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
57 |     max_freq=10.,  # maximum frequency, hyperparameter depending on how fine the data is
58 |     depth=6,  # depth of net
59 |     num_latents=256,
60 |     # number of latents, or induced set points, or centroids. different papers give it different names
61 |     latent_dim=512,  # latent dimension
62 |     cross_heads=1,  # number of heads for cross attention. paper said 1
63 |     latent_heads=8,  # number of heads for latent self attention, 8
64 |     cross_dim_head=64,
65 |     latent_dim_head=64,
66 |     num_classes=1000,  # output number of classes
67 |     attn_dropout=0.,
68 |     ff_dropout=0.,
69 |     weight_tie_layers=False  # whether to weight tie layers (optional, as indicated in the diagram)
70 | )
71 | 
72 | img = torch.randn(1, 224, 224, 3)  # 1 imagenet image, pixelized
73 | 
74 | model(img)  # (1, 1000)
75 | ```
76 | 
77 | ## Multi-modality perceiver
78 | 
79 | An attractive feature of the perceiver architecture is that it can process multiple modalities of data in the same
80 | batch. This is not obvious from the perceiver forward signature shown above, but a relatively modest change can support
81 | processing video, images and audio with a single model, in one forward pass. This feature is demonstrated by the
82 | MultiModalityPerceiver, contributed by Fabien Campagne.
83 | 
84 | ```python
85 | import torch
86 | from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality
87 | 
88 | image_inputs = torch.rand(size=(3, 260, 260, 3), requires_grad=True)
89 | video_inputs = torch.rand(size=(3, 32, 260, 260, 3), requires_grad=True)
90 | audio_inputs = torch.rand(size=(3, 44100, 1), requires_grad=True)
91 | 
92 | video_modality = InputModality(
93 |     name='video',
94 |     input_channels=3,  # number of channels for each token of the input
95 |     input_axis=3,  # number of axes, 3 for video
96 |     num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
97 |     max_freq=4.,  # maximum frequency, hyperparameter depending on how fine the data is
98 | )
99 | image_modality = InputModality(
100 |     name='image',
101 |     input_channels=3,  # number of channels for each token of the input
102 |     input_axis=2,  # number of axes, 2 for images
103 |     num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
104 |     max_freq=4.,  # maximum frequency, hyperparameter depending on how fine the data is
105 | )
106 | audio_modality = InputModality(
107 |     name='audio',
108 |     input_channels=1,  # number of channels for mono audio
109 |     input_axis=1,  # number of axes, 1 for audio
110 |     num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
111 |     max_freq=8.,  # maximum frequency, hyperparameter depending on how fine the data is
112 | )
113 | model = MultiModalityPerceiver(
114 |     modalities=(video_modality, image_modality, audio_modality),
115 |     depth=8,  # depth of net, combined with num_latent_blocks_per_layer to produce full Perceiver
116 |     num_latents=12,
117 |     # number of latents, or induced set points, or centroids. different papers give it different names
118 |     latent_dim=64,  # latent dimension
119 |     cross_heads=1,  # number of heads for cross attention. paper said 1
120 |     latent_heads=8,  # number of heads for latent self attention, 8
121 |     cross_dim_head=64,
122 |     latent_dim_head=64,
123 |     num_classes=1000,  # output number of classes
124 |     attn_dropout=0.,
125 |     ff_dropout=0.,
126 |     weight_tie_layers=True,  # whether to weight tie layers (optional, as indicated in the diagram)
127 |     num_latent_blocks_per_layer=6  # Note that this parameter is 1 in the original Lucidrains implementation
128 | )
129 | result = model({'image': image_inputs,
130 |                 'video': video_inputs,
131 |                 'audio': audio_inputs})
132 | ```
133 | 
134 | ### Text perceiver
135 | 
136 | While the Perceiver architecture described by [jaegle2021perceiver] could support text if text were embedded and each
137 | dimension of the embedding were provided as a channel in the input, this introduces a mismatch between the text embedding
138 | dimension (typically large, 512/768 or more) and the number of channels used for video and images (typically 3 channels,
139 | one for red, green and blue), or audio
140 | (1 for mono or 2 for stereo channels). When training text embeddings from scratch, this creates an opportunity, because
141 | there should be no need for the perceiver to attend to the entire text embedding in each layer. If we split the text
142 | embedding into as many chunks as there are layers in the perceiver, we reduce how much we need to pad other modalities,
143 | and introduce a structure to the learned embeddings, where parts of the text embedding can specialize according to the
144 | needs of each layer.
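For intuition, the split itself is just a chunking of the embedding dimension. The sketch below is plain PyTorch, independent of this repo's internal implementation; the embedding dimension of 256 and depth of 8 are assumptions for illustration:

```python
import torch

depth = 8                                  # number of perceiver layers (illustrative)
text_embeddings = torch.rand(3, 512, 256)  # batch x tokens x embedding_dim (assumed 256)

# One chunk per layer: layer i only cross-attends to a 256 // 8 = 32-dimensional slice,
# so other modalities only need padding up to the chunk width rather than the full 256.
chunks = text_embeddings.chunk(depth, dim=-1)
assert len(chunks) == depth and chunks[0].shape == (3, 512, 32)
```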
The perceiver implementation provided in this repo can be used to explore the question of whether 145 | splitting text embeddings across layers is beneficial (you would compare the performance of 146 | MultiModalityWithTextPerceiver with that of MultiModalityPerceiver). 147 | 148 | ## Citations 149 | 150 | ```bibtex 151 | @misc{jaegle2021perceiver, 152 | title = {Perceiver: General Perception with Iterative Attention}, 153 | author = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira}, 154 | year = {2021}, 155 | eprint = {2103.03206}, 156 | archivePrefix = {arXiv}, 157 | primaryClass = {cs.CV} 158 | } 159 | @misc{campagne2021textperceiver, 160 | title = {Adapting Perceiver for learning with text modalities}, 161 | author = {Fabien Campagne}, 162 | year = {2021}, 163 | eprint = {unpublished results}, 164 | } 165 | ``` 166 | -------------------------------------------------------------------------------- /perceiver_pytorch/perceiver_pytorch.py: -------------------------------------------------------------------------------- 1 | from math import pi, log 2 | from functools import wraps, partial 3 | 4 | import torch 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, repeat 9 | 10 | # helpers 11 | from torch.nn import GELU 12 | from torch.utils.checkpoint import checkpoint 13 | 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | 19 | def default(val, d): 20 | return val if exists(val) else d 21 | 22 | 23 | def cache_fn(f): 24 | cache = None 25 | 26 | @wraps(f) 27 | def cached_fn(*args, _cache=True, **kwargs): 28 | if not _cache: 29 | return f(*args, **kwargs) 30 | nonlocal cache 31 | if cache is not None: 32 | return cache 33 | cache = f(*args, **kwargs) 34 | return cache 35 | 36 | return cached_fn 37 | 38 | 39 | def fourier_encode(x, max_freq, num_bands=4, base=2): 40 | x = x.unsqueeze(-1) 41 | device, dtype, orig_x = x.device, x.dtype, x 42 | 43 | scales = torch.logspace(1., log(max_freq / 2) / log(base), num_bands, base=base, device=device, dtype=dtype) 44 | scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] 45 | 46 | x = x * scales * pi 47 | x = torch.cat([x.sin(), x.cos()], dim=-1) 48 | x = torch.cat((x, orig_x), dim=-1) 49 | return x 50 | 51 | 52 | # helper classes 53 | 54 | class PreNorm(nn.Module): 55 | def __init__(self, dim, fn, context_dim=None): 56 | super().__init__() 57 | self.fn = fn 58 | self.norm = nn.LayerNorm(dim) 59 | self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None 60 | 61 | def forward(self, x, **kwargs): 62 | x = self.norm(x) 63 | 64 | if exists(self.norm_context): 65 | context = kwargs['context'] 66 | normed_context = self.norm_context(context) 67 | kwargs.update(context=normed_context) 68 | 69 | return self.fn(x, **kwargs) 70 | 71 | 72 | class GEGLU(nn.Module): 73 | def forward(self, x): 74 | x, gates = x.chunk(2, dim=-1) 75 | return x * F.gelu(gates) 76 | 77 | 78 | class FeedForward(nn.Module): 79 | def __init__(self, dim, mult=4, dropout=0.): 80 | super().__init__() 81 | self.net = nn.Sequential( 82 | nn.Linear(dim, dim * mult * 2), 83 | GEGLU(), 84 | nn.Dropout(dropout), 85 | nn.Linear(dim * mult, dim) 86 | ) 87 | 88 | def forward(self, x): 89 | return self.net(x) 90 | 91 | 92 | class FeedForwardGELU(nn.Module): 93 | def __init__(self, dim, mult=4, dropout=0.): 94 | super().__init__() 95 | self.net = nn.Sequential( 96 | nn.Linear(dim, dim * mult), 97 | GELU(), 98 | nn.Dropout(dropout), 99 | nn.Linear(dim * 
mult, dim) 100 | ) 101 | 102 | def forward(self, x): 103 | return self.net(x) 104 | 105 | 106 | class Attention(nn.Module): 107 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 108 | super().__init__() 109 | inner_dim = dim_head * heads 110 | context_dim = default(context_dim, query_dim) 111 | 112 | self.scale = dim_head ** -0.5 113 | self.heads = heads 114 | 115 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 116 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 117 | 118 | self.to_out = nn.Sequential( 119 | nn.Linear(inner_dim, query_dim), 120 | nn.Dropout(dropout) 121 | ) 122 | 123 | def forward(self, x, context=None, mask=None): 124 | h = self.heads 125 | 126 | q = self.to_q(x) 127 | context = default(context, x) 128 | k, v = self.to_kv(context).chunk(2, dim=-1) 129 | # Cast query and keys to float 32 to avoid instability as attention weights grow 130 | # during training, per https://twitter.com/tsuname/status/1430653484827697155?s=20 131 | k = k.float() 132 | q = q.float() 133 | 134 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 135 | 136 | sim = checkpoint(partial(einsum,'b i d, b j d -> b i j'), q, k) * self.scale 137 | 138 | if exists(mask): 139 | mask = rearrange(mask, 'b ... -> b (...)') 140 | max_neg_value = -torch.finfo(sim.dtype).max 141 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 142 | sim.masked_fill_(~mask, max_neg_value) 143 | 144 | # attention, what we cannot get enough of 145 | attn = sim.softmax(dim=-1) 146 | 147 | out = checkpoint(partial(einsum,'b i j, b j d -> b i d'), attn, v) 148 | # cast back to input type: 149 | out = out.type(x.dtype) 150 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 151 | return self.to_out(out) 152 | 153 | 154 | # main class 155 | 156 | class Perceiver(nn.Module): 157 | def __init__( 158 | self, 159 | *, 160 | num_freq_bands, 161 | depth, 162 | max_freq, 163 | freq_base=2, 164 | input_channels=3, 165 | input_axis=2, 166 | num_latents=512, 167 | latent_dim=512, 168 | cross_heads=1, 169 | latent_heads=8, 170 | cross_dim_head=64, 171 | latent_dim_head=64, 172 | num_classes=1000, 173 | attn_dropout=0., 174 | ff_dropout=0., 175 | weight_tie_layers=False 176 | ): 177 | super().__init__() 178 | self.input_axis = input_axis 179 | self.max_freq = max_freq 180 | self.num_freq_bands = num_freq_bands 181 | self.freq_base = freq_base 182 | 183 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels 184 | 185 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 186 | 187 | get_cross_attn = lambda: PreNorm(latent_dim, 188 | Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, 189 | dropout=attn_dropout), context_dim=input_dim) 190 | get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 191 | get_latent_attn = lambda: PreNorm(latent_dim, 192 | Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, 193 | dropout=attn_dropout)) 194 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 195 | 196 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, ( 197 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)) 198 | 199 | self.layers = nn.ModuleList([]) 200 | for i in range(depth): 201 | should_cache = i > 0 and weight_tie_layers 202 | cache_args = {'_cache': should_cache} 203 | 204 | self.layers.append(nn.ModuleList([ 205 | get_cross_attn(**cache_args), 206 | get_cross_ff(**cache_args), 
207 | get_latent_attn(**cache_args), 208 | get_latent_ff(**cache_args) 209 | ])) 210 | 211 | self.to_logits = nn.Sequential( 212 | nn.LayerNorm(latent_dim), 213 | nn.Linear(latent_dim, num_classes) 214 | ) 215 | 216 | def forward(self, data, mask=None): 217 | b, *axis, _, device = *data.shape, data.device 218 | assert len(axis) == self.input_axis, 'input data must have the right number of axis' 219 | 220 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 221 | 222 | axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device), axis)) 223 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 224 | enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base=self.freq_base) 225 | enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') 226 | enc_pos = repeat(enc_pos, '... -> b ...', b=b) 227 | 228 | # concat to channels of data and flatten axis 229 | 230 | data = torch.cat((data, enc_pos), dim=-1) 231 | data = rearrange(data, 'b ... d -> b (...) d') 232 | 233 | x = repeat(self.latents, 'n d -> b n d', b=b) 234 | 235 | for cross_attn, cross_ff, latent_attn, latent_ff in self.layers: 236 | x = cross_attn(x, context=data, mask=mask) + x 237 | x = cross_ff(x) + x 238 | x = latent_attn(x) + x 239 | x = latent_ff(x) + x 240 | 241 | x = x.mean(dim=-2) 242 | return self.to_logits(x) 243 | -------------------------------------------------------------------------------- /perceiver_pytorch/multi_modality_perceiver.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Dict, List 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | from torch import Tensor 6 | from torch import nn 7 | from torch.nn import Identity 8 | 9 | from perceiver_pytorch.caching import cache_by_name_fn 10 | from perceiver_pytorch.modalities import InputModality, modality_encoding 11 | from perceiver_pytorch.perceiver_pytorch import PreNorm, Attention, FeedForward, cache_fn, fourier_encode, \ 12 | FeedForwardGELU 13 | from perceiver_pytorch.common import build_perceiver_layers 14 | 15 | 16 | # An implementation of Perceiver that can accept multiple data modalities in the same forward. 17 | class MultiModalityPerceiver(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | modalities: Iterable[InputModality], 22 | depth, 23 | num_latents=512, 24 | latent_dim=512, 25 | cross_heads=1, 26 | latent_heads=8, 27 | cross_dim_head=64, 28 | latent_dim_head=64, 29 | num_classes=None, 30 | attn_dropout=0., 31 | ff_dropout=0., 32 | weight_tie_layers=False, 33 | num_latent_blocks_per_layer=1, 34 | use_gelu: bool = False, 35 | ): 36 | """ 37 | 38 | :param modalities: 39 | :param depth: Number of times the perceiver will perform cross-attention between latent and input. 40 | :param num_latents: 41 | :param latent_dim: 42 | :param cross_heads: 43 | :param latent_heads: 44 | :param cross_dim_head: 45 | :param latent_dim_head: 46 | :param num_classes: Number of classes to predict, or if None, return the hidden state (num latents x hidden_dim) 47 | :param attn_dropout: 48 | :param ff_dropout: 49 | :param weight_tie_layers: True: share weights across layers, False no shared weights. 50 | :param num_latent_blocks_per_layer: Number of blocks in the latent transformer. 51 | :param use_gelu: Use GELU activation like the Perceiver preprint indicates. False, 52 | with Lucidrains' GEGLU activation in feed forward instead. 
53 | 54 | """ 55 | super().__init__() 56 | self.modalities = {modality.name: modality for modality in modalities} 57 | # we encode modality with one hot encoding, so need one dim per modality: 58 | modality_encoding_dim = sum([1 for _ in modalities]) 59 | # input_dim is the maximum dimension over all input modalities: 60 | input_dim = max(modality.input_dim for modality in modalities) + modality_encoding_dim 61 | self.max_modality_dim = input_dim 62 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 63 | ff_type = FeedForwardGELU if use_gelu else FeedForward 64 | get_cross_attn = lambda: PreNorm(latent_dim, 65 | Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, 66 | dropout=attn_dropout), context_dim=input_dim) 67 | get_cross_ff = lambda: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout)) 68 | get_latent_attn = lambda: PreNorm(latent_dim, 69 | Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, 70 | dropout=attn_dropout)) 71 | get_latent_ff = lambda: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout)) 72 | 73 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_by_name_fn, ( 74 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)) 75 | 76 | self.layers = nn.ModuleList([]) 77 | build_perceiver_layers(self.layers, depth, get_cross_attn, get_cross_ff, 78 | get_latent_attn, get_latent_ff, 79 | weight_tie_layers, 80 | num_latent_blocks_per_layer=num_latent_blocks_per_layer) 81 | 82 | self.to_logits = nn.Sequential( 83 | nn.LayerNorm(latent_dim), 84 | nn.Linear(latent_dim, num_classes) 85 | ) 86 | 87 | def forward(self, multi_modality_data: Dict[str, Tensor], mask=None): 88 | """ 89 | 90 | :param data: a dictionary where keys are modality names and Tensor contain a batch 91 | of modality input data. 92 | :param mask: 93 | :return: 94 | """ 95 | batch_sizes = set() 96 | num_modalities = len(multi_modality_data) 97 | linearized_data = [] 98 | linearized_data_per_layer: Dict[int, List[Tensor]] = {} 99 | 100 | for modality_index, modality_name in enumerate(sorted(multi_modality_data.keys())): 101 | assert modality_name in self.modalities, f"modality {modality_name} was not defined in constructor" 102 | data = multi_modality_data[modality_name] 103 | modality = self.modalities[modality_name] 104 | b, *axis, _, device = *data.shape, data.device 105 | assert len( 106 | axis) == modality.input_axis, f'input data must have the right number of for modality {modality_name}. ' \ 107 | f'Expected {modality.input_axis} while forward argument offered {len(axis)}' 108 | batch_sizes.add(b) 109 | assert len(batch_sizes) == 1, "batch size must be the same across all modalities" 110 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 111 | 112 | axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device), axis)) 113 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 114 | enc_pos = fourier_encode(pos, 115 | modality.max_freq, modality.num_freq_bands, modality.freq_base) 116 | enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') 117 | enc_pos = repeat(enc_pos, '... 
-> b ...', b=b) 118 | 119 | # Figure out padding for this modality, given max dimension across all modalities: 120 | padding_size = self.max_modality_dim - modality.input_dim - num_modalities 121 | 122 | padding = torch.zeros(size=data.size()[0:-1] + (padding_size,)).to(device) 123 | # concat to channels of data and flatten axis 124 | modality_encodings = modality_encoding(b, axis, modality_index, num_modalities, device=device) 125 | 126 | to_concat = (data, padding, enc_pos, modality_encodings) 127 | 128 | data = torch.cat(to_concat, dim=-1) 129 | data = rearrange(data, 'b ... d -> b (...) d') 130 | linearized_data.append(data) 131 | b = batch_sizes.pop() 132 | x = repeat(self.latents, 'n d -> b n d', b=b) 133 | 134 | # Concatenate all the modalities: 135 | data = torch.cat(linearized_data, dim=1) 136 | 137 | for cross_attn, cross_ff, latent_transformer in self.layers: 138 | x = cross_attn(x, context=data, mask=mask) + x 139 | x = cross_ff(x) + x 140 | x = latent_transformer(x) + x 141 | x = self.pool(x) 142 | 143 | return self.to_logits(x) 144 | 145 | def pool(self, x): 146 | """ 147 | Perform pooling over latents. 148 | :param x: batch x num_latents x latent_dim 149 | :return: pooled x 150 | """ 151 | # implement global pooling 152 | return x.mean(dim=-2) 153 | 154 | 155 | class MultiModalityPerceiverNoPooling(MultiModalityPerceiver): 156 | def __init__(self, *, modalities: Iterable[InputModality], depth, 157 | num_latents=512, latent_dim=512, cross_heads=1, 158 | latent_heads=8, cross_dim_head=64, latent_dim_head=64, 159 | attn_dropout=0., ff_dropout=0., 160 | weight_tie_layers=False, num_latent_blocks_per_layer=1, 161 | use_gelu: bool = True): 162 | """ 163 | Perceiver that returns hidden state. Makes it possible to configure pooling with 164 | the result of forward. 165 | :param modalities: 166 | :param depth: Number of times the perceiver will perform cross-attention between latent and input. 167 | :param num_latents: 168 | :param latent_dim: 169 | :param cross_heads: 170 | :param latent_heads: 171 | :param cross_dim_head: 172 | :param latent_dim_head: 173 | :param attn_dropout: 174 | :param ff_dropout: 175 | :param weight_tie_layers: True: share weights across layers, False no shared weights. 176 | :param num_latent_blocks_per_layer: Number of blocks in the latent transformer. 177 | :param use_gelu: Use GELU activation like the Perceiver preprint indicates. False, 178 | with Lucidrains' GEGLU activation in feed forward instead. 179 | 180 | """ 181 | 182 | super().__init__(modalities=modalities, depth=depth, num_latents=num_latents, latent_dim=latent_dim, 183 | cross_heads=cross_heads, latent_heads=latent_heads, cross_dim_head=cross_dim_head, 184 | latent_dim_head=latent_dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, 185 | weight_tie_layers=weight_tie_layers, num_latent_blocks_per_layer=num_latent_blocks_per_layer, 186 | use_gelu=use_gelu, num_classes=1) 187 | self.to_logits = Identity() 188 | 189 | def pool(self, x): 190 | """ 191 | Do not pool. 
192 | :param x: batch x num_latents x latent_dim 193 | :return: pooled x 194 | """ 195 | # no pooling 196 | return x 197 | -------------------------------------------------------------------------------- /perceiver_pytorch/hierarchical_multi_modality_perceiver.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Dict, List 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | from torch import Tensor 6 | from torch import nn 7 | from torch.nn import Identity, Module 8 | from torch.nn.functional import pad 9 | 10 | from perceiver_pytorch.caching import cache_by_name_fn 11 | from perceiver_pytorch.modalities import InputModality, modality_encoding 12 | from perceiver_pytorch.perceiver_pytorch import PreNorm, Attention, FeedForward, cache_fn, fourier_encode, \ 13 | FeedForwardGELU 14 | from perceiver_pytorch.common import build_perceiver_layers, LatentTransformer 15 | from torch.nn.modules.container import ParameterList 16 | 17 | 18 | class HierarchicalConfigurator(): 19 | def __init__(self, depth: int, num_latents_begin: int, 20 | latent_dim_begin: int): 21 | self.depth = depth 22 | self.num_latents_begin = num_latents_begin 23 | self.latent_dim_begin = latent_dim_begin 24 | 25 | def get_num_latents(self, layer_index: int) -> int: 26 | assert layer_index < self.depth, 'layer_index cannot be larger than depth' 27 | num_latents = self.num_latents_begin 28 | 29 | for layer_index in range(layer_index): 30 | num_latents = num_latents // 2 31 | assert num_latents > 0, f"num_latents_begin is too small, no remaining latents at layer {layer_index}" 32 | return num_latents 33 | 34 | def get_latent_dim(self, layer_index: int) -> int: 35 | assert layer_index < self.depth, 'layer_index cannot be larger than depth' 36 | latent_dim = self.latent_dim_begin 37 | 38 | for layer_index in range(layer_index): 39 | latent_dim = latent_dim * 2 40 | 41 | return latent_dim 42 | 43 | 44 | class HierarchicalLatentTransformer(Module): 45 | def __init__(self, get_latent_attn, get_latent_ff, num_latent_blocks_per_layer, 46 | weight_tie_layers, latent_dim: int): 47 | super().__init__() 48 | self.latent_blocks = nn.ModuleList([]) 49 | self.num_latent_blocks_per_layer = num_latent_blocks_per_layer 50 | for latent_block_index in range(num_latent_blocks_per_layer): 51 | should_cache = latent_block_index > 0 and weight_tie_layers 52 | cache_args = {'_cache': should_cache} 53 | self.latent_blocks.append(nn.ModuleList([ 54 | get_latent_attn(**cache_args, name=f"latent_attn_{latent_block_index}", latent_dim=latent_dim), 55 | get_latent_ff(**cache_args, name=f"latent_ff_{latent_block_index}", latent_dim=latent_dim)])) 56 | 57 | def forward(self, x): 58 | for latent_attn, latent_ff in self.latent_blocks: 59 | x = latent_attn(x) + x 60 | x = latent_ff(x) + x 61 | return x 62 | 63 | 64 | def build_perceiver_layers_hierarchical(layers, depth, get_cross_attn, get_cross_ff, 65 | get_latent_attn, get_latent_ff, 66 | weight_tie_layers, 67 | configurator: HierarchicalConfigurator, 68 | num_latent_blocks_per_layer=1, 69 | 70 | ): 71 | for i in range(depth): 72 | should_cache = i > 0 and weight_tie_layers 73 | cache_args = {'_cache': should_cache} 74 | latent_dim = configurator.get_latent_dim(i) 75 | layers.append(nn.ModuleList([ 76 | get_cross_attn(**cache_args, name=f"cross_attn_{latent_dim}", latent_dim=latent_dim), 77 | get_cross_ff(**cache_args, name=f"cross_ff_{latent_dim}", latent_dim=latent_dim), 78 | HierarchicalLatentTransformer(get_latent_attn, 
get_latent_ff, 79 | num_latent_blocks_per_layer=num_latent_blocks_per_layer, 80 | weight_tie_layers=weight_tie_layers, 81 | latent_dim=latent_dim)])) 82 | 83 | 84 | # An implementation of Perceiver that can accept multiple data modalities in the same forward. 85 | # Can be configured with different numbers of latents and latent_dim at each layer. Initial 86 | # Implementation supports increasing latent_dim, while reducing the number of latents. 87 | class HierarchicalMultiModalityPerceiver(nn.Module): 88 | def __init__( 89 | self, 90 | *, 91 | modalities: Iterable[InputModality], 92 | depth, 93 | cross_heads=1, 94 | latent_heads=8, 95 | cross_dim_head=64, 96 | latent_dim_head=64, 97 | num_classes=None, 98 | attn_dropout=0., 99 | ff_dropout=0., 100 | weight_tie_layers=False, 101 | num_latent_blocks_per_layer=1, 102 | use_gelu: bool = False, 103 | configurator: HierarchicalConfigurator 104 | ): 105 | """ 106 | 107 | :param modalities: 108 | :param depth: Number of times the perceiver will perform cross-attention between latent and input. 109 | :param cross_heads: 110 | :param latent_heads: 111 | :param cross_dim_head: 112 | :param latent_dim_head: 113 | :param num_classes: Number of classes to predict, or if None, return the hidden state (num latents x hidden_dim) 114 | :param attn_dropout: 115 | :param ff_dropout: 116 | :param weight_tie_layers: True: share weights across layers, False no shared weights. 117 | :param num_latent_blocks_per_layer: Number of blocks in the latent transformer. 118 | :param use_gelu: Use GELU activation like the Perceiver preprint indicates. False, 119 | with Lucidrains' GEGLU activation in feed forward instead. 120 | :param configurator: instance of HierarchicalConfigurator that determines how many latents and latent_dim 121 | to setup at each layer. 
122 | 123 | """ 124 | super().__init__() 125 | assert not weight_tie_layers, "HierarchicalMultiModalityPerceiver does not support tied weights" 126 | self.modalities = {modality.name: modality for modality in modalities} 127 | # we encode modality with one hot encoding, so need one dim per modality: 128 | modality_encoding_dim = sum([1 for _ in modalities]) 129 | # input_dim is the maximum dimension over all input modalities: 130 | input_dim = max(modality.input_dim for modality in modalities) + modality_encoding_dim 131 | self.max_modality_dim = input_dim 132 | self.latents = ParameterList([nn.Parameter(torch.randn(configurator.get_num_latents(layer_index), 133 | configurator.get_latent_dim(layer_index))) for 134 | layer_index in range(depth)]) 135 | ff_type = FeedForwardGELU if use_gelu else FeedForward 136 | get_cross_attn = lambda latent_dim: PreNorm(latent_dim, 137 | Attention(latent_dim, input_dim, heads=cross_heads, 138 | dim_head=cross_dim_head, 139 | dropout=attn_dropout), context_dim=input_dim) 140 | get_cross_ff = lambda latent_dim: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout)) 141 | get_latent_attn = lambda latent_dim: PreNorm(latent_dim, 142 | Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, 143 | dropout=attn_dropout)) 144 | get_latent_ff = lambda latent_dim: PreNorm(latent_dim, ff_type(latent_dim, dropout=ff_dropout)) 145 | 146 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_by_name_fn, ( 147 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)) 148 | 149 | self.layers = nn.ModuleList([]) 150 | build_perceiver_layers_hierarchical(self.layers, depth, get_cross_attn, get_cross_ff, 151 | get_latent_attn, get_latent_ff, 152 | weight_tie_layers, 153 | num_latent_blocks_per_layer=num_latent_blocks_per_layer, 154 | configurator=configurator) 155 | 156 | last_layer_latent_dim = configurator.get_latent_dim(depth - 1) 157 | self.to_logits = nn.Sequential( 158 | nn.LayerNorm(last_layer_latent_dim), 159 | nn.Linear(last_layer_latent_dim, num_classes) 160 | ) 161 | 162 | def forward(self, multi_modality_data: Dict[str, Tensor], mask=None): 163 | """ 164 | 165 | :param data: a dictionary where keys are modality names and Tensor contain a batch 166 | of modality input data. 167 | :param mask: 168 | :return: 169 | """ 170 | batch_sizes = set() 171 | num_modalities = len(multi_modality_data) 172 | linearized_data = [] 173 | linearized_data_per_layer: Dict[int, List[Tensor]] = {} 174 | 175 | for modality_index, modality_name in enumerate(sorted(multi_modality_data.keys())): 176 | assert modality_name in self.modalities, f"modality {modality_name} was not defined in constructor" 177 | data = multi_modality_data[modality_name] 178 | modality = self.modalities[modality_name] 179 | b, *axis, _, device = *data.shape, data.device 180 | assert len( 181 | axis) == modality.input_axis, f'input data must have the right number of for modality {modality_name}. 
' \ 182 | f'Expected {modality.input_axis} while forward argument offered {len(axis)}' 183 | batch_sizes.add(b) 184 | assert len(batch_sizes) == 1, "batch size must be the same across all modalities" 185 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 186 | 187 | axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device), axis)) 188 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 189 | enc_pos = fourier_encode(pos, 190 | modality.max_freq, modality.num_freq_bands, modality.freq_base) 191 | enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') 192 | enc_pos = repeat(enc_pos, '... -> b ...', b=b) 193 | 194 | # Figure out padding for this modality, given max dimension across all modalities: 195 | padding_size = self.max_modality_dim - modality.input_dim - num_modalities 196 | 197 | padding = torch.zeros(size=data.size()[0:-1] + (padding_size,)).to(device) 198 | # concat to channels of data and flatten axis 199 | modality_encodings = modality_encoding(b, axis, modality_index, num_modalities, device=device) 200 | 201 | to_concat = (data, padding, enc_pos, modality_encodings) 202 | 203 | data = torch.cat(to_concat, dim=-1) 204 | data = rearrange(data, 'b ... d -> b (...) d') 205 | linearized_data.append(data) 206 | b = batch_sizes.pop() 207 | 208 | # Concatenate all the modalities: 209 | data = torch.cat(linearized_data, dim=1) 210 | x = None 211 | for layer_index, (cross_attn, cross_ff, latent_transformer) in enumerate(self.layers): 212 | 213 | latents = repeat(self.latents[layer_index], 'n d -> b n d', b=b) 214 | num_latents_in_layer = self.latents[layer_index].size(0) 215 | if x is None: 216 | x = latents 217 | else: 218 | pad_right_size = latents.size(2) - x.size(2) 219 | # x produced by the prior layer has more latents, and less dimention in each one. We pad x with zero in the 220 | # latent_dim and keep only the first num_latents_in_layer latents as input to this new layer. 221 | # Choosing the first latents, or last, or any other combination should be equivalent since all latents 222 | # are optimized during training. 223 | prior_layer_x_padded = pad(x, (0, pad_right_size), mode="constant", value=0)[:, 0:num_latents_in_layer, :] 224 | x = latents + prior_layer_x_padded 225 | x = cross_attn(x, context=data, mask=mask) + x 226 | x = cross_ff(x) + x 227 | x = latent_transformer(x) + x 228 | x = self.pool(x) 229 | 230 | return self.to_logits(x) 231 | 232 | def pool(self, x): 233 | """ 234 | Perform pooling over latents. 235 | :param x: batch x num_latents x latent_dim 236 | :return: pooled x 237 | """ 238 | # implement global pooling 239 | return x.mean(dim=-2) 240 | 241 | 242 | class HierarchicalMultiModalityPerceiverNoPooling(HierarchicalMultiModalityPerceiver): 243 | def __init__(self, *, modalities: Iterable[InputModality], depth, 244 | num_latents=512, latent_dim=512, cross_heads=1, 245 | latent_heads=8, cross_dim_head=64, latent_dim_head=64, 246 | attn_dropout=0., ff_dropout=0., 247 | weight_tie_layers=False, num_latent_blocks_per_layer=1, 248 | use_gelu: bool = True, 249 | configurator: HierarchicalConfigurator): 250 | """ 251 | Perceiver that returns hidden state. Makes it possible to configure pooling with 252 | the result of forward. 253 | :param modalities: 254 | :param depth: Number of times the perceiver will perform cross-attention between latent and input. 
255 | :param num_latents: 256 | :param latent_dim: 257 | :param cross_heads: 258 | :param latent_heads: 259 | :param cross_dim_head: 260 | :param latent_dim_head: 261 | :param attn_dropout: 262 | :param ff_dropout: 263 | :param weight_tie_layers: True: share weights across layers, False no shared weights. 264 | :param num_latent_blocks_per_layer: Number of blocks in the latent transformer. 265 | :param use_gelu: Use GELU activation like the Perceiver preprint indicates. False, 266 | with Lucidrains' GEGLU activation in feed forward instead. 267 | 268 | """ 269 | 270 | super().__init__(modalities=modalities, depth=depth, 271 | cross_heads=cross_heads, latent_heads=latent_heads, cross_dim_head=cross_dim_head, 272 | latent_dim_head=latent_dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, 273 | weight_tie_layers=weight_tie_layers, num_latent_blocks_per_layer=num_latent_blocks_per_layer, 274 | use_gelu=use_gelu, num_classes=1, 275 | configurator=configurator) 276 | self.to_logits = Identity() 277 | 278 | def pool(self, x): 279 | """ 280 | Do not pool. 281 | :param x: batch x num_latents x latent_dim 282 | :return: pooled x 283 | """ 284 | # no pooling 285 | return x 286 | --------------------------------------------------------------------------------