├── rope.png
├── rotary_embedding_torch
│   ├── __init__.py
│   └── rotary_embedding_torch.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
└── README.md
/rope.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/rotary-embedding-torch/HEAD/rope.png
--------------------------------------------------------------------------------
/rotary_embedding_torch/__init__.py:
--------------------------------------------------------------------------------
1 | from rotary_embedding_torch.rotary_embedding_torch import (
2 | apply_rotary_emb,
3 | RotaryEmbedding,
4 | apply_learned_rotations,
5 | broadcat
6 | )
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'rotary-embedding-torch',
5 | packages = find_packages(),
6 | version = '0.8.9',
7 | license='MIT',
8 | description = 'Rotary Embedding - Pytorch',
9 | long_description_content_type = 'text/markdown',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | url = 'https://github.com/lucidrains/rotary-embedding-torch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'positional embedding'
17 | ],
18 | install_requires=[
19 | 'einops>=0.7',
20 | 'torch>=2.0'
21 | ],
22 | classifiers=[
23 | 'Development Status :: 4 - Beta',
24 | 'Intended Audience :: Developers',
25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
26 | 'License :: OSI Approved :: MIT License',
27 | 'Programming Language :: Python :: 3.6',
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./rope.png" width="500px"></img>
2 |
3 | ## Rotary Embeddings - Pytorch
4 |
5 | A standalone library for adding rotary embeddings to transformers in Pytorch, following their success as a form of relative positional encoding. Specifically, it makes it easy and efficient to rotate information into any axis of a tensor, whether the positions are fixed or learned. This library will give you state-of-the-art results for positional embedding, at little cost.
6 |
7 | My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install rotary-embedding-torch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from rotary_embedding_torch import RotaryEmbedding
20 |
21 | # instantiate the positional embedding in your transformer and pass to all your attention layers
22 |
23 | rotary_emb = RotaryEmbedding(dim = 32)
24 |
25 | # mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)
26 |
27 | q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
28 | k = torch.randn(1, 8, 1024, 64) # keys
29 |
30 | # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)
31 |
32 | q = rotary_emb.rotate_queries_or_keys(q)
33 | k = rotary_emb.rotate_queries_or_keys(k)
34 |
35 | # then do your attention with your queries (q) and keys (k) as usual
36 | ```
37 |
38 | If you do all the steps above correctly, you should see a dramatic improvement during training.
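
For reference, here is a minimal sketch of where the rotations sit inside an attention layer. Only `RotaryEmbedding` comes from this library; the attention call is plain Pytorch and is shown purely for illustration.

```python
import torch
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)

q = torch.randn(1, 8, 1024, 64) # (batch, heads, seq len, dim head)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# rotate queries and keys after the heads are split out, before computing attention scores

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

# standard scaled dot product attention on the rotated queries / keys

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
```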
39 |
40 | ## Inference Key-Value Cache
41 |
42 | When dealing with key / value caches at inference, the query positions need to be offset by `key_value_seq_length - query_seq_length`.
43 |
44 | To make this easy, use the `rotate_queries_with_cached_keys` method.
45 |
46 | ```python
47 | q = torch.randn(1, 8, 1, 64) # only one query at a time
48 | k = torch.randn(1, 8, 1024, 64) # key / values with cache concatted
49 |
50 | q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)
51 | ```
52 |
53 | You can also do this manually, like so:
54 |
55 | ```python
56 | q = rotary_emb.rotate_queries_or_keys(q, offset = k.shape[-2] - q.shape[-2])
57 | ```
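
To make the offset concrete, here is a minimal sketch of one decoding step following the pattern above: the cache holds unrotated keys, and rotation is re-applied each step. The plain-tensor cache is purely illustrative.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)

# the cache holds *unrotated* keys; rotation is applied fresh each step

cached_k = torch.randn(1, 8, 1023, 64)

# one new timestep arrives

q_new = torch.randn(1, 8, 1, 64)
k_new = torch.randn(1, 8, 1, 64)

cached_k = torch.cat((cached_k, k_new), dim = -2) # now 1024 keys

# the query is rotated at position 1023 (= 1024 - 1), the keys at positions 0..1023

q_rotated, k_rotated = rotary_emb.rotate_queries_with_cached_keys(q_new, cached_k)
```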
58 |
59 | ## Axial Rotary Embeddings
60 |
61 | For easy use of n-dimensional axial relative positional embeddings, e.g. for video transformers.
62 |
63 | ```python
64 | import torch
65 |
66 | from rotary_embedding_torch import (
67 | RotaryEmbedding,
68 | apply_rotary_emb
69 | )
70 |
71 | pos_emb = RotaryEmbedding(
72 | dim = 16,
73 | freqs_for = 'pixel',
74 | max_freq = 256
75 | )
76 |
77 | # queries and keys into which the frequencies will be rotated
78 | # say for a video with 8 frames and a rectangular feature map (feature dimension comes last)
79 |
80 | q = torch.randn(1, 8, 64, 32, 64)
81 | k = torch.randn(1, 8, 64, 32, 64)
82 |
83 | # get axial frequencies - (8, 64, 32, 16 * 3 = 48)
84 | # will automatically do partial rotary
85 |
86 | freqs = pos_emb.get_axial_freqs(8, 64, 32)
87 |
88 | # rotate in frequencies
89 |
90 | q = apply_rotary_emb(freqs, q)
91 | k = apply_rotary_emb(freqs, k)
92 | ```
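
As a quick sanity check on the shapes in the example above: with `dim = 16` per axis and three axes, the frequencies take up 48 of the 64 head features, and the remaining 16 are passed through unrotated (the "partial rotary" mentioned in the comment).

```python
freqs = pos_emb.get_axial_freqs(8, 64, 32)

print(freqs.shape)  # torch.Size([8, 64, 32, 48])
print(q.shape)      # torch.Size([1, 8, 64, 32, 64]) - only the first 48 features are rotated
```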
93 |
94 | ## Length Extrapolatable Rotary Embeddings
95 |
96 | In the length-extrapolatable transformer paper cited below, the authors fix the length extrapolation issue with rotary embeddings by giving them a decay similar to ALiBi. They named this technique XPos, and you can use it by setting `use_xpos = True` on initialization.
97 |
98 | This can only be used for autoregressive transformers.
99 |
100 | ```python
101 | import torch
102 | from rotary_embedding_torch import RotaryEmbedding
103 |
104 | # instantiate the positional embedding in your transformer and pass to all your attention layers
105 |
106 | rotary_emb = RotaryEmbedding(
107 | dim = 32,
108 | use_xpos = True # set this to True to make rotary embeddings extrapolate better to sequence lengths greater than the one used at training time
109 | )
110 |
111 | # mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)
112 |
113 | q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
114 | k = torch.randn(1, 8, 1024, 64) # keys
115 |
116 | # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)
117 |
118 | # instead of using `rotate_queries_or_keys`, you will use `rotate_queries_and_keys`; the rest is taken care of
119 |
120 | q, k = rotary_emb.rotate_queries_and_keys(q, k)
121 | ```
122 |
123 | ## Interpolating Sequence Positions
124 |
125 | This MetaAI paper (Chen et al., cited below) proposes simply fine-tuning on interpolations of the sequence positions to extend a pretrained model to longer context lengths. They show this performs much better than fine-tuning on the same sequence positions, just extended further.
126 |
127 | You can use this by setting `interpolate_factor` on initialization to a value greater than `1.` (e.g. if the pretrained model was trained on a sequence length of 2048, setting `interpolate_factor = 2.` would allow fine-tuning to `2048 x 2. = 4096`).
128 |
129 | Update: someone in the community has reported that it does not work well. Please email me if you see either a positive or negative result.
130 |
131 | ```python
132 | import torch
133 | from rotary_embedding_torch import RotaryEmbedding
134 |
135 | rotary_emb = RotaryEmbedding(
136 | dim = 32,
137 | interpolate_factor = 2. # add this line of code to pretrained model and fine-tune for ~1000 steps, as shown in paper
138 | )
139 | ```
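
Under the hood, the sequence positions are simply divided by `interpolate_factor` before the frequencies are computed (see `get_seq_pos` in `rotary_embedding_torch.py`), so a model fine-tuned this way keeps seeing positions in its original range:

```python
# with interpolate_factor = 2., integer positions become half-steps

rotary_emb.get_seq_pos(4)   # tensor([0.0000, 0.5000, 1.0000, 1.5000])
```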
140 |
141 | ## Citations
142 |
143 | ```bibtex
144 | @misc{su2021roformer,
145 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
146 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
147 | year = {2021},
148 | eprint = {2104.09864},
149 | archivePrefix = {arXiv},
150 | primaryClass = {cs.CL}
151 | }
152 | ```
153 |
154 | ```bibtex
155 | @inproceedings{Sun2022ALT,
156 | title = {A Length-Extrapolatable Transformer},
157 | author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
158 | year = {2022}
159 | }
160 | ```
161 |
162 | ```bibtex
163 | @inproceedings{Chen2023ExtendingCW,
164 | title = {Extending Context Window of Large Language Models via Positional Interpolation},
165 | author = {Shouyuan Chen and Sherman Wong and Liangjian Chen and Yuandong Tian},
166 | year = {2023}
167 | }
168 | ```
169 |
170 | ```bibtex
171 | @misc{bloc97-2023,
172 | title = {NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.},
173 | author = {/u/bloc97},
174 | url = {https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/}
175 | }
176 | ```
177 |
--------------------------------------------------------------------------------
/rotary_embedding_torch/rotary_embedding_torch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from math import pi, log
3 |
4 | import torch
5 | from torch.amp import autocast
6 | from torch.nn import Module, ModuleList
7 | from torch import nn, einsum, broadcast_tensors, is_tensor, tensor, Tensor
8 |
9 | from einops import rearrange, repeat
10 |
11 | from typing import Literal
12 |
13 | # helper functions
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def default(val, d):
19 | return val if exists(val) else d
20 |
21 | # broadcat, as tortoise-tts was using it
22 |
23 | def broadcat(tensors, dim = -1):
24 | broadcasted_tensors = broadcast_tensors(*tensors)
25 | return torch.cat(broadcasted_tensors, dim = dim)
26 |
27 | def slice_at_dim(t, dim_slice: slice, *, dim):
28 | dim += (t.ndim if dim < 0 else 0)
29 | colons = [slice(None)] * t.ndim
30 | colons[dim] = dim_slice
31 | return t[tuple(colons)]
32 |
33 | # rotary embedding helper functions
34 |
35 | def rotate_half(x):
36 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
37 | x1, x2 = x.unbind(dim = -1)
38 | x = torch.stack((-x2, x1), dim = -1)
39 | return rearrange(x, '... d r -> ... (d r)')
40 |
41 | @autocast('cuda', enabled = False)
42 | def apply_rotary_emb(
43 | freqs,
44 | t,
45 | start_index = 0,
46 | scale = 1.,
47 | seq_dim = -2,
48 | freqs_seq_dim = None
49 | ):
50 | dtype = t.dtype
51 |
52 | if not exists(freqs_seq_dim):
53 | if freqs.ndim == 2 or t.ndim == 3:
54 | freqs_seq_dim = 0
55 |
56 | if t.ndim == 3 or exists(freqs_seq_dim):
57 | seq_len = t.shape[seq_dim]
58 | freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)
59 |
60 | rot_dim = freqs.shape[-1]
61 | end_index = start_index + rot_dim
62 |
63 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
64 |
65 | # Split t into three parts: left, middle (to be transformed), and right
66 | t_left = t[..., :start_index]
67 | t_middle = t[..., start_index:end_index]
68 | t_right = t[..., end_index:]
69 |
70 | # Apply rotary embeddings without modifying t in place
71 | t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
72 |
73 | out = torch.cat((t_left, t_transformed, t_right), dim=-1)
74 |
75 | return out.type(dtype)
76 |
77 | # learned rotation helpers
78 |
79 | def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
80 | if exists(freq_ranges):
81 | rotations = einsum('..., f -> ... f', rotations, freq_ranges)
82 | rotations = rearrange(rotations, '... r f -> ... (r f)')
83 |
84 | rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
85 | return apply_rotary_emb(rotations, t, start_index = start_index)
86 |
87 | # classes
88 |
89 | class RotaryEmbedding(Module):
90 | def __init__(
91 | self,
92 | dim,
93 | custom_freqs: Tensor | None = None,
94 | freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
95 | theta = 10000,
96 | max_freq = 10,
97 | num_freqs = 1,
98 | learned_freq = False,
99 | use_xpos = False,
100 | xpos_scale_base = 512,
101 | interpolate_factor = 1.,
102 | theta_rescale_factor = 1.,
103 | seq_before_head_dim = False,
104 | cache_if_possible = True,
105 | cache_max_seq_len = 8192
106 | ):
107 | super().__init__()
108 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
109 | # has some connection to NTK literature
110 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
111 |
112 | theta *= theta_rescale_factor ** (dim / (dim - 2))
113 |
114 | self.freqs_for = freqs_for
115 |
116 | if exists(custom_freqs):
117 | freqs = custom_freqs
118 | elif freqs_for == 'lang':
119 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
120 | elif freqs_for == 'pixel':
121 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
122 | elif freqs_for == 'constant':
123 | freqs = torch.ones(num_freqs).float()
124 |
125 | self.cache_if_possible = cache_if_possible
126 | self.cache_max_seq_len = cache_max_seq_len
127 |
128 | self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
129 | self.cached_freqs_seq_len = 0
130 |
131 | self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
132 |
133 | self.learned_freq = learned_freq
134 |
135 | # dummy for device
136 |
137 | self.register_buffer('dummy', torch.tensor(0), persistent = False)
138 |
139 | # default sequence dimension
140 |
141 | self.seq_before_head_dim = seq_before_head_dim
142 | self.default_seq_dim = -3 if seq_before_head_dim else -2
143 |
144 | # interpolation factors
145 |
146 | assert interpolate_factor >= 1.
147 | self.interpolate_factor = interpolate_factor
148 |
149 | # xpos
150 |
151 | self.use_xpos = use_xpos
152 |
153 | if not use_xpos:
154 | return
155 |
156 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
157 | self.scale_base = xpos_scale_base
158 |
159 | self.register_buffer('scale', scale, persistent = False)
160 | self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False)
161 | self.cached_scales_seq_len = 0
162 |
163 | # add apply_rotary_emb as static method
164 |
165 | self.apply_rotary_emb = staticmethod(apply_rotary_emb)
166 |
167 | @property
168 | def device(self):
169 | return self.dummy.device
170 |
171 | def get_seq_pos(self, seq_len, device = None, dtype = None, offset = 0):
172 | device = default(device, self.device)
173 | dtype = default(dtype, self.cached_freqs.dtype)
174 |
175 | return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
176 |
177 | def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None):
178 | seq_dim = default(seq_dim, self.default_seq_dim)
179 |
180 | assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
181 |
182 | device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
183 |
184 | seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
185 |
186 | freqs = self.forward(seq, seq_len = seq_len, offset = offset)
187 |
188 | if seq_dim == -3:
189 | freqs = rearrange(freqs, 'n d -> n 1 d')
190 |
191 | return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim)
192 |
193 | def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
194 | dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
195 |
196 | q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
197 | assert q_len <= k_len
198 |
199 | q_scale = k_scale = 1.
200 |
201 | if self.use_xpos:
202 | seq = self.get_seq_pos(k_len, dtype = dtype, device = device)
203 |
204 | q_scale = self.get_scale(seq[-q_len:]).type(dtype)
205 | k_scale = self.get_scale(seq).type(dtype)
206 |
207 | rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset)
208 | rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1)
209 |
210 | rotated_q = rotated_q.type(q.dtype)
211 | rotated_k = rotated_k.type(k.dtype)
212 |
213 | return rotated_q, rotated_k
214 |
215 | def rotate_queries_and_keys(self, q, k, seq_dim = None):
216 | seq_dim = default(seq_dim, self.default_seq_dim)
217 |
218 | assert self.use_xpos
219 | device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
220 |
221 | seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
222 |
223 | freqs = self.forward(seq, seq_len = seq_len)
224 | scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
225 |
226 | if seq_dim == -3:
227 | freqs = rearrange(freqs, 'n d -> n 1 d')
228 | scale = rearrange(scale, 'n d -> n 1 d')
229 |
230 | rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
231 | rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
232 |
233 | rotated_q = rotated_q.type(q.dtype)
234 | rotated_k = rotated_k.type(k.dtype)
235 |
236 | return rotated_q, rotated_k
237 |
238 | def get_scale(
239 | self,
240 | t: Tensor,
241 | seq_len: int | None = None,
242 | offset = 0
243 | ):
244 | assert self.use_xpos
245 |
246 | should_cache = (
247 | self.cache_if_possible and
248 | exists(seq_len) and
249 | (offset + seq_len) <= self.cache_max_seq_len
250 | )
251 |
252 | if (
253 | should_cache and \
254 | exists(self.cached_scales) and \
255 | (seq_len + offset) <= self.cached_scales_seq_len
256 | ):
257 | return self.cached_scales[offset:(offset + seq_len)]
258 |
259 | scale = 1.
260 | if self.use_xpos:
261 | power = (t - len(t) // 2) / self.scale_base
262 | scale = self.scale ** rearrange(power, 'n -> n 1')
263 | scale = repeat(scale, 'n d -> n (d r)', r = 2)
264 |
265 | if should_cache and offset == 0:
266 | self.cached_scales[:seq_len] = scale.detach()
267 | self.cached_scales_seq_len = seq_len
268 |
269 | return scale
270 |
271 | def get_axial_freqs(
272 | self,
273 | *dims,
274 | offsets: (
275 | tuple[int | float, ...] |
276 | Tensor |
277 | None
278 | ) = None
279 | ):
280 | Colon = slice(None)
281 | all_freqs = []
282 |
283 | # handle offset
284 |
285 | if exists(offsets):
286 | if not is_tensor(offsets):
287 | offsets = tensor(offsets)
288 |
289 | assert len(offsets) == len(dims)
290 |
291 | # get frequencies for each axis
292 |
293 | for ind, dim in enumerate(dims):
294 |
295 | offset = 0
296 | if exists(offsets):
297 | offset = offsets[ind]
298 |
299 | if self.freqs_for == 'pixel':
300 | pos = torch.linspace(-1, 1, steps = dim, device = self.device)
301 | else:
302 | pos = torch.arange(dim, device = self.device)
303 |
304 | pos = pos + offset
305 |
306 | freqs = self.forward(pos, seq_len = dim)
307 |
308 | all_axis = [None] * len(dims)
309 | all_axis[ind] = Colon
310 |
311 | new_axis_slice = (Ellipsis, *all_axis, Colon)
312 | all_freqs.append(freqs[new_axis_slice])
313 |
314 | # concat all freqs
315 |
316 | all_freqs = broadcast_tensors(*all_freqs)
317 | return torch.cat(all_freqs, dim = -1)
318 |
319 | @autocast('cuda', enabled = False)
320 | def forward(
321 | self,
322 | t: Tensor,
323 | seq_len: int | None = None,
324 | offset = 0
325 | ):
326 | should_cache = (
327 | self.cache_if_possible and
328 | not self.learned_freq and
329 | exists(seq_len) and
330 | self.freqs_for != 'pixel' and
331 | (offset + seq_len) <= self.cache_max_seq_len
332 | )
333 |
334 | if (
335 | should_cache and \
336 | exists(self.cached_freqs) and \
337 | (offset + seq_len) <= self.cached_freqs_seq_len
338 | ):
339 | return self.cached_freqs[offset:(offset + seq_len)].detach()
340 |
341 | freqs = self.freqs
342 |
343 | freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
344 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
345 |
346 | if should_cache and offset == 0:
347 | self.cached_freqs[:seq_len] = freqs.detach()
348 | self.cached_freqs_seq_len = seq_len
349 |
350 | return freqs
351 |
--------------------------------------------------------------------------------