├── .bumpversion.cfg
├── .github
│   └── workflows
│       ├── linters.yaml
│       ├── release.yaml
│       └── workflows.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── perceiver.png
├── perceiver_pytorch
│   ├── __init__.py
│   ├── convolutions.py
│   ├── decoders.py
│   ├── encoders.py
│   ├── gated.py
│   ├── layers.py
│   ├── mixed_latents.py
│   ├── modalities.py
│   ├── multi_perceiver_pytorch.py
│   ├── perceiver_io.py
│   ├── perceiver_pytorch.py
│   ├── queries.py
│   ├── rotary.py
│   └── utils.py
├── requirements.txt
├── setup.py
└── tests
    ├── test_decoders.py
    ├── test_encoders.py
    ├── test_model.py
    ├── test_perceiver_pytorch.py
    ├── test_queries.py
    └── test_rotary.py
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | commit = True
3 | tag = True
4 | current_version = 0.7.6
5 |
6 | [bumpversion:file:setup.py]
7 | search = version="{current_version}"
8 | replace = version="{new_version}"
9 |
--------------------------------------------------------------------------------
/.github/workflows/linters.yaml:
--------------------------------------------------------------------------------
1 | name: Lint Python
2 |
3 | on: [push]
4 |
5 | jobs:
6 | call-run-python-linters:
7 | uses: openclimatefix/.github/.github/workflows/python-lint.yml@e67a64b086a5662c39f6b4523a97dd0641904279
8 | with:
9 | folder: "perceiver_pytorch"
10 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Bump version and auto-release
2 | on:
3 | push:
4 | branches:
5 | - main
6 | jobs:
7 | call-run-python-release:
8 | uses: openclimatefix/.github/.github/workflows/python-release.yml@e67a64b086a5662c39f6b4523a97dd0641904279
9 | secrets:
10 | token: ${{ secrets.PYPI_API_TOKEN }}
11 |
--------------------------------------------------------------------------------
/.github/workflows/workflows.yaml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push]
4 | jobs:
5 | call-run-python-tests:
6 | uses: openclimatefix/.github/.github/workflows/python-test.yml@e67a64b086a5662c39f6b4523a97dd0641904279
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | /.idea/
131 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3.9
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v3.4.0
7 | hooks:
8 | # list of supported hooks: https://pre-commit.com/hooks.html
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-yaml
12 | - id: debug-statements
13 | - id: detect-private-key
14 |
15 | # python code formatting/linting
16 | - repo: https://github.com/PyCQA/pydocstyle
17 | rev: 6.1.1
18 | hooks:
19 | - id: pydocstyle
20 | args:
21 | [
22 | --convention=google,
23 | "--add-ignore=D200,D202,D210,D212,D415",
24 | "perceiver_pytorch",
25 | ]
26 | - repo: https://github.com/PyCQA/flake8
27 | rev: 4.0.1
28 | hooks:
29 | - id: flake8
30 | args:
31 | [
32 | --max-line-length,
33 | "100",
34 | --extend-ignore=E203,
35 | --per-file-ignores,
36 | "__init__.py:F401",
37 | "perceiver_pytorch",
38 | ]
39 | - repo: https://github.com/PyCQA/isort
40 | rev: 5.9.3
41 | hooks:
42 | - id: isort
43 | args: [--profile, black, --line-length, "100", "perceiver_pytorch"]
44 | - repo: https://github.com/psf/black
45 | rev: 20.8b1
46 | hooks:
47 | - id: black
48 | args: [--line-length, "100"]
49 |
50 | # yaml formatting
51 | - repo: https://github.com/pre-commit/mirrors-prettier
52 | rev: v2.3.0
53 | hooks:
54 | - id: prettier
55 | types: [yaml]
56 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Perceiver - Pytorch
4 |
5 | Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch.
6 | Extended from Phil Wang's perceiver-pytorch.
7 |
8 | Yannic Kilcher explanation!
9 |
10 | ## Install
11 |
12 | ```bash
13 | $ pip install perceiver-model
14 | ```
15 |
16 | ## Usage
17 |
18 | ```python
19 | import torch
20 | from perceiver_pytorch import Perceiver
21 |
22 | model = Perceiver(
23 | input_channels = 3, # number of channels for each token of the input
24 | input_axis = 2, # number of axes for input data (2 for images, 3 for video)
25 | num_freq_bands = 6, # number of freq bands; each axis contributes (2 * num_freq_bands + 1) Fourier feature channels
26 | max_freq = 10., # maximum frequency, hyperparameter depending on how fine the data is
27 | depth = 6, # depth of net. The shape of the final attention mechanism will be:
28 | # depth * (cross attention -> self_per_cross_attn * self attention)
29 | num_latents = 256, # number of latents, or induced set points, or centroids. Different papers give it different names.
30 | latent_dim = 512, # latent dimension
31 | cross_heads = 1, # number of heads for cross attention. paper said 1
32 | latent_heads = 8, # number of heads for latent self attention, 8
33 | cross_dim_head = 64, # number of dimensions per cross attention head
34 | latent_dim_head = 64, # number of dimensions per latent self attention head
35 | num_classes = 1000, # output number of classes
36 | attn_dropout = 0.,
37 | ff_dropout = 0.,
38 | weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
39 | fourier_encode_data = True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
40 | self_per_cross_attn = 2 # number of self attention blocks per cross attention
41 | )
42 |
43 | img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized
44 |
45 | model(img) # (1, 1000)
46 | ```
47 |
48 | For the backbone of Perceiver IO, the follow-up paper that allows for a flexible output sequence length, just import `PerceiverIO` instead:
49 |
50 | ```python
51 | import torch
52 | from perceiver_pytorch import PerceiverIO
53 |
54 | model = PerceiverIO(
55 | dim = 32, # dimension of sequence to be encoded
56 | queries_dim = 32, # dimension of decoder queries
57 | logits_dim = 100, # dimension of final logits
58 | depth = 6, # depth of net
59 | num_latents = 256, # number of latents, or induced set points, or centroids. Different papers give it different names.
60 | latent_dim = 512, # latent dimension
61 | cross_heads = 1, # number of heads for cross attention. paper said 1
62 | latent_heads = 8, # number of heads for latent self attention, 8
63 | cross_dim_head = 64, # number of dimensions per cross attention head
64 | latent_dim_head = 64, # number of dimensions per latent self attention head
65 | weight_tie_layers = False # whether to weight tie layers (optional, as indicated in the diagram)
66 | )
67 |
68 | seq = torch.randn(1, 512, 32)
69 | queries = torch.randn(1, 128, 32)
70 |
71 | logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim)
72 | ```
73 |
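The package also exports `MultiPerceiver`, a multimodal wrapper around `PerceiverIO` (see `perceiver_pytorch/multi_perceiver_pytorch.py`). The snippet below is a minimal sketch based on the constructor and forward signatures in that file and in `modalities.py`; the parameter values are purely illustrative, not recommendations.

```python
import torch
from perceiver_pytorch import MultiPerceiver
from perceiver_pytorch.modalities import InputModality

video_modality = InputModality(
    name = "video",
    input_channels = 3,       # channels per frame
    input_axis = 3,           # time, height, width
    num_freq_bands = 4,
    max_freq = 4.,
)

model = MultiPerceiver(
    modalities = [video_modality],
    output_channels = 12,
    output_shape = 32,        # logits_dim becomes output_shape * output_channels
    queries_dim = 32,         # forwarded to PerceiverIO, must match the queries below
    depth = 4,
    num_latents = 128,
    latent_dim = 256,
)

video = torch.randn(1, 6, 16, 16, 3)    # (batch, time, height, width, channels)
queries = torch.randn(1, 64, 32)        # (batch, decoder seq, queries_dim)

out = model({"video": video}, queries = queries)  # (1, 64, 384)
```

Each modality is Fourier position encoded, padded to a common channel dimension, tagged with a one-hot modality encoding, flattened, and concatenated before being passed to `PerceiverIO`.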
74 | ## Citations
75 |
76 | ```bibtex
77 | @misc{jaegle2021perceiver,
78 | title = {Perceiver: General Perception with Iterative Attention},
79 | author = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
80 | year = {2021},
81 | eprint = {2103.03206},
82 | archivePrefix = {arXiv},
83 | primaryClass = {cs.CV}
84 | }
85 | ```
86 |
87 | ```bibtex
88 | @misc{jaegle2021perceiverio,
89 | title = {Perceiver IO: A General Architecture for Structured Inputs & Outputs},
90 | author = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
91 | year = {2021},
92 | eprint = {2107.14795},
93 | archivePrefix = {arXiv},
94 | primaryClass = {cs.LG}
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
/perceiver.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/perceiver-pytorch/fbab157fffe7d68d123a1ccf45f4131372873183/perceiver.png
--------------------------------------------------------------------------------
/perceiver_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from perceiver_pytorch.perceiver_pytorch import Perceiver
2 | from perceiver_pytorch.perceiver_io import PerceiverIO
3 | from perceiver_pytorch.multi_perceiver_pytorch import MultiPerceiver
4 |
--------------------------------------------------------------------------------
/perceiver_pytorch/convolutions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 |
5 | class Conv2DDownsample(torch.nn.Sequential):
6 | def __init__(
7 | self,
8 | num_layers: int = 1,
9 | input_channels: int = 12,
10 | output_channels: int = 64,
11 | use_batchnorm: bool = True,
12 | ):
13 | """
14 | Constructs a Conv2DDownsample model
15 | Args:
16 | num_layers: Number of conv -> maxpool layers
17 | output_channels: Number of output channels
18 | input_channels: Number of input channels to first layer
19 | use_batchnorm: Whether to use Batch Norm
20 | """
21 |
22 | layers = [self.make_layer(input_channels, output_channels, batch=use_batchnorm)]
23 | for _ in range(num_layers - 1):
24 | layers += [
25 | self.make_layer(output_channels, output_channels, batch=use_batchnorm)
26 | ]
27 |
28 | super().__init__(*layers)
29 |
30 | def make_layer(self, c_in, c_out, ks=7, stride=2, padding=7 // 2, batch=True):
31 | "Make Conv->Batch->Relu->MaxPool stack"
32 | layers = [torch.nn.Conv2d(c_in, c_out, ks, stride, padding, bias=False)]
33 | if batch:
34 | layers += [torch.nn.BatchNorm2d(c_out)]
35 | layers += [torch.nn.ReLU(), torch.nn.MaxPool2d(3, stride=2, padding=3 // 2)]
36 | return torch.nn.Sequential(*layers)
37 |
38 |
39 | class Conv2DUpsample(torch.nn.Module):
40 | def __init__(self, input_channels: int = 12, output_channels: int = 12):
41 | """
42 | Upsamples 4x using 2 2D transposed convolutions
43 | Args:
44 | input_channels: Input channels to the first layer
45 | output_channels: Number of output channels
46 | """
47 |
48 | super().__init__()
49 | self.transpose_conv1 = torch.nn.ConvTranspose2d(
50 | in_channels=input_channels,
51 | out_channels=output_channels * 2,
52 | kernel_size=(4, 4),
53 | stride=(2, 2),
54 | padding=(1, 1),
55 | )
56 | self.transpose_conv2 = torch.nn.ConvTranspose2d(
57 | in_channels=output_channels * 2,
58 | out_channels=output_channels,
59 | kernel_size=(4, 4),
60 | stride=(2, 2),
61 | padding=(1, 1),
62 | )
63 |
64 | def forward(self, x):
65 | x = self.transpose_conv1(x)
66 | x = F.relu(x)
67 | x = self.transpose_conv2(x)
68 | return x
69 |
70 |
71 | class Conv3DUpsample(torch.nn.Module):
72 | def __init__(
73 | self,
74 | input_channels: int = 12,
75 | output_channels: int = 12,
76 | num_temporal_upsamples: int = 2,
77 | num_space_upsamples: int = 4,
78 | ):
79 | """
80 | Upsamples temporally and spatially using 3D transposed convolutions
81 | Args:
82 | output_channels: Final output channels
83 | num_temporal_upsamples: Number of temporal upsamples to perform
84 | num_space_upsamples: Number of spatial upsamples to perform
85 | """
86 |
87 | super().__init__()
88 | temporal_stride = 2
89 | space_stride = 2
90 | num_upsamples = max(num_space_upsamples, num_temporal_upsamples)
91 |
92 | # create the input and output channels for the different layers
93 | # The intermediate channels are (output channels) * 2^(number of upsamples - 1 - index of the current upsample)
94 | # The decoder sets the number of upsamples as log2(upsample_value), and this changes the number of channels
95 | # in a similar way, so it all scales together.
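# e.g. with output_channels = 12 and num_upsamples = 3, the intermediate output channels are [48, 24, 12]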
96 | intermediate_output_channels = [
97 | output_channels * pow(2, num_upsamples - 1 - i)
98 | for i in range(0, num_upsamples)
99 | ]
100 | intermediate_input_channels = [input_channels] + intermediate_output_channels
101 |
102 | self.layers = torch.nn.ModuleList()
103 | for i in range(num_upsamples):
104 | if i >= num_temporal_upsamples:
105 | temporal_stride = 1
106 | if i >= num_space_upsamples:
107 | space_stride = 1
108 |
109 | input_channels = input_channels if i == 0 else output_channels
110 | stride = (temporal_stride, space_stride, space_stride)
111 | conv = torch.nn.ConvTranspose3d(
112 | in_channels=intermediate_input_channels[i],
113 | out_channels=intermediate_output_channels[i],
114 | stride=stride,
115 | kernel_size=stride,
116 | )
117 | # see output dims calculations - https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
118 | # with kernel = stride, padding = 0 and dilation = 1:
119 | # D_out = (D - 1) * stride + (kernel - 1) + 1 = D * stride
120 |
121 | self.layers.append(conv)
122 | if i != num_upsamples - 1:
123 | self.layers.append(torch.nn.ReLU())
124 |
125 | def forward(self, x):
126 | # Move around to Channels first
127 | x = x.permute(0, 2, 1, 3, 4)
128 |
129 | # loop over layers
130 | for layer in self.layers:
131 | x = layer(x)
132 |
133 | x = F.relu(x)
134 | # Move back
135 | x = x.permute(0, 2, 1, 3, 4)
136 | return x
137 |
--------------------------------------------------------------------------------
/perceiver_pytorch/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from perceiver_pytorch.utils import reverse_space_to_depth
4 | from perceiver_pytorch.convolutions import Conv2DUpsample, Conv3DUpsample
5 |
6 |
7 | class ImageDecoder(torch.nn.Module):
8 | def __init__(
9 | self,
10 | postprocess_type: str = "pixels",
11 | spatial_upsample: int = 1,
12 | temporal_upsample: int = 1,
13 | output_channels: int = -1,
14 | input_channels: int = 12,
15 | input_reshape_size=None,
16 | ):
17 | """
18 | ImageDecoder modeled after JAX version here
19 | https://github.com/deepmind/deepmind-research/blob/769bfdbeafbcb472cb8e2c6cfa746b53ac82efc2/perceiver/io_processors.py#L441-L510
20 |
21 | Args:
22 | postprocess_type: Type of postprocessing, one of conv, patches, pixels, raft, or conv1x1
23 | spatial_upsample: How much to spatially upsample
24 | temporal_upsample: How much to temporally upsample
25 | output_channels: Number of output channels, should be the final desired number of channels
26 | Has to be explicitly set for the conv and conv1x1 options, otherwise an error will be raised.
27 | Ignored for patches and pixels options.
28 | input_channels: Number of input channels to decoder
29 | input_reshape_size: The size to reshape the input to
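
Example (illustrative):
    # map 64-channel Perceiver outputs back to 3-channel images with a 1x1 convolution
    decoder = ImageDecoder(postprocess_type="conv1x1", input_channels=64, output_channels=3)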
30 | """
31 |
32 | super().__init__()
33 |
34 | if postprocess_type not in ("conv", "patches", "pixels", "conv1x1"):
35 | # TODO Add Raft
36 | raise ValueError("Invalid postprocess_type!")
37 |
38 | # Architecture parameters:
39 | self.postprocess_type = postprocess_type
40 |
41 | self.temporal_upsample = temporal_upsample
42 | self.spatial_upsample = spatial_upsample
43 | self.input_reshape_size = input_reshape_size
44 |
45 | if postprocess_type == "pixels":
46 | # No postprocessing for pixels
47 | self.decoder = torch.nn.Identity()
48 | elif postprocess_type == "patches":
49 | self.decoder = ImageDecoderPatches(
50 | spatial_upsample=spatial_upsample, temporal_upsample=temporal_upsample
51 | )
52 | elif postprocess_type == "conv":
53 | self.decoder = ImageDecoderConv(
54 | spatial_upsample=spatial_upsample,
55 | temporal_upsample=temporal_upsample,
56 | output_channels=output_channels,
57 | input_channels=input_channels,
58 | )
59 | elif postprocess_type == "conv1x1":
60 | self.decoder = ImageDecoderConv1x1(
61 | spatial_upsample=spatial_upsample,
62 | output_channels=output_channels,
63 | input_channels=input_channels,
64 | )
65 |
66 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
67 | if self.input_reshape_size is not None:
68 | inputs = torch.reshape(
69 | inputs,
70 | [inputs.shape[0]] + list(self.input_reshape_size) + [inputs.shape[-1]],
71 | )
72 | return self.decoder(inputs)
73 |
74 |
75 | class ImageDecoderConv(torch.nn.Module):
76 | def __init__(
77 | self,
78 | spatial_upsample: int = 1,
79 | temporal_upsample: int = 1,
80 | output_channels: int = -1,
81 | input_channels: int = 12,
82 | ):
83 | """
84 | Convolutional image decoder that can upsample temporally and spatially
85 |
86 | Args:
87 | spatial_upsample: How much to spatially upsample
88 | temporal_upsample: How much to temporally upsample
89 | output_channels: Number of output channels, should be the final desired number of channels
90 | Has to be explicitly set for the conv and conv1x1 options, otherwise an error will be raised.
91 | Ignored for patches and pixels options.
92 | input_channels: Number of input channels to decoder
93 | """
94 |
95 | super().__init__()
96 |
97 | self.temporal_upsample = temporal_upsample
98 | self.spatial_upsample = spatial_upsample
99 |
100 | if output_channels == -1:
101 | raise ValueError("Expected value for output_channels")
102 | if self.temporal_upsample != 1:
103 |
104 | def int_log2(x):
105 | return int(np.round(np.log(x) / np.log(2)))
106 |
107 | self.convnet = Conv3DUpsample(
108 | input_channels=input_channels,
109 | output_channels=output_channels,
110 | num_temporal_upsamples=int_log2(temporal_upsample),
111 | num_space_upsamples=int_log2(spatial_upsample),
112 | )
113 | else:
114 | assert (
115 | self.spatial_upsample == 4
116 | ), "Conv2DUpsample only support 4x spatial upsample right now"
117 | self.convnet = Conv2DUpsample(
118 | input_channels=input_channels, output_channels=output_channels
119 | )
120 |
121 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
122 | # Convnet image featurization.
123 | if len(inputs.shape) == 5 and self.temporal_upsample == 1:
124 | # Timeseries, do it to each timestep independently
125 | outs = []
126 | for i in range(inputs.shape[1]):
127 | outs.append(self.convnet(inputs[:, i, :, :, :]))
128 | inputs = torch.stack(outs, dim=1)
129 | else:
130 | inputs = self.convnet(inputs)
131 |
132 | return inputs
133 |
134 |
135 | class ImageDecoderConv1x1(torch.nn.Module):
136 | def __init__(
137 | self,
138 | spatial_upsample: int = 1,
139 | output_channels: int = -1,
140 | input_channels: int = 12,
141 | ):
142 | """
143 | Convolutional 1x1 image decoder
144 |
145 | Args:
146 | spatial_upsample: How much to spatially upsample
147 | output_channels: Number of output channels, should be the final desired number of channels
148 | Has to be explicitly set for the conv and conv1x1 options, otherwise an error will be raised.
149 | Ignored for patches and pixels options.
150 | input_channels: Number of input channels to decoder
151 | """
152 |
153 | super().__init__()
154 |
155 | self.spatial_upsample = spatial_upsample
156 |
157 | if output_channels == -1:
158 | raise ValueError("Expected value for output_channels")
159 | self.conv1x1 = torch.nn.Conv2d(
160 | in_channels=input_channels,
161 | out_channels=output_channels,
162 | kernel_size=(1, 1),
163 | # spatial_downsample is unconstrained for 1x1 convolutions.
164 | stride=(self.spatial_upsample, self.spatial_upsample),
165 | )
166 |
167 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
168 | # Convnet image featurization.
169 | if len(inputs.shape) == 5:
170 | # Timeseries, do it to each timestep independently
171 | outs = []
172 | for i in range(inputs.shape[1]):
173 | outs.append(self.conv1x1(inputs[:, i, :, :, :]))
174 | inputs = torch.stack(outs, dim=1)
175 | else:
176 | inputs = self.conv1x1(inputs)
177 |
178 | return inputs
179 |
180 |
181 | class ImageDecoderPatches(torch.nn.Module):
182 | def __init__(
183 | self, spatial_upsample: int = 1, temporal_upsample: int = 1,
184 | ):
185 | """
186 | Patch-based image decoder
187 |
188 | Args:
189 | spatial_upsample: How much to spatially upsample
190 | temporal_upsample: How much to temporally upsample
191 | """
192 |
193 | super().__init__()
194 |
195 | self.temporal_upsample = temporal_upsample
196 | self.spatial_upsample = spatial_upsample
197 |
198 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
199 | inputs = reverse_space_to_depth(
200 | inputs, self.temporal_upsample, self.spatial_upsample
201 | )
202 | return inputs
203 |
--------------------------------------------------------------------------------
/perceiver_pytorch/encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torchvision
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import math
7 |
8 | from perceiver_pytorch.convolutions import Conv2DDownsample
9 | from perceiver_pytorch.utils import space_to_depth
10 |
11 |
12 | class ImageEncoder(torch.nn.Module):
13 | def __init__(
14 | self,
15 | input_channels: int = 12,
16 | prep_type: str = "conv",
17 | spatial_downsample: int = 4,
18 | temporal_downsample: int = 1,
19 | output_channels: int = 64,
20 | conv2d_use_batchnorm: bool = True,
21 | crop_size: int = 256,
22 | use_space2depth: bool = True,
23 | ):
24 | """
25 | Image encoder class, modeled off the JAX version
26 | https://github.com/deepmind/deepmind-research/blob/769bfdbeafbcb472cb8e2c6cfa746b53ac82efc2/perceiver/io_processors.py#L291-L438
27 |
28 | Args:
29 | input_channels: Number of input channels of the original image/video
30 | prep_type: How to encode the images, one of conv, patches, pixels, or conv1x1
31 | spatial_downsample: How much to downsample spatially
32 | temporal_downsample: How much to downsample temporally
33 | output_channels: Number of output channels to send to Perceiver
34 | conv2d_use_batchnorm: Whether to use batch norm
35 | crop_size: Only for MetNet preprocessor, the center crop size
36 | use_space2depth: Only for MetNet preprocessor, whether to use average pooling, or space2depth for downsampling
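
Example (illustrative):
    # 4x spatial downsampling of RGB frames into 64-channel features
    encoder = ImageEncoder(input_channels=3, prep_type="conv", spatial_downsample=4, output_channels=64)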
37 | """
38 | super().__init__()
39 | self.prep_type = prep_type
40 |
41 | if prep_type not in ("conv", "patches", "pixels", "conv1x1", "metnet"):
42 | raise ValueError("Invalid prep_type!")
43 |
44 | if self.prep_type == "conv":
45 | self.encoder = ImageEncoderConv(
46 | input_channels=input_channels,
47 | temporal_downsample=temporal_downsample,
48 | spatial_downsample=spatial_downsample,
49 | output_channels=output_channels,
50 | conv2d_use_batchnorm=conv2d_use_batchnorm,
51 | )
52 | elif self.prep_type == "conv1x1":
53 | self.encoder = ImageEncoderConv1x1(
54 | input_channels=input_channels,
55 | spatial_downsample=spatial_downsample,
56 | output_channels=output_channels,
57 | )
58 | elif self.prep_type == "patches":
59 | self.encoder = ImageEncoderPatches(
60 | temporal_downsample=temporal_downsample,
61 | spatial_downsample=spatial_downsample,
62 | )
63 | elif self.prep_type == "pixels":
64 | self.encoder = ImageEncoderPixel(
65 | temporal_downsample=temporal_downsample,
66 | spatial_downsample=spatial_downsample,
67 | )
68 | elif self.prep_type == "metnet":
69 | self.encoder = ImageEncoderMetNet(
70 | crop_size=crop_size, use_space2depth=use_space2depth
71 | )
72 |
73 | def forward(self, x: torch.Tensor) -> torch.Tensor:
74 | return self.encoder(x)
75 |
76 |
77 | class ImageEncoderConv(torch.nn.Module):
78 | def __init__(
79 | self,
80 | input_channels: int = 12,
81 | spatial_downsample: int = 4,
82 | temporal_downsample: int = 1,
83 | output_channels: int = 64,
84 | conv2d_use_batchnorm: bool = True,
85 | ):
86 | """
87 | Convolutional image encoder that can spatially and temporally downsample
88 |
89 | Args:
90 | input_channels: Number of input channels of the original image/video
91 | spatial_downsample: How much to downsample spatially
92 | temporal_downsample: How much to downsample temporally
93 | output_channels: Number of output channels to send to Perceiver
94 | conv2d_use_batchnorm: Whether to use batch norm
95 | """
96 | super().__init__()
97 | self.temporal_downsample = temporal_downsample
98 | self.spatial_downsample = spatial_downsample
99 | self.output_channels = output_channels
100 |
101 | # Downsampling with conv is currently restricted
102 | convnet_num_layers = math.log(spatial_downsample, 4)
103 | convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers)
104 | if not convnet_num_layers_is_int or temporal_downsample != 1:
105 | raise ValueError(
106 | "Only powers of 4 expected for spatial "
107 | "and 1 expected for temporal "
108 | "downsampling with conv."
109 | )
110 |
111 | self.convnet = Conv2DDownsample(
112 | num_layers=int(convnet_num_layers),
113 | output_channels=output_channels,
114 | input_channels=input_channels,
115 | use_batchnorm=conv2d_use_batchnorm,
116 | )
117 |
118 | def forward(self, x: torch.Tensor) -> torch.Tensor:
119 | if len(x.shape) == 5:
120 | # Timeseries, do it to each timestep independently
121 | outs = []
122 | for i in range(x.shape[1]):
123 | outs.append(self.convnet(x[:, i, :, :, :]))
124 | x = torch.stack(outs, dim=1)
125 | else:
126 | x = self.convnet(x)
127 | return x
128 |
129 |
130 | class ImageEncoderConv1x1(torch.nn.Module):
131 | def __init__(
132 | self,
133 | input_channels: int = 12,
134 | spatial_downsample: int = 4,
135 | output_channels: int = 64,
136 | ):
137 | """
138 | Convolutional 1x1 encoder that can spatially downsample
139 |
140 | Args:
141 | input_channels: Number of input channels of the original image/video
142 | spatial_downsample: How much to downsample spatially
143 | output_channels: Number of output channels to send to Perceiver
144 | """
145 | super().__init__()
146 | self.spatial_downsample = spatial_downsample
147 | self.output_channels = output_channels
148 |
149 | self.convnet_1x1 = torch.nn.Conv2d(
150 | in_channels=input_channels,
151 | out_channels=output_channels,
152 | kernel_size=(1, 1),
153 | # spatial_downsample is unconstrained for 1x1 convolutions.
154 | stride=(spatial_downsample, spatial_downsample),
155 | )
156 |
157 | def forward(self, x: torch.Tensor) -> torch.Tensor:
158 | if len(x.shape) == 5:
159 | # Timeseries, do it to each timestep independently
160 | outs = []
161 | for i in range(x.shape[1]):
162 | outs.append(self.convnet_1x1(x[:, i, :, :, :]))
163 | x = torch.stack(outs, dim=1)
164 | else:
165 | x = self.convnet_1x1(x)
166 |
167 | return x
168 |
169 |
170 | class ImageEncoderPatches(torch.nn.Module):
171 | def __init__(
172 | self, spatial_downsample: int = 4, temporal_downsample: int = 1,
173 | ):
174 | """
175 | Image encoder that uses patches
176 |
177 | Args:
178 | spatial_downsample: How much to downsample spatially
179 | temporal_downsample: How much to downsample temporally
180 | """
181 | super().__init__()
182 | self.temporal_downsample = temporal_downsample
183 | self.spatial_downsample = spatial_downsample
184 |
185 | def forward(self, x: torch.Tensor) -> torch.Tensor:
186 | x = space_to_depth(
187 | x,
188 | temporal_block_size=self.temporal_downsample,
189 | spatial_block_size=self.spatial_downsample,
190 | )
191 |
192 | # For flow
193 | if x.ndim == 5 and x.shape[1] == 1:
194 | x = x.squeeze(dim=1)
195 |
196 | return x
197 |
198 |
199 | class ImageEncoderPixel(torch.nn.Module):
200 | def __init__(
201 | self, spatial_downsample: int = 4, temporal_downsample: int = 1,
202 | ):
203 | """
204 | Image encoder class for simple downsampling with pixels
205 |
206 | Args:
207 | spatial_downsample: How much to downsample spatially
208 | temporal_downsample: How much to downsample temporally
209 | """
210 | super().__init__()
211 | self.temporal_downsample = temporal_downsample
212 | self.spatial_downsample = spatial_downsample
213 |
214 | def forward(self, x: torch.Tensor) -> torch.Tensor:
215 | # If requested, will downsample in simplest way
216 | if x.ndim == 4:
217 | x = x[:, :, :: self.spatial_downsample, :: self.spatial_downsample]
218 | elif x.ndim == 5:
219 | x = x[
220 | :,
221 | :: self.temporal_downsample,
222 | :,
223 | :: self.spatial_downsample,
224 | :: self.spatial_downsample,
225 | ]
226 | else:
227 | raise ValueError("Unsupported data format for pixels")
228 |
229 | return x
230 |
231 |
232 | class ImageEncoderMetNet(nn.Module):
233 | def __init__(
234 | self, crop_size: int = 256, use_space2depth: bool = True,
235 | ):
236 | """
237 | Performs the MetNet preprocessing of mean pooling Sat channels, followed by
238 | concatenating the center crop and mean pool
239 |
240 | In the paper, the radar data is space2depth'd while the satellite channels are mean pooled; for this different
241 | task, either option can be used for the satellite channels
242 |
243 | Args:
245 | crop_size: Center crop size
246 | use_space2depth: Whether to use space2depth on satellite channels, or mean pooling, like in paper
247 | """
248 | super().__init__()
249 | # Split off sat + mask channels into own image, and the rest, which we just take a center crop
250 | # For this,
251 | self.sat_downsample = (
252 | torch.nn.PixelUnshuffle(downscale_factor=2)
253 | if use_space2depth
254 | else torch.nn.AvgPool3d(kernel_size=(1, 2, 2))
255 | )
256 | self.center_crop = torchvision.transforms.CenterCrop(size=crop_size)
257 |
258 | def forward(self, x: torch.Tensor) -> torch.Tensor:
259 | x = self.sat_downsample(x)
260 | # In paper, satellite and radar data is concatenated here
261 | # We are just going to skip that bit
262 | sat_center = self.center_crop(x)
263 | sat_mean = F.avg_pool3d(x, (1, 2, 2))
264 | # All the same size now, so concatenate together, already have time, lat/long, and elevation image
265 | x = torch.cat([sat_center, sat_mean], dim=2)
266 | return x
267 |
--------------------------------------------------------------------------------
/perceiver_pytorch/gated.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, repeat
6 |
7 | from perceiver_pytorch.layers import exists, default, cache_fn, PreNorm, FeedForward, Attention
8 | from perceiver_pytorch.utils import fourier_encode
9 |
10 |
11 | # helpers
12 |
13 |
14 | class Residual(nn.Module):
15 | def __init__(self, fn):
16 | super().__init__()
17 | self.fn = fn
18 |
19 | def forward(self, x, **kwargs):
20 | return x + self.fn(x, **kwargs)
21 |
22 |
23 | class GRUGating(nn.Module):
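"""
Residual-style wrapper that combines the wrapped fn's output with the skip input
using a GRU cell (fn output as the GRU input, skip as the hidden state) instead of
plain addition.
"""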
24 | def __init__(self, dim, fn):
25 | super().__init__()
26 | self.dim = dim
27 | self.fn = fn
28 | self.gru = nn.GRUCell(dim, dim)
29 |
30 | def forward(self, x, **kwargs):
31 | b, dim = x.shape[0], self.dim
32 | y = self.fn(x, **kwargs)
33 |
34 | gated_output = self.gru(
35 | rearrange(y, "... d -> (...) d"), rearrange(x, "... d -> (...) d")
36 | )
37 |
38 | gated_output = rearrange(gated_output, "(b n) d -> b n d", b=b)
39 | return gated_output
40 |
41 |
42 | # main class
43 |
44 |
45 | class Perceiver(nn.Module):
46 | def __init__(
47 | self,
48 | *,
49 | num_freq_bands,
50 | depth,
51 | max_freq,
52 | freq_base=2,
53 | input_channels=3,
54 | input_axis=2,
55 | num_latents=512,
56 | latent_dim=512,
57 | cross_heads=1,
58 | latent_heads=8,
59 | cross_dim_head=64,
60 | latent_dim_head=64,
61 | num_classes=1000,
62 | attn_dropout=0.0,
63 | ff_dropout=0.0,
64 | weight_tie_layers=False
65 | ):
66 | super().__init__()
67 | self.input_axis = input_axis
68 | self.max_freq = max_freq
69 | self.num_freq_bands = num_freq_bands
70 | self.freq_base = freq_base
71 |
72 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
73 |
74 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
75 |
76 | get_cross_attn = lambda: GRUGating(
77 | latent_dim,
78 | PreNorm(
79 | latent_dim,
80 | Attention(
81 | latent_dim,
82 | input_dim,
83 | heads=cross_heads,
84 | dim_head=cross_dim_head,
85 | dropout=attn_dropout,
86 | ),
87 | context_dim=input_dim,
88 | ),
89 | )
90 | get_latent_attn = lambda: GRUGating(
91 | latent_dim,
92 | PreNorm(
93 | latent_dim,
94 | Attention(
95 | latent_dim,
96 | heads=latent_heads,
97 | dim_head=latent_dim_head,
98 | dropout=attn_dropout,
99 | ),
100 | ),
101 | )
102 | get_cross_ff = lambda: Residual(
103 | PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
104 | )
105 | get_latent_ff = lambda: Residual(
106 | PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
107 | )
108 |
109 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
110 | cache_fn,
111 | (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff),
112 | )
113 |
114 | self.layers = nn.ModuleList([])
115 | for i in range(depth):
116 | should_cache = i > 0 and weight_tie_layers
117 | cache_args = {"_cache": should_cache}
118 |
119 | self.layers.append(
120 | nn.ModuleList(
121 | [
122 | get_cross_attn(**cache_args),
123 | get_cross_ff(**cache_args),
124 | get_latent_attn(**cache_args),
125 | get_latent_ff(**cache_args),
126 | ]
127 | )
128 | )
129 |
130 | self.to_logits = nn.Sequential(
131 | nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes)
132 | )
133 |
134 | def forward(self, data, mask=None):
135 | b, *axis, _, device = *data.shape, data.device
136 | assert (
137 | len(axis) == self.input_axis
138 | ), "input data must have the right number of axis"
139 |
140 | # calculate fourier encoded positions in the range of [-1, 1], for all axis
141 |
142 | axis_pos = list(
143 | map(
144 | lambda size: torch.linspace(
145 | -1.0, 1.0, steps=size, device=device
146 | ),
147 | axis,
148 | )
149 | )
150 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
151 | enc_pos = fourier_encode(
152 | pos, self.max_freq, self.num_freq_bands, base=self.freq_base
153 | )
154 | enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
155 | enc_pos = repeat(enc_pos, "... -> b ...", b=b)
156 |
157 | # concat to channels of data and flatten axis
158 |
159 | data = torch.cat((data, enc_pos), dim=-1)
160 | data = rearrange(data, "b ... d -> b (...) d")
161 |
162 | x = repeat(self.latents, "n d -> b n d", b=b)
163 |
164 | for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
165 | x = cross_attn(x, context=data, mask=mask)
166 | x = cross_ff(x)
167 | x = latent_attn(x)
168 | x = latent_ff(x)
169 |
170 | x = x.mean(dim=-2)
171 | return self.to_logits(x)
172 |
--------------------------------------------------------------------------------
/perceiver_pytorch/layers.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 |
3 | import torch
4 | from einops import rearrange, repeat
5 | from torch import nn, einsum
6 | from torch.nn import functional as F
7 |
8 | from perceiver_pytorch.rotary import apply_rotary_emb
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def default(val, d):
16 | return val if exists(val) else d
17 |
18 |
19 | def cache_fn(f):
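# Wraps a module-constructing function so that, once it has been called with
# _cache=True, later calls return the same module instance. This is what
# implements weight tying across depth in the Perceiver models.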
20 | cache = None
21 |
22 | @wraps(f)
23 | def cached_fn(*args, _cache=True, **kwargs):
24 | if not _cache:
25 | return f(*args, **kwargs)
26 | nonlocal cache
27 | if cache is not None:
28 | return cache
29 | cache = f(*args, **kwargs)
30 | return cache
31 |
32 | return cached_fn
33 |
34 |
35 | class PreNorm(nn.Module):
36 | def __init__(self, dim, fn, context_dim=None):
37 | super().__init__()
38 | self.fn = fn
39 | self.norm = nn.LayerNorm(dim)
40 | self.norm_context = (
41 | nn.LayerNorm(context_dim) if exists(context_dim) else None
42 | )
43 |
44 | def forward(self, x, **kwargs):
45 | x = self.norm(x)
46 |
47 | if exists(self.norm_context):
48 | context = kwargs["context"]
49 | normed_context = self.norm_context(context)
50 | kwargs.update(context=normed_context)
51 |
52 | return self.fn(x, **kwargs)
53 |
54 |
55 | class GEGLU(nn.Module):
56 | """
57 | Gaussian Error Gated Linear Unit.
58 | See Shazeer 2020: https://arxiv.org/abs/2002.05202
59 | """
60 | def forward(self, x):
61 | x, gates = x.chunk(2, dim=-1)
62 | return x * F.gelu(gates)
63 |
64 |
65 | class FeedForward(nn.Module):
66 | """Feed forward neural net with GEGLU activation."""
67 |
68 | def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
69 | """
70 | Args:
71 | dim: Input & Output size.
72 | mult: The inner dimension of the FF net will be dim * mult.
73 | dropout: Proportion to dropout after the GEGLU.
74 | """
75 | super().__init__()
76 | self.net = nn.Sequential(
77 | nn.Linear(dim, dim * mult * 2),
78 | GEGLU(),
79 | nn.Dropout(dropout),
80 | nn.Linear(dim * mult, dim),
81 | )
82 |
83 | def forward(self, x):
84 | return self.net(x)
85 |
86 |
87 | class Attention(nn.Module):
88 | def __init__(
89 | self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
90 | ):
91 | """
92 | Args:
93 | query_dim: Size of the queries.
94 | context_dim: Size of the 'context' (the 'byte array' in the paper).
95 | If None, will default to the query_dim.
96 | heads: Number of attention heads.
97 | dim_head: Number of dimensions per head.
98 | dropout: Proportion to dropout (in the final linear layer).
99 | """
100 |
101 | super().__init__()
102 | inner_dim = dim_head * heads
103 | context_dim = default(context_dim, query_dim)
104 |
105 | self.scale = dim_head ** -0.5
106 | self.heads = heads
107 |
108 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
109 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
110 |
111 | self.to_out = nn.Sequential(
112 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
113 | )
114 |
115 | def forward(self, x, context=None, mask=None, pos_emb=None):
116 | """
117 |
118 | Args:
119 | x: The 'latent array' in the Perceiver paper.
120 | context: The 'byte array' in the Perceiver paper (the input data).
121 | mask:
122 | pos_emb:
123 |
124 | Returns:
125 |
126 | """
127 |
128 | h = self.heads
129 |
130 | q = self.to_q(x)
131 | context = default(context, x)
132 | k, v = self.to_kv(context).chunk(2, dim=-1)
133 |
134 | # Rearrange the query, key and value tensors.
135 | # b = batch size; n = sequence length (latents for q, context elements for k and v);
136 | # h = number of heads; d = number of dims per head.
137 | q, k, v = map(
138 | lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
139 | )
140 |
141 | if exists(pos_emb):
142 | q, k = apply_rotary_emb(q, k, pos_emb)
143 |
144 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
145 |
146 | if exists(mask):
147 | mask = rearrange(mask, "b ... -> b (...)")
148 | max_neg_value = -torch.finfo(sim.dtype).max
149 | mask = repeat(mask, "b j -> (b h) () j", h=h)
150 | sim.masked_fill_(~mask, max_neg_value)
151 |
152 | # attention, what we cannot get enough of
153 | attn = sim.softmax(dim=-1)
154 |
155 | out = einsum("b i j, b j d -> b i d", attn, v)
156 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
157 | return self.to_out(out)
158 |
--------------------------------------------------------------------------------
/perceiver_pytorch/mixed_latents.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, repeat
6 |
7 | from perceiver_pytorch.layers import exists, default, cache_fn, PreNorm, FeedForward, Attention
8 | from perceiver_pytorch.utils import fourier_encode
9 |
10 |
11 | # latent mixer
12 |
13 |
14 | def Mixer(seq_len, mult=4, dropout=0.0):
15 | return nn.Sequential(
16 | nn.Conv1d(seq_len, seq_len * mult, 1),
17 | nn.GELU(),
18 | nn.Dropout(dropout),
19 | nn.Conv1d(seq_len * mult, seq_len, 1),
20 | )
21 |
22 |
23 | # main class
24 |
25 |
26 | class Perceiver(nn.Module):
27 | def __init__(
28 | self,
29 | *,
30 | num_freq_bands,
31 | depth,
32 | max_freq,
33 | input_channels=3,
34 | input_axis=2,
35 | num_latents=512,
36 | latent_dim=512,
37 | cross_heads=1,
38 | latent_heads=8,
39 | cross_dim_head=64,
40 | latent_dim_head=64,
41 | num_classes=1000,
42 | attn_dropout=0.0,
43 | ff_dropout=0.0,
44 | weight_tie_layers=False,
45 | **kwargs
46 | ):
47 | super().__init__()
48 | self.input_axis = input_axis
49 | self.max_freq = max_freq
50 | self.num_freq_bands = num_freq_bands
51 |
52 | input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
53 |
54 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
55 |
56 | get_cross_attn = lambda: PreNorm(
57 | latent_dim,
58 | Attention(
59 | latent_dim,
60 | input_dim,
61 | heads=cross_heads,
62 | dim_head=cross_dim_head,
63 | dropout=attn_dropout,
64 | ),
65 | context_dim=input_dim,
66 | )
67 | get_latent_attn = lambda: PreNorm(latent_dim, Mixer(num_latents, dropout=ff_dropout))
68 | get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
69 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
70 |
71 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
72 | cache_fn,
73 | (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff),
74 | )
75 |
76 | self.layers = nn.ModuleList([])
77 | for i in range(depth):
78 | should_cache = i > 0 and weight_tie_layers
79 | cache_args = {"_cache": should_cache}
80 |
81 | self.layers.append(
82 | nn.ModuleList(
83 | [
84 | get_cross_attn(**cache_args),
85 | get_cross_ff(**cache_args),
86 | get_latent_attn(**cache_args),
87 | get_latent_ff(**cache_args),
88 | ]
89 | )
90 | )
91 |
92 | self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes))
93 |
94 | def forward(self, data, mask=None):
95 | b, *axis, _, device = *data.shape, data.device
96 | assert len(axis) == self.input_axis, "input data must have the right number of axes"
97 |
98 | # calculate fourier encoded positions in the range of [-1, 1], for all axis
99 |
100 | axis_pos = list(
101 | map(
102 | lambda size: torch.linspace(-1.0, 1.0, steps=size, device=device),
103 | axis,
104 | )
105 | )
106 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
107 | enc_pos = fourier_encode(
108 | x=pos,
109 | max_freq=self.max_freq,
110 | num_bands=self.num_freq_bands,
111 | )
112 | enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
113 | enc_pos = repeat(enc_pos, "... -> b ...", b=b)
114 |
115 | # concat to channels of data and flatten axis
116 |
117 | data = torch.cat((data, enc_pos), dim=-1)
118 | data = rearrange(data, "b ... d -> b (...) d")
119 |
120 | x = repeat(self.latents, "n d -> b n d", b=b)
121 |
122 | for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
123 | x = cross_attn(x, context=data, mask=mask) + x
124 | x = cross_ff(x) + x
125 | x = latent_attn(x) + x
126 | x = latent_ff(x) + x
127 |
128 | x = x.mean(dim=-2)
129 | return self.to_logits(x)
130 |
--------------------------------------------------------------------------------
/perceiver_pytorch/modalities.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass
3 |
4 |
5 | @dataclass
6 | class InputModality:
7 | name: str
8 | input_channels: int
9 | input_axis: int
10 | num_freq_bands: int
11 | max_freq: float
12 | sin_only: bool = False
13 | fourier_encode: bool = True
14 |
15 | @property
16 | def input_dim(self) -> int:
17 | # Calculate the dimension of this modality.
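# e.g. with fourier_encode=True, sin_only=False, input_axis=2, num_freq_bands=6 and
# input_channels=3, this is 2 * (2 * 6 + 1) + 3 = 29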
18 | if self.fourier_encode:
19 | fourier_channels = self.input_axis * ((self.num_freq_bands * 2) + 1)
20 | fourier_channels = fourier_channels // 2 if self.sin_only else fourier_channels
21 | input_dim = fourier_channels + self.input_channels
22 | return input_dim
23 | else:
24 | return self.input_channels
25 |
26 |
27 | def modality_encoding(
28 | batch_size: int, axes, modality_index: int, num_modalities: int
29 | ) -> torch.Tensor:
30 | """
31 | Return one-hot encoding of modality given num_modalities, batch size and axes.
32 | The result needs to be compatible with the modality data for concatenation.
33 |
34 | Args:
35 | batch_size: Batch size of the input
36 | axes: The size of each axis, other than batch size, of the input
37 | modality_index: The index of this modality i.e. if there are 3 modalities, this would be 0, 1, or 2
38 | num_modalities: Total number of modalities
39 |
40 | Returns:
41 | One hot encoding of which modality the input is
42 |
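Example:
    With batch_size=2, axes=(3, 4), modality_index=1 and num_modalities=3, the
    result has shape (2, 3, 4, 3) and every vector along the last dimension is
    the one-hot encoding [0., 1., 0.].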
43 | """
44 | one_hot = torch.eye(num_modalities, num_modalities)[modality_index]
45 | to_expand = [batch_size]
46 | one_hot = one_hot.unsqueeze(0)
47 | for i, axis in enumerate(axes):
48 | one_hot = one_hot.unsqueeze(0)
49 | to_expand.append(axis)
50 | to_expand.append(num_modalities)
51 |
52 | one_hot = one_hot.expand(to_expand)
53 | return one_hot
54 |
--------------------------------------------------------------------------------
/perceiver_pytorch/multi_perceiver_pytorch.py:
--------------------------------------------------------------------------------
1 | from perceiver_pytorch.perceiver_io import PerceiverIO
2 | from perceiver_pytorch.modalities import InputModality, modality_encoding
3 | from perceiver_pytorch.utils import encode_position, fourier_encode
4 | import torch
5 | from typing import List, Iterable, Dict, Optional, Any, Union, Tuple
6 | from einops import rearrange, repeat
7 | from math import prod
8 |
9 |
10 | class MultiPerceiver(torch.nn.Module):
11 | def __init__(
12 | self,
13 | modalities: Iterable[InputModality],
14 | fourier_encode_data: bool = True,
15 | input_channels: int = 3,
16 | output_channels: int = 12,
17 | forecast_steps: int = 48,
18 | sine_only: bool = False,
19 | output_shape: Union[int, Tuple[int, ...]] = 32,
20 | **kwargs,
21 | ):
22 | """
23 | PerceiverIO made to work more specifically with timeseries images and multimodal inputs https://arxiv.org/abs/2107.14795
24 | This is a wrapper around the PerceiverIO implementation to encode the inputs correctly
25 |
26 | Args:
27 | input_channels: Number of input channels (int)
28 | forecast_steps: Number of forecast steps to make (int)
29 | fourier_encode_data: Whether to add Fourier features to the input data; if False, inputs should have some type of positional encoding added beforehand
30 | output_channels: Number of output channels per image (int)
31 | sine_only: Only use Sine part of Fourier features (bool)
32 | output_shape: Int or Tuple of ints, giving the desired output shape of the model
33 | **kwargs: Extra kwargs to pass through to PerceiverIO
34 | """
35 | super(MultiPerceiver, self).__init__()
36 | self.fourier_encode_data = fourier_encode_data
37 | self.forecast_steps = forecast_steps
38 | self.input_channels = input_channels
39 | self.sine_only = sine_only
40 | self.output_channels = output_channels
41 | self.modalities = {modality.name: modality for modality in modalities}
42 | # we encode modality with one hot encoding, so need one dim per modality:
43 | modality_encoding_dim = len(modalities)
44 | # input_dim is the maximum dimension over all input modalities:
45 | input_dim = max(modality.input_dim for modality in modalities) + modality_encoding_dim
46 | # Pop any caller-supplied dim; PerceiverIO's dim is set from the computed input_dim below
47 | self.max_modality_dim = input_dim
48 | kwargs.pop("dim", None)
49 | # Want the logits_dim to be the same as the channels * width or height
50 | if isinstance(output_shape, int):
51 | kwargs["logits_dim"] = output_shape * self.output_channels
52 | else:
53 | kwargs["logits_dim"] = prod(output_shape)
54 | self.perceiver = PerceiverIO(dim=input_dim, **kwargs)
55 |
56 | def decode_output(self, data):
57 | pass
58 |
59 | def forward(self, multi_modality_data: Dict[str, torch.Tensor], mask=None, queries=None):
60 | batch_sizes = set()
61 | num_modalities = len(multi_modality_data)
62 | linearized_data = []
63 |
64 | for modality_index, modality_name in enumerate(sorted(multi_modality_data.keys())):
65 | assert (
66 | modality_name in self.modalities
67 | ), f"modality {modality_name} was not defined in constructor"
68 | data = multi_modality_data[modality_name]
69 | modality = self.modalities[modality_name]
70 | b, *axis, _ = data.size()
71 | assert len(axis) == modality.input_axis, (
72 | f"input data must have the right number of axes for modality {modality_name}. "
73 | f"Expected {modality.input_axis} while forward argument offered {len(axis)}"
74 | )
75 | batch_sizes.add(b)
76 | assert len(batch_sizes) == 1, "batch size must be the same across all modalities"
77 | enc_pos = []
78 | if self.fourier_encode_data:
79 | # calculate fourier encoded positions in the range of [-1, 1], for all axis
80 | enc_pos = encode_position(
81 | batch_size=b,
82 | axis=axis,
83 | max_frequency=modality.max_freq,
84 | num_frequency_bands=modality.num_freq_bands,
85 | sine_only=self.sine_only,
86 | ).type_as(data)
87 |
88 | # Figure out padding for this modality, given max dimension across all modalities:
89 | padding_size = self.max_modality_dim - modality.input_dim - num_modalities
90 |
91 | padding = torch.zeros(size=data.size()[0:-1] + (padding_size,)).type_as(data)
92 | # concat to channels of data and flatten axis
93 | modality_encodings = modality_encoding(b, axis, modality_index, num_modalities).type_as(
94 | data
95 | )
96 | to_concat = (
97 | (data, padding, enc_pos, modality_encodings)
98 | if len(enc_pos) > 0
99 | else (data, padding, modality_encodings)
100 | )
101 | data = torch.cat(to_concat, dim=-1)
102 | # concat to channels of data and flatten axis
103 | data = rearrange(data, "b ... d -> b (...) d")
104 | linearized_data.append(data)
105 |
106 | # Concatenate all the modalities:
107 | data = torch.cat(linearized_data, dim=1)
108 |
109 | perceiver_output = self.perceiver.forward(data, mask, queries)
110 |
111 | # To keep this more general, leave the reshaping to postprocessing outside the model
112 | return perceiver_output
113 |
--------------------------------------------------------------------------------
/perceiver_pytorch/perceiver_io.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from einops import repeat
5 | from perceiver_pytorch.layers import exists, cache_fn, PreNorm, FeedForward, Attention
6 |
7 | # main class
8 |
9 | class PerceiverIO(nn.Module):
10 | def __init__(
11 | self,
12 | *,
13 | depth,
14 | dim,
15 | queries_dim,
16 | logits_dim = None,
17 | num_latents = 512,
18 | latent_dim = 512,
19 | cross_heads = 1,
20 | latent_heads = 8,
21 | cross_dim_head = 64,
22 | latent_dim_head = 64,
23 | weight_tie_layers = False,
24 | decoder_ff = False
25 | ):
26 | """
27 | PerceiverIO implementation from https://arxiv.org/abs/2107.14795
28 |
29 | Args:
30 | depth: Depth of the network
31 | dim: Dimension of the sequence to be encoded
32 | queries_dim: Dimension of the decoder queries
33 | logits_dim: Dimension of final logits
34 | num_latents: Number of latents
35 | latent_dim: Latent dimension
36 | cross_heads: Number of heads for cross attention
37 | latent_heads: Number of heads for latent self-attention
38 | cross_dim_head: Number of dimensions per cross attention head
39 | latent_dim_head: Number of dimensions per latent self-attention head
40 | weight_tie_layers: Whether to weight tie layers
41 | decoder_ff: Whether to use a feed forward network on the decoder queries
42 | """
43 | super().__init__()
44 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
45 |
46 | self.cross_attend_blocks = nn.ModuleList([
47 | PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim),
48 | PreNorm(latent_dim, FeedForward(latent_dim))
49 | ])
50 |
51 | get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head))
52 | get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
53 | get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
54 |
55 | self.layers = nn.ModuleList([])
56 | cache_args = {'_cache': weight_tie_layers}
57 |
58 | for i in range(depth):
59 | self.layers.append(nn.ModuleList([
60 | get_latent_attn(**cache_args),
61 | get_latent_ff(**cache_args)
62 | ]))
63 |
64 | self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim)
65 | self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
66 |
67 | self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()
68 |
69 | def forward(
70 | self,
71 | data,
72 | mask = None,
73 | queries = None
74 | ):
75 | b, *_, device = *data.shape, data.device
76 |
77 | x = repeat(self.latents, 'n d -> b n d', b = b)
78 |
79 | cross_attn, cross_ff = self.cross_attend_blocks
80 |
81 | # cross attention only happens once for Perceiver IO
82 |
83 | x = cross_attn(x, context = data, mask = mask) + x
84 | x = cross_ff(x) + x
85 |
86 | # layers
87 |
88 | for self_attn, self_ff in self.layers:
89 | x = self_attn(x) + x
90 | x = self_ff(x) + x
91 |
92 | if not exists(queries):
93 | return x
94 |
95 | # cross attend from decoder queries to latents
96 |
97 | latents = self.decoder_cross_attn(queries, context = x)
98 |
99 | # optional decoder feedforward
100 |
101 | if exists(self.decoder_ff):
102 | latents = latents + self.decoder_ff(latents)
103 |
104 | # final linear out
105 |
106 | return self.to_logits(latents)
107 |
108 |
--------------------------------------------------------------------------------
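A minimal usage sketch for PerceiverIO; the sizes below (depth, dim, queries_dim, logits_dim and the tensor shapes) are illustrative assumptions, not values fixed by the model:

import torch
from perceiver_pytorch.perceiver_io import PerceiverIO

model = PerceiverIO(depth=2, dim=32, queries_dim=16, logits_dim=10)
data = torch.randn(2, 256, 32)    # [batch, sequence length, dim]
queries = torch.randn(2, 64, 16)  # [batch, number of queries, queries_dim]
with torch.no_grad():
    out = model(data, queries=queries)
# With queries supplied the output is [batch, number of queries, logits_dim];
# without queries the latents of shape [batch, num_latents, latent_dim] are returned.
assert out.shape == (2, 64, 10)
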
/perceiver_pytorch/perceiver_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange, repeat
3 | from torch import nn
4 |
5 | from perceiver_pytorch.layers import exists, cache_fn, PreNorm, FeedForward, Attention
6 | from perceiver_pytorch.rotary import SinusoidalEmbeddings
7 | from perceiver_pytorch.utils import encode_position
8 |
9 | # main class
10 |
11 |
12 | class Perceiver(nn.Module):
13 | def __init__(
14 | self,
15 | *,
16 | num_freq_bands,
17 | depth,
18 | max_freq,
19 | input_channels=3,
20 | input_axis=2,
21 | num_latents=512,
22 | latent_dim=512,
23 | cross_heads=1,
24 | latent_heads=8,
25 | cross_dim_head=64,
26 | latent_dim_head=64,
27 | num_classes=1000,
28 | attn_dropout=0.0,
29 | ff_dropout=0.0,
30 | weight_tie_layers=False,
31 | fourier_encode_data=True,
32 | sine_only: bool = False,
33 | self_per_cross_attn=1,
34 | self_attn_rel_pos=True,
35 | ):
36 | """
37 | Perceiver: https://arxiv.org/abs/2103.03206
38 | The shape of the final attention mechanism will be:
39 | depth * (cross attention -> self_per_cross_attn * self attention)
40 |
41 | Args:
42 | num_freq_bands: Number of freq bands, with original value (2 * K + 1)
43 | depth: Depth of net.
44 | max_freq: Maximum frequency, hyperparameter depending on how
45 | fine the data is.
46 | input_channels: Number of channels for each token of the input.
47 | input_axis: Number of axes for input data (2 for images, 3 for video)
48 | num_latents: Number of latents, or induced set points, or centroids.
49 | Different papers giving it different names.
50 | latent_dim: Latent dimension.
51 | cross_heads: Number of heads for cross attention. Paper said 1.
52 | latent_heads: Number of heads for latent self attention, 8.
53 | cross_dim_head: Number of dimensions per cross attention head.
54 | latent_dim_head: Number of dimensions per latent self attention head.
55 | num_classes: Output number of classes.
56 | attn_dropout: Attention dropout
57 | ff_dropout: Feedforward dropout
58 | weight_tie_layers: Whether to weight tie layers (optional).
59 | fourier_encode_data: Whether to auto-fourier encode the data, using
60 | the input_axis given. defaults to True, but can be turned off
61 | if you are fourier encoding the data yourself.
62 | sine_only: Use only sine encoding in fourier encoding, compared to using sine and cos
63 | self_per_cross_attn: Number of self attention blocks per cross attn.
64 |             self_attn_rel_pos: Whether to apply rotary position embeddings in the latent self-attention.
65 | """
66 | super().__init__()
67 | self.input_axis = input_axis
68 | self.max_freq = max_freq
69 | self.num_freq_bands = num_freq_bands
70 |
71 | self.fourier_encode_data = fourier_encode_data
72 |         fourier_channels = (input_axis * ((num_freq_bands * (1 if sine_only else 2)) + 1)) if fourier_encode_data else 0
73 | self.sine_only = sine_only
74 | input_dim = fourier_channels + input_channels
75 |
76 | # Randomly initialise the 'latent array'.
77 | self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
78 |
79 | def get_cross_attn():
80 | return PreNorm(
81 | latent_dim,
82 | Attention(
83 | latent_dim,
84 | input_dim,
85 | heads=cross_heads,
86 | dim_head=cross_dim_head,
87 | dropout=attn_dropout,
88 | ),
89 | context_dim=input_dim,
90 | )
91 |
92 | def get_cross_ff():
93 | return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
94 |
95 | def get_latent_attn():
96 | return PreNorm(
97 | latent_dim,
98 | Attention(
99 | latent_dim,
100 | heads=latent_heads,
101 | dim_head=latent_dim_head,
102 | dropout=attn_dropout,
103 | ),
104 | )
105 |
106 | def get_latent_ff():
107 | return PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
108 |
109 | # Cache all the above functions.
110 | get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
111 | cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
112 | )
113 |
114 | self.layers = nn.ModuleList([])
115 | for i in range(depth):
116 | should_cache = i > 0 and weight_tie_layers
117 | cache_args = {"_cache": should_cache}
118 |
119 | self_attns = nn.ModuleList([])
120 |
121 | for _ in range(self_per_cross_attn):
122 | self_attns.append(
123 | nn.ModuleList(
124 | [
125 | get_latent_attn(**cache_args),
126 | get_latent_ff(**cache_args),
127 | ]
128 | )
129 | )
130 |
131 | self.layers.append(
132 | nn.ModuleList(
133 | [
134 | get_cross_attn(**cache_args),
135 | get_cross_ff(**cache_args),
136 | self_attns,
137 | ]
138 | )
139 | )
140 |
141 | self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes))
142 |
143 | self.sinu_emb = None
144 | if self_attn_rel_pos:
145 | self.sinu_emb = SinusoidalEmbeddings(latent_dim_head)
146 |
147 | def forward(self, data, mask=None):
148 | """
149 | Args:
150 |             data: Input data of shape (batch size, *axes, input_channels),
151 |                 where axes would be height and width for images, or
152 |                 time, height and width for video.
153 | """
154 |
155 | b, *axis, _ = data.shape
156 | device = data.device
157 |
158 | assert (
159 | len(axis) == self.input_axis
160 | ), f"Input data must have {self.input_axis} axes, not {len(axis)}!"
161 |
162 | if self.fourier_encode_data:
163 | # Calculate Fourier encoded positions in the range of [-1, 1],
164 | # for all axes.
165 | enc_pos = encode_position(
166 | b,
167 | axis,
168 | self.max_freq,
169 | self.num_freq_bands,
170 | sine_only=self.sine_only,
171 | ).type_as(data)
172 |
173 | data = torch.cat((data, enc_pos), dim=-1)
174 |
175 | # Concat to channels of data and flatten axes.
176 | # b = batch size; d = last dimension of data
177 | data = rearrange(data, "b ... d -> b (...) d", b=b)
178 |
179 | # x is the 'latent array' in the paper.
180 | # b = batch size; n = number of latents; d = latent dimensions.
181 | x = repeat(self.latents, "n d -> b n d", b=b)
182 |
183 | # Rotary embeddings for latents, if specified.
184 | pos_emb = self.sinu_emb(x) if exists(self.sinu_emb) else None
185 |
186 | # Layers.
187 | for cross_attn, cross_ff, self_attns in self.layers:
188 | x = cross_attn(x, context=data, mask=mask) + x
189 | x = cross_ff(x) + x
190 |
191 | for self_attn, self_ff in self_attns:
192 | x = self_attn(x, pos_emb=pos_emb) + x
193 | x = self_ff(x) + x
194 |
195 | x = x.mean(dim=-2)
196 | return self.to_logits(x)
197 |
--------------------------------------------------------------------------------
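A minimal classification sketch for Perceiver, with illustrative (assumed) hyperparameters and input sizes:

import torch
from perceiver_pytorch.perceiver_pytorch import Perceiver

model = Perceiver(
    num_freq_bands=6,
    depth=2,
    max_freq=10.0,
    input_channels=3,
    input_axis=2,   # 2 axes for images
    num_latents=64,
    latent_dim=128,
    num_classes=10,
)
images = torch.randn(2, 32, 32, 3)  # [batch, height, width, channels]
with torch.no_grad():
    logits = model(images)
assert logits.shape == (2, 10)
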
/perceiver_pytorch/queries.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.distributions import uniform
3 | from typing import List, Union, Tuple, Optional
4 | from perceiver_pytorch.utils import encode_position
5 | from math import prod
6 | import einops
7 | import logging
8 |
9 | _LOG = logging.getLogger("perceiver.queries")
10 | _LOG.setLevel(logging.WARN)
11 |
12 |
13 | class LearnableQuery(torch.nn.Module):
14 | """
15 | Module that constructs a learnable query of query_shape for the Perceiver
16 | """
17 |
18 | def __init__(
19 | self,
20 | channel_dim: int,
21 | query_shape: Union[Tuple[int], List[int]],
22 | conv_layer: str = "3d",
23 | max_frequency: float = 16.0,
24 | num_frequency_bands: int = 64,
25 | sine_only: bool = False,
26 | precomputed_fourier: Optional[torch.Tensor] = None,
27 | generate_fourier_features: bool = False,
28 | ):
29 | """
30 | Learnable Query with some inbuilt randomness to help with ensembling
31 |
32 | Args:
33 | channel_dim: Channel dimension for the output of the network
34 | query_shape: The final shape of the query, generally, the (T, H, W) of the output
35 | conv_layer: The type of convolutional layer to use, either 3d or 2d
36 | max_frequency: Max frequency for the Fourier Features
37 | num_frequency_bands: Number of frequency bands for the Fourier Features
38 | sine_only: Whether to use only the sine Fourier features
39 | precomputed_fourier: Fourier features to use instead of computing them here,
40 |             useful for having temporally consistent features from history timesteps to future predictions.
41 |             These features will be concatenated directly to the query, so they should be compatible and made
42 |             in the same way as in encode_position
43 | generate_fourier_features: Whether to use generated Fourier features giving the relative
44 | position within the predictions.
45 | """
46 | super().__init__()
47 | self.query_shape = query_shape
48 | self.generate_fourier_features = generate_fourier_features
49 | # Need to get Fourier Features once and then just append to the output
50 | fourier_features = []
51 | if precomputed_fourier is not None:
52 | fourier_features += [precomputed_fourier]
53 | if self.generate_fourier_features:
54 | generated_features = encode_position(
55 | 1, # Batch size, 1 for this as it will be adapted in forward
56 | axis=query_shape,
57 | max_frequency=max_frequency,
58 | num_frequency_bands=num_frequency_bands,
59 | sine_only=sine_only,
60 | )
61 | fourier_features += [generated_features]
62 | if len(fourier_features) > 1:
63 | self.fourier_features = torch.cat(fourier_features, dim=-1)
64 | elif fourier_features:
65 | # Only have one of the two options
66 | self.fourier_features = fourier_features[0]
67 | else:
68 | # None are set
69 | self.fourier_features = None
70 |
71 | self.channel_dim = channel_dim
72 | if (
73 | conv_layer == "3d" and len(self.query_shape) == 3
74 |         ):  # A 3D conv only works when the query shape includes a time dimension
75 | conv = torch.nn.Conv3d
76 | elif conv_layer == "2d":
77 | conv = torch.nn.Conv2d
78 | else:
79 |             raise ValueError(f"Value for 'conv_layer' is {conv_layer}, which is not one of '3d', '2d'")
80 | self.conv_layer = conv_layer
81 | self.layer = conv(
82 | in_channels=channel_dim, out_channels=channel_dim, kernel_size=3, padding=1
83 | )
84 | # Linear layer to compress channels down to query_dim size?
85 | self.fc = torch.nn.Linear(self.channel_dim, self.channel_dim)
86 | self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0]))
87 |
88 | def output_shape(self) -> Tuple[int, int]:
89 | """
90 | Gives the output shape from the query, useful for setting the correct
91 | query_dim in the Perceiver
92 |
93 | Returns:
94 | The shape of the resulting query, excluding the batch size
95 | """
96 |
97 |         # The channel count is channel_dim plus the Fourier Feature channels
98 | if self.fourier_features is not None:
99 | channels = self.fourier_features.shape[-1] + self.channel_dim
100 | else:
101 | channels = self.channel_dim
102 | return prod(self.query_shape), channels
103 |
104 | def forward(
105 | self, x: torch.Tensor, fourier_features: Optional[torch.Tensor] = None
106 | ) -> torch.Tensor:
107 | """
108 | Samples the uniform distribution and creates the query by passing the
109 | sample through the model and appending Fourier features
110 |
111 | Args:
112 |             x: The input tensor to the model, used only to determine the batch size
113 |             fourier_features: Fourier features to append to the query. If used, output_shape() will not account
114 |             for them, so their channel count has to be added to output_shape() to get the correct output shape
115 |
116 | Returns:
117 | Torch tensor used to query the output of the PerceiverIO model
118 | """
119 | _LOG.debug(f"Batch: {x.shape[0]} Query: {self.query_shape} Dim: {self.channel_dim}")
120 | z = self.distribution.sample((x.shape[0], self.channel_dim, *self.query_shape)).type_as(
121 | x
122 | ) # [B, Query, T, H, W, 1] or [B, Query, H, W, 1]
123 |         z = torch.squeeze(z, dim=-1)  # Remove the trailing singleton dim added by Uniform.sample
124 | _LOG.debug(f"Z: {z.shape}")
125 | # Do 3D or 2D CNN to keep same spatial size, concat, then linearize
126 | if self.conv_layer == "2d" and len(self.query_shape) == 3:
127 | # Iterate through time dimension
128 | outs = []
129 |             for i in range(z.shape[2]):  # iterate over the query's time steps, not the input's
130 | outs.append(self.layer(z[:, :, i, :, :]))
131 | query = torch.stack(outs, dim=2)
132 | else:
133 | query = self.layer(z)
134 | # Move channels to correct location
135 | query = einops.rearrange(query, "b c ... -> b ... c")
136 | to_concat = [query]
137 | if self.fourier_features is not None:
138 | ff = einops.repeat(
139 | self.fourier_features, "b ... -> (repeat b) ...", repeat=x.shape[0]
140 | ) # Match batches
141 | to_concat = to_concat + [ff]
142 | if fourier_features is not None:
143 | to_concat = to_concat + [fourier_features]
144 | if len(to_concat) > 1:
145 | query = torch.cat(to_concat, dim=-1)
146 | # concat to channels of data and flatten axis
147 | query = einops.rearrange(query, "b ... d -> b (...) d")
148 | _LOG.debug(f"Final Query Shape: {query.shape}")
149 | return query
150 |
--------------------------------------------------------------------------------
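A sketch of wiring a LearnableQuery into PerceiverIO under assumed sizes; output_shape() provides the queries_dim the decoder expects:

import torch
from perceiver_pytorch.queries import LearnableQuery
from perceiver_pytorch.perceiver_io import PerceiverIO

query_creator = LearnableQuery(
    channel_dim=32,
    query_shape=(6, 16, 16),       # (T, H, W) of the desired output
    conv_layer="3d",
    max_frequency=16.0,
    num_frequency_bands=32,
    generate_fourier_features=True,
)
x = torch.randn(2, 6, 12, 16, 16)  # only the batch size of x matters here
queries = query_creator(x)         # [batch, T*H*W, channels]
model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1])
with torch.no_grad():
    out = model(torch.randn(2, 256, 100), queries=queries)
assert out.shape == (2, 6 * 16 * 16, query_creator.output_shape()[-1])
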
/perceiver_pytorch/rotary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange, repeat
4 |
5 |
6 | class SinusoidalEmbeddings(nn.Module):
7 | def __init__(self, dim):
8 | super().__init__()
9 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
10 | self.register_buffer("inv_freq", inv_freq)
11 |
12 | def forward(self, x):
13 | n = x.shape[-2]
14 | t = torch.arange(n, device=x.device).type_as(self.inv_freq)
15 | sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
16 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
17 | return emb[None, :, :]
18 |
19 |
20 | def rotate_every_two(x):
21 | x = rearrange(x, "... (d j) -> ... d j", j=2)
22 | x1, x2 = x.unbind(dim=-1)
23 | x = torch.stack((-x2, x1), dim=-1)
24 | return rearrange(x, "... d j -> ... (d j)")
25 |
26 |
27 | def apply_rotary_emb(q, k, sinu_pos):
28 | sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
29 | sin, cos = sinu_pos.unbind(dim=-2)
30 | sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
31 | q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
32 | return q, k
33 |
--------------------------------------------------------------------------------
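A small sketch of how these helpers fit together; the shapes are assumptions, and inside the model this is applied per attention head to the query/key tensors:

import torch
from perceiver_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_emb

dim_head = 64                     # must be even, matching the attention head dimension
q = torch.randn(5, 10, dim_head)  # [batch * heads, sequence, dim_head]
k = torch.randn(5, 10, dim_head)

sinu_emb = SinusoidalEmbeddings(dim_head)
pos_emb = sinu_emb(q)             # [1, sequence, dim_head]
q_rot, k_rot = apply_rotary_emb(q, k, pos_emb)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
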
/perceiver_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | from math import log, pi
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import math
7 | import einops
8 |
9 |
10 | def extract_image_patches(
11 | x: torch.Tensor, kernel: int, stride: int = 1, dilation: int = 1
12 | ) -> torch.Tensor:
13 | """
14 | Extract image patches in a way similar to TensorFlow extract_image_patches
15 | Taken from https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8
16 |
17 | In the Perceiver JAX implementation they extract image patches matching TensorFlow's SAME padding.
18 | PyTorch doesn't have that same kind of option, so this is a way to do that.
19 |
20 | Args:
21 | x: Input Torch Tensor
22 | kernel: Size of kernel
23 | stride: Stride of patch
24 | dilation: Dilation rate
25 |
26 | Returns:
27 |         Tensor of size [Batch, Channels*kernel*kernel, Height, Width]
28 |
29 | """
30 | # Do TF 'SAME' Padding
31 | b, c, h, w = x.shape
32 | h2 = math.ceil(h / stride)
33 | w2 = math.ceil(w / stride)
34 | pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
35 | pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
36 |     x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))  # pad width, then height
37 |
38 | # Extract patches
39 | # get all image windows of size (kernel, stride) and stride (kernel, stride)
40 | patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
41 | # Permute so that channels are next to patch dimension
42 | patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()
43 | # View as [batch_size, height, width, channels*kh*kw]
44 | return patches.view(b, -1, patches.shape[-2], patches.shape[-1])
45 |
46 |
47 | def reverse_space_to_depth(
48 | frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1
49 | ) -> torch.Tensor:
50 | """Reverse space to depth transform.
51 | Works for images (dim = 4) and videos (dim = 5)"""
52 | if len(frames.shape) == 4:
53 | return einops.rearrange(
54 | frames,
55 | "b (dh dw c) h w -> b c (h dh) (w dw)",
56 | dh=spatial_block_size,
57 | dw=spatial_block_size,
58 | )
59 | elif len(frames.shape) == 5:
60 | return einops.rearrange(
61 | frames,
62 | "b t (dt dh dw c) h w -> b (t dt) c (h dh) (w dw)",
63 | dt=temporal_block_size,
64 | dh=spatial_block_size,
65 | dw=spatial_block_size,
66 | )
67 | else:
68 | raise ValueError(
69 | "Frames should be of rank 4 (batch, height, width, channels)"
70 | " or rank 5 (batch, time, height, width, channels)"
71 | )
72 |
73 |
74 | def space_to_depth(
75 | frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1
76 | ) -> torch.Tensor:
77 | """Space to depth transform.
78 | Works for images (dim = 4) and videos (dim = 5)"""
79 | if len(frames.shape) == 4:
80 | return einops.rearrange(
81 | frames,
82 | "b c (h dh) (w dw) -> b (dh dw c) h w",
83 | dh=spatial_block_size,
84 | dw=spatial_block_size,
85 | )
86 | elif len(frames.shape) == 5:
87 | return einops.rearrange(
88 | frames,
89 | "b (t dt) c (h dh) (w dw) -> b t (dt dh dw c) h w ",
90 | dt=temporal_block_size,
91 | dh=spatial_block_size,
92 | dw=spatial_block_size,
93 | )
94 | else:
95 | raise ValueError(
96 | "Frames should be of rank 4 (batch, height, width, channels)"
97 | " or rank 5 (batch, time, height, width, channels)"
98 | )
99 |
100 |
101 | def encode_position(
102 | batch_size: int,
103 | axis: list,
104 | max_frequency: float,
105 | num_frequency_bands: int,
106 | sine_only: bool = False,
107 | ) -> torch.Tensor:
108 | """
109 | Encode the Fourier Features and return them
110 |
111 | Args:
112 | batch_size: Batch size
113 | axis: List containing the size of each axis
114 | max_frequency: Max frequency
115 | num_frequency_bands: Number of frequency bands to use
116 | sine_only: (bool) Whether to only use Sine features or both Sine and Cosine, defaults to both
117 |
118 | Returns:
119 |         Torch tensor containing the Fourier Features of shape [Batch, *axis, channels]
120 | """
121 | axis_pos = list(
122 | map(
123 | lambda size: torch.linspace(-1.0, 1.0, steps=size),
124 | axis,
125 | )
126 | )
127 | pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
128 | enc_pos = fourier_encode(
129 | pos,
130 | max_frequency,
131 | num_frequency_bands,
132 | sine_only=sine_only,
133 | )
134 | enc_pos = einops.rearrange(enc_pos, "... n d -> ... (n d)")
135 | enc_pos = einops.repeat(enc_pos, "... -> b ...", b=batch_size)
136 | return enc_pos
137 |
138 |
139 | def fourier_encode(
140 | x: torch.Tensor,
141 | max_freq: float,
142 | num_bands: int = 4,
143 | sine_only: bool = False,
144 | ) -> torch.Tensor:
145 | """
146 | Create Fourier Encoding
147 |
148 | Args:
149 | x: Input Torch Tensor
150 | max_freq: Maximum frequency for the Fourier features
151 | num_bands: Number of frequency bands
152 | sine_only: Whether to only use sine or both sine and cosine features
153 |
154 | Returns:
155 |         Torch Tensor with the Fourier position encodings concatenated to the input
156 | """
157 | x = x.unsqueeze(-1)
158 | device, dtype, orig_x = x.device, x.dtype, x
159 |
160 | scales = torch.linspace(
161 | 1.0,
162 | max_freq / 2,
163 | num_bands,
164 | device=device,
165 | dtype=dtype,
166 | )
167 | scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
168 |
169 | x = x * scales * pi
170 | x = x.sin() if sine_only else torch.cat([x.sin(), x.cos()], dim=-1)
171 | x = torch.cat((x, orig_x), dim=-1)
172 | return x
173 |
--------------------------------------------------------------------------------
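As a quick check of the Fourier-feature sizes produced by encode_position (a sketch with assumed arguments): each axis contributes 2 * num_frequency_bands + 1 channels, or num_frequency_bands + 1 when sine_only=True.

from perceiver_pytorch.utils import encode_position

# Two 32-long axes with 6 bands: 2 * (2 * 6 + 1) = 26 channels
enc = encode_position(batch_size=2, axis=[32, 32], max_frequency=10.0, num_frequency_bands=6)
assert enc.shape == (2, 32, 32, 26)

# sine_only drops the cosine bands: 2 * (6 + 1) = 14 channels
enc_sine = encode_position(2, [32, 32], 10.0, 6, sine_only=True)
assert enc_sine.shape == (2, 32, 32, 14)
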
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops>=0.3
2 | torch>=1.6
3 | numpy
4 | torchvision
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from pathlib import Path
3 |
4 | this_directory = Path(__file__).parent
5 | install_requires = (this_directory / "requirements.txt").read_text().splitlines()
6 | long_description = (this_directory / "README.md").read_text()
7 |
8 |
9 | setup(
10 | name="perceiver-model",
11 | packages=find_packages(),
12 | version="0.7.6",
13 | license="MIT",
14 | description="Multimodal Perceiver - Pytorch",
15 | author="Jacob Bieker, Jack Kelly, Peter Dudfield",
16 | author_email="jacob@openclimatefix.org",
17 | company="Open Climate Fix Ltd",
18 | url="https://github.com/openclimatefix/perceiver-pytorch",
19 | keywords=[
20 | "artificial intelligence",
21 | "deep learning",
22 | "transformer",
23 | "attention mechanism",
24 | ],
25 | long_description=long_description,
26 | long_description_content_type="text/markdown",
27 | install_requires=install_requires,
28 | classifiers=[
29 | "Development Status :: 4 - Beta",
30 | "Intended Audience :: Developers",
31 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 | "License :: OSI Approved :: MIT License",
33 | "Programming Language :: Python :: 3.6",
34 | ],
35 | )
36 |
--------------------------------------------------------------------------------
/tests/test_decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from perceiver_pytorch.decoders import ImageDecoder
4 | import pytest
5 |
6 |
7 | def test_conv_image_decoder():
8 | decoder = ImageDecoder(
9 | postprocess_type="conv",
10 | output_channels=12,
11 | input_channels=48,
12 | spatial_upsample=4,
13 | )
14 | inputs = torch.randn(2, 48, 64, 64)
15 | with torch.no_grad():
16 | out = decoder(inputs)
17 | assert not torch.isnan(out).any(), "Output included NaNs"
18 | assert out.size() == (2, 12, 256, 256)
19 |
20 |
21 | def test_conv1x1_image_decoder():
22 | decoder = ImageDecoder(
23 | postprocess_type="conv1x1",
24 | output_channels=12,
25 | input_channels=48,
26 | spatial_upsample=4,
27 | )
28 | inputs = torch.randn(2, 48, 64, 64)
29 | with torch.no_grad():
30 | out = decoder(inputs)
31 | assert not torch.isnan(out).any(), "Output included NaNs"
32 | # Conv1x1 downsample if spatial_upsample > 1
33 | assert out.size() == (2, 12, 16, 16)
34 |
35 |
36 | def test_patches_image_decoder():
37 | decoder = ImageDecoder(postprocess_type="patches", spatial_upsample=4)
38 | inputs = torch.randn(2, 192, 64, 64)
39 | with torch.no_grad():
40 | out = decoder(inputs)
41 | assert not torch.isnan(out).any(), "Output included NaNs"
42 | assert out.size() == (2, 12, 256, 256)
43 |
44 |
45 | def test_pixel_image_decoder():
46 | decoder = ImageDecoder(postprocess_type="pixels")
47 | inputs = torch.randn(2, 192, 64, 64)
48 | with torch.no_grad():
49 | out = decoder(inputs)
50 | assert not torch.isnan(out).any(), "Output included NaNs"
51 | assert out.size() == (2, 192, 64, 64)
52 |     assert torch.allclose(inputs, out)
53 |
54 |
55 | def test_conv_video_decoder():
56 | decoder = ImageDecoder(
57 | postprocess_type="conv",
58 | output_channels=12,
59 | input_channels=48,
60 | spatial_upsample=4,
61 | )
62 | inputs = torch.randn(2, 3, 48, 64, 64)
63 | with torch.no_grad():
64 | out = decoder(inputs)
65 | assert not torch.isnan(out).any(), "Output included NaNs"
66 | assert out.size() == (2, 3, 12, 256, 256)
67 |
68 |
69 | def test_conv1x1_video_decoder():
70 | decoder = ImageDecoder(
71 | postprocess_type="conv1x1",
72 | output_channels=12,
73 | input_channels=48,
74 | spatial_upsample=4,
75 | )
76 | inputs = torch.randn(2, 3, 48, 64, 64)
77 | with torch.no_grad():
78 | out = decoder(inputs)
79 | assert not torch.isnan(out).any(), "Output included NaNs"
80 | # Conv1x1 downsample if spatial_upsample > 1
81 | assert out.size() == (2, 3, 12, 16, 16)
82 |
83 |
84 | def test_conv3d_video_decoder():
85 | decoder = ImageDecoder(
86 | postprocess_type="conv",
87 | output_channels=12,
88 | input_channels=48,
89 | spatial_upsample=4,
90 | temporal_upsample=2,
91 | )
92 | inputs = torch.randn(2, 1, 48, 64, 64)
93 | with torch.no_grad():
94 | out = decoder(inputs)
95 | assert not torch.isnan(out).any(), "Output included NaNs"
96 | assert out.size() == (2, 2, 12, 256, 256)
97 |
98 |
99 | def test_patches_video_decoder():
100 | decoder = ImageDecoder(postprocess_type="patches", spatial_upsample=4)
101 | inputs = torch.randn(2, 3, 192, 64, 64)
102 | with torch.no_grad():
103 | out = decoder(inputs)
104 | assert not torch.isnan(out).any(), "Output included NaNs"
105 | assert out.size() == (2, 3, 12, 256, 256)
106 |
107 |
108 | def test_pixel_video_decoder():
109 | decoder = ImageDecoder(postprocess_type="pixels")
110 | inputs = torch.randn(2, 3, 192, 64, 64)
111 | with torch.no_grad():
112 | out = decoder(inputs)
113 | assert not torch.isnan(out).any(), "Output included NaNs"
114 | assert out.size() == (2, 3, 192, 64, 64)
115 |     assert torch.allclose(inputs, out)
116 |
--------------------------------------------------------------------------------
/tests/test_encoders.py:
--------------------------------------------------------------------------------
1 | from perceiver_pytorch.encoders import ImageEncoder
2 | import torch
3 | import pytest
4 |
5 |
6 | @pytest.mark.parametrize("prep_type", ["conv", "conv1x1"])
7 | def test_conv_image_encoder(prep_type):
8 | encoder = ImageEncoder(prep_type=prep_type, output_channels=48)
9 | image = torch.randn(2, 12, 256, 256)
10 | with torch.no_grad():
11 | out = encoder(image)
12 | assert not torch.isnan(out).any(), "Output included NaNs"
13 | assert out.size() == (2, 48, 64, 64)
14 |
15 |
16 | def test_patches_image_encoder():
17 | encoder = ImageEncoder(prep_type="patches", output_channels=48)
18 | image = torch.randn(2, 12, 256, 256)
19 | with torch.no_grad():
20 | out = encoder(image)
21 | assert not torch.isnan(out).any(), "Output included NaNs"
22 | assert out.size() == (2, 192, 64, 64)
23 |
24 |
25 | def test_pixels_image_encoder():
26 | encoder = ImageEncoder(prep_type="pixels", output_channels=48)
27 | image = torch.randn(2, 12, 256, 256)
28 | with torch.no_grad():
29 | out = encoder(image)
30 | assert not torch.isnan(out).any(), "Output included NaNs"
31 | assert out.size() == (2, 12, 64, 64)
32 |
33 |
34 | @pytest.mark.parametrize("prep_type", ["conv", "conv1x1"])
35 | def test_conv_video_encoder(prep_type):
36 | encoder = ImageEncoder(prep_type=prep_type, output_channels=48)
37 | image = torch.randn(2, 6, 12, 256, 256)
38 | with torch.no_grad():
39 | out = encoder(image)
40 | assert not torch.isnan(out).any(), "Output included NaNs"
41 | assert out.size() == (2, 6, 48, 64, 64)
42 |
43 |
44 | def test_patches_video_encoder():
45 | encoder = ImageEncoder(prep_type="patches", output_channels=48)
46 | image = torch.randn(2, 6, 12, 256, 256)
47 | with torch.no_grad():
48 | out = encoder(image)
49 | assert not torch.isnan(out).any(), "Output included NaNs"
50 | assert out.size() == (2, 6, 192, 64, 64)
51 |
52 |
53 | def test_pixels_video_encoder():
54 | encoder = ImageEncoder(prep_type="pixels", output_channels=48)
55 | image = torch.randn(2, 6, 12, 256, 256)
56 | with torch.no_grad():
57 | out = encoder(image)
58 | assert not torch.isnan(out).any(), "Output included NaNs"
59 | assert out.size() == (2, 6, 12, 64, 64)
60 |
61 |
62 | def test_pixels_video_downsample_encoder():
63 | encoder = ImageEncoder(
64 | prep_type="pixels", output_channels=48, temporal_downsample=2
65 | )
66 | image = torch.randn(2, 6, 12, 256, 256)
67 | with torch.no_grad():
68 | out = encoder(image)
69 | assert not torch.isnan(out).any(), "Output included NaNs"
70 | assert out.size() == (2, 3, 12, 64, 64)
71 |
72 |
73 | def test_metnet_video_encoder():
74 | encoder = ImageEncoder(prep_type="metnet", crop_size=128)
75 | image = torch.randn(2, 6, 12, 512, 512)
76 | with torch.no_grad():
77 | out = encoder(image)
78 | assert not torch.isnan(out).any(), "Output included NaNs"
79 | assert out.size() == (2, 6, 96, 128, 128)
80 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from perceiver_pytorch.multi_perceiver_pytorch import MultiPerceiver
4 | from perceiver_pytorch.modalities import InputModality
5 | from perceiver_pytorch.decoders import ImageDecoder
6 |
7 |
8 | def test_multiperceiver_creation():
9 | # Timeseries input
10 | input_size = 64
11 | max_frequency = 16.0
12 | video_modality = InputModality(
13 | name="timeseries",
14 | input_channels=12,
15 | input_axis=3, # number of axes, 3 for video
16 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1)
17 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is, should be Nyquist frequency (i.e. 112 for 224 input image)
18 |         sin_only=False,  # Whether to use sine-only Fourier encoding, TODO test more
19 | fourier_encode=True, # Whether to encode position with Fourier features
20 | )
21 | # Use image modality for latlon, elevation, other base data?
22 | image_modality = InputModality(
23 | name="base",
24 | input_channels=4,
25 | input_axis=2, # number of axes, 2 for images
26 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1)
27 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is
28 | sin_only=False,
29 | fourier_encode=True,
30 | )
31 |     # Timestep modality: Fourier/one-hot encode the forecast time, or include it under another modality?
32 | timestep_modality = InputModality(
33 | name="forecast_time",
34 |         input_channels=1,  # single channel for the forecast time index
35 |         input_axis=1,  # number of axes, 1 for a 1D sequence
36 | num_freq_bands=24, # number of freq bands, with original value (2 * K + 1)
37 | max_freq=16.0, # maximum frequency, hyperparameter depending on how fine the data is
38 | sin_only=False,
39 | fourier_encode=True,
40 | )
41 | model = MultiPerceiver(
42 | modalities=[video_modality, image_modality, timestep_modality],
43 | queries_dim=input_size,
44 | depth=6,
45 | forecast_steps=12,
46 | output_shape=input_size,
47 | )
48 | x = {
49 | "timeseries": torch.randn((2, 6, input_size, input_size, 12)),
50 | "base": torch.randn((2, input_size, input_size, 4)),
51 | "forecast_time": torch.randn(2, 24, 1),
52 | }
53 | query = torch.randn((2, input_size * 12, input_size))
54 | model.eval()
55 | with torch.no_grad():
56 | out = model(x, queries=query)
57 | out = rearrange(
58 | out, "b h (w c) -> b c h w", c=12
59 | )
60 | # MetNet creates predictions for the center 1/4th
61 | assert out.size() == (
62 | 2,
63 | 12,
64 | 12 * input_size,
65 | input_size,
66 | )
67 | assert not torch.isnan(out).any(), "Output included NaNs"
68 |
69 |
70 | def test_multiperceiver_decoder():
71 | # Timeseries input
72 | input_size = 64
73 | max_frequency = 16.0
74 | video_modality = InputModality(
75 | name="timeseries",
76 | input_channels=12,
77 | input_axis=3, # number of axes, 3 for video
78 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1)
79 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is, should be Nyquist frequency (i.e. 112 for 224 input image)
80 |         sin_only=False,  # Whether to use sine-only Fourier encoding, TODO test more
81 | fourier_encode=True, # Whether to encode position with Fourier features
82 | )
83 | # Use image modality for latlon, elevation, other base data?
84 | image_modality = InputModality(
85 | name="base",
86 | input_channels=4,
87 | input_axis=2, # number of axes, 2 for images
88 | num_freq_bands=input_size, # number of freq bands, with original value (2 * K + 1)
89 | max_freq=max_frequency, # maximum frequency, hyperparameter depending on how fine the data is
90 | sin_only=False,
91 | fourier_encode=True,
92 | )
93 |     # Timestep modality: Fourier/one-hot encode the forecast time, or include it under another modality?
94 | timestep_modality = InputModality(
95 | name="forecast_time",
96 |         input_channels=1,  # single channel for the forecast time index
97 |         input_axis=1,  # number of axes, 1 for a 1D sequence
98 | num_freq_bands=24, # number of freq bands, with original value (2 * K + 1)
99 | max_freq=16.0, # maximum frequency, hyperparameter depending on how fine the data is
100 | sin_only=False,
101 | fourier_encode=True,
102 | )
103 | model = MultiPerceiver(
104 | modalities=[video_modality, image_modality, timestep_modality],
105 | queries_dim=input_size,
106 | depth=6,
107 | forecast_steps=12,
108 | output_shape=(24,input_size,input_size),
109 | )
110 |
111 | x = {
112 | "timeseries": torch.randn((2, 6, input_size, input_size, 12)),
113 | "base": torch.randn((2, input_size, input_size, 4)),
114 | "forecast_time": torch.randn(2, 24, 1),
115 | }
116 | query = torch.randn((2, input_size * 12, input_size))
117 | model.eval()
118 | decoder = ImageDecoder(postprocess_type='conv1x1', input_channels=768, output_channels=12, spatial_upsample=1, temporal_upsample=1)
119 | decoder.eval()
120 | with torch.no_grad():
121 | out = model(x, queries=query)
122 | out = rearrange(
123 | out, "b c (t w h) -> b t c h w", t=24, h=input_size, w=input_size
124 | )
125 | out = decoder(out)
126 | # MetNet creates predictions for the center 1/4th
127 | assert out.size() == (
128 | 2,
129 | 24,
130 | 12,
131 | input_size,
132 | input_size,
133 | )
134 | assert not torch.isnan(out).any(), "Output included NaNs"
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
/tests/test_perceiver_pytorch.py:
--------------------------------------------------------------------------------
1 | from perceiver_pytorch.perceiver_pytorch import Perceiver
2 | import torch
3 |
4 |
5 | def test_init_model():
6 |
7 | _ = Perceiver(
8 | input_channels=16,
9 | input_axis=2,
10 | num_freq_bands=6,
11 | max_freq=10,
12 | depth=13,
13 | num_latents=16,
14 | latent_dim=17,
15 | num_classes=7,
16 | weight_tie_layers=True,
17 | fourier_encode_data=False,
18 | )
19 |
20 |
21 | def test_model_forward():
22 |
23 | model = Perceiver(
24 | input_channels=16,
25 | input_axis=2,
26 | num_freq_bands=6,
27 | max_freq=10,
28 | depth=13,
29 | num_latents=16,
30 | latent_dim=17,
31 | num_classes=7,
32 | weight_tie_layers=True,
33 | fourier_encode_data=False,
34 | )
35 |
36 | x = torch.randn(8 * 13, 32, 32, 16)
37 | y = model(x)
38 |
39 |
40 | def test_model_forward_fourier():
41 |
42 | model = Perceiver(
43 | input_channels=16,
44 | input_axis=2,
45 | num_freq_bands=6,
46 | max_freq=10,
47 | depth=13,
48 | num_latents=16,
49 | latent_dim=17,
50 | num_classes=7,
51 | weight_tie_layers=True,
52 | fourier_encode_data=True,
53 | )
54 |
55 | x = torch.randn(8 * 13, 32, 32, 16)
56 | y = model(x)
57 |
--------------------------------------------------------------------------------
/tests/test_queries.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from perceiver_pytorch.queries import LearnableQuery
4 | from perceiver_pytorch.perceiver_io import PerceiverIO
5 | from perceiver_pytorch.utils import encode_position
6 | import einops
7 |
8 |
9 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
10 | def test_learnable_query(layer_shape):
11 | query_creator = LearnableQuery(
12 | channel_dim=32,
13 | query_shape=(6, 16, 16),
14 | conv_layer=layer_shape,
15 | max_frequency=64.0,
16 | num_frequency_bands=128,
17 | sine_only=False,
18 | generate_fourier_features=True,
19 | )
20 | x = torch.randn((4, 6, 12, 16, 16))
21 | out = query_creator(x)
22 | # Output is flattened, so should be [B, T*H*W, C]
23 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1)
24 |     # 3 * (2*128 + 1) = 771, plus channel_dim 32 = 803
25 | assert out.shape == (4, 16 * 16 * 6, 803)
26 |
27 |
28 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
29 | def test_learnable_query_no_fourier(layer_shape):
30 | query_creator = LearnableQuery(
31 | channel_dim=32,
32 | query_shape=(6, 16, 16),
33 | conv_layer=layer_shape,
34 | max_frequency=64.0,
35 | num_frequency_bands=128,
36 | sine_only=False,
37 | generate_fourier_features=False,
38 | )
39 | x = torch.randn((4, 6, 12, 16, 16))
40 | out = query_creator(x)
41 | assert out.shape == (4, 16 * 16 * 6, 32)
42 |
43 |
44 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
45 | def test_learnable_query_application(layer_shape):
46 | output_shape = (6, 16, 16)
47 | query_creator = LearnableQuery(
48 | channel_dim=32,
49 | query_shape=output_shape,
50 | conv_layer=layer_shape,
51 | max_frequency=64.0,
52 | num_frequency_bands=32,
53 | sine_only=False,
54 | generate_fourier_features=True,
55 | )
56 | with torch.no_grad():
57 | query_creator.eval()
58 | x = torch.randn((2, 6, 12, 16, 16))
59 | out = query_creator(x)
60 |
61 | model = PerceiverIO(depth=2, dim=100, queries_dim=query_creator.output_shape()[-1])
62 | model.eval()
63 | model_input = torch.randn((2, 256, 100))
64 | model_out = model(model_input, queries=out)
65 | # Reshape back to correct shape
66 | model_out = einops.rearrange(
67 | model_out,
68 | "b (t h w) c -> b t c h w",
69 | t=output_shape[0],
70 | h=output_shape[1],
71 | w=output_shape[2],
72 | )
73 | assert model_out.shape == (2, 6, 227, 16, 16)
74 |
75 |
76 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
77 | def test_learnable_query_precomputed_fourier_only(layer_shape):
78 | precomputed_features = encode_position(
79 | 1, # Batch size, 1 for this as it will be adapted in forward
80 | axis=(10, 16, 16), # 4 history + 6 future steps
81 | max_frequency=16.0,
82 | num_frequency_bands=128,
83 | sine_only=False,
84 | )
85 | # Only take future ones
86 | precomputed_features = precomputed_features[:, 4:]
87 | query_creator = LearnableQuery(
88 | channel_dim=32,
89 | query_shape=(6, 16, 16),
90 | conv_layer=layer_shape,
91 | max_frequency=64.0,
92 | num_frequency_bands=16,
93 | sine_only=False,
94 | precomputed_fourier=precomputed_features,
95 | generate_fourier_features=False,
96 | )
97 | x = torch.randn((4, 6, 12, 16, 16))
98 | out = query_creator(x)
99 | # Output is flattened, so should be [B, T*H*W, C]
100 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1)
101 |     # 3 * (2*128 + 1) = 771, plus channel_dim 32 = 803
102 | assert out.shape == (4, 16 * 16 * 6, 803)
103 |
104 |
105 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
106 | def test_learnable_query_precomputed_and_generated_fourier(layer_shape):
107 | precomputed_features = encode_position(
108 | 1, # Batch size, 1 for this as it will be adapted in forward
109 | axis=(10, 16, 16), # 4 history + 6 future steps
110 | max_frequency=16.0,
111 | num_frequency_bands=128,
112 | sine_only=False,
113 | )
114 | # Only take future ones
115 | precomputed_features = precomputed_features[:, 4:]
116 | query_creator = LearnableQuery(
117 | channel_dim=32,
118 | query_shape=(6, 16, 16),
119 | conv_layer=layer_shape,
120 | max_frequency=64.0,
121 | num_frequency_bands=128,
122 | sine_only=False,
123 | precomputed_fourier=precomputed_features,
124 | generate_fourier_features=True,
125 | )
126 | x = torch.randn((4, 6, 12, 16, 16))
127 | out = query_creator(x)
128 | # Output is flattened, so should be [B, T*H*W, C]
129 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1)
130 |     # 3 * (2*128 + 1) = 771, plus channel_dim 32 = 803
131 |     # Then add 771 from the precomputed features, to get 803 + 771 = 1574
132 | assert out.shape == (4, 16 * 16 * 6, 803 + 771)
133 |
134 |
135 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
136 | def test_learnable_query_pass_in_fourier(layer_shape):
137 | precomputed_features = encode_position(
138 | 4,
139 | axis=(10, 16, 16), # 4 history + 6 future steps
140 | max_frequency=16.0,
141 | num_frequency_bands=64,
142 | sine_only=False,
143 | )
144 | # Only take future ones
145 | precomputed_features = precomputed_features[:, 4:]
146 | query_creator = LearnableQuery(
147 | channel_dim=32,
148 | query_shape=(6, 16, 16),
149 | conv_layer=layer_shape,
150 | max_frequency=64.0,
151 | num_frequency_bands=128,
152 | sine_only=False,
153 | generate_fourier_features=False,
154 | )
155 | x = torch.randn((4, 6, 12, 16, 16))
156 | out = query_creator(x, precomputed_features)
157 | # Output is flattened, so should be [B, T*H*W, C]
158 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1)
159 |     # 3 * (2*64 + 1) = 387, plus channel_dim 32 = 419
160 |     # Since 64 bands were used rather than the 128 configured on LearnableQuery, we know it's using the passed-in features
161 | assert out.shape == (4, 16 * 16 * 6, 419)
162 |
163 |
164 | @pytest.mark.parametrize("layer_shape", ["2d", "3d"])
165 | def test_learnable_query_all_fouriers(layer_shape):
166 | batch_ff = encode_position(
167 | 4,
168 | axis=(10, 16, 16), # 4 history + 6 future steps
169 | max_frequency=16.0,
170 | num_frequency_bands=32,
171 | sine_only=False,
172 | )
173 | # Only take future ones
174 | batch_ff = batch_ff[:, 4:]
175 | precomputed_features = encode_position(
176 | 1,
177 | axis=(10, 16, 16), # 4 history + 6 future steps
178 | max_frequency=16.0,
179 | num_frequency_bands=64,
180 | sine_only=False,
181 | )
182 | # Only take future ones
183 | precomputed_features = precomputed_features[:, 4:]
184 | query_creator = LearnableQuery(
185 | channel_dim=32,
186 | query_shape=(6, 16, 16),
187 | conv_layer=layer_shape,
188 | max_frequency=64.0,
189 | num_frequency_bands=128,
190 | sine_only=False,
191 | precomputed_fourier=precomputed_features,
192 | generate_fourier_features=True,
193 | )
194 | x = torch.randn((4, 6, 12, 16, 16))
195 | out = query_creator(x, batch_ff)
196 | # Output is flattened, so should be [B, T*H*W, C]
197 | # Channels is from channel_dim + 3*(num_frequency_bands * 2 + 1)
198 |     # 3 * (2*64 + 1) = 387, plus channel_dim 32 = 419, plus 771 from the generated features and 195 from the batch features = 1385
199 |     # All three sets of Fourier features are concatenated, so we know the passed-in features are used as well
200 | assert out.shape == (4, 16 * 16 * 6, 1385)
201 |
--------------------------------------------------------------------------------
/tests/test_rotary.py:
--------------------------------------------------------------------------------
1 | from perceiver_pytorch.rotary import (
2 | rotate_every_two,
3 | apply_rotary_emb,
4 | SinusoidalEmbeddings,
5 | )
6 | import torch
7 |
8 |
9 | def test_rotate_every_two():
10 | """
11 | Test for rotate every two
12 | :return:
13 | """
14 |
15 | x = torch.randn(5, 4, 4)
16 | y = rotate_every_two(x)
17 |
18 | assert y.shape == torch.Size([5, 4, 4])
19 | assert y[0, 0, 0] == -x[0, 0, 1]
20 | assert y[0, 0, 1] == x[0, 0, 0]
21 |
22 |
23 | def test_apply_rotary_emb():
24 | """
25 | Check that 'apply_rotary_emb' works correctly
26 | :return:
27 | """
28 |
29 | sinu_pos = torch.randn(1, 4, 10)
30 | q = torch.randn(5, 4, 10)
31 | k = torch.randn(5, 4, 10)
32 |
33 | q, k = apply_rotary_emb(q, k, sinu_pos=sinu_pos)
34 |
35 |
36 | def test_torch_sinusoidal_embeddings():
37 | model = SinusoidalEmbeddings(dim=128)
38 |
39 | y = model(torch.randn(4, 10))
40 | assert y.shape[-1] == 128
41 | assert y.shape[-2] == 4
42 |
--------------------------------------------------------------------------------