├── .bumpversion.cfg ├── .github └── workflows │ ├── linters.yaml │ ├── release.yaml │ └── workflows.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── perceiver.png ├── perceiver_pytorch ├── __init__.py ├── convolutions.py ├── decoders.py ├── encoders.py ├── gated.py ├── layers.py ├── mixed_latents.py ├── modalities.py ├── multi_perceiver_pytorch.py ├── perceiver_io.py ├── perceiver_pytorch.py ├── queries.py ├── rotary.py └── utils.py ├── requirements.txt ├── setup.py └── tests ├── test_decoders.py ├── test_encoders.py ├── test_model.py ├── test_perceiver_pytorch.py ├── test_queries.py └── test_rotary.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | commit = True 3 | tag = True 4 | current_version = 0.7.6 5 | 6 | [bumpversion:file:setup.py] 7 | search = version="{current_version}" 8 | replace = version="{new_version}" 9 | -------------------------------------------------------------------------------- /.github/workflows/linters.yaml: -------------------------------------------------------------------------------- 1 | name: Lint Python 2 | 3 | on: [push] 4 | 5 | jobs: 6 | call-run-python-linters: 7 | uses: openclimatefix/.github/.github/workflows/python-lint.yml@e67a64b086a5662c39f6b4523a97dd0641904279 8 | with: 9 | folder: "perceiver_pytorch" 10 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Bump version and auto-release 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | call-run-python-release: 8 | uses: openclimatefix/.github/.github/workflows/python-release.yml@e67a64b086a5662c39f6b4523a97dd0641904279 9 | secrets: 10 | token: ${{ secrets.PYPI_API_TOKEN }} 11 | -------------------------------------------------------------------------------- /.github/workflows/workflows.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | jobs: 5 | call-run-python-tests: 6 | uses: openclimatefix/.github/.github/workflows/python-test.yml@e67a64b086a5662c39f6b4523a97dd0641904279 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /.idea/ 131 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.9 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v3.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-yaml 12 | - id: debug-statements 13 | - id: detect-private-key 14 | 15 | # python code formatting/linting 16 | - repo: https://github.com/PyCQA/pydocstyle 17 | rev: 6.1.1 18 | hooks: 19 | - id: pydocstyle 20 | args: 21 | [ 22 | --convention=google, 23 | "--add-ignore=D200,D202,D210,D212,D415", 24 | "perceiver_pytorch", 25 | ] 26 | - repo: https://github.com/PyCQA/flake8 27 | rev: 4.0.1 28 | hooks: 29 | - id: flake8 30 | args: 31 | [ 32 | --max-line-length, 33 | "100", 34 | --extend-ignore=E203, 35 | --per-file-ignores, 36 | "__init__.py:F401", 37 | "perceiver_pytorch", 38 | ] 39 | - repo: https://github.com/PyCQA/isort 40 | rev: 5.9.3 41 | hooks: 42 | - id: isort 43 | args: [--profile, black, --line-length, "100", "perceiver_pytorch"] 44 | - repo: https://github.com/psf/black 45 | rev: 20.8b1 46 | hooks: 47 | - id: black 48 | args: [--line-length, "100"] 49 | 50 | # yaml formatting 51 | - repo: https://github.com/pre-commit/mirrors-prettier 52 | rev: v2.3.0 53 | hooks: 54 | - id: prettier 55 | types: [yaml] 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 
(c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Perceiver - Pytorch 4 | 5 | Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch. 6 | Extended from Phil Wang's perceiver-pytorch 7 | 8 | Yannic Kilcher explanation! 9 | 10 | ## Install 11 | 12 | ```bash 13 | $ pip install perceiver-model 14 | ``` 15 | 16 | ## Usage 17 | 18 | ```python 19 | import torch 20 | from perceiver_pytorch import Perceiver 21 | 22 | model = Perceiver( 23 | input_channels = 3, # number of channels for each token of the input 24 | input_axis = 2, # number of axis for input data (2 for images, 3 for video) 25 | num_freq_bands = 6, # number of freq bands, with original value (2 * K + 1) 26 | max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is 27 | depth = 6, # depth of net. The shape of the final attention mechanism will be: 28 | # depth * (cross attention -> self_per_cross_attn * self attention) 29 | num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names 30 | latent_dim = 512, # latent dimension 31 | cross_heads = 1, # number of heads for cross attention. paper said 1 32 | latent_heads = 8, # number of heads for latent self attention, 8 33 | cross_dim_head = 64, # number of dimensions per cross attention head 34 | latent_dim_head = 64, # number of dimensions per latent self attention head 35 | num_classes = 1000, # output number of classes 36 | attn_dropout = 0., 37 | ff_dropout = 0., 38 | weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram) 39 | fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. 
defaults to True, but can be turned off if you are fourier encoding the data yourself 40 | self_per_cross_attn = 2 # number of self attention blocks per cross attention 41 | ) 42 | 43 | img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized 44 | 45 | model(img) # (1, 1000) 46 | ``` 47 | 48 | For the backbone of Perceiver IO, the follow up paper that allows for flexible number of output sequence length, just import `PerceiverIO` instead 49 | 50 | ```python 51 | import torch 52 | from perceiver_pytorch import PerceiverIO 53 | 54 | model = PerceiverIO( 55 | dim = 32, # dimension of sequence to be encoded 56 | queries_dim = 32, # dimension of decoder queries 57 | logits_dim = 100, # dimension of final logits 58 | depth = 6, # depth of net 59 | num_latents = 256, # number of latents, or induced set points, or centroids. different papers giving it different names 60 | latent_dim = 512, # latent dimension 61 | cross_heads = 1, # number of heads for cross attention. paper said 1 62 | latent_heads = 8, # number of heads for latent self attention, 8 63 | cross_dim_head = 64, # number of dimensions per cross attention head 64 | latent_dim_head = 64, # number of dimensions per latent self attention head 65 | weight_tie_layers = False # whether to weight tie layers (optional, as indicated in the diagram) 66 | ) 67 | 68 | seq = torch.randn(1, 512, 32) 69 | queries = torch.randn(1, 128, 32) 70 | 71 | logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim) 72 | ``` 73 | 74 | ## Citations 75 | 76 | ```bibtex 77 | @misc{jaegle2021perceiver, 78 | title = {Perceiver: General Perception with Iterative Attention}, 79 | author = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira}, 80 | year = {2021}, 81 | eprint = {2103.03206}, 82 | archivePrefix = {arXiv}, 83 | primaryClass = {cs.CV} 84 | } 85 | ``` 86 | 87 | ```bibtex 88 | @misc{jaegle2021perceiver, 89 | title = {Perceiver IO: A General Architecture for Structured Inputs & Outputs}, 90 | author = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. 
Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira}, 91 | year = {2021}, 92 | eprint = {2107.14795}, 93 | archivePrefix = {arXiv}, 94 | primaryClass = {cs.LG} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /perceiver.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/perceiver-pytorch/fbab157fffe7d68d123a1ccf45f4131372873183/perceiver.png -------------------------------------------------------------------------------- /perceiver_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from perceiver_pytorch.perceiver_pytorch import Perceiver 2 | from perceiver_pytorch.perceiver_io import PerceiverIO 3 | from perceiver_pytorch.multi_perceiver_pytorch import MultiPerceiver 4 | -------------------------------------------------------------------------------- /perceiver_pytorch/convolutions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | class Conv2DDownsample(torch.nn.Sequential): 6 | def __init__( 7 | self, 8 | num_layers: int = 1, 9 | input_channels: int = 12, 10 | output_channels: int = 64, 11 | use_batchnorm: bool = True, 12 | ): 13 | """ 14 | Constructs a Conv2DDownsample model 15 | Args: 16 | num_layers: Number of conv -> maxpool layers 17 | output_channels: Number of output channels 18 | input_channels: Number of input channels to first layer 19 | use_batchnorm: Whether to use Batch Norm 20 | """ 21 | 22 | layers = [self.make_layer(input_channels, output_channels, batch=use_batchnorm)] 23 | for _ in range(num_layers - 1): 24 | layers += [ 25 | self.make_layer(output_channels, output_channels, batch=use_batchnorm) 26 | ] 27 | 28 | super().__init__(*layers) 29 | 30 | def make_layer(self, c_in, c_out, ks=7, stride=2, padding=7 // 2, batch=True): 31 | "Make Conv->Batch->Relu->MaxPool stack" 32 | layers = [torch.nn.Conv2d(c_in, c_out, ks, stride, padding, bias=False)] 33 | if batch: 34 | layers += [torch.nn.BatchNorm2d(c_out)] 35 | layers += [torch.nn.ReLU(), torch.nn.MaxPool2d(3, stride=2, padding=3 // 2)] 36 | return torch.nn.Sequential(*layers) 37 | 38 | 39 | class Conv2DUpsample(torch.nn.Module): 40 | def __init__(self, input_channels: int = 12, output_channels: int = 12): 41 | """ 42 | Upsamples 4x using 2 2D transposed convolutions 43 | Args: 44 | input_channels: Input channels to the first layer 45 | output_channels: Number of output channels 46 | """ 47 | 48 | super().__init__() 49 | self.transpose_conv1 = torch.nn.ConvTranspose2d( 50 | in_channels=input_channels, 51 | out_channels=output_channels * 2, 52 | kernel_size=(4, 4), 53 | stride=(2, 2), 54 | padding=(1, 1), 55 | ) 56 | self.transpose_conv2 = torch.nn.ConvTranspose2d( 57 | in_channels=output_channels * 2, 58 | out_channels=output_channels, 59 | kernel_size=(4, 4), 60 | stride=(2, 2), 61 | padding=(1, 1), 62 | ) 63 | 64 | def forward(self, x): 65 | x = self.transpose_conv1(x) 66 | x = F.relu(x) 67 | x = self.transpose_conv2(x) 68 | return x 69 | 70 | 71 | class Conv3DUpsample(torch.nn.Module): 72 | def __init__( 73 | self, 74 | input_channels: int = 12, 75 | output_channels: int = 12, 76 | num_temporal_upsamples: int = 2, 77 | num_space_upsamples: int = 4, 78 | ): 79 | """ 80 | Simple convolutional auto-encoder 81 | Args: 82 | output_channels: Final output channels 83 | num_temporal_upsamples: Number of temporal 
upsamples to perform 84 | num_space_upsamples: Number of spatial upsamples to perform 85 | """ 86 | 87 | super().__init__() 88 | temporal_stride = 2 89 | space_stride = 2 90 | num_upsamples = max(num_space_upsamples, num_temporal_upsamples) 91 | 92 | # create the input and output changesl for the different layers 93 | # The intermediate channels are (output channels) * 2^(number of upsamples - 1 - index of the current upsample) 94 | # The decoder sets the number of upsamples as log2(upsample_value), and this changes the number of channels 95 | # in a similar way, so it all scales together. 96 | intermediate_output_channels = [ 97 | output_channels * pow(2, num_upsamples - 1 - i) 98 | for i in range(0, num_upsamples) 99 | ] 100 | intermediate_input_channels = [input_channels] + intermediate_output_channels 101 | 102 | self.layers = torch.nn.ModuleList() 103 | for i in range(num_upsamples): 104 | if i >= num_temporal_upsamples: 105 | temporal_stride = 1 106 | if i >= num_space_upsamples: 107 | space_stride = 1 108 | 109 | input_channels = input_channels if i == 0 else output_channels 110 | stride = (temporal_stride, space_stride, space_stride) 111 | conv = torch.nn.ConvTranspose3d( 112 | in_channels=intermediate_input_channels[i], 113 | out_channels=intermediate_output_channels[i], 114 | stride=stride, 115 | kernel_size=stride, 116 | ) 117 | # see output dims calculations - https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html 118 | # if kernel=stride, and dilation -1 119 | # D = (D-1) * stride + (kernel - 1) + 2 = D * stride 120 | 121 | self.layers.append(conv) 122 | if i != num_upsamples - 1: 123 | self.layers.append(torch.nn.ReLU()) 124 | 125 | def forward(self, x): 126 | # Move around to Channels first 127 | x = x.permute(0, 2, 1, 3, 4) 128 | 129 | # loop over layers 130 | for layer in self.layers: 131 | x = layer(x) 132 | 133 | x = F.relu(x) 134 | # Move back 135 | x = x.permute(0, 2, 1, 3, 4) 136 | return x 137 | -------------------------------------------------------------------------------- /perceiver_pytorch/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from perceiver_pytorch.utils import reverse_space_to_depth 4 | from perceiver_pytorch.convolutions import Conv2DUpsample, Conv3DUpsample 5 | 6 | 7 | class ImageDecoder(torch.nn.Module): 8 | def __init__( 9 | self, 10 | postprocess_type: str = "pixels", 11 | spatial_upsample: int = 1, 12 | temporal_upsample: int = 1, 13 | output_channels: int = -1, 14 | input_channels: int = 12, 15 | input_reshape_size=None, 16 | ): 17 | """ 18 | ImageDecoder modeled after JAX version here 19 | https://github.com/deepmind/deepmind-research/blob/769bfdbeafbcb472cb8e2c6cfa746b53ac82efc2/perceiver/io_processors.py#L441-L510 20 | 21 | Args: 22 | postprocess_type: Type of postprocessing, one of conv, patches, pixels, raft, or conv1x1 23 | spatial_upsample: How much to spatially upsample 24 | temporal_upsample: How much to temporally upsample 25 | output_channels: Number of output channels, should be the final desired number of channels 26 | Has to explicitly set for conv and conv1x1 options, otherwise an error will be raised. 27 | Ignored for patches and pixels options. 
28 | input_channels: Number of input channels to decoder 29 | input_reshape_size: The size to reshape the input to 30 | """ 31 | 32 | super().__init__() 33 | 34 | if postprocess_type not in ("conv", "patches", "pixels", "conv1x1"): 35 | # TODO Add Raft 36 | raise ValueError("Invalid postprocess_type!") 37 | 38 | # Architecture parameters: 39 | self.postprocess_type = postprocess_type 40 | 41 | self.temporal_upsample = temporal_upsample 42 | self.spatial_upsample = spatial_upsample 43 | self.input_reshape_size = input_reshape_size 44 | 45 | if postprocess_type == "pixels": 46 | # No postprocessing for pixels 47 | self.decoder = torch.nn.Identity() 48 | elif postprocess_type == "patches": 49 | self.decoder = ImageDecoderPatches( 50 | spatial_upsample=spatial_upsample, temporal_upsample=temporal_upsample 51 | ) 52 | elif postprocess_type == "conv": 53 | self.decoder = ImageDecoderConv( 54 | spatial_upsample=spatial_upsample, 55 | temporal_upsample=temporal_upsample, 56 | output_channels=output_channels, 57 | input_channels=input_channels, 58 | ) 59 | elif postprocess_type == "conv1x1": 60 | self.decoder = ImageDecoderConv1x1( 61 | spatial_upsample=spatial_upsample, 62 | output_channels=output_channels, 63 | input_channels=input_channels, 64 | ) 65 | 66 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 67 | if self.input_reshape_size is not None: 68 | inputs = torch.reshape( 69 | inputs, 70 | [inputs.shape[0]] + list(self.input_reshape_size) + [inputs.shape[-1]], 71 | ) 72 | return self.decoder(inputs) 73 | 74 | 75 | class ImageDecoderConv(torch.nn.Module): 76 | def __init__( 77 | self, 78 | spatial_upsample: int = 1, 79 | temporal_upsample: int = 1, 80 | output_channels: int = -1, 81 | input_channels: int = 12, 82 | ): 83 | """ 84 | Convolutional image decoder that can upsample temporally and spatially 85 | 86 | Args: 87 | spatial_upsample: How much to spatially upsample 88 | temporal_upsample: How much to temporally upsample 89 | output_channels: Number of output channels, should be the final desired number of channels 90 | Has to explicitly set for conv and conv1x1 options, otherwise an error will be raised. 91 | Ignored for patches and pixels options. 92 | input_channels: Number of input channels to decoder 93 | """ 94 | 95 | super().__init__() 96 | 97 | self.temporal_upsample = temporal_upsample 98 | self.spatial_upsample = spatial_upsample 99 | 100 | if output_channels == -1: 101 | raise ValueError("Expected value for output_channels") 102 | if self.temporal_upsample != 1: 103 | 104 | def int_log2(x): 105 | return int(np.round(np.log(x) / np.log(2))) 106 | 107 | self.convnet = Conv3DUpsample( 108 | input_channels=input_channels, 109 | output_channels=output_channels, 110 | num_temporal_upsamples=int_log2(temporal_upsample), 111 | num_space_upsamples=int_log2(spatial_upsample), 112 | ) 113 | else: 114 | assert ( 115 | self.spatial_upsample == 4 116 | ), "Conv2DUpsample only support 4x spatial upsample right now" 117 | self.convnet = Conv2DUpsample( 118 | input_channels=input_channels, output_channels=output_channels 119 | ) 120 | 121 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 122 | # Convnet image featurization. 
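        # A 5-D input is treated as a timeseries of shape
        # (batch, time, channels, height, width): when no temporal upsampling is
        # requested, each timestep is pushed through the upsampling convnet
        # independently and the results are re-stacked along the time axis;
        # otherwise the whole tensor is handed to the convnet (the 3D
        # transposed-conv net when temporal upsampling is used).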
123 | if len(inputs.shape) == 5 and self.temporal_upsample == 1: 124 | # Timeseries, do it to each timestep independently 125 | outs = [] 126 | for i in range(inputs.shape[1]): 127 | outs.append(self.convnet(inputs[:, i, :, :, :])) 128 | inputs = torch.stack(outs, dim=1) 129 | else: 130 | inputs = self.convnet(inputs) 131 | 132 | return inputs 133 | 134 | 135 | class ImageDecoderConv1x1(torch.nn.Module): 136 | def __init__( 137 | self, 138 | spatial_upsample: int = 1, 139 | output_channels: int = -1, 140 | input_channels: int = 12, 141 | ): 142 | """ 143 | Convolutional 1x1 image decoder 144 | 145 | Args: 146 | spatial_upsample: How much to spatially upsample 147 | output_channels: Number of output channels, should be the final desired number of channels 148 | Has to explicitly set for conv and conv1x1 options, otherwise an error will be raised. 149 | Ignored for patches and pixels options. 150 | input_channels: Number of input channels to decoder 151 | """ 152 | 153 | super().__init__() 154 | 155 | self.spatial_upsample = spatial_upsample 156 | 157 | if output_channels == -1: 158 | raise ValueError("Expected value for output_channels") 159 | self.conv1x1 = torch.nn.Conv2d( 160 | in_channels=input_channels, 161 | out_channels=output_channels, 162 | kernel_size=(1, 1), 163 | # spatial_downsample is unconstrained for 1x1 convolutions. 164 | stride=(self.spatial_upsample, self.spatial_upsample), 165 | ) 166 | 167 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 168 | # Convnet image featurization. 169 | if len(inputs.shape) == 5: 170 | # Timeseries, do it to each timestep independently 171 | outs = [] 172 | for i in range(inputs.shape[1]): 173 | outs.append(self.conv1x1(inputs[:, i, :, :, :])) 174 | inputs = torch.stack(outs, dim=1) 175 | else: 176 | inputs = self.conv1x1(inputs) 177 | 178 | return inputs 179 | 180 | 181 | class ImageDecoderPatches(torch.nn.Module): 182 | def __init__( 183 | self, spatial_upsample: int = 1, temporal_upsample: int = 1, 184 | ): 185 | """ 186 | Patch-based image decoder 187 | 188 | Args: 189 | spatial_upsample: How much to spatially upsample 190 | temporal_upsample: How much to temporally upsample 191 | """ 192 | 193 | super().__init__() 194 | 195 | self.temporal_upsample = temporal_upsample 196 | self.spatial_upsample = spatial_upsample 197 | 198 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 199 | inputs = reverse_space_to_depth( 200 | inputs, self.temporal_upsample, self.spatial_upsample 201 | ) 202 | return inputs 203 | -------------------------------------------------------------------------------- /perceiver_pytorch/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | 8 | from perceiver_pytorch.convolutions import Conv2DDownsample 9 | from perceiver_pytorch.utils import space_to_depth 10 | 11 | 12 | class ImageEncoder(torch.nn.Module): 13 | def __init__( 14 | self, 15 | input_channels: int = 12, 16 | prep_type: str = "conv", 17 | spatial_downsample: int = 4, 18 | temporal_downsample: int = 1, 19 | output_channels: int = 64, 20 | conv2d_use_batchnorm: bool = True, 21 | crop_size: int = 256, 22 | use_space2depth: bool = True, 23 | ): 24 | """ 25 | Image encoder class, modeled off the JAX version 26 | https://github.com/deepmind/deepmind-research/blob/769bfdbeafbcb472cb8e2c6cfa746b53ac82efc2/perceiver/io_processors.py#L291-L438 27 | 28 | Args: 29 | 
input_channels: Number of input channels of the original image/video 30 | prep_type: How to encode the images, one of conv, patches, pixels, or conv1x1 31 | spatial_downsample: How much to downsample spatially 32 | temporal_downsample: How much to downsample temporally 33 | output_channels: Number of output channels to send to Perceiver 34 | conv2d_use_batchnorm: Whether to use batch norm 35 | crop_size: Only for MetNet preprocessor, the center crop size 36 | use_space2depth: Only for MetNet preprocessor, whether to use average pooling, or space2depth for downsampling 37 | """ 38 | super().__init__() 39 | self.prep_type = prep_type 40 | 41 | if prep_type not in ("conv", "patches", "pixels", "conv1x1", "metnet"): 42 | raise ValueError("Invalid prep_type!") 43 | 44 | if self.prep_type == "conv": 45 | self.encoder = ImageEncoderConv( 46 | input_channels=input_channels, 47 | temporal_downsample=temporal_downsample, 48 | spatial_downsample=spatial_downsample, 49 | output_channels=output_channels, 50 | conv2d_use_batchnorm=conv2d_use_batchnorm, 51 | ) 52 | elif self.prep_type == "conv1x1": 53 | self.encoder = ImageEncoderConv1x1( 54 | input_channels=input_channels, 55 | spatial_downsample=spatial_downsample, 56 | output_channels=output_channels, 57 | ) 58 | elif self.prep_type == "patches": 59 | self.encoder = ImageEncoderPatches( 60 | temporal_downsample=temporal_downsample, 61 | spatial_downsample=spatial_downsample, 62 | ) 63 | elif self.prep_type == "pixels": 64 | self.encoder = ImageEncoderPixel( 65 | temporal_downsample=temporal_downsample, 66 | spatial_downsample=spatial_downsample, 67 | ) 68 | elif self.prep_type == "metnet": 69 | self.encoder = ImageEncoderMetNet( 70 | crop_size=crop_size, use_space2depth=use_space2depth 71 | ) 72 | 73 | def forward(self, x: torch.Tensor) -> torch.Tensor: 74 | return self.encoder(x) 75 | 76 | 77 | class ImageEncoderConv(torch.nn.Module): 78 | def __init__( 79 | self, 80 | input_channels: int = 12, 81 | spatial_downsample: int = 4, 82 | temporal_downsample: int = 1, 83 | output_channels: int = 64, 84 | conv2d_use_batchnorm: bool = True, 85 | ): 86 | """ 87 | Convolutional image encoder that can spatially and temporally downsample 88 | 89 | Args: 90 | input_channels: Number of input channels of the original image/video 91 | spatial_downsample: How much to downsample spatially 92 | temporal_downsample: How much to downsample temporally 93 | output_channels: Number of output channels to send to Perceiver 94 | conv2d_use_batchnorm: Whether to use batch norm 95 | """ 96 | super().__init__() 97 | self.temporal_downsample = temporal_downsample 98 | self.spatial_downsample = spatial_downsample 99 | self.output_channels = output_channels 100 | 101 | # Downsampling with conv is currently restricted 102 | convnet_num_layers = math.log(spatial_downsample, 4) 103 | convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers) 104 | if not convnet_num_layers_is_int or temporal_downsample != 1: 105 | raise ValueError( 106 | "Only powers of 4 expected for spatial " 107 | "and 1 expected for temporal " 108 | "downsampling with conv." 
109 | ) 110 | 111 | self.convnet = Conv2DDownsample( 112 | num_layers=int(convnet_num_layers), 113 | output_channels=output_channels, 114 | input_channels=input_channels, 115 | use_batchnorm=conv2d_use_batchnorm, 116 | ) 117 | 118 | def forward(self, x: torch.Tensor) -> torch.Tensor: 119 | if len(x.shape) == 5: 120 | # Timeseries, do it to each timestep independently 121 | outs = [] 122 | for i in range(x.shape[1]): 123 | outs.append(self.convnet(x[:, i, :, :, :])) 124 | x = torch.stack(outs, dim=1) 125 | else: 126 | x = self.convnet(x) 127 | return x 128 | 129 | 130 | class ImageEncoderConv1x1(torch.nn.Module): 131 | def __init__( 132 | self, 133 | input_channels: int = 12, 134 | spatial_downsample: int = 4, 135 | output_channels: int = 64, 136 | ): 137 | """ 138 | Convolutional 1x1 encoder that can spatially downsample 139 | 140 | Args: 141 | input_channels: Number of input channels of the original image/video 142 | spatial_downsample: How much to downsample spatially 143 | output_channels: Number of output channels to send to Perceiver 144 | """ 145 | super().__init__() 146 | self.spatial_downsample = spatial_downsample 147 | self.output_channels = output_channels 148 | 149 | self.convnet_1x1 = torch.nn.Conv2d( 150 | in_channels=input_channels, 151 | out_channels=output_channels, 152 | kernel_size=(1, 1), 153 | # spatial_downsample is unconstrained for 1x1 convolutions. 154 | stride=(spatial_downsample, spatial_downsample), 155 | ) 156 | 157 | def forward(self, x: torch.Tensor) -> torch.Tensor: 158 | if len(x.shape) == 5: 159 | # Timeseries, do it to each timestep independently 160 | outs = [] 161 | for i in range(x.shape[1]): 162 | outs.append(self.convnet_1x1(x[:, i, :, :, :])) 163 | x = torch.stack(outs, dim=1) 164 | else: 165 | x = self.convnet_1x1(x) 166 | 167 | return x 168 | 169 | 170 | class ImageEncoderPatches(torch.nn.Module): 171 | def __init__( 172 | self, spatial_downsample: int = 4, temporal_downsample: int = 1, 173 | ): 174 | """ 175 | Image encoder that uses patches 176 | 177 | Args: 178 | spatial_downsample: How much to downsample spatially 179 | temporal_downsample: How much to downsample temporally 180 | """ 181 | super().__init__() 182 | self.temporal_downsample = temporal_downsample 183 | self.spatial_downsample = spatial_downsample 184 | 185 | def forward(self, x: torch.Tensor) -> torch.Tensor: 186 | x = space_to_depth( 187 | x, 188 | temporal_block_size=self.temporal_downsample, 189 | spatial_block_size=self.spatial_downsample, 190 | ) 191 | 192 | # For flow 193 | if x.ndim == 5 and x.shape[1] == 1: 194 | x = x.squeeze(axis=1) 195 | 196 | return x 197 | 198 | 199 | class ImageEncoderPixel(torch.nn.Module): 200 | def __init__( 201 | self, spatial_downsample: int = 4, temporal_downsample: int = 1, 202 | ): 203 | """ 204 | Image encoder class for simple downsampling with pixels 205 | 206 | Args: 207 | spatial_downsample: How much to downsample spatially 208 | temporal_downsample: How much to downsample temporally 209 | """ 210 | super().__init__() 211 | self.temporal_downsample = temporal_downsample 212 | self.spatial_downsample = spatial_downsample 213 | 214 | def forward(self, x: torch.Tensor) -> torch.Tensor: 215 | # If requested, will downsample in simplest way 216 | if x.ndim == 4: 217 | x = x[:, :, :: self.spatial_downsample, :: self.spatial_downsample] 218 | elif x.ndim == 5: 219 | x = x[ 220 | :, 221 | :: self.temporal_downsample, 222 | :, 223 | :: self.spatial_downsample, 224 | :: self.spatial_downsample, 225 | ] 226 | else: 227 | raise 
ValueError("Unsupported data format for pixels") 228 | 229 | return x 230 | 231 | 232 | class ImageEncoderMetNet(nn.Module): 233 | def __init__( 234 | self, crop_size: int = 256, use_space2depth: bool = True, 235 | ): 236 | """ 237 | Performs the MetNet preprocessing of mean pooling Sat channels, followed by 238 | concatenating the center crop and mean pool 239 | 240 | In the paper, the radar data is space2depth'd, while satellite channel is mean pooled, but for this different 241 | task, we choose to do either option for satellites 242 | 243 | Args: 244 | sat_channels: Number of satellite channels 245 | crop_size: Center crop size 246 | use_space2depth: Whether to use space2depth on satellite channels, or mean pooling, like in paper 247 | """ 248 | super().__init__() 249 | # Split off sat + mask channels into own image, and the rest, which we just take a center crop 250 | # For this, 251 | self.sat_downsample = ( 252 | torch.nn.PixelUnshuffle(downscale_factor=2) 253 | if use_space2depth 254 | else torch.nn.AvgPool3d(kernel_size=(1, 2, 2)) 255 | ) 256 | self.center_crop = torchvision.transforms.CenterCrop(size=crop_size) 257 | 258 | def forward(self, x: torch.Tensor) -> torch.Tensor: 259 | x = self.sat_downsample(x) 260 | # In paper, satellite and radar data is concatenated here 261 | # We are just going to skip that bit 262 | sat_center = self.center_crop(x) 263 | sat_mean = F.avg_pool3d(x, (1, 2, 2)) 264 | # All the same size now, so concatenate together, already have time, lat/long, and elevation image 265 | x = torch.cat([sat_center, sat_mean], dim=2) 266 | return x 267 | -------------------------------------------------------------------------------- /perceiver_pytorch/gated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | 7 | from perceiver_pytorch.layers import exists, default, cache_fn, PreNorm, FeedForward, Attention 8 | from perceiver_pytorch.utils import fourier_encode 9 | 10 | 11 | # helpers 12 | 13 | 14 | class Residual(nn.Module): 15 | def __init__(self, fn): 16 | super().__init__() 17 | self.fn = fn 18 | 19 | def forward(self, x, **kwargs): 20 | return x + self.fn(x, **kwargs) 21 | 22 | 23 | class GRUGating(nn.Module): 24 | def __init__(self, dim, fn): 25 | super().__init__() 26 | self.dim = dim 27 | self.fn = fn 28 | self.gru = nn.GRUCell(dim, dim) 29 | 30 | def forward(self, x, **kwargs): 31 | b, dim = x.shape[0], self.dim 32 | y = self.fn(x, **kwargs) 33 | 34 | gated_output = self.gru( 35 | rearrange(y, "... d -> (...) d"), rearrange(x, "... d -> (...) 
d") 36 | ) 37 | 38 | gated_output = rearrange(gated_output, "(b n) d -> b n d", b=b) 39 | return gated_output 40 | 41 | 42 | # main class 43 | 44 | 45 | class Perceiver(nn.Module): 46 | def __init__( 47 | self, 48 | *, 49 | num_freq_bands, 50 | depth, 51 | max_freq, 52 | freq_base=2, 53 | input_channels=3, 54 | input_axis=2, 55 | num_latents=512, 56 | latent_dim=512, 57 | cross_heads=1, 58 | latent_heads=8, 59 | cross_dim_head=64, 60 | latent_dim_head=64, 61 | num_classes=1000, 62 | attn_dropout=0.0, 63 | ff_dropout=0.0, 64 | weight_tie_layers=False 65 | ): 66 | super().__init__() 67 | self.input_axis = input_axis 68 | self.max_freq = max_freq 69 | self.num_freq_bands = num_freq_bands 70 | self.freq_base = freq_base 71 | 72 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels 73 | 74 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 75 | 76 | get_cross_attn = lambda: GRUGating( 77 | latent_dim, 78 | PreNorm( 79 | latent_dim, 80 | Attention( 81 | latent_dim, 82 | input_dim, 83 | heads=cross_heads, 84 | dim_head=cross_dim_head, 85 | dropout=attn_dropout, 86 | ), 87 | context_dim=input_dim, 88 | ), 89 | ) 90 | get_latent_attn = lambda: GRUGating( 91 | latent_dim, 92 | PreNorm( 93 | latent_dim, 94 | Attention( 95 | latent_dim, 96 | heads=latent_heads, 97 | dim_head=latent_dim_head, 98 | dropout=attn_dropout, 99 | ), 100 | ), 101 | ) 102 | get_cross_ff = lambda: Residual( 103 | PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 104 | ) 105 | get_latent_ff = lambda: Residual( 106 | PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 107 | ) 108 | 109 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map( 110 | cache_fn, 111 | (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff), 112 | ) 113 | 114 | self.layers = nn.ModuleList([]) 115 | for i in range(depth): 116 | should_cache = i > 0 and weight_tie_layers 117 | cache_args = {"_cache": should_cache} 118 | 119 | self.layers.append( 120 | nn.ModuleList( 121 | [ 122 | get_cross_attn(**cache_args), 123 | get_cross_ff(**cache_args), 124 | get_latent_attn(**cache_args), 125 | get_latent_ff(**cache_args), 126 | ] 127 | ) 128 | ) 129 | 130 | self.to_logits = nn.Sequential( 131 | nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes) 132 | ) 133 | 134 | def forward(self, data, mask=None): 135 | b, *axis, _, device = *data.shape, data.device 136 | assert ( 137 | len(axis) == self.input_axis 138 | ), "input data must have the right number of axis" 139 | 140 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 141 | 142 | axis_pos = list( 143 | map( 144 | lambda size: torch.linspace( 145 | -1.0, 1.0, steps=size, device=device 146 | ), 147 | axis, 148 | ) 149 | ) 150 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 151 | enc_pos = fourier_encode( 152 | pos, self.max_freq, self.num_freq_bands, base=self.freq_base 153 | ) 154 | enc_pos = rearrange(enc_pos, "... n d -> ... (n d)") 155 | enc_pos = repeat(enc_pos, "... -> b ...", b=b) 156 | 157 | # concat to channels of data and flatten axis 158 | 159 | data = torch.cat((data, enc_pos), dim=-1) 160 | data = rearrange(data, "b ... d -> b (...) 
d") 161 | 162 | x = repeat(self.latents, "n d -> b n d", b=b) 163 | 164 | for cross_attn, cross_ff, latent_attn, latent_ff in self.layers: 165 | x = cross_attn(x, context=data, mask=mask) 166 | x = cross_ff(x) 167 | x = latent_attn(x) 168 | x = latent_ff(x) 169 | 170 | x = x.mean(dim=-2) 171 | return self.to_logits(x) 172 | -------------------------------------------------------------------------------- /perceiver_pytorch/layers.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | from torch import nn, einsum 6 | from torch.nn import functional as F 7 | 8 | from perceiver_pytorch.rotary import apply_rotary_emb 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | 19 | def cache_fn(f): 20 | cache = None 21 | 22 | @wraps(f) 23 | def cached_fn(*args, _cache=True, **kwargs): 24 | if not _cache: 25 | return f(*args, **kwargs) 26 | nonlocal cache 27 | if cache is not None: 28 | return cache 29 | cache = f(*args, **kwargs) 30 | return cache 31 | 32 | return cached_fn 33 | 34 | 35 | class PreNorm(nn.Module): 36 | def __init__(self, dim, fn, context_dim=None): 37 | super().__init__() 38 | self.fn = fn 39 | self.norm = nn.LayerNorm(dim) 40 | self.norm_context = ( 41 | nn.LayerNorm(context_dim) if exists(context_dim) else None 42 | ) 43 | 44 | def forward(self, x, **kwargs): 45 | x = self.norm(x) 46 | 47 | if exists(self.norm_context): 48 | context = kwargs["context"] 49 | normed_context = self.norm_context(context) 50 | kwargs.update(context=normed_context) 51 | 52 | return self.fn(x, **kwargs) 53 | 54 | 55 | class GEGLU(nn.Module): 56 | """ 57 | Gaussian Error Gated Linear Unit. 58 | See Shazer 2020: https://arxiv.org/abs/2002.05202 59 | """ 60 | def forward(self, x): 61 | x, gates = x.chunk(2, dim=-1) 62 | return x * F.gelu(gates) 63 | 64 | 65 | class FeedForward(nn.Module): 66 | """Feed forward neural net with GEGLU activation.""" 67 | 68 | def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0): 69 | """ 70 | Args: 71 | dim: Input & Output size. 72 | mult: The inner dimension of the FF net will be dim * mult. 73 | dropout: Proportion to dropout after the GEGLU. 74 | """ 75 | super().__init__() 76 | self.net = nn.Sequential( 77 | nn.Linear(dim, dim * mult * 2), 78 | GEGLU(), 79 | nn.Dropout(dropout), 80 | nn.Linear(dim * mult, dim), 81 | ) 82 | 83 | def forward(self, x): 84 | return self.net(x) 85 | 86 | 87 | class Attention(nn.Module): 88 | def __init__( 89 | self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0 90 | ): 91 | """ 92 | Args: 93 | query_dim: Size of the queries. 94 | context_dim: Size of the 'context' (the 'byte array' in the paper). 95 | If None, will default to the query_dim. 96 | heads: Number of attention heads. 97 | dim_head: Number of dimensions per head. 98 | dropout: Proportion to dropout (in the final linear layer). 
99 | """ 100 | 101 | super().__init__() 102 | inner_dim = dim_head * heads 103 | context_dim = default(context_dim, query_dim) 104 | 105 | self.scale = dim_head ** -0.5 106 | self.heads = heads 107 | 108 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 109 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 110 | 111 | self.to_out = nn.Sequential( 112 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 113 | ) 114 | 115 | def forward(self, x, context=None, mask=None, pos_emb=None): 116 | """ 117 | 118 | Args: 119 | x: The 'latent array' in the Perceiver paper. 120 | context: The 'byte array' in the Perceiver paper (the input data). 121 | mask: 122 | pos_emb: 123 | 124 | Returns: 125 | 126 | """ 127 | 128 | h = self.heads 129 | 130 | q = self.to_q(x) 131 | context = default(context, x) 132 | k, v = self.to_kv(context).chunk(2, dim=-1) 133 | 134 | # Rearrange the query, key and value tensors. 135 | # b = batch size; n = TODO (PD-2021-09-13) 136 | # h = number of heads; d = number of dims per head. 137 | q, k, v = map( 138 | lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v) 139 | ) 140 | 141 | if exists(pos_emb): 142 | q, k = apply_rotary_emb(q, k, pos_emb) 143 | 144 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 145 | 146 | if exists(mask): 147 | mask = rearrange(mask, "b ... -> b (...)") 148 | max_neg_value = -torch.finfo(sim.dtype).max 149 | mask = repeat(mask, "b j -> (b h) () j", h=h) 150 | sim.masked_fill_(~mask, max_neg_value) 151 | 152 | # attention, what we cannot get enough of 153 | attn = sim.softmax(dim=-1) 154 | 155 | out = einsum("b i j, b j d -> b i d", attn, v) 156 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 157 | return self.to_out(out) 158 | -------------------------------------------------------------------------------- /perceiver_pytorch/mixed_latents.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | 7 | from perceiver_pytorch.layers import exists, default, cache_fn, PreNorm, FeedForward, Attention 8 | from perceiver_pytorch.utils import fourier_encode 9 | 10 | 11 | # latent mixer 12 | 13 | 14 | def Mixer(seq_len, mult=4, dropout=0.0): 15 | return nn.Sequential( 16 | nn.Conv1d(seq_len, seq_len * mult, 1), 17 | nn.GELU(), 18 | nn.Dropout(dropout), 19 | nn.Conv1d(seq_len * mult, seq_len, 1), 20 | ) 21 | 22 | 23 | # main class 24 | 25 | 26 | class Perceiver(nn.Module): 27 | def __init__( 28 | self, 29 | *, 30 | num_freq_bands, 31 | depth, 32 | max_freq, 33 | input_channels=3, 34 | input_axis=2, 35 | num_latents=512, 36 | latent_dim=512, 37 | cross_heads=1, 38 | latent_heads=8, 39 | cross_dim_head=64, 40 | latent_dim_head=64, 41 | num_classes=1000, 42 | attn_dropout=0.0, 43 | ff_dropout=0.0, 44 | weight_tie_layers=False, 45 | **kwargs 46 | ): 47 | super().__init__() 48 | self.input_axis = input_axis 49 | self.max_freq = max_freq 50 | self.num_freq_bands = num_freq_bands 51 | 52 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels 53 | 54 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 55 | 56 | get_cross_attn = lambda: PreNorm( 57 | latent_dim, 58 | Attention( 59 | latent_dim, 60 | input_dim, 61 | heads=cross_heads, 62 | dim_head=cross_dim_head, 63 | dropout=attn_dropout, 64 | ), 65 | context_dim=input_dim, 66 | ) 67 | get_latent_attn = lambda: PreNorm(latent_dim, Mixer(num_latents, dropout=ff_dropout)) 68 | 
get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 69 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 70 | 71 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map( 72 | cache_fn, 73 | (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff), 74 | ) 75 | 76 | self.layers = nn.ModuleList([]) 77 | for i in range(depth): 78 | should_cache = i > 0 and weight_tie_layers 79 | cache_args = {"_cache": should_cache} 80 | 81 | self.layers.append( 82 | nn.ModuleList( 83 | [ 84 | get_cross_attn(**cache_args), 85 | get_cross_ff(**cache_args), 86 | get_latent_attn(**cache_args), 87 | get_latent_ff(**cache_args), 88 | ] 89 | ) 90 | ) 91 | 92 | self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes)) 93 | 94 | def forward(self, data, mask=None): 95 | b, *axis, _, device = *data.shape, data.device 96 | assert len(axis) == self.input_axis, "input data must have the right number of axis" 97 | 98 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 99 | 100 | axis_pos = list( 101 | map( 102 | lambda size: torch.linspace(-1.0, 1.0, steps=size, device=device), 103 | axis, 104 | ) 105 | ) 106 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 107 | enc_pos = fourier_encode( 108 | x=pos, 109 | max_freq=self.max_freq, 110 | num_bands=self.num_freq_bands, 111 | ) 112 | enc_pos = rearrange(enc_pos, "... n d -> ... (n d)") 113 | enc_pos = repeat(enc_pos, "... -> b ...", b=b) 114 | 115 | # concat to channels of data and flatten axis 116 | 117 | data = torch.cat((data, enc_pos), dim=-1) 118 | data = rearrange(data, "b ... d -> b (...) d") 119 | 120 | x = repeat(self.latents, "n d -> b n d", b=b) 121 | 122 | for cross_attn, cross_ff, latent_attn, latent_ff in self.layers: 123 | x = cross_attn(x, context=data, mask=mask) + x 124 | x = cross_ff(x) + x 125 | x = latent_attn(x) + x 126 | x = latent_ff(x) + x 127 | 128 | x = x.mean(dim=-2) 129 | return self.to_logits(x) 130 | -------------------------------------------------------------------------------- /perceiver_pytorch/modalities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class InputModality: 7 | name: str 8 | input_channels: int 9 | input_axis: int 10 | num_freq_bands: int 11 | max_freq: float 12 | sin_only: bool = False 13 | fourier_encode: bool = True 14 | 15 | @property 16 | def input_dim(self) -> int: 17 | # Calculate the dimension of this modality. 18 | if self.fourier_encode: 19 | fourier_channels = self.input_axis * ((self.num_freq_bands * 2) + 1) 20 | fourier_channels = fourier_channels // 2 if self.sin_only else fourier_channels 21 | input_dim = fourier_channels + self.input_channels 22 | return input_dim 23 | else: 24 | return self.input_channels 25 | 26 | 27 | def modality_encoding( 28 | batch_size: int, axes, modality_index: int, num_modalities: int 29 | ) -> torch.Tensor: 30 | """ 31 | Return one-hot encoding of modality given num_modalities, batch size and axes. 32 | The result need to be compatible with the modality data for concatenation. 33 | 34 | Args: 35 | batch_size: Batch size of the input 36 | axes: The size of each axis, other than batch size, of the input 37 | modality_index: The index of this modality i.e. 
if there are 3 modalities, this would be 0, 1, or 2 38 | num_modalities: Total number of modalities 39 | 40 | Returns: 41 | One hot encoding of which modality the input is 42 | 43 | """ 44 | one_hot = torch.eye(num_modalities, num_modalities)[modality_index] 45 | to_expand = [batch_size] 46 | one_hot = one_hot.unsqueeze(0) 47 | for i, axis in enumerate(axes): 48 | one_hot = one_hot.unsqueeze(0) 49 | to_expand.append(axis) 50 | to_expand.append(num_modalities) 51 | 52 | one_hot = one_hot.expand(to_expand) 53 | return one_hot 54 | -------------------------------------------------------------------------------- /perceiver_pytorch/multi_perceiver_pytorch.py: -------------------------------------------------------------------------------- 1 | from perceiver_pytorch.perceiver_io import PerceiverIO 2 | from perceiver_pytorch.modalities import InputModality, modality_encoding 3 | from perceiver_pytorch.utils import encode_position, fourier_encode 4 | import torch 5 | from typing import List, Iterable, Dict, Optional, Any, Union, Tuple 6 | from einops import rearrange, repeat 7 | from math import prod 8 | 9 | 10 | class MultiPerceiver(torch.nn.Module): 11 | def __init__( 12 | self, 13 | modalities: Iterable[InputModality], 14 | fourier_encode_data: bool = True, 15 | input_channels: int = 3, 16 | output_channels: int = 12, 17 | forecast_steps: int = 48, 18 | sine_only: bool = False, 19 | output_shape: Union[int, Tuple[int, ...]] = 32, 20 | **kwargs, 21 | ): 22 | """ 23 | PerceiverIO made to work more specifically with timeseries images and multimodal inputs https://arxiv.org/abs/2107.14795 24 | This is a wrapper around the PerceiverIO implementation to encode the inputs correctly 25 | 26 | Args: 27 | input_channels: Number of input channels (int) 28 | forecast_steps: Number of forecast steps to make (int) 29 | fourier_encode_data: Whether to add Fourier Features to the input data, if this is false, inputs should be have some type of positional encoding added beforehand 30 | output_channels: Number of output channels per image (int) 31 | sine_only: Only use Sine part of Fourier features (bool) 32 | output_shape: Int or Tuple of ints, giving the desired output shape of the model 33 | **kwargs: Extra kwargs to pass through to PerceiverIO 34 | """ 35 | super(MultiPerceiver, self).__init__() 36 | self.fourier_encode_data = fourier_encode_data 37 | self.forecast_steps = forecast_steps 38 | self.input_channels = input_channels 39 | self.sine_only = sine_only 40 | self.output_channels = output_channels 41 | self.modalities = {modality.name: modality for modality in modalities} 42 | # we encode modality with one hot encoding, so need one dim per modality: 43 | modality_encoding_dim = len(modalities) 44 | # input_dim is the maximum dimension over all input modalities: 45 | input_dim = max(modality.input_dim for modality in modalities) + modality_encoding_dim 46 | # Pop dim 47 | self.max_modality_dim = input_dim 48 | kwargs.pop("dim", None) 49 | # Want toe logit_dim to be the same as the channels * width or height 50 | if isinstance(output_shape, int): 51 | kwargs["logits_dim"] = output_shape * self.output_channels 52 | else: 53 | kwargs["logits_dim"] = prod(output_shape) 54 | self.perceiver = PerceiverIO(dim=input_dim, **kwargs) 55 | 56 | def decode_output(self, data): 57 | pass 58 | 59 | def forward(self, multi_modality_data: Dict[str, torch.Tensor], mask=None, queries=None): 60 | batch_sizes = set() 61 | num_modalities = len(multi_modality_data) 62 | linearized_data = [] 63 | 64 | for modality_index, 
modality_name in enumerate(sorted(multi_modality_data.keys())): 65 | assert ( 66 | modality_name in self.modalities 67 | ), f"modality {modality_name} was not defined in constructor" 68 | data = multi_modality_data[modality_name] 69 | modality = self.modalities[modality_name] 70 | b, *axis, _ = data.size() 71 | assert len(axis) == modality.input_axis, ( 72 | f"input data must have the right number of axes for modality {modality_name}. " 73 | f"Expected {modality.input_axis} while forward argument offered {len(axis)}" 74 | ) 75 | batch_sizes.add(b) 76 | assert len(batch_sizes) == 1, "batch size must be the same across all modalities" 77 | enc_pos = [] 78 | if self.fourier_encode_data: 79 | # calculate fourier encoded positions in the range of [-1, 1], for all axis 80 | enc_pos = encode_position( 81 | batch_size=b, 82 | axis=axis, 83 | max_frequency=modality.max_freq, 84 | num_frequency_bands=modality.num_freq_bands, 85 | sine_only=self.sine_only, 86 | ).type_as(data) 87 | 88 | # Figure out padding for this modality, given max dimension across all modalities: 89 | padding_size = self.max_modality_dim - modality.input_dim - num_modalities 90 | 91 | padding = torch.zeros(size=data.size()[0:-1] + (padding_size,)).type_as(data) 92 | # concat to channels of data and flatten axis 93 | modality_encodings = modality_encoding(b, axis, modality_index, num_modalities).type_as( 94 | data 95 | ) 96 | to_concat = ( 97 | (data, padding, enc_pos, modality_encodings) 98 | if len(enc_pos) > 0 99 | else (data, padding, modality_encodings) 100 | ) 101 | data = torch.cat(to_concat, dim=-1) 102 | # concat to channels of data and flatten axis 103 | data = rearrange(data, "b ... d -> b (...) d") 104 | linearized_data.append(data) 105 | 106 | # Concatenate all the modalities: 107 | data = torch.cat(linearized_data, dim=1) 108 | 109 | perceiver_output = self.perceiver.forward(data, mask, queries) 110 | 111 | # To keep this more general, leave the reshaping to postprocessing outside the model 112 | return perceiver_output 113 | -------------------------------------------------------------------------------- /perceiver_pytorch/perceiver_io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import repeat 5 | from perceiver_pytorch.layers import exists, cache_fn, PreNorm, FeedForward, Attention 6 | 7 | # main class 8 | 9 | class PerceiverIO(nn.Module): 10 | def __init__( 11 | self, 12 | *, 13 | depth, 14 | dim, 15 | queries_dim, 16 | logits_dim = None, 17 | num_latents = 512, 18 | latent_dim = 512, 19 | cross_heads = 1, 20 | latent_heads = 8, 21 | cross_dim_head = 64, 22 | latent_dim_head = 64, 23 | weight_tie_layers = False, 24 | decoder_ff = False 25 | ): 26 | """ 27 | PerceiverIO implementation from https://arxiv.org/abs/2107.14795 28 | 29 | Args: 30 | depth: Depth of the network 31 | dim: dimension of sequence to be encoded 32 | queries_dim: Dimension of the decoder queries 33 | logits_dim: Dimension of final logits 34 | num_latents: Number of latents 35 | latent_dim: Latent dimension 36 | cross_heads: Number of heads for cross attention 37 | latent_heads: Number of heads for latent self-attention 38 | cross_dim_head: Number of dimensions per cross attention head 39 | latent_dim_head: Number of dimensions per latent self-attention head 40 | weight_tie_layers: Whether to weight tie layers 41 | decoder_ff: Whether to use a feed forward network on the decoder queries 42 | """ 43 | super().__init__() 44 | self.latents = 
nn.Parameter(torch.randn(num_latents, latent_dim)) 45 | 46 | self.cross_attend_blocks = nn.ModuleList([ 47 | PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim), 48 | PreNorm(latent_dim, FeedForward(latent_dim)) 49 | ]) 50 | 51 | get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head)) 52 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim)) 53 | get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff)) 54 | 55 | self.layers = nn.ModuleList([]) 56 | cache_args = {'_cache': weight_tie_layers} 57 | 58 | for i in range(depth): 59 | self.layers.append(nn.ModuleList([ 60 | get_latent_attn(**cache_args), 61 | get_latent_ff(**cache_args) 62 | ])) 63 | 64 | self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim) 65 | self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None 66 | 67 | self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity() 68 | 69 | def forward( 70 | self, 71 | data, 72 | mask = None, 73 | queries = None 74 | ): 75 | b, *_, device = *data.shape, data.device 76 | 77 | x = repeat(self.latents, 'n d -> b n d', b = b) 78 | 79 | cross_attn, cross_ff = self.cross_attend_blocks 80 | 81 | # cross attention only happens once for Perceiver IO 82 | 83 | x = cross_attn(x, context = data, mask = mask) + x 84 | x = cross_ff(x) + x 85 | 86 | # layers 87 | 88 | for self_attn, self_ff in self.layers: 89 | x = self_attn(x) + x 90 | x = self_ff(x) + x 91 | 92 | if not exists(queries): 93 | return x 94 | 95 | # cross attend from decoder queries to latents 96 | 97 | latents = self.decoder_cross_attn(queries, context = x) 98 | 99 | # optional decoder feedforward 100 | 101 | if exists(self.decoder_ff): 102 | latents = latents + self.decoder_ff(latents) 103 | 104 | # final linear out 105 | 106 | return self.to_logits(latents) 107 | 108 | -------------------------------------------------------------------------------- /perceiver_pytorch/perceiver_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from torch import nn 4 | 5 | from perceiver_pytorch.layers import exists, cache_fn, PreNorm, FeedForward, Attention 6 | from perceiver_pytorch.rotary import SinusoidalEmbeddings 7 | from perceiver_pytorch.utils import encode_position 8 | 9 | # main class 10 | 11 | 12 | class Perceiver(nn.Module): 13 | def __init__( 14 | self, 15 | *, 16 | num_freq_bands, 17 | depth, 18 | max_freq, 19 | input_channels=3, 20 | input_axis=2, 21 | num_latents=512, 22 | latent_dim=512, 23 | cross_heads=1, 24 | latent_heads=8, 25 | cross_dim_head=64, 26 | latent_dim_head=64, 27 | num_classes=1000, 28 | attn_dropout=0.0, 29 | ff_dropout=0.0, 30 | weight_tie_layers=False, 31 | fourier_encode_data=True, 32 | sine_only: bool = False, 33 | self_per_cross_attn=1, 34 | self_attn_rel_pos=True, 35 | ): 36 | """ 37 | Perceiver: https://arxiv.org/abs/2103.03206 38 | The shape of the final attention mechanism will be: 39 | depth * (cross attention -> self_per_cross_attn * self attention) 40 | 41 | Args: 42 | num_freq_bands: Number of freq bands, with original value (2 * K + 1) 43 | depth: Depth of net. 44 | max_freq: Maximum frequency, hyperparameter depending on how 45 | fine the data is. 
46 | input_channels: Number of channels for each token of the input. 47 | input_axis: Number of axes for input data (2 for images, 3 for video) 48 | num_latents: Number of latents, or induced set points, or centroids. 49 | Different papers giving it different names. 50 | latent_dim: Latent dimension. 51 | cross_heads: Number of heads for cross attention. Paper said 1. 52 | latent_heads: Number of heads for latent self attention, 8. 53 | cross_dim_head: Number of dimensions per cross attention head. 54 | latent_dim_head: Number of dimensions per latent self attention head. 55 | num_classes: Output number of classes. 56 | attn_dropout: Attention dropout 57 | ff_dropout: Feedforward dropout 58 | weight_tie_layers: Whether to weight tie layers (optional). 59 | fourier_encode_data: Whether to auto-fourier encode the data, using 60 | the input_axis given. defaults to True, but can be turned off 61 | if you are fourier encoding the data yourself. 62 | sine_only: Use only sine encoding in fourier encoding, compared to using sine and cos 63 | self_per_cross_attn: Number of self attention blocks per cross attn. 64 | self_attn_rel_pos: 65 | """ 66 | super().__init__() 67 | self.input_axis = input_axis 68 | self.max_freq = max_freq 69 | self.num_freq_bands = num_freq_bands 70 | 71 | self.fourier_encode_data = fourier_encode_data 72 | fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0 73 | self.sine_only = sine_only 74 | input_dim = fourier_channels + input_channels 75 | 76 | # Randomly initialise the 'latent array'. 77 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) 78 | 79 | def get_cross_attn(): 80 | return PreNorm( 81 | latent_dim, 82 | Attention( 83 | latent_dim, 84 | input_dim, 85 | heads=cross_heads, 86 | dim_head=cross_dim_head, 87 | dropout=attn_dropout, 88 | ), 89 | context_dim=input_dim, 90 | ) 91 | 92 | def get_cross_ff(): 93 | return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 94 | 95 | def get_latent_attn(): 96 | return PreNorm( 97 | latent_dim, 98 | Attention( 99 | latent_dim, 100 | heads=latent_heads, 101 | dim_head=latent_dim_head, 102 | dropout=attn_dropout, 103 | ), 104 | ) 105 | 106 | def get_latent_ff(): 107 | return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout)) 108 | 109 | # Cache all the above functions. 110 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map( 111 | cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff) 112 | ) 113 | 114 | self.layers = nn.ModuleList([]) 115 | for i in range(depth): 116 | should_cache = i > 0 and weight_tie_layers 117 | cache_args = {"_cache": should_cache} 118 | 119 | self_attns = nn.ModuleList([]) 120 | 121 | for _ in range(self_per_cross_attn): 122 | self_attns.append( 123 | nn.ModuleList( 124 | [ 125 | get_latent_attn(**cache_args), 126 | get_latent_ff(**cache_args), 127 | ] 128 | ) 129 | ) 130 | 131 | self.layers.append( 132 | nn.ModuleList( 133 | [ 134 | get_cross_attn(**cache_args), 135 | get_cross_ff(**cache_args), 136 | self_attns, 137 | ] 138 | ) 139 | ) 140 | 141 | self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes)) 142 | 143 | self.sinu_emb = None 144 | if self_attn_rel_pos: 145 | self.sinu_emb = SinusoidalEmbeddings(latent_dim_head) 146 | 147 | def forward(self, data, mask=None): 148 | """ 149 | Args: 150 | data: If sequential is True, then data must be of shape: 151 | (batch size, sequence length, *axes) where axes would be width 152 | and height for images. 
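            Example (a sketch, assuming the default input_axis=2 and
            fourier_encode_data=True, i.e. a channels-last image batch):
                model = Perceiver(num_freq_bands=6, depth=4, max_freq=10.0,
                                  input_channels=3, num_classes=1000)
                logits = model(torch.randn(8, 224, 224, 3))  # -> (8, 1000)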
153 | """ 154 | 155 | b, *axis, _ = data.shape 156 | device = data.device 157 | 158 | assert ( 159 | len(axis) == self.input_axis 160 | ), f"Input data must have {self.input_axis} axes, not {len(axis)}!" 161 | 162 | if self.fourier_encode_data: 163 | # Calculate Fourier encoded positions in the range of [-1, 1], 164 | # for all axes. 165 | enc_pos = encode_position( 166 | b, 167 | axis, 168 | self.max_freq, 169 | self.num_freq_bands, 170 | sine_only=self.sine_only, 171 | ).type_as(data) 172 | 173 | data = torch.cat((data, enc_pos), dim=-1) 174 | 175 | # Concat to channels of data and flatten axes. 176 | # b = batch size; d = last dimension of data 177 | data = rearrange(data, "b ... d -> b (...) d", b=b) 178 | 179 | # x is the 'latent array' in the paper. 180 | # b = batch size; n = number of latents; d = latent dimensions. 181 | x = repeat(self.latents, "n d -> b n d", b=b) 182 | 183 | # Rotary embeddings for latents, if specified. 184 | pos_emb = self.sinu_emb(x) if exists(self.sinu_emb) else None 185 | 186 | # Layers. 187 | for cross_attn, cross_ff, self_attns in self.layers: 188 | x = cross_attn(x, context=data, mask=mask) + x 189 | x = cross_ff(x) + x 190 | 191 | for self_attn, self_ff in self_attns: 192 | x = self_attn(x, pos_emb=pos_emb) + x 193 | x = self_ff(x) + x 194 | 195 | x = x.mean(dim=-2) 196 | return self.to_logits(x) 197 | -------------------------------------------------------------------------------- /perceiver_pytorch/queries.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import uniform 3 | from typing import List, Union, Tuple, Optional 4 | from perceiver_pytorch.utils import encode_position 5 | from math import prod 6 | import einops 7 | import logging 8 | 9 | _LOG = logging.getLogger("perceiver.queries") 10 | _LOG.setLevel(logging.WARN) 11 | 12 | 13 | class LearnableQuery(torch.nn.Module): 14 | """ 15 | Module that constructs a learnable query of query_shape for the Perceiver 16 | """ 17 | 18 | def __init__( 19 | self, 20 | channel_dim: int, 21 | query_shape: Union[Tuple[int], List[int]], 22 | conv_layer: str = "3d", 23 | max_frequency: float = 16.0, 24 | num_frequency_bands: int = 64, 25 | sine_only: bool = False, 26 | precomputed_fourier: Optional[torch.Tensor] = None, 27 | generate_fourier_features: bool = False, 28 | ): 29 | """ 30 | Learnable Query with some inbuilt randomness to help with ensembling 31 | 32 | Args: 33 | channel_dim: Channel dimension for the output of the network 34 | query_shape: The final shape of the query, generally, the (T, H, W) of the output 35 | conv_layer: The type of convolutional layer to use, either 3d or 2d 36 | max_frequency: Max frequency for the Fourier Features 37 | num_frequency_bands: Number of frequency bands for the Fourier Features 38 | sine_only: Whether to use only the sine Fourier features 39 | precomputed_fourier: Fourier features to use instead of computing them here, 40 | useful for having temporally consistent features from history timesteps to future predictions 41 | These features will be concatenated directly to the query, so should be compatible, and made in the 42 | same way as in encode_position 43 | generate_fourier_features: Whether to use generated Fourier features giving the relative 44 | position within the predictions. 
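        Example (a sketch mirroring tests/test_queries.py; x and encoded_input are placeholders):
            query_creator = LearnableQuery(channel_dim=32, query_shape=(6, 16, 16))
            perceiver = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1])
            queries = query_creator(x)  # x only supplies the batch size here
            out = perceiver(encoded_input, queries=queries)  # encoded_input: (batch, seq, 100)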
45 | """ 46 | super().__init__() 47 | self.query_shape = query_shape 48 | self.generate_fourier_features = generate_fourier_features 49 | # Need to get Fourier Features once and then just append to the output 50 | fourier_features = [] 51 | if precomputed_fourier is not None: 52 | fourier_features += [precomputed_fourier] 53 | if self.generate_fourier_features: 54 | generated_features = encode_position( 55 | 1, # Batch size, 1 for this as it will be adapted in forward 56 | axis=query_shape, 57 | max_frequency=max_frequency, 58 | num_frequency_bands=num_frequency_bands, 59 | sine_only=sine_only, 60 | ) 61 | fourier_features += [generated_features] 62 | if len(fourier_features) > 1: 63 | self.fourier_features = torch.cat(fourier_features, dim=-1) 64 | elif fourier_features: 65 | # Only have one of the two options 66 | self.fourier_features = fourier_features[0] 67 | else: 68 | # None are set 69 | self.fourier_features = None 70 | 71 | self.channel_dim = channel_dim 72 | if ( 73 | conv_layer == "3d" and len(self.query_shape) == 3 74 | ): # If Query shape is for an image, then 3D conv won't work 75 | conv = torch.nn.Conv3d 76 | elif conv_layer == "2d": 77 | conv = torch.nn.Conv2d 78 | else: 79 | raise ValueError(f"Value for 'conv_layer' is {conv_layer} which is not one of '3d', '2d'") 80 | self.conv_layer = conv_layer 81 | self.layer = conv( 82 | in_channels=channel_dim, out_channels=channel_dim, kernel_size=3, padding=1 83 | ) 84 | # Linear layer to compress channels down to query_dim size (currently unused in forward) 85 | self.fc = torch.nn.Linear(self.channel_dim, self.channel_dim) 86 | self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0])) 87 | 88 | def output_shape(self) -> Tuple[int, int]: 89 | """ 90 | Gives the output shape from the query, useful for setting the correct 91 | queries_dim in the PerceiverIO 92 | 93 | Returns: 94 | The shape of the resulting query, excluding the batch size 95 | """ 96 | 97 | # The shape is the query_dim + Fourier Feature channels 98 | if self.fourier_features is not None: 99 | channels = self.fourier_features.shape[-1] + self.channel_dim 100 | else: 101 | channels = self.channel_dim 102 | return prod(self.query_shape), channels 103 | 104 | def forward( 105 | self, x: torch.Tensor, fourier_features: Optional[torch.Tensor] = None 106 | ) -> torch.Tensor: 107 | """ 108 | Samples the uniform distribution and creates the query by passing the 109 | sample through the model and appending Fourier features 110 | 111 | Args: 112 | x: The input tensor to the model, used to get the batch size (and timestep count for the 2D conv path) 113 | fourier_features: Fourier features to append to the query. If used, output_shape() will not include these channels, 114 | so add the fourier_features channel count to output_shape() to get the true query size 115 | 116 | Returns: 117 | Torch tensor used to query the output of the PerceiverIO model 118 | """ 119 | _LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.channel_dim}") 120 | z = self.distribution.sample((x.shape[0], self.channel_dim, *self.query_shape)).type_as( 121 | x 122 | ) # [B, Query, T, H, W, 1] or [B, Query, H, W, 1] 123 | z = torch.squeeze(z, dim=-1) # Drop the trailing singleton dimension added by Uniform.sample 124 | _LOG.debug(f"Z: {z.shape}") 125 | # Do 3D or 2D CNN to keep same spatial size, concat, then linearize 126 | if self.conv_layer == "2d" and len(self.query_shape) == 3: 127 | # Iterate through time dimension 128 | outs = [] 129 | for i in range(x.shape[1]): 130 | outs.append(self.layer(z[:, :, i, :, :])) 131 | query = torch.stack(outs, dim=2) 132 | else: 133 |
query = self.layer(z) 134 | # Move channels to correct location 135 | query = einops.rearrange(query, "b c ... -> b ... c") 136 | to_concat = [query] 137 | if self.fourier_features is not None: 138 | ff = einops.repeat( 139 | self.fourier_features, "b ... -> (repeat b) ...", repeat=x.shape[0] 140 | ) # Match batches 141 | to_concat = to_concat + [ff] 142 | if fourier_features is not None: 143 | to_concat = to_concat + [fourier_features] 144 | if len(to_concat) > 1: 145 | query = torch.cat(to_concat, dim=-1) 146 | # concat to channels of data and flatten axis 147 | query = einops.rearrange(query, "b ... d -> b (...) d") 148 | _LOG.debug(f"Final Query Shape: {query.shape}") 149 | return query 150 | -------------------------------------------------------------------------------- /perceiver_pytorch/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange, repeat 4 | 5 | 6 | class SinusoidalEmbeddings(nn.Module): 7 | def __init__(self, dim): 8 | super().__init__() 9 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 10 | self.register_buffer("inv_freq", inv_freq) 11 | 12 | def forward(self, x): 13 | n = x.shape[-2] 14 | t = torch.arange(n, device=x.device).type_as(self.inv_freq) 15 | sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) 16 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 17 | return emb[None, :, :] 18 | 19 | 20 | def rotate_every_two(x): 21 | x = rearrange(x, "... (d j) -> ... d j", j=2) 22 | x1, x2 = x.unbind(dim=-1) 23 | x = torch.stack((-x2, x1), dim=-1) 24 | return rearrange(x, "... d j -> ... (d j)") 25 | 26 | 27 | def apply_rotary_emb(q, k, sinu_pos): 28 | sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2) 29 | sin, cos = sinu_pos.unbind(dim=-2) 30 | sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos)) 31 | q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) 32 | return q, k 33 | -------------------------------------------------------------------------------- /perceiver_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | from math import log, pi 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | import einops 8 | 9 | 10 | def extract_image_patches( 11 | x: torch.Tensor, kernel: int, stride: int = 1, dilation: int = 1 12 | ) -> torch.Tensor: 13 | """ 14 | Extract image patches in a way similar to TensorFlow extract_image_patches 15 | Taken from https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8 16 | 17 | In the Perceiver JAX implementation they extract image patches matching TensorFlow's SAME padding. 18 | PyTorch doesn't have that same kind of option, so this is a way to do that. 
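For example, with kernel=3, stride=1 and dilation=1 a (B, C, H, W) input yields patches of shape (B, 9 * C, H, W).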
19 | 20 | Args: 21 | x: Input Torch Tensor 22 | kernel: Size of kernel 23 | stride: Stride of patch 24 | dilation: Dilation rate 25 | 26 | Returns: 27 | Tensor of size [Batch, Channels*kernel*kernel, Height, Width] 28 | 29 | """ 30 | # Do TF 'SAME' Padding (note F.pad pads the last dimension, width, first) 31 | b, c, h, w = x.shape 32 | h2 = math.ceil(h / stride) 33 | w2 = math.ceil(w / stride) 34 | pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h 35 | pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w 36 | x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2)) 37 | 38 | # Extract patches 39 | # get all image windows of size (kernel, kernel) with stride (stride, stride) 40 | patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride) 41 | # Permute so that channels are next to patch dimension 42 | patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous() 43 | # View as [batch_size, channels*kh*kw, height, width] 44 | return patches.view(b, -1, patches.shape[-2], patches.shape[-1]) 45 | 46 | 47 | def reverse_space_to_depth( 48 | frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1 49 | ) -> torch.Tensor: 50 | """Reverse space to depth transform. 51 | Works for images (dim = 4) and videos (dim = 5)""" 52 | if len(frames.shape) == 4: 53 | return einops.rearrange( 54 | frames, 55 | "b (dh dw c) h w -> b c (h dh) (w dw)", 56 | dh=spatial_block_size, 57 | dw=spatial_block_size, 58 | ) 59 | elif len(frames.shape) == 5: 60 | return einops.rearrange( 61 | frames, 62 | "b t (dt dh dw c) h w -> b (t dt) c (h dh) (w dw)", 63 | dt=temporal_block_size, 64 | dh=spatial_block_size, 65 | dw=spatial_block_size, 66 | ) 67 | else: 68 | raise ValueError( 69 | "Frames should be of rank 4 (batch, height, width, channels)" 70 | " or rank 5 (batch, time, height, width, channels)" 71 | ) 72 | 73 | 74 | def space_to_depth( 75 | frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1 76 | ) -> torch.Tensor: 77 | """Space to depth transform.
78 | Works for images (dim = 4) and videos (dim = 5)""" 79 | if len(frames.shape) == 4: 80 | return einops.rearrange( 81 | frames, 82 | "b c (h dh) (w dw) -> b (dh dw c) h w", 83 | dh=spatial_block_size, 84 | dw=spatial_block_size, 85 | ) 86 | elif len(frames.shape) == 5: 87 | return einops.rearrange( 88 | frames, 89 | "b (t dt) c (h dh) (w dw) -> b t (dt dh dw c) h w ", 90 | dt=temporal_block_size, 91 | dh=spatial_block_size, 92 | dw=spatial_block_size, 93 | ) 94 | else: 95 | raise ValueError( 96 | "Frames should be of rank 4 (batch, height, width, channels)" 97 | " or rank 5 (batch, time, height, width, channels)" 98 | ) 99 | 100 | 101 | def encode_position( 102 | batch_size: int, 103 | axis: list, 104 | max_frequency: float, 105 | num_frequency_bands: int, 106 | sine_only: bool = False, 107 | ) -> torch.Tensor: 108 | """ 109 | Encode the Fourier Features and return them 110 | 111 | Args: 112 | batch_size: Batch size 113 | axis: List containing the size of each axis 114 | max_frequency: Max frequency 115 | num_frequency_bands: Number of frequency bands to use 116 | sine_only: (bool) Whether to only use Sine features or both Sine and Cosine, defaults to both 117 | 118 | Returns: 119 | Torch tensor containing the Fourier Features of shape [Batch, *axis] 120 | """ 121 | axis_pos = list( 122 | map( 123 | lambda size: torch.linspace(-1.0, 1.0, steps=size), 124 | axis, 125 | ) 126 | ) 127 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) 128 | enc_pos = fourier_encode( 129 | pos, 130 | max_frequency, 131 | num_frequency_bands, 132 | sine_only=sine_only, 133 | ) 134 | enc_pos = einops.rearrange(enc_pos, "... n d -> ... (n d)") 135 | enc_pos = einops.repeat(enc_pos, "... -> b ...", b=batch_size) 136 | return enc_pos 137 | 138 | 139 | def fourier_encode( 140 | x: torch.Tensor, 141 | max_freq: float, 142 | num_bands: int = 4, 143 | sine_only: bool = False, 144 | ) -> torch.Tensor: 145 | """ 146 | Create Fourier Encoding 147 | 148 | Args: 149 | x: Input Torch Tensor 150 | max_freq: Maximum frequency for the Fourier features 151 | num_bands: Number of frequency bands 152 | sine_only: Whether to only use sine or both sine and cosine features 153 | 154 | Returns: 155 | Torch Tensor with the fourier position encoded concatenated 156 | """ 157 | x = x.unsqueeze(-1) 158 | device, dtype, orig_x = x.device, x.dtype, x 159 | 160 | scales = torch.linspace( 161 | 1.0, 162 | max_freq / 2, 163 | num_bands, 164 | device=device, 165 | dtype=dtype, 166 | ) 167 | scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] 168 | 169 | x = x * scales * pi 170 | x = x.sin() if sine_only else torch.cat([x.sin(), x.cos()], dim=-1) 171 | x = torch.cat((x, orig_x), dim=-1) 172 | return x 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops>=0.3 2 | torch>=1.6 3 | numpy 4 | torchvision -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from pathlib import Path 3 | 4 | this_directory = Path(__file__).parent 5 | install_requires = (this_directory / "requirements.txt").read_text().splitlines() 6 | long_description = (this_directory / "README.md").read_text() 7 | 8 | 9 | setup( 10 | name="perceiver-model", 11 | packages=find_packages(), 12 | version="0.7.6", 13 | license="MIT", 14 | description="Multimodal 
Perceiver - Pytorch", 15 | author="Jacob Bieker, Jack Kelly, Peter Dudfield", 16 | author_email="jacob@openclimatefix.org", 17 | company="Open Climate Fix Ltd", 18 | url="https://github.com/openclimatefix/perceiver-pytorch", 19 | keywords=[ 20 | "artificial intelligence", 21 | "deep learning", 22 | "transformer", 23 | "attention mechanism", 24 | ], 25 | long_description=long_description, 26 | long_description_content_type="text/markdown", 27 | install_requires=install_requires, 28 | classifiers=[ 29 | "Development Status :: 4 - Beta", 30 | "Intended Audience :: Developers", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | "License :: OSI Approved :: MIT License", 33 | "Programming Language :: Python :: 3.6", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /tests/test_decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from perceiver_pytorch.decoders import ImageDecoder 4 | import pytest 5 | 6 | 7 | def test_conv_image_decoder(): 8 | decoder = ImageDecoder( 9 | postprocess_type="conv", 10 | output_channels=12, 11 | input_channels=48, 12 | spatial_upsample=4, 13 | ) 14 | inputs = torch.randn(2, 48, 64, 64) 15 | with torch.no_grad(): 16 | out = decoder(inputs) 17 | assert not torch.isnan(out).any(), "Output included NaNs" 18 | assert out.size() == (2, 12, 256, 256) 19 | 20 | 21 | def test_conv1x1_image_decoder(): 22 | decoder = ImageDecoder( 23 | postprocess_type="conv1x1", 24 | output_channels=12, 25 | input_channels=48, 26 | spatial_upsample=4, 27 | ) 28 | inputs = torch.randn(2, 48, 64, 64) 29 | with torch.no_grad(): 30 | out = decoder(inputs) 31 | assert not torch.isnan(out).any(), "Output included NaNs" 32 | # Conv1x1 downsample if spatial_upsample > 1 33 | assert out.size() == (2, 12, 16, 16) 34 | 35 | 36 | def test_patches_image_decoder(): 37 | decoder = ImageDecoder(postprocess_type="patches", spatial_upsample=4) 38 | inputs = torch.randn(2, 192, 64, 64) 39 | with torch.no_grad(): 40 | out = decoder(inputs) 41 | assert not torch.isnan(out).any(), "Output included NaNs" 42 | assert out.size() == (2, 12, 256, 256) 43 | 44 | 45 | def test_pixel_image_decoder(): 46 | decoder = ImageDecoder(postprocess_type="pixels") 47 | inputs = torch.randn(2, 192, 64, 64) 48 | with torch.no_grad(): 49 | out = decoder(inputs) 50 | assert not torch.isnan(out).any(), "Output included NaNs" 51 | assert out.size() == (2, 192, 64, 64) 52 | assert pytest.approx(inputs, out) 53 | 54 | 55 | def test_conv_video_decoder(): 56 | decoder = ImageDecoder( 57 | postprocess_type="conv", 58 | output_channels=12, 59 | input_channels=48, 60 | spatial_upsample=4, 61 | ) 62 | inputs = torch.randn(2, 3, 48, 64, 64) 63 | with torch.no_grad(): 64 | out = decoder(inputs) 65 | assert not torch.isnan(out).any(), "Output included NaNs" 66 | assert out.size() == (2, 3, 12, 256, 256) 67 | 68 | 69 | def test_conv1x1_video_decoder(): 70 | decoder = ImageDecoder( 71 | postprocess_type="conv1x1", 72 | output_channels=12, 73 | input_channels=48, 74 | spatial_upsample=4, 75 | ) 76 | inputs = torch.randn(2, 3, 48, 64, 64) 77 | with torch.no_grad(): 78 | out = decoder(inputs) 79 | assert not torch.isnan(out).any(), "Output included NaNs" 80 | # Conv1x1 downsample if spatial_upsample > 1 81 | assert out.size() == (2, 3, 12, 16, 16) 82 | 83 | 84 | def test_conv3d_video_decoder(): 85 | decoder = ImageDecoder( 86 | postprocess_type="conv", 87 | output_channels=12, 88 | input_channels=48, 89 | 
spatial_upsample=4, 90 | temporal_upsample=2, 91 | ) 92 | inputs = torch.randn(2, 1, 48, 64, 64) 93 | with torch.no_grad(): 94 | out = decoder(inputs) 95 | assert not torch.isnan(out).any(), "Output included NaNs" 96 | assert out.size() == (2, 2, 12, 256, 256) 97 | 98 | 99 | def test_patches_video_decoder(): 100 | decoder = ImageDecoder(postprocess_type="patches", spatial_upsample=4) 101 | inputs = torch.randn(2, 3, 192, 64, 64) 102 | with torch.no_grad(): 103 | out = decoder(inputs) 104 | assert not torch.isnan(out).any(), "Output included NaNs" 105 | assert out.size() == (2, 3, 12, 256, 256) 106 | 107 | 108 | def test_pixel_video_decoder(): 109 | decoder = ImageDecoder(postprocess_type="pixels") 110 | inputs = torch.randn(2, 3, 192, 64, 64) 111 | with torch.no_grad(): 112 | out = decoder(inputs) 113 | assert not torch.isnan(out).any(), "Output included NaNs" 114 | assert out.size() == (2, 3, 192, 64, 64) 115 | assert pytest.approx(inputs, out) 116 | -------------------------------------------------------------------------------- /tests/test_encoders.py: -------------------------------------------------------------------------------- 1 | from perceiver_pytorch.encoders import ImageEncoder 2 | import torch 3 | import pytest 4 | 5 | 6 | @pytest.mark.parametrize("prep_type", ["conv", "conv1x1"]) 7 | def test_conv_image_encoder(prep_type): 8 | encoder = ImageEncoder(prep_type=prep_type, output_channels=48) 9 | image = torch.randn(2, 12, 256, 256) 10 | with torch.no_grad(): 11 | out = encoder(image) 12 | assert not torch.isnan(out).any(), "Output included NaNs" 13 | assert out.size() == (2, 48, 64, 64) 14 | 15 | 16 | def test_patches_image_encoder(): 17 | encoder = ImageEncoder(prep_type="patches", output_channels=48) 18 | image = torch.randn(2, 12, 256, 256) 19 | with torch.no_grad(): 20 | out = encoder(image) 21 | assert not torch.isnan(out).any(), "Output included NaNs" 22 | assert out.size() == (2, 192, 64, 64) 23 | 24 | 25 | def test_pixels_image_encoder(): 26 | encoder = ImageEncoder(prep_type="pixels", output_channels=48) 27 | image = torch.randn(2, 12, 256, 256) 28 | with torch.no_grad(): 29 | out = encoder(image) 30 | assert not torch.isnan(out).any(), "Output included NaNs" 31 | assert out.size() == (2, 12, 64, 64) 32 | 33 | 34 | @pytest.mark.parametrize("prep_type", ["conv", "conv1x1"]) 35 | def test_conv_video_encoder(prep_type): 36 | encoder = ImageEncoder(prep_type=prep_type, output_channels=48) 37 | image = torch.randn(2, 6, 12, 256, 256) 38 | with torch.no_grad(): 39 | out = encoder(image) 40 | assert not torch.isnan(out).any(), "Output included NaNs" 41 | assert out.size() == (2, 6, 48, 64, 64) 42 | 43 | 44 | def test_patches_video_encoder(): 45 | encoder = ImageEncoder(prep_type="patches", output_channels=48) 46 | image = torch.randn(2, 6, 12, 256, 256) 47 | with torch.no_grad(): 48 | out = encoder(image) 49 | assert not torch.isnan(out).any(), "Output included NaNs" 50 | assert out.size() == (2, 6, 192, 64, 64) 51 | 52 | 53 | def test_pixels_video_encoder(): 54 | encoder = ImageEncoder(prep_type="pixels", output_channels=48) 55 | image = torch.randn(2, 6, 12, 256, 256) 56 | with torch.no_grad(): 57 | out = encoder(image) 58 | assert not torch.isnan(out).any(), "Output included NaNs" 59 | assert out.size() == (2, 6, 12, 64, 64) 60 | 61 | 62 | def test_pixels_video_downsample_encoder(): 63 | encoder = ImageEncoder( 64 | prep_type="pixels", output_channels=48, temporal_downsample=2 65 | ) 66 | image = torch.randn(2, 6, 12, 256, 256) 67 | with torch.no_grad(): 68 | out = 
encoder(image) 69 | assert not torch.isnan(out).any(), "Output included NaNs" 70 | assert out.size() == (2, 3, 12, 64, 64) 71 | 72 | 73 | def test_metnet_video_encoder(): 74 | encoder = ImageEncoder(prep_type="metnet", crop_size=128) 75 | image = torch.randn(2, 6, 12, 512, 512) 76 | with torch.no_grad(): 77 | out = encoder(image) 78 | assert not torch.isnan(out).any(), "Output included NaNs" 79 | assert out.size() == (2, 6, 96, 128, 128) 80 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from perceiver_pytorch.multi_perceiver_pytorch import MultiPerceiver 4 | from perceiver_pytorch.modalities import InputModality 5 | from perceiver_pytorch.decoders import ImageDecoder 6 | 7 | 8 | def test_multiperceiver_creation(): 9 | # Timeseries input 10 | input_size = 64 11 | max_frequency = 16.0 12 | video_modality = InputModality( 13 | name="timeseries", 14 | input_channels=12, 15 | input_axis=3, # number of axes, 3 for video 16 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1) 17 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is, should be Nyquist frequency (i.e. 112 for 224 input image) 18 | sin_only=False, # Whether if sine only for Fourier encoding, TODO test more 19 | fourier_encode=True, # Whether to encode position with Fourier features 20 | ) 21 | # Use image modality for latlon, elevation, other base data? 22 | image_modality = InputModality( 23 | name="base", 24 | input_channels=4, 25 | input_axis=2, # number of axes, 2 for images 26 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1) 27 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is 28 | sin_only=False, 29 | fourier_encode=True, 30 | ) 31 | # Sort audio for timestep one-hot encode? Or include under other modality? 
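(In this test the third modality actually carries the forecast lead time as a 1-axis input rather than audio.)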
32 | timestep_modality = InputModality( 33 | name="forecast_time", 34 | input_channels=1, # number of channels for mono audio 35 | input_axis=1, # number of axes, 2 for images 36 | num_freq_bands=24, # number of freq bands, with original value (2 * K + 1) 37 | max_freq=16.0, # maximum frequency, hyperparameter depending on how fine the data is 38 | sin_only=False, 39 | fourier_encode=True, 40 | ) 41 | model = MultiPerceiver( 42 | modalities=[video_modality, image_modality, timestep_modality], 43 | queries_dim=input_size, 44 | depth=6, 45 | forecast_steps=12, 46 | output_shape=input_size, 47 | ) 48 | x = { 49 | "timeseries": torch.randn((2, 6, input_size, input_size, 12)), 50 | "base": torch.randn((2, input_size, input_size, 4)), 51 | "forecast_time": torch.randn(2, 24, 1), 52 | } 53 | query = torch.randn((2, input_size * 12, input_size)) 54 | model.eval() 55 | with torch.no_grad(): 56 | out = model(x, queries=query) 57 | out = rearrange( 58 | out, "b h (w c) -> b c h w", c=12 59 | ) 60 | # MetNet creates predictions for the center 1/4th 61 | assert out.size() == ( 62 | 2, 63 | 12, 64 | 12 * input_size, 65 | input_size, 66 | ) 67 | assert not torch.isnan(out).any(), "Output included NaNs" 68 | 69 | 70 | def test_multiperceiver_decoder(): 71 | # Timeseries input 72 | input_size = 64 73 | max_frequency = 16.0 74 | video_modality = InputModality( 75 | name="timeseries", 76 | input_channels=12, 77 | input_axis=3, # number of axes, 3 for video 78 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1) 79 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is, should be Nyquist frequency (i.e. 112 for 224 input image) 80 | sin_only=False, # Whether if sine only for Fourier encoding, TODO test more 81 | fourier_encode=True, # Whether to encode position with Fourier features 82 | ) 83 | # Use image modality for latlon, elevation, other base data? 84 | image_modality = InputModality( 85 | name="base", 86 | input_channels=4, 87 | input_axis=2, # number of axes, 2 for images 88 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1) 89 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is 90 | sin_only=False, 91 | fourier_encode=True, 92 | ) 93 | # Sort audio for timestep one-hot encode? Or include under other modality? 
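(As above, this modality encodes the forecast lead time, not audio.)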
94 | timestep_modality = InputModality( 95 | name="forecast_time", 96 | input_channels=1, # number of channels for mono audio 97 | input_axis=1, # number of axes, 2 for images 98 | num_freq_bands=24, # number of freq bands, with original value (2 * K + 1) 99 | max_freq=16.0, # maximum frequency, hyperparameter depending on how fine the data is 100 | sin_only=False, 101 | fourier_encode=True, 102 | ) 103 | model = MultiPerceiver( 104 | modalities=[video_modality, image_modality, timestep_modality], 105 | queries_dim=input_size, 106 | depth=6, 107 | forecast_steps=12, 108 | output_shape=(24,input_size,input_size), 109 | ) 110 | 111 | x = { 112 | "timeseries": torch.randn((2, 6, input_size, input_size, 12)), 113 | "base": torch.randn((2, input_size, input_size, 4)), 114 | "forecast_time": torch.randn(2, 24, 1), 115 | } 116 | query = torch.randn((2, input_size * 12, input_size)) 117 | model.eval() 118 | decoder = ImageDecoder(postprocess_type='conv1x1', input_channels=768, output_channels=12, spatial_upsample=1, temporal_upsample=1) 119 | decoder.eval() 120 | with torch.no_grad(): 121 | out = model(x, queries=query) 122 | out = rearrange( 123 | out, "b c (t w h) -> b t c h w", t=24, h=input_size, w=input_size 124 | ) 125 | out = decoder(out) 126 | # MetNet creates predictions for the center 1/4th 127 | assert out.size() == ( 128 | 2, 129 | 24, 130 | 12, 131 | input_size, 132 | input_size, 133 | ) 134 | assert not torch.isnan(out).any(), "Output included NaNs" 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /tests/test_perceiver_pytorch.py: -------------------------------------------------------------------------------- 1 | from perceiver_pytorch.perceiver_pytorch import Perceiver 2 | import torch 3 | 4 | 5 | def test_init_model(): 6 | 7 | _ = Perceiver( 8 | input_channels=16, 9 | input_axis=2, 10 | num_freq_bands=6, 11 | max_freq=10, 12 | depth=13, 13 | num_latents=16, 14 | latent_dim=17, 15 | num_classes=7, 16 | weight_tie_layers=True, 17 | fourier_encode_data=False, 18 | ) 19 | 20 | 21 | def test_model_forward(): 22 | 23 | model = Perceiver( 24 | input_channels=16, 25 | input_axis=2, 26 | num_freq_bands=6, 27 | max_freq=10, 28 | depth=13, 29 | num_latents=16, 30 | latent_dim=17, 31 | num_classes=7, 32 | weight_tie_layers=True, 33 | fourier_encode_data=False, 34 | ) 35 | 36 | x = torch.randn(8 * 13, 32, 32, 16) 37 | y = model(x) 38 | 39 | 40 | def test_model_forward_fourier(): 41 | 42 | model = Perceiver( 43 | input_channels=16, 44 | input_axis=2, 45 | num_freq_bands=6, 46 | max_freq=10, 47 | depth=13, 48 | num_latents=16, 49 | latent_dim=17, 50 | num_classes=7, 51 | weight_tie_layers=True, 52 | fourier_encode_data=True, 53 | ) 54 | 55 | x = torch.randn(8 * 13, 32, 32, 16) 56 | y = model(x) 57 | -------------------------------------------------------------------------------- /tests/test_queries.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from perceiver_pytorch.queries import LearnableQuery 4 | from perceiver_pytorch.perceiver_io import PerceiverIO 5 | from perceiver_pytorch.utils import encode_position 6 | import einops 7 | 8 | 9 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 10 | def test_learnable_query(layer_shape): 11 | query_creator = LearnableQuery( 12 | channel_dim=32, 13 | query_shape=(6, 16, 16), 14 | conv_layer=layer_shape, 15 | max_frequency=64.0, 16 | num_frequency_bands=128, 17 | sine_only=False, 18 | generate_fourier_features=True, 19 
| ) 20 | x = torch.randn((4, 6, 12, 16, 16)) 21 | out = query_creator(x) 22 | # Output is flattened, so should be [B, T*H*W, C] 23 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) 24 | # 32 + 3*(257) = 771 + 32 = 803 25 | assert out.shape == (4, 16 * 16 * 6, 803) 26 | 27 | 28 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 29 | def test_learnable_query_no_fourier(layer_shape): 30 | query_creator = LearnableQuery( 31 | channel_dim=32, 32 | query_shape=(6, 16, 16), 33 | conv_layer=layer_shape, 34 | max_frequency=64.0, 35 | num_frequency_bands=128, 36 | sine_only=False, 37 | generate_fourier_features=False, 38 | ) 39 | x = torch.randn((4, 6, 12, 16, 16)) 40 | out = query_creator(x) 41 | assert out.shape == (4, 16 * 16 * 6, 32) 42 | 43 | 44 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 45 | def test_learnable_query_application(layer_shape): 46 | output_shape = (6, 16, 16) 47 | query_creator = LearnableQuery( 48 | channel_dim=32, 49 | query_shape=output_shape, 50 | conv_layer=layer_shape, 51 | max_frequency=64.0, 52 | num_frequency_bands=32, 53 | sine_only=False, 54 | generate_fourier_features=True, 55 | ) 56 | with torch.no_grad(): 57 | query_creator.eval() 58 | x = torch.randn((2, 6, 12, 16, 16)) 59 | out = query_creator(x) 60 | 61 | model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1]) 62 | model.eval() 63 | model_input = torch.randn((2, 256, 100)) 64 | model_out = model(model_input, queries=out) 65 | # Reshape back to correct shape 66 | model_out = einops.rearrange( 67 | model_out, 68 | "b (t h w) c -> b t c h w", 69 | t=output_shape[0], 70 | h=output_shape[1], 71 | w=output_shape[2], 72 | ) 73 | assert model_out.shape == (2, 6, 227, 16, 16) 74 | 75 | 76 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 77 | def test_learnable_query_precomputed_fourier_only(layer_shape): 78 | precomputed_features = encode_position( 79 | 1, # Batch size, 1 for this as it will be adapted in forward 80 | axis=(10, 16, 16), # 4 history + 6 future steps 81 | max_frequency=16.0, 82 | num_frequency_bands=128, 83 | sine_only=False, 84 | ) 85 | # Only take future ones 86 | precomputed_features = precomputed_features[:, 4:] 87 | query_creator = LearnableQuery( 88 | channel_dim=32, 89 | query_shape=(6, 16, 16), 90 | conv_layer=layer_shape, 91 | max_frequency=64.0, 92 | num_frequency_bands=16, 93 | sine_only=False, 94 | precomputed_fourier=precomputed_features, 95 | generate_fourier_features=False, 96 | ) 97 | x = torch.randn((4, 6, 12, 16, 16)) 98 | out = query_creator(x) 99 | # Output is flattened, so should be [B, T*H*W, C] 100 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) 101 | # 32 + 3*(257) = 771 + 32 = 803 102 | assert out.shape == (4, 16 * 16 * 6, 803) 103 | 104 | 105 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 106 | def test_learnable_query_precomputed_and_generated_fourier(layer_shape): 107 | precomputed_features = encode_position( 108 | 1, # Batch size, 1 for this as it will be adapted in forward 109 | axis=(10, 16, 16), # 4 history + 6 future steps 110 | max_frequency=16.0, 111 | num_frequency_bands=128, 112 | sine_only=False, 113 | ) 114 | # Only take future ones 115 | precomputed_features = precomputed_features[:, 4:] 116 | query_creator = LearnableQuery( 117 | channel_dim=32, 118 | query_shape=(6, 16, 16), 119 | conv_layer=layer_shape, 120 | max_frequency=64.0, 121 | num_frequency_bands=128, 122 | sine_only=False, 123 | precomputed_fourier=precomputed_features, 124 | generate_fourier_features=True, 125
| ) 126 | x = torch.randn((4, 6, 12, 16, 16)) 127 | out = query_creator(x) 128 | # Output is flattened, so should be [B, T*H*W, C] 129 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) 130 | # 32 + 3*(257) = 771 + 32 = 803 131 | # Then add 771 from the precomputed features, to get 803 + 771 132 | assert out.shape == (4, 16 * 16 * 6, 803 + 771) 133 | 134 | 135 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 136 | def test_learnable_query_pass_in_fourier(layer_shape): 137 | precomputed_features = encode_position( 138 | 4, 139 | axis=(10, 16, 16), # 4 history + 6 future steps 140 | max_frequency=16.0, 141 | num_frequency_bands=64, 142 | sine_only=False, 143 | ) 144 | # Only take future ones 145 | precomputed_features = precomputed_features[:, 4:] 146 | query_creator = LearnableQuery( 147 | channel_dim=32, 148 | query_shape=(6, 16, 16), 149 | conv_layer=layer_shape, 150 | max_frequency=64.0, 151 | num_frequency_bands=128, 152 | sine_only=False, 153 | generate_fourier_features=False, 154 | ) 155 | x = torch.randn((4, 6, 12, 16, 16)) 156 | out = query_creator(x, precomputed_features) 157 | # Output is flattened, so should be [B, T*H*W, C] 158 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) 159 | # 3*(129) = 387 + 32 = 419 160 | # Since this is less than what is passed to LearnableQuery, we know it's using the passed in features 161 | assert out.shape == (4, 16 * 16 * 6, 419) 162 | 163 | 164 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"]) 165 | def test_learnable_query_all_fouriers(layer_shape): 166 | batch_ff = encode_position( 167 | 4, 168 | axis=(10, 16, 16), # 4 history + 6 future steps 169 | max_frequency=16.0, 170 | num_frequency_bands=32, 171 | sine_only=False, 172 | ) 173 | # Only take future ones 174 | batch_ff = batch_ff[:, 4:] 175 | precomputed_features = encode_position( 176 | 1, 177 | axis=(10, 16, 16), # 4 history + 6 future steps 178 | max_frequency=16.0, 179 | num_frequency_bands=64, 180 | sine_only=False, 181 | ) 182 | # Only take future ones 183 | precomputed_features = precomputed_features[:, 4:] 184 | query_creator = LearnableQuery( 185 | channel_dim=32, 186 | query_shape=(6, 16, 16), 187 | conv_layer=layer_shape, 188 | max_frequency=64.0, 189 | num_frequency_bands=128, 190 | sine_only=False, 191 | precomputed_fourier=precomputed_features, 192 | generate_fourier_features=True, 193 | ) 194 | x = torch.randn((4, 6, 12, 16, 16)) 195 | out = query_creator(x, batch_ff) 196 | # Output is flattened, so should be [B, T*H*W, C] 197 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1) 198 | # 3*(129) = 387 + 32 = 419 + 771 from the generated ones + 195 from the batch features 199 | # Since this is less than what is passed to LearnableQuery, we know it's using the passed in features 200 | assert out.shape == (4, 16 * 16 * 6, 1385) 201 | -------------------------------------------------------------------------------- /tests/test_rotary.py: -------------------------------------------------------------------------------- 1 | from perceiver_pytorch.rotary import ( 2 | rotate_every_two, 3 | apply_rotary_emb, 4 | SinusoidalEmbeddings, 5 | ) 6 | import torch 7 | 8 | 9 | def test_rotate_every_two(): 10 | """ 11 | Test for rotate every two 12 | :return: 13 | """ 14 | 15 | x = torch.randn(5, 4, 4) 16 | y = rotate_every_two(x) 17 | 18 | assert y.shape == torch.Size([5, 4, 4]) 19 | assert y[0, 0, 0] == -x[0, 0, 1] 20 | assert y[0, 0, 1] == x[0, 0, 0] 21 | 22 | 23 | def test_apply_rotary_emb(): 24 | """ 25 | Check that 'apply_rotary_emb'
works correctly 26 | :return: 27 | """ 28 | 29 | sinu_pos = torch.randn(1, 4, 10) 30 | q = torch.randn(5, 4, 10) 31 | k = torch.randn(5, 4, 10) 32 | 33 | q, k = apply_rotary_emb(q, k, sinu_pos=sinu_pos) 34 | 35 | 36 | def test_torch_sinusoidal_embeddings(): 37 | model = SinusoidalEmbeddings(dim=128) 38 | 39 | y = model(torch.randn(4, 10)) 40 | assert y.shape[-1] == 128 41 | assert y.shape[-2] == 4 42 | --------------------------------------------------------------------------------