├── .github
│   └── workflows
│       ├── main.yml
│       └── upload-to-pip.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── example.png
├── positional_encodings
│   ├── __init__.py
│   ├── tf_encodings.py
│   └── torch_encodings.py
├── requirements.txt
├── setup.py
├── svgs
│   └── cov.svg
└── test_suite.py

/.github/workflows/main.yml:
--------------------------------------------------------------------------------
name: Main CI workflow
on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3
        uses: actions/setup-python@v1
        with:
          python-version: 3.12
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install pytest-cov coverage-badge black isort
      - name: Test with pytest
        run: pytest -vv --cov=positional_encodings/
      - name: Run coverage report
        run: coverage-badge -o svgs/cov.svg -f
      - name: Check code formatting with black
        run: black --check .
      - name: Check import order with isort
        run: isort --check . --skip=__init__.py

--------------------------------------------------------------------------------
/.github/workflows/upload-to-pip.yml:
--------------------------------------------------------------------------------
name: Upload to PIP

on:
  release:
    types: [created]
  workflow_dispatch:

jobs:
  upload:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: 3.11

      - name: "Install dependencies"
        run: |
          python -m pip install --upgrade pip
          python -m pip install setuptools wheel twine
      - name: "Build and upload to PyPI"
        run: |
          python setup.py sdist bdist_wheel
          python -m twine upload dist/*
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
build/
positional_encodings.egg-info/
dist/

.coverage
Pipfile

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort
        exclude: __init__.py
  - repo: https://github.com/psf/black
    rev: 19.3b0
    hooks:
      - id: black

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Peter Tatkowski

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 1D, 2D, and 3D Sinusoidal Positional Encoding (PyTorch and TensorFlow)

![Code Coverage](./svgs/cov.svg)
[![PyPI version](https://badge.fury.io/py/positional-encodings.svg)](https://badge.fury.io/py/positional-encodings)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

![A 2D Example](./example.png)

This is a practical, easy-to-use implementation of 1D, 2D, and 3D
sinusoidal positional encodings for PyTorch and TensorFlow.

It can encode tensors of the form `(batchsize, x, ch)`, `(batchsize,
x, y, ch)`, and `(batchsize, x, y, z, ch)`, where the positional encodings are
calculated along the `ch` dimension. The [Attention is All You
Need](https://arxiv.org/pdf/1706.03762.pdf) paper introduced positional
encodings for one dimension only; this package extends the idea to 2 and 3 dimensions.

It also works on tensors of the form `(batchsize, ch, x)`, etc. See the usage section for more information.

**NOTE**: The import syntax has changed as of version `6.0.1`. See the section "Changes as of version `6.0.1`" below for details.

To install, simply run:

```
pip install positional-encodings[pytorch,tensorflow]
```

You can also install the PyTorch and TensorFlow encodings individually with the following
commands.

* For a PyTorch-only installation, run `pip install positional-encodings[pytorch]`
* For a TensorFlow-only installation, run `pip install positional-encodings[tensorflow]`

## Usage:

### PyTorch

The repo comes with the three main positional encoding models,
`PositionalEncoding{1,2,3}D`. In addition, there is a `Summer` class that adds
the positional encodings to the input tensor.

```python3
import torch
from positional_encodings.torch_encodings import PositionalEncoding1D, PositionalEncoding2D, PositionalEncoding3D, Summer

# Returns the position encoding only
p_enc_1d_model = PositionalEncoding1D(10)

# Returns the inputs with the position encoding added
p_enc_1d_model_sum = Summer(PositionalEncoding1D(10))

x = torch.rand(1,6,10)
penc_no_sum = p_enc_1d_model(x) # penc_no_sum.shape == (1, 6, 10)
penc_sum = p_enc_1d_model_sum(x)
print(penc_no_sum + x == penc_sum) # True
```

```python3
p_enc_2d = PositionalEncoding2D(8)
y = torch.zeros((1,6,2,8))
print(p_enc_2d(y).shape) # (1, 6, 2, 8)

p_enc_3d = PositionalEncoding3D(11)
z = torch.zeros((1,5,6,4,11))
print(p_enc_3d(z).shape) # (1, 5, 6, 4, 11)
```

For tensors of the form `(batchsize, ch, x)` or their 2D and 3D
counterparts, include the word `Permute` before the number in the class name; e.g.
for a 1D input of size `(batchsize, ch, x)`, use `PositionalEncodingPermute1D`
instead of `PositionalEncoding1D`.

```python3
import torch
from positional_encodings.torch_encodings import PositionalEncodingPermute3D

p_enc_3d = PositionalEncodingPermute3D(11)
z = torch.zeros((1,11,5,6,4))
print(p_enc_3d(z).shape) # (1, 11, 5, 6, 4)
```

Note that to override the output dtype, you can specify it when creating the encoding:

```python3
p_enc_3d = PositionalEncodingPermute3D(11, dtype_override=torch.float64)
```

This is particularly useful when the input tensor is of an `int` type, since
otherwise the output would always be zero (see issue #39).

### TensorFlow Keras

TensorFlow is also supported. Simply prepend all class names with `TF`.

```python3
import tensorflow as tf
from positional_encodings.tf_encodings import TFPositionalEncoding2D, TFSummer

# Returns the position encoding only
p_enc_2d = TFPositionalEncoding2D(170)
y = tf.zeros((1,8,6,2))
print(p_enc_2d(y).shape) # (1, 8, 6, 2)

# Returns the inputs with the position encoding added
add_p_enc_2d = TFSummer(TFPositionalEncoding2D(170))
y = tf.ones((1,8,6,2))
print(add_p_enc_2d(y) - p_enc_2d(y)) # tf.ones((1,8,6,2))
```

## Changes as of version `6.0.1`

Before `6.0.1`, users had to install both the `tensorflow` and the
`torch` packages, both of which are quite large.
Now, one can install the
packages individually, but the imports have to be changed accordingly:

If using PyTorch:

```
from positional_encodings import * -> from positional_encodings.torch_encodings import *
```

If using TensorFlow:

```
from positional_encodings import * -> from positional_encodings.tf_encodings import *
```

## Formulas

The formulas for the positional encodings are as follows (a short NumPy sketch
illustrating the 1D case appears at the end of this README):

1D:
```
PE(x,2i) = sin(x/10000^(2i/D))
PE(x,2i+1) = cos(x/10000^(2i/D))

Where:
x is a point in 1d space
i is an integer in [0, D/2), where D is the size of the ch dimension
```

2D:
```
PE(x,y,2i) = sin(x/10000^(4i/D))
PE(x,y,2i+1) = cos(x/10000^(4i/D))
PE(x,y,2j+D/2) = sin(y/10000^(4j/D))
PE(x,y,2j+1+D/2) = cos(y/10000^(4j/D))

Where:
(x,y) is a point in 2d space
i,j are integers in [0, D/4), where D is the size of the ch dimension
```

3D:
```
PE(x,y,z,2i) = sin(x/10000^(6i/D))
PE(x,y,z,2i+1) = cos(x/10000^(6i/D))
PE(x,y,z,2j+D/3) = sin(y/10000^(6j/D))
PE(x,y,z,2j+1+D/3) = cos(y/10000^(6j/D))
PE(x,y,z,2k+2D/3) = sin(z/10000^(6k/D))
PE(x,y,z,2k+1+2D/3) = cos(z/10000^(6k/D))

Where:
(x,y,z) is a point in 3d space
i,j,k are integers in [0, D/6), where D is the size of the ch dimension
```

The 3D formula is just a natural extension of the 2D positional encoding used
in [this](https://arxiv.org/pdf/1908.11415.pdf) paper.

Don't worry if the input is not divisible by 2 (1D), 4 (2D), or 6 (3D); all the
necessary padding will be taken care of.

## Thank you

Thank you to [this](https://github.com/wzlxjtu/PositionalEncoding2D) repo for the inspiration behind this method.

## Citations
1D:
```bibtex
@inproceedings{vaswani2017attention,
  title={Attention is all you need},
  author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
  booktitle={Advances in neural information processing systems},
  pages={5998--6008},
  year={2017}
}
```

2D:
```bibtex
@misc{wang2019translating,
  title={Translating Math Formula Images to LaTeX Sequences Using Deep Neural Networks with Sequence-level Training},
  author={Zelun Wang and Jyh-Charn Liu},
  year={2019},
  eprint={1908.11415},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

3D:
Coming soon!
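
As a final illustration of the 1D formula above, here is a minimal NumPy
sketch. It is not part of the package, and `positional_encoding_1d` is just an
illustrative name; for an even channel count `D`, its output should agree with
what `PositionalEncoding1D(D)` computes:

```python3
import numpy as np

def positional_encoding_1d(length, d_model):
    """Illustrative sketch of the 1D formula; assumes d_model is even."""
    pos = np.arange(length)[:, None]           # positions x = 0 .. length-1
    two_i = np.arange(0, d_model, 2)[None, :]  # even channel indices 2i
    angle = pos / 10000 ** (two_i / d_model)   # x / 10000^(2i/D)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angle)                # PE(x, 2i)   = sin(...)
    pe[:, 1::2] = np.cos(angle)                # PE(x, 2i+1) = cos(...)
    return pe

print(positional_encoding_1d(6, 10).shape)     # (6, 10)
```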
206 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tatp22/multidim-positional-encoding/efeb8d9d70e8184da50eae9fddd1bbda10896529/example.png -------------------------------------------------------------------------------- /positional_encodings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tatp22/multidim-positional-encoding/efeb8d9d70e8184da50eae9fddd1bbda10896529/positional_encodings/__init__.py -------------------------------------------------------------------------------- /positional_encodings/tf_encodings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def get_emb(sin_inp): 6 | """ 7 | Gets a base embedding for one dimension with sin and cos intertwined 8 | """ 9 | emb = tf.stack((tf.sin(sin_inp), tf.cos(sin_inp)), -1) 10 | emb = tf.reshape(emb, (*emb.shape[:-2], -1)) 11 | return emb 12 | 13 | 14 | class TFPositionalEncoding1D(tf.keras.layers.Layer): 15 | def __init__(self, channels: int, dtype=tf.float32): 16 | """ 17 | Args: 18 | channels int: The last dimension of the tensor you want to apply pos emb to. 19 | 20 | Keyword Args: 21 | dtype: output type of the encodings. Default is "tf.float32". 22 | 23 | """ 24 | super(TFPositionalEncoding1D, self).__init__() 25 | 26 | self.channels = int(np.ceil(channels / 2) * 2) 27 | self.inv_freq = np.float32( 28 | 1 29 | / np.power( 30 | 10000, np.arange(0, self.channels, 2) / np.float32(self.channels) 31 | ) 32 | ) 33 | 34 | @tf.function 35 | def call(self, inputs): 36 | """ 37 | :param tensor: A 3d tensor of size (batch_size, x, ch) 38 | :return: Positional Encoding Matrix of size (batch_size, x, ch) 39 | """ 40 | if len(inputs.shape) != 3: 41 | raise RuntimeError("The input tensor has to be 3d!") 42 | 43 | _, x, org_channels = inputs.shape 44 | 45 | dtype = self.inv_freq.dtype 46 | pos_x = tf.range(x, dtype=dtype) 47 | sin_inp_x = tf.einsum("i,j->ij", pos_x, self.inv_freq) 48 | emb = tf.expand_dims(get_emb(sin_inp_x), 0) 49 | emb = emb[0] # A bit of a hack 50 | return tf.repeat(emb[None, :, :org_channels], tf.shape(inputs)[0], axis=0) 51 | 52 | 53 | class TFPositionalEncoding2D(tf.keras.layers.Layer): 54 | def __init__(self, channels: int, dtype=tf.float32): 55 | """ 56 | Args: 57 | channels int: The last dimension of the tensor you want to apply pos emb to. 58 | 59 | Keyword Args: 60 | dtype: output type of the encodings. Default is "tf.float32". 
61 | 62 | """ 63 | super(TFPositionalEncoding2D, self).__init__() 64 | 65 | self.channels = int(2 * np.ceil(channels / 4)) 66 | self.inv_freq = np.float32( 67 | 1 68 | / np.power( 69 | 10000, np.arange(0, self.channels, 2) / np.float32(self.channels) 70 | ) 71 | ) 72 | 73 | @tf.function 74 | def call(self, inputs): 75 | """ 76 | :param tensor: A 4d tensor of size (batch_size, x, y, ch) 77 | :return: Positional Encoding Matrix of size (batch_size, x, y, ch) 78 | """ 79 | if len(inputs.shape) != 4: 80 | raise RuntimeError("The input tensor has to be 4d!") 81 | 82 | _, x, y, org_channels = inputs.shape 83 | 84 | dtype = self.inv_freq.dtype 85 | 86 | pos_x = tf.range(x, dtype=dtype) 87 | pos_y = tf.range(y, dtype=dtype) 88 | 89 | sin_inp_x = tf.einsum("i,j->ij", pos_x, self.inv_freq) 90 | sin_inp_y = tf.einsum("i,j->ij", pos_y, self.inv_freq) 91 | 92 | emb_x = tf.expand_dims(get_emb(sin_inp_x), 1) 93 | emb_y = tf.expand_dims(get_emb(sin_inp_y), 0) 94 | 95 | emb_x = tf.tile(emb_x, (1, y, 1)) 96 | emb_y = tf.tile(emb_y, (x, 1, 1)) 97 | emb = tf.concat((emb_x, emb_y), -1) 98 | return tf.repeat(emb[None, :, :, :org_channels], tf.shape(inputs)[0], axis=0) 99 | 100 | 101 | class TFPositionalEncoding3D(tf.keras.layers.Layer): 102 | def __init__(self, channels: int, dtype=tf.float32): 103 | """ 104 | Args: 105 | channels int: The last dimension of the tensor you want to apply pos emb to. 106 | 107 | Keyword Args: 108 | dtype: output type of the encodings. Default is "tf.float32". 109 | 110 | """ 111 | super(TFPositionalEncoding3D, self).__init__() 112 | 113 | channels = int(np.ceil(channels / 6) * 2) 114 | if channels % 2: 115 | channels += 1 116 | self.channels = channels 117 | self.inv_freq = np.float32( 118 | 1 119 | / np.power( 120 | 10000, np.arange(0, self.channels, 2) / np.float32(self.channels) 121 | ) 122 | ) 123 | 124 | @tf.function 125 | def call(self, inputs): 126 | """ 127 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) 128 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) 129 | """ 130 | if len(inputs.shape) != 5: 131 | raise RuntimeError("The input tensor has to be 5d!") 132 | 133 | _, x, y, z, org_channels = inputs.shape 134 | 135 | dtype = self.inv_freq.dtype 136 | 137 | pos_x = tf.range(x, dtype=dtype) 138 | pos_y = tf.range(y, dtype=dtype) 139 | pos_z = tf.range(z, dtype=dtype) 140 | 141 | sin_inp_x = tf.einsum("i,j->ij", pos_x, self.inv_freq) 142 | sin_inp_y = tf.einsum("i,j->ij", pos_y, self.inv_freq) 143 | sin_inp_z = tf.einsum("i,j->ij", pos_z, self.inv_freq) 144 | 145 | emb_x = tf.expand_dims(tf.expand_dims(get_emb(sin_inp_x), 1), 1) 146 | emb_y = tf.expand_dims(tf.expand_dims(get_emb(sin_inp_y), 1), 0) 147 | emb_z = tf.expand_dims(tf.expand_dims(get_emb(sin_inp_z), 0), 0) 148 | 149 | emb_x = tf.tile(emb_x, (1, y, z, 1)) 150 | emb_y = tf.tile(emb_y, (x, 1, z, 1)) 151 | emb_z = tf.tile(emb_z, (x, y, 1, 1)) 152 | 153 | emb = tf.concat((emb_x, emb_y, emb_z), -1) 154 | return tf.repeat(emb[None, :, :, :, :org_channels], tf.shape(inputs)[0], axis=0) 155 | 156 | 157 | class TFSummer(tf.keras.layers.Layer): 158 | def __init__(self, penc): 159 | """ 160 | :param model: The type of positional encoding to run the summer on. 
161 | """ 162 | super(TFSummer, self).__init__() 163 | self.penc = penc 164 | 165 | @tf.function 166 | def call(self, tensor): 167 | """ 168 | :param tensor: A 3, 4 or 5d tensor that matches the model output size 169 | :return: Positional Encoding Matrix summed to the original tensor 170 | """ 171 | penc = self.penc(tensor) 172 | assert ( 173 | tensor.shape == penc.shape 174 | ), "The original tensor size {} and the positional encoding tensor size {} must match!".format( 175 | tensor.shape, penc.shape 176 | ) 177 | return tensor + penc 178 | -------------------------------------------------------------------------------- /positional_encodings/torch_encodings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def get_emb(sin_inp): 7 | """ 8 | Gets a base embedding for one dimension with sin and cos intertwined 9 | """ 10 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 11 | return torch.flatten(emb, -2, -1) 12 | 13 | 14 | class PositionalEncoding1D(nn.Module): 15 | def __init__(self, channels, dtype_override=None): 16 | """ 17 | :param channels: The last dimension of the tensor you want to apply pos emb to. 18 | :param dtype_override: If set, overrides the dtype of the output embedding. 19 | """ 20 | super(PositionalEncoding1D, self).__init__() 21 | self.org_channels = channels 22 | channels = int(np.ceil(channels / 2) * 2) 23 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 24 | self.register_buffer("inv_freq", inv_freq) 25 | self.register_buffer("cached_penc", None, persistent=False) 26 | self.channels = channels 27 | self.dtype_override = dtype_override 28 | 29 | def forward(self, tensor): 30 | """ 31 | :param tensor: A 3d tensor of size (batch_size, x, ch) 32 | :return: Positional Encoding Matrix of size (batch_size, x, ch) 33 | """ 34 | if len(tensor.shape) != 3: 35 | raise RuntimeError("The input tensor has to be 3d!") 36 | 37 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 38 | return self.cached_penc 39 | 40 | self.cached_penc = None 41 | batch_size, x, orig_ch = tensor.shape 42 | pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype) 43 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 44 | emb_x = get_emb(sin_inp_x) 45 | emb = torch.zeros( 46 | (x, self.channels), 47 | device=tensor.device, 48 | dtype=( 49 | self.dtype_override if self.dtype_override is not None else tensor.dtype 50 | ), 51 | ) 52 | emb[:, : self.channels] = emb_x 53 | 54 | self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1) 55 | return self.cached_penc 56 | 57 | 58 | class PositionalEncodingPermute1D(nn.Module): 59 | def __init__(self, channels, dtype_override=None): 60 | """ 61 | Accepts (batchsize, ch, x) instead of (batchsize, x, ch) 62 | """ 63 | super(PositionalEncodingPermute1D, self).__init__() 64 | self.penc = PositionalEncoding1D(channels, dtype_override) 65 | 66 | def forward(self, tensor): 67 | tensor = tensor.permute(0, 2, 1) 68 | enc = self.penc(tensor) 69 | return enc.permute(0, 2, 1) 70 | 71 | @property 72 | def org_channels(self): 73 | return self.penc.org_channels 74 | 75 | 76 | class PositionalEncoding2D(nn.Module): 77 | def __init__(self, channels, dtype_override=None): 78 | """ 79 | :param channels: The last dimension of the tensor you want to apply pos emb to. 80 | :param dtype_override: If set, overrides the dtype of the output embedding. 
81 | """ 82 | super(PositionalEncoding2D, self).__init__() 83 | self.org_channels = channels 84 | channels = int(np.ceil(channels / 4) * 2) 85 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 86 | self.register_buffer("inv_freq", inv_freq) 87 | self.register_buffer("cached_penc", None, persistent=False) 88 | self.dtype_override = dtype_override 89 | self.channels = channels 90 | 91 | def forward(self, tensor): 92 | """ 93 | :param tensor: A 4d tensor of size (batch_size, x, y, ch) 94 | :return: Positional Encoding Matrix of size (batch_size, x, y, ch) 95 | """ 96 | if len(tensor.shape) != 4: 97 | raise RuntimeError("The input tensor has to be 4d!") 98 | 99 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 100 | return self.cached_penc 101 | 102 | self.cached_penc = None 103 | batch_size, x, y, orig_ch = tensor.shape 104 | pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype) 105 | pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype) 106 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 107 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 108 | emb_x = get_emb(sin_inp_x).unsqueeze(1) 109 | emb_y = get_emb(sin_inp_y) 110 | emb = torch.zeros( 111 | (x, y, self.channels * 2), 112 | device=tensor.device, 113 | dtype=( 114 | self.dtype_override if self.dtype_override is not None else tensor.dtype 115 | ), 116 | ) 117 | emb[:, :, : self.channels] = emb_x 118 | emb[:, :, self.channels : 2 * self.channels] = emb_y 119 | 120 | self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1) 121 | return self.cached_penc 122 | 123 | 124 | class PositionalEncodingPermute2D(nn.Module): 125 | def __init__(self, channels, dtype_override=None): 126 | """ 127 | Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch) 128 | """ 129 | super(PositionalEncodingPermute2D, self).__init__() 130 | self.penc = PositionalEncoding2D(channels, dtype_override) 131 | 132 | def forward(self, tensor): 133 | tensor = tensor.permute(0, 2, 3, 1) 134 | enc = self.penc(tensor) 135 | return enc.permute(0, 3, 1, 2) 136 | 137 | @property 138 | def org_channels(self): 139 | return self.penc.org_channels 140 | 141 | 142 | class PositionalEncoding3D(nn.Module): 143 | def __init__(self, channels, dtype_override=None): 144 | """ 145 | :param channels: The last dimension of the tensor you want to apply pos emb to. 146 | :param dtype_override: If set, overrides the dtype of the output embedding. 
147 | """ 148 | super(PositionalEncoding3D, self).__init__() 149 | self.org_channels = channels 150 | channels = int(np.ceil(channels / 6) * 2) 151 | if channels % 2: 152 | channels += 1 153 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 154 | self.register_buffer("inv_freq", inv_freq) 155 | self.register_buffer("cached_penc", None, persistent=False) 156 | self.dtype_override = dtype_override 157 | self.channels = channels 158 | 159 | def forward(self, tensor): 160 | """ 161 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) 162 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) 163 | """ 164 | if len(tensor.shape) != 5: 165 | raise RuntimeError("The input tensor has to be 5d!") 166 | 167 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 168 | return self.cached_penc 169 | 170 | self.cached_penc = None 171 | batch_size, x, y, z, orig_ch = tensor.shape 172 | pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype) 173 | pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype) 174 | pos_z = torch.arange(z, device=tensor.device, dtype=self.inv_freq.dtype) 175 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 176 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 177 | sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) 178 | emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1) 179 | emb_y = get_emb(sin_inp_y).unsqueeze(1) 180 | emb_z = get_emb(sin_inp_z) 181 | emb = torch.zeros( 182 | (x, y, z, self.channels * 3), 183 | device=tensor.device, 184 | dtype=( 185 | self.dtype_override if self.dtype_override is not None else tensor.dtype 186 | ), 187 | ) 188 | emb[:, :, :, : self.channels] = emb_x 189 | emb[:, :, :, self.channels : 2 * self.channels] = emb_y 190 | emb[:, :, :, 2 * self.channels :] = emb_z 191 | 192 | self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1) 193 | return self.cached_penc 194 | 195 | 196 | class PositionalEncodingPermute3D(nn.Module): 197 | def __init__(self, channels, dtype_override=None): 198 | """ 199 | Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch) 200 | """ 201 | super(PositionalEncodingPermute3D, self).__init__() 202 | self.penc = PositionalEncoding3D(channels, dtype_override) 203 | 204 | def forward(self, tensor): 205 | tensor = tensor.permute(0, 2, 3, 4, 1) 206 | enc = self.penc(tensor) 207 | return enc.permute(0, 4, 1, 2, 3) 208 | 209 | @property 210 | def org_channels(self): 211 | return self.penc.org_channels 212 | 213 | 214 | class Summer(nn.Module): 215 | def __init__(self, penc): 216 | """ 217 | :param model: The type of positional encoding to run the summer on. 
218 | """ 219 | super(Summer, self).__init__() 220 | self.penc = penc 221 | 222 | def forward(self, tensor): 223 | """ 224 | :param tensor: A 3, 4 or 5d tensor that matches the model output size 225 | :return: Positional Encoding Matrix summed to the original tensor 226 | """ 227 | penc = self.penc(tensor) 228 | assert ( 229 | tensor.size() == penc.size() 230 | ), "The original tensor size {} and the positional encoding tensor size {} must match!".format( 231 | tensor.size(), penc.size() 232 | ) 233 | return tensor + penc.to(tensor.device) 234 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | tensorflow==2.17.0 3 | torch==2.5.0 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="positional_encodings", 8 | version="6.0.4", 9 | author="Peter Tatkowski", 10 | author_email="tatp22@gmail.com", 11 | description="1D, 2D, and 3D Sinusodal Positional Encodings in PyTorch", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/tatp22/multidim-positional-encoding", 15 | packages=setuptools.find_packages(), 16 | keywords=["transformers", "attention"], 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | python_requires=">=3.12", 23 | install_requires=["numpy"], 24 | extras_require={ 25 | "pytorch": ["torch"], 26 | "tensorflow": ["tensorflow"], 27 | }, 28 | ) 29 | -------------------------------------------------------------------------------- /svgs/cov.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 97% 19 | 97% 20 | 21 | 22 | -------------------------------------------------------------------------------- /test_suite.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import torch 6 | 7 | from positional_encodings.tf_encodings import * 8 | from positional_encodings.torch_encodings import * 9 | 10 | tf.config.experimental_run_functions_eagerly(True) 11 | 12 | 13 | def test_torch_1d_correct_shape(): 14 | p_enc_1d = PositionalEncoding1D(10) 15 | x = torch.zeros((1, 6, 10)) 16 | assert p_enc_1d(x).shape == (1, 6, 10) 17 | 18 | p_enc_1d = PositionalEncodingPermute1D(10) 19 | x = torch.zeros((1, 10, 6)) 20 | assert p_enc_1d(x).shape == (1, 10, 6) 21 | 22 | 23 | def test_torch_2d_correct_shape(): 24 | p_enc_2d = PositionalEncoding2D(170) 25 | y = torch.zeros((1, 1, 1024, 170)) 26 | assert p_enc_2d(y).shape == (1, 1, 1024, 170) 27 | 28 | p_enc_2d = PositionalEncodingPermute2D(169) 29 | y = torch.zeros((1, 169, 1, 1024)) 30 | assert p_enc_2d(y).shape == (1, 169, 1, 1024) 31 | 32 | 33 | def test_torch_3d_correct_shape(): 34 | p_enc_3d = PositionalEncoding3D(125) 35 | z = torch.zeros((3, 5, 6, 4, 125)) 36 | assert p_enc_3d(z).shape == (3, 5, 6, 4, 125) 37 | 38 | p_enc_3d = PositionalEncodingPermute3D(11) 39 | z = torch.zeros((7, 11, 5, 6, 4)) 40 | assert p_enc_3d(z).shape == (7, 11, 5, 6, 
4) 41 | 42 | 43 | def test_tf_1d_correct_shape(): 44 | p_enc_1d = TFPositionalEncoding1D(170) 45 | x = tf.zeros((1, 1024, 170)) 46 | assert p_enc_1d(x).shape == (1, 1024, 170) 47 | 48 | 49 | def test_tf_2d_correct_shape(): 50 | p_enc_2d = TFPositionalEncoding2D(170) 51 | y = tf.zeros((1, 1, 1024, 170)) 52 | assert p_enc_2d(y).shape == (1, 1, 1024, 170) 53 | 54 | 55 | def test_tf_3d_correct_shape(): 56 | p_enc_3d = TFPositionalEncoding3D(170) 57 | z = tf.zeros((1, 4, 1, 1024, 170)) 58 | assert p_enc_3d(z).shape == (1, 4, 1, 1024, 170) 59 | 60 | 61 | def test_torch_tf_1d_same(): 62 | tf_enc_1d = TFPositionalEncoding1D(123) 63 | pt_enc_1d = PositionalEncoding1D(123) 64 | 65 | sample = np.random.randn(2, 15, 123) 66 | 67 | tf_out = tf_enc_1d(sample) 68 | pt_out = pt_enc_1d(torch.tensor(sample)) 69 | 70 | # There is some rounding discrepancy 71 | assert np.sum(np.abs(tf_out.numpy() - pt_out.numpy()) > 0.0001) == 0 72 | 73 | 74 | def test_torch_tf_2d_same(): 75 | tf_enc_2d = TFPositionalEncoding2D(123) 76 | pt_enc_2d = PositionalEncoding2D(123) 77 | 78 | sample = np.random.randn(2, 123, 321, 170) 79 | 80 | tf_out = tf_enc_2d(sample) 81 | pt_out = pt_enc_2d(torch.tensor(sample)) 82 | 83 | # There is some rounding discrepancy 84 | assert np.sum(np.abs(tf_out.numpy() - pt_out.numpy()) > 0.0001) == 0 85 | 86 | 87 | def test_torch_tf_3d_same(): 88 | tf_enc_3d = TFPositionalEncoding3D(123) 89 | pt_enc_3d = PositionalEncoding3D(123) 90 | 91 | sample = np.random.randn(2, 123, 24, 21, 10) 92 | 93 | tf_out = tf_enc_3d(sample) 94 | pt_out = pt_enc_3d(torch.tensor(sample)) 95 | 96 | # There is some rounding discrepancy 97 | assert np.sum(np.abs(tf_out.numpy() - pt_out.numpy()) > 0.0001) == 0 98 | 99 | 100 | def test_torch_summer(): 101 | model_with_sum = Summer(PositionalEncoding2D(125)) 102 | model_wo_sum = PositionalEncoding2D(125) 103 | z = torch.rand(3, 5, 6, 125) 104 | assert ( 105 | np.sum(np.abs((model_wo_sum(z) + z).numpy() - model_with_sum(z).numpy())) 106 | < 0.0001 107 | ), "The summer is not working properly!" 108 | 109 | 110 | def test_torch_1D_cache(): 111 | p_enc_1d = PositionalEncoding1D(10) 112 | x = torch.zeros((1, 6, 10)) 113 | y = torch.zeros((1, 7, 10)) 114 | 115 | assert not p_enc_1d.cached_penc 116 | assert p_enc_1d(x).shape == (1, 6, 10) 117 | assert p_enc_1d.cached_penc.shape == (1, 6, 10) 118 | 119 | assert p_enc_1d(y).shape == (1, 7, 10) 120 | assert p_enc_1d.cached_penc.shape == (1, 7, 10) 121 | 122 | 123 | def test_tf_summer(): 124 | model_with_sum = TFSummer(TFPositionalEncoding2D(125)) 125 | model_wo_sum = TFPositionalEncoding2D(125) 126 | z = tf.random.uniform(shape=(3, 5, 6, 125), name="input_tensor") 127 | assert ( 128 | np.sum(np.abs((model_wo_sum(z) + z).numpy() - model_with_sum(z).numpy())) 129 | < 0.0001 130 | ), "The tf summer is not working properly!" 
131 | 132 | 133 | def test_torch_1d_dtype_override(): 134 | penc = PositionalEncoding1D(10, dtype_override=torch.float32) 135 | penc_permute = PositionalEncodingPermute1D(6, dtype_override=torch.float64) 136 | x = torch.zeros((1, 6, 10), dtype=torch.int64) 137 | 138 | assert penc(x).dtype == torch.float32 139 | assert penc_permute(x).dtype == torch.float64 140 | 141 | 142 | def test_torch_2d_dtype_override(): 143 | penc = PositionalEncoding2D(10, dtype_override=torch.float32) 144 | penc_permute = PositionalEncodingPermute2D(6, dtype_override=torch.float64) 145 | x = torch.zeros((1, 6, 6, 10), dtype=torch.int64) 146 | 147 | assert penc(x).dtype == torch.float32 148 | assert penc_permute(x).dtype == torch.float64 149 | 150 | 151 | def test_torch_3d_dtype_override(): 152 | penc = PositionalEncoding3D(10, dtype_override=torch.float32) 153 | penc_permute = PositionalEncodingPermute3D(6, dtype_override=torch.float64) 154 | x = torch.zeros((1, 6, 6, 6, 10), dtype=torch.int64) 155 | 156 | assert penc(x).dtype == torch.float32 157 | assert penc_permute(x).dtype == torch.float64 158 | --------------------------------------------------------------------------------