├── .gitignore ├── .idea ├── inspectionProfiles │ └── profiles_settings.xml └── workspace.xml ├── LICENSE ├── README.md ├── rotary_embedding_tensorflow ├── __init__.py └── rotary_embedding_tensorflow.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 13 | 14 | 15 | 16 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 1630091666021 29 | 33 | 34 | 35 | 36 | 45 | 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Arya Aftab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | 
copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Rotary Embeddings - Tensorflow 2 | 3 | A standalone library for adding rotary embeddings to transformers in TensorFlow, following its success as relative positional encoding. Specifically it will make rotating information into any axis of a tensor easy and efficient, whether they be fixed positional or learned. This library will give you state of the art results for positional embedding, at little cost. 4 | 5 | My gut also tells me there is something more to rotations that can be exploited in artificial neural networks. 6 | 7 | ## Note 8 | A PyTorch implementation is available in the rotary-embedding-torch repository (see Citations). 9 | 10 | This TensorFlow version was written by porting that PyTorch implementation. 11 | 12 | The three functions rearrange, irearrange and repeat were written by hand because the einops library is incompatible with TensorFlow 2.x. 
13 | ## Install 14 | 15 | ```bash 16 | $ pip install rotary-embedding-tensorflow 17 | ``` 18 | 19 | ## Usage 20 | 21 | ```python 22 | import tensorflow as tf 23 | from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding 24 | 25 | # instantiate the positional embedding in your transformer and pass to all your attention layers 26 | 27 | pos_emb = RotaryEmbedding(dim = 32) 28 | 29 | # generate the rotations 30 | 31 | freqs = pos_emb(tf.range(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute 32 | 33 | # mock queries and keys 34 | 35 | q = tf.random.normal((1, 1024, 64)) # queries - (batch, seq len, dimension of head) 36 | k = tf.random.normal((1, 1024, 64)) # keys 37 | 38 | # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention) 39 | 40 | freqs = freqs[None, ...] # expand dimension for batch dimension 41 | q = apply_rotary_emb(freqs, q) 42 | k = apply_rotary_emb(freqs, k) 43 | 44 | # then do your attention with your queries (q) and keys (k) 45 | ``` 46 | 47 | If you do all the steps above correctly, you should see a dramatic improvement during training 48 | 49 | ## Axial Rotary Embeddings 50 | 51 | For easy use of 2d axial relative positional embedding, ie. 
vision transformers 52 | 53 | ```python 54 | import tensorflow as tf 55 | from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding, broadcat 56 | 57 | pos_emb = RotaryEmbedding( 58 | dim = 32, 59 | freqs_for = 'pixel' 60 | ) 61 | 62 | # queries and keys for frequencies to be rotated into 63 | 64 | q = tf.random.normal((1, 256, 256, 64)) 65 | k = tf.random.normal((1, 256, 256, 64)) 66 | 67 | # get frequencies for each axial 68 | # -1 to 1 has been shown to be a good choice for images and audio 69 | 70 | freqs_h = pos_emb(tf.linspace(-1, 1, num = 256), cache_key = 256) 71 | freqs_w = pos_emb(tf.linspace(-1, 1, num = 256), cache_key = 256) 72 | 73 | # concat the frequencies along each axial 74 | # broadcat function makes this easy without a bunch of expands 75 | 76 | freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1) 77 | 78 | # rotate in frequencies 79 | 80 | q = apply_rotary_emb(freqs, q) 81 | k = apply_rotary_emb(freqs, k) 82 | ``` 83 | 84 | ## Learned Rotations 85 | 86 | For injecting learned rotations into a network. Experiments pending 87 | 88 | Update: doesn't seem to do anything -_-, will keep trying... 89 | 90 | ```python 91 | import tensorflow as tf 92 | from tensorflow.keras import layers 93 | from rotary_embedding_tensorflow import apply_learned_rotations 94 | 95 | x = tf.random.normal((1, 1024, 512)) 96 | 97 | # you can only rotate in (dim // 2) values 98 | # ex. 
for 512, you can only rotate in 256 values 99 | 100 | # say you have two sets of learned rotations of 128 values each 101 | 102 | rots1 = layers.Dense(128)(x) 103 | rots2 = layers.Dense(128)(x) 104 | 105 | # you rotate in 256 (128 x 2) at first 106 | 107 | x = apply_learned_rotations(rots1, x, start_index = 0) 108 | 109 | # then you start at index 256 and rotate in the last (128 x 2) 110 | 111 | x = apply_learned_rotations(rots2, x, start_index = 256) 112 | 113 | # you could also concat the rotations together and pass it in all at once 114 | 115 | rots = tf.concat((rots1, rots2), axis = -1) 116 | 117 | x = apply_learned_rotations(rots, x) 118 | ``` 119 | 120 | ## Citations 121 | 122 | ```bibtex 123 | @misc{su2021roformer, 124 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 125 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 126 | year = {2021}, 127 | eprint = {2104.09864}, 128 | archivePrefix = {arXiv}, 129 | primaryClass = {cs.CL} 130 | } 131 | 132 | @misc{rotary-embedding-torch, 133 | title = {Rotary Embeddings - Pytorch}, 134 | author = {Phil Wang (lucidrains)}, 135 | year = {2021}, 136 | url = {https://github.com/lucidrains/rotary-embedding-torch}, 137 | publisher = {Github}, 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /rotary_embedding_tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | from rotary_embedding_tensorflow.rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding, broadcat, apply_learned_rotations 2 | -------------------------------------------------------------------------------- /rotary_embedding_tensorflow/rotary_embedding_tensorflow.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import numpy as np 3 | 4 | import tensorflow as tf 5 | from tensorflow.keras import layers 6 | 7 | 8 | # 
helper functions 9 | 10 | #The three functions of rearrange, irearrange and repeat have been written 11 | # due to the incompatibility of the einops library with tensorflow 2.x. 12 | 13 | def rearrange(x, r=2): 14 | b = tf.shape(x) 15 | b1 = b[:-1] 16 | b2 = b[-1, None] 17 | b3 = tf.constant([r], dtype=tf.int32) 18 | b4 = tf.cast(b2/b3, dtype=tf.int32) 19 | b_ = tf.concat([b1, b4, b3], axis=0) 20 | 21 | return tf.reshape(x, b_) 22 | 23 | def irearrange(x): 24 | c = tf.shape(x) 25 | c1 = c[:-2] 26 | c2 = tf.reduce_prod(c[-2:])[None] 27 | c_ = tf.concat([c1, c2], axis=0) 28 | 29 | return tf.reshape(x, c_) 30 | 31 | def repeat(x, r): 32 | c = tf.ones_like(tf.shape(x), dtype=tf.int32) 33 | c1 = c[:-1] 34 | c2 = c[-1][None] * r 35 | c_ = tf.concat([c1, c2], axis=0) 36 | 37 | return tf.tile(x, c_) 38 | 39 | 40 | 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | 46 | def broadcat(tensors, dim = -1): 47 | num_tensors = len(tensors) 48 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 49 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 50 | shape_len = list(shape_lens)[0] 51 | 52 | dim = (dim + shape_len) if dim < 0 else dim 53 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 54 | 55 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 56 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 57 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 58 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 59 | expanded_dims.insert(dim, (dim, dims[dim])) 60 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 61 | tensors = list(map(lambda t: tf.broadcast_to(t[0], t[1]), zip(tensors, expandable_shapes))) 62 | return tf.concat(tensors, axis=dim) 63 | 64 | # rotary embedding helper functions 65 | 66 | def rotate_half(x): 67 | x = rearrange(x, r = 2) 68 | x1, x2 = 
tf.unstack(x, axis=-1) 69 | x = tf.stack((-x2, x1), axis=-1) 70 | return irearrange(x) 71 | 72 | 73 | def apply_rotary_emb(freqs, t, start_index = 0): 74 | rot_dim = freqs.shape[-1] 75 | end_index = start_index + rot_dim 76 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 78 | t = (t * tf.cos(freqs)) + (rotate_half(t) * tf.sin(freqs)) 79 | return tf.concat((t_left, t, t_right), axis=-1) 80 | 81 | # learned rotation helpers 82 | 83 | def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): 84 | if exists(freq_ranges): 85 | rotations = tf.einsum('..., f -> ... f', rotations, freq_ranges) 86 | rotations = irearrange(rotations) 87 | 88 | rotations = repeat(rotations, r = 2) 89 | return apply_rotary_emb(rotations, t, start_index = start_index) 90 | 91 | 92 | # classes 93 | 94 | class RotaryEmbedding(layers.Layer): 95 | def __init__( 96 | self, 97 | dim, 98 | custom_freqs = None, 99 | freqs_for = 'lang', 100 | theta = 10000, 101 | max_freq = 10, 102 | num_freqs = 1, 103 | learned_freq = False 104 | ): 105 | super(RotaryEmbedding, self).__init__() 106 | if exists(custom_freqs): 107 | freqs = custom_freqs 108 | elif freqs_for == 'lang': 109 | freqs = tf.convert_to_tensor(1. 
/ (theta ** (np.arange(0, dim, 2)[:(dim // 2)] / dim)), dtype=tf.float32) 110 | elif freqs_for == 'pixel': 111 | freqs = tf.convert_to_tensor(np.logspace(0., np.log(max_freq / 2) / np.log(2), dim // 2, base = 2) * np.pi, dtype=tf.float32) 112 | elif freqs_for == 'constant': 113 | freqs = tf.ones(num_freqs, dtype=tf.float32) 114 | else: 115 | raise ValueError(f'unknown modality {freqs_for}') 116 | 117 | self.cache = dict() 118 | 119 | if learned_freq: 120 | self.freqs = tf.Variable(freqs, trainable=True) 121 | else: 122 | # self.register_buffer('freqs', freqs) 123 | self.freqs = freqs 124 | 125 | def call(self, t, cache_key = None): 126 | if exists(cache_key) and cache_key in self.cache: 127 | return self.cache[cache_key] 128 | 129 | if isfunction(t): 130 | t = t() 131 | 132 | freqs = self.freqs 133 | 134 | freqs = tf.einsum('..., f -> ... f', tf.cast(t, dtype=freqs.dtype), freqs) 135 | freqs = repeat(freqs, r = 2) 136 | 137 | if exists(cache_key): 138 | self.cache[cache_key] = freqs 139 | 140 | return freqs -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs 3 | from setuptools import setup, find_packages 4 | 5 | 6 | current_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read_file(*parts): 10 | with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader: 11 | return reader.read() 12 | 13 | 14 | 15 | setup( 16 | name = 'rotary-embedding-tensorflow', 17 | packages = find_packages(), 18 | version = '0.1.1', 19 | license='MIT', 20 | description = 'Rotary Embedding - Tensorflow', 21 | long_description=read_file('README.md'), 22 | 
long_description_content_type='text/markdown', 23 | author = 'Arya Aftab', 24 | author_email = 'arya.aftab@gmail.com', 25 | url = 'https://github.com/AryaAftab/rotary-embedding-tensorflow', 26 | keywords = [ 27 | 'deep learning', 28 | 'tensorflow', 29 | 'positional embedding' 30 | ], 31 | install_requires=[ 32 | 'numpy>=1.18.5', 33 | 'tensorflow>=2.2' 34 | ], 35 | classifiers=[ 36 | 'Development Status :: 4 - Beta', 37 | 'Intended Audience :: Developers', 38 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 39 | 'License :: OSI Approved :: MIT License', 40 | 'Programming Language :: Python :: 3.6', 41 | ], 42 | ) 43 | --------------------------------------------------------------------------------