├── .gitignore
├── .idea
├── inspectionProfiles
│ └── profiles_settings.xml
└── workspace.xml
├── LICENSE
├── README.md
├── rotary_embedding_tensorflow
├── __init__.py
└── rotary_embedding_tensorflow.py
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | 1630091666021
29 |
30 |
31 | 1630091666021
32 |
33 |
34 |
35 |
36 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Arya Aftab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Rotary Embeddings - Tensorflow
2 |
3 | A standalone library for adding rotary embeddings to transformers in TensorFlow, following their success as a relative positional encoding. Specifically, it makes rotating information into any axis of a tensor easy and efficient, whether the rotations are fixed positional or learned. This library will give you state-of-the-art results for positional embedding, at little cost.
4 |
5 | My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.
6 |
7 | ## Note
8 | A PyTorch implementation is available in the [rotary-embedding-torch](https://github.com/lucidrains/rotary-embedding-torch) repository.
9 |
10 | This version was written by porting the PyTorch version to TensorFlow.
11 |
12 | The three helper functions `rearrange`, `irearrange` and `repeat` were written because the einops library was incompatible with TensorFlow 2.x.
13 | ## Install
14 |
15 | ```bash
16 | $ pip install rotary-embedding-tensorflow
17 | ```
18 |
19 | ## Usage
20 |
21 | ```python
22 | import tensorflow as tf
23 | from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding
24 |
25 | # instantiate the positional embedding in your transformer and pass to all your attention layers
26 |
27 | pos_emb = RotaryEmbedding(dim = 32)
28 |
29 | # generate the rotations
30 |
31 | freqs = pos_emb(tf.range(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute
32 |
33 | # mock queries and keys
34 |
35 | q = tf.random.normal((1, 1024, 64)) # queries - (batch, seq len, dimension of head)
36 | k = tf.random.normal((1, 1024, 64)) # keys
37 |
38 | # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)
39 |
40 | freqs = freqs[None, ...] # expand dimension for batch dimension
41 | q = apply_rotary_emb(freqs, q)
42 | k = apply_rotary_emb(freqs, k)
43 |
44 | # then do your attention with your queries (q) and keys (k)
45 | ```
46 |
47 | If you do all the steps above correctly, you should see a dramatic improvement during training.
48 |
49 | ## Axial Rotary Embeddings
50 |
51 | For easy use of 2D axial relative positional embedding, i.e. in vision transformers.
52 |
53 | ```python
54 | import tensorflow as tf
55 | from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding, broadcat
56 |
57 | pos_emb = RotaryEmbedding(
58 | dim = 32,
59 | freqs_for = 'pixel'
60 | )
61 |
62 | # queries and keys for frequencies to be rotated into
63 |
64 | q = tf.random.normal((1, 256, 256, 64))
65 | k = tf.random.normal((1, 256, 256, 64))
66 |
67 | # get frequencies for each axial
68 | # -1 to 1 has been shown to be a good choice for images and audio
69 |
70 | freqs_h = pos_emb(tf.linspace(-1., 1., num = 256), cache_key = 256)
71 | freqs_w = pos_emb(tf.linspace(-1., 1., num = 256), cache_key = 256)
72 |
73 | # concat the frequencies along each axial
74 | # broadcat function makes this easy without a bunch of expands
75 |
76 | freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)
77 |
78 | # rotate in frequencies
79 |
80 | q = apply_rotary_emb(freqs, q)
81 | k = apply_rotary_emb(freqs, k)
82 | ```
83 |
84 | ## Learned Rotations
85 |
86 | For injecting learned rotations into a network. Experiments pending
87 |
88 | Update: doesn't seem to do anything -_-, will keep trying...
89 |
90 | ```python
91 | import tensorflow as tf
92 | from tensorflow.keras import layers
93 | from rotary_embedding_tensorflow import apply_learned_rotations
94 |
95 | x = tf.random.normal((1, 1024, 512))
96 |
97 | # you can only rotate in (dim // 2) values
98 | # ex. for 512, you can only rotate in 256 values
99 |
100 | # say you have two sets of learned rotations of 128 values each
101 |
102 | rots1 = layers.Dense(128)(x)
103 | rots2 = layers.Dense(128)(x)
104 |
105 | # you rotate in 256 (128 x 2) at first
106 |
107 | x = apply_learned_rotations(rots1, x, start_index = 0)
108 |
109 | # then you start at index 256 and rotate in the last (128 x 2)
110 |
111 | x = apply_learned_rotations(rots2, x, start_index = 256)
112 |
113 | # you could also concat the rotations together and pass it in all at once
114 |
115 | rots = tf.concat((rots1, rots2), axis = -1)
116 |
117 | x = apply_learned_rotations(rots, x)
118 | ```
119 |
120 | ## Citations
121 |
122 | ```bibtex
123 | @misc{su2021roformer,
124 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
125 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
126 | year = {2021},
127 | eprint = {2104.09864},
128 | archivePrefix = {arXiv},
129 | primaryClass = {cs.CL}
130 | }
131 |
132 | @misc{rotary-embedding-torch,
133 | title = {Rotary Embeddings - Pytorch},
134 | author = {Phil Wang (lucidrains)},
135 | year = {2021},
136 | url = {https://github.com/lucidrains/rotary-embedding-torch},
137 | publisher = {Github},
138 | }
139 | ```
140 |
--------------------------------------------------------------------------------
/rotary_embedding_tensorflow/__init__.py:
--------------------------------------------------------------------------------
1 | from rotary_embedding_tensorflow.rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding, broadcat, apply_learned_rotations
2 |
--------------------------------------------------------------------------------
/rotary_embedding_tensorflow/rotary_embedding_tensorflow.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import numpy as np
3 |
4 | import tensorflow as tf
5 | from tensorflow.keras import layers
6 |
7 |
8 | # helper functions
9 |
10 | #The three functions of rearrange, irearrange and repeat have been written
11 | # due to the incompatibility of the einops library with tensorflow 2.x.
12 |
def rearrange(x, r=2):
    """Split the last axis of ``x`` into two axes of sizes ``(last // r, r)``.

    Stand-in for the einops pattern ``'... (d r) -> ... d r'``: a tensor of
    shape ``(b, n, 8)`` with ``r=2`` is reshaped to ``(b, n, 4, 2)`` so that
    element ``[..., i, j]`` is ``x[..., i * r + j]``.
    """
    dims = tf.shape(x)
    group = tf.constant([r], dtype=tf.int32)
    # integer division of the (non-negative) last dimension by the group size
    outer = dims[-1:] // group
    return tf.reshape(x, tf.concat([dims[:-1], outer, group], axis=0))
22 |
def irearrange(x):
    """Merge the last two axes of ``x`` into a single axis (inverse of ``rearrange``).

    Stand-in for the einops pattern ``'... d r -> ... (d r)'``.
    """
    dims = tf.shape(x)
    # product of the last two sizes, kept as a length-1 vector for concatenation
    merged = tf.reduce_prod(dims[-2:], keepdims=True)
    return tf.reshape(x, tf.concat([dims[:-2], merged], axis=0))
30 |
def repeat(x, r):
    """Repeat every element along the last axis ``r`` times, interleaved.

    Stand-in for the einops pattern ``'... n -> ... (n r)'``: with ``r=2`` the
    last axis ``[a, b, c]`` becomes ``[a, a, b, b, c, c]``.

    The interleaved order matters: ``rotate_half`` treats channels
    ``(x[..., 2i], x[..., 2i + 1])`` as one 2-D pair, so both members of a pair
    must receive the same angle. Tiling (``[a, b, c, a, b, c]``) would give the
    two members different angles and the result would no longer be a rotation.

    Args:
        x: tensor whose last axis is repeated.
        r: number of consecutive copies of each element.

    Returns:
        Tensor whose last axis is ``r`` times longer than ``x``'s.
    """
    return tf.repeat(x, repeats=r, axis=-1)
38 |
39 |
40 |
41 |
def exists(val):
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    return not (val is None)
44 |
45 |
def broadcat(tensors, dim = -1):
    """Broadcast tensors against each other on every axis except ``dim``, then concatenate.

    Example: static shapes ``(1, 256, 1, 32)`` and ``(1, 1, 256, 32)``
    concatenated on ``dim=-1`` give ``(1, 256, 256, 64)``.

    Works on static shapes (``tensor.shape``). All tensors must have the same
    rank, and on every non-concatenated axis the sizes may take at most two
    distinct values; each tensor is broadcast to the largest of them.
    """
    count = len(tensors)
    ranks = {len(tensor.shape) for tensor in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = next(iter(ranks))

    if dim < 0:
        dim = dim + rank

    # sizes_per_axis[axis] holds that axis' size in each tensor, in input order
    sizes_per_axis = list(zip(*(list(tensor.shape) for tensor in tensors)))

    other_axes = [(axis, sizes) for axis, sizes in enumerate(sizes_per_axis) if axis != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in other_axes), 'invalid dimensions for broadcastable concatentation'

    # non-concatenated axes: every tensor takes the largest size seen on that axis
    targets = [(axis, (max(sizes),) * count) for axis, sizes in other_axes]
    # concatenation axis: every tensor keeps its own size
    targets.insert(dim, (dim, sizes_per_axis[dim]))

    shapes = zip(*(sizes for _, sizes in targets))
    expanded = [tf.broadcast_to(tensor, shape) for tensor, shape in zip(tensors, shapes)]
    return tf.concat(expanded, axis=dim)
63 |
64 | # rotary embedding helper functions
65 |
def rotate_half(x):
    """Rotate each adjacent channel pair of ``x`` by 90 degrees.

    Channels are grouped as ``(x[..., 2i], x[..., 2i + 1])`` and every pair
    ``(a, b)`` becomes ``(-b, a)``. The output has the same shape as ``x``.
    """
    pairs = rearrange(x, r = 2)
    first, second = tf.unstack(pairs, axis=-1)
    return irearrange(tf.stack((-second, first), axis=-1))
71 |
72 |
def apply_rotary_emb(freqs, t, start_index = 0):
    """Rotate a window of channels of ``t`` by the angles in ``freqs``.

    Channels ``t[..., start_index:start_index + freqs.shape[-1]]`` are rotated;
    channels outside that window are passed through unchanged.

    Args:
        freqs: rotation angles; the length of the last axis is the number of
            channels to rotate, and the tensor must broadcast against the
            rotated window of ``t``.
        t: tensor to rotate (e.g. queries or keys).
        start_index: index of the first channel of ``t`` to rotate.

    Returns:
        Tensor with the same shape and dtype as ``t``.

    Raises:
        AssertionError: if the rotation window does not fit in ``t``'s last axis.
    """
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
    # slicing silently truncates past the end, which would otherwise show up
    # later as a confusing broadcast error
    assert end_index <= t.shape[-1], f'rotation window [{start_index}, {end_index}) exceeds feature dimension {t.shape[-1]}'

    t_left, t_mid, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]

    # evaluate the trig terms at the angles' precision, then match t's dtype so
    # that float16/bfloat16 inputs (mixed precision) can be combined with
    # float32 angles instead of raising a dtype mismatch
    cos = tf.cast(tf.cos(freqs), t.dtype)
    sin = tf.cast(tf.sin(freqs), t.dtype)

    t_mid = (t_mid * cos) + (rotate_half(t_mid) * sin)
    return tf.concat((t_left, t_mid, t_right), axis=-1)
80 |
81 | # learned rotation helpers
82 |
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    """Rotate channels of ``t`` by learned angles ``rotations``.

    The angles are repeated twice along their last axis before rotation, so
    ``n`` angles rotate ``2 * n`` channels of ``t`` starting at ``start_index``.

    Args:
        rotations: learned angles (e.g. the output of a Dense layer).
        t: tensor to rotate.
        start_index: index of the first channel of ``t`` to rotate.
        freq_ranges: optional frequencies; when given, every angle is multiplied
            by every frequency (outer product over a new last axis) and the two
            last axes are then flattened into one.

    Returns:
        The rotated tensor, as produced by ``apply_rotary_emb``.
    """
    angles = rotations
    if freq_ranges is not None:
        angles = irearrange(tf.einsum('..., f -> ... f', angles, freq_ranges))
    return apply_rotary_emb(repeat(angles, r = 2), t, start_index = start_index)
90 |
91 |
92 | # classes
93 |
class RotaryEmbedding(layers.Layer):
    """Keras layer that turns positions into rotary-embedding angles.

    Calling the layer with positions ``t`` returns the outer product of ``t``
    with the layer's frequencies, with every frequency repeated twice along the
    last axis, ready to be passed to ``apply_rotary_emb``.

    Args:
        dim: number of feature channels to rotate; ``dim // 2`` frequencies are
            built for the ``'lang'`` and ``'pixel'`` presets.
        custom_freqs: explicit frequencies; when given, ``freqs_for`` is ignored.
        freqs_for: frequency preset -- ``'lang'`` (``1 / theta ** (2i / dim)``),
            ``'pixel'`` (log-spaced from ``pi`` to ``max_freq / 2 * pi``) or
            ``'constant'`` (``num_freqs`` ones).
        theta: base of the ``'lang'`` schedule.
        max_freq: upper bound used by the ``'pixel'`` schedule.
        num_freqs: number of frequencies for the ``'constant'`` preset.
        learned_freq: if True, the frequencies are a trainable ``tf.Variable``.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``freqs_for`` is not one of the presets above.
    """

    def __init__(
        self,
        dim,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        learned_freq = False,
        **kwargs
    ):
        # forward standard Layer kwargs (name, dtype, ...) instead of rejecting them
        super(RotaryEmbedding, self).__init__(**kwargs)
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # 1 / theta^(2i / dim) for i in [0, dim // 2)
            freqs = tf.convert_to_tensor(1. / (theta ** (np.arange(0, dim, 2)[:(dim // 2)] / dim)), dtype=tf.float32)
        elif freqs_for == 'pixel':
            # dim // 2 values log-spaced (base 2) from 1 to max_freq / 2, scaled by pi
            freqs = tf.convert_to_tensor(np.logspace(0., np.log(max_freq / 2) / np.log(2), dim // 2, base = 2) * np.pi, dtype=tf.float32)
        elif freqs_for == 'constant':
            freqs = tf.ones(num_freqs, dtype=tf.float32)
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        self.learned_freq = learned_freq
        # cache_key -> angles tensor; only populated when caching is safe (see call)
        self.cache = dict()

        if learned_freq:
            self.freqs = tf.Variable(freqs, trainable=True)
        else:
            self.freqs = freqs

    def call(self, t, cache_key = None):
        """Return the rotation angles for positions ``t``.

        Args:
            t: tensor of positions, or a zero-argument function returning one
                (only evaluated on a cache miss).
            cache_key: optional hashable key. When given, results are stored
                and returned for later calls with the same key, regardless of
                ``t``.

        Returns:
            Tensor of shape ``t.shape + (2 * num_frequencies,)``.
        """
        # Caching is skipped in two cases where it would be wrong:
        # * graph mode (inside tf.function / functional-model construction):
        #   the tensor belongs to that trace's graph and cannot be reused by
        #   another trace or by eager code;
        # * learned frequencies: a stored tensor is not recorded on later
        #   GradientTapes, so the frequencies would stop receiving gradients.
        use_cache = exists(cache_key) and not self.learned_freq and tf.executing_eagerly()

        if use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        if isfunction(t):
            t = t()

        freqs = self.freqs

        # outer product: every position times every frequency
        freqs = tf.einsum('..., f -> ... f', tf.cast(t, dtype=freqs.dtype), freqs)
        freqs = repeat(freqs, r = 2)

        if use_cache:
            self.cache[cache_key] = freqs

        return freqs
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import codecs
3 | from setuptools import setup, find_packages
4 |
5 |
current_path = os.path.abspath(os.path.dirname(__file__))


def read_file(*parts):
    """Return the UTF-8 decoded text of a file located relative to this script.

    Line endings are returned untranslated.
    """
    target = os.path.join(current_path, *parts)
    with open(target, 'r', encoding='utf8', newline='') as handle:
        contents = handle.read()
    return contents
12 |
13 |
14 |
# Package metadata; the long description is the README rendered as Markdown.
setup(
    name = 'rotary-embedding-tensorflow',
    packages = find_packages(),
    version = '0.1.1',
    license='MIT',
    description = 'Rotary Embedding - Tensorflow',
    long_description=read_file('README.md'),
    long_description_content_type='text/markdown',
    author = 'Arya Aftab',
    author_email = 'arya.aftab@gmail.com',
    url = 'https://github.com/AryaAftab/rotary-embedding-tensorflow',
    keywords = [
        'deep learning',
        'tensorflow',
        'positional embedding'
    ],
    # the package source uses f-strings, which require Python 3.6+; without this
    # pip would install it on older interpreters where importing fails
    python_requires='>=3.6',
    install_requires=[
        'numpy>=1.18.5',
        'tensorflow>=2.2'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
43 |
--------------------------------------------------------------------------------