├── rope.png
├── rotary_embedding_torch
│   ├── __init__.py
│   └── rotary_embedding_torch.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
└── README.md
/rope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/rotary-embedding-torch/HEAD/rope.png -------------------------------------------------------------------------------- /rotary_embedding_torch/__init__.py: -------------------------------------------------------------------------------- 1 | from rotary_embedding_torch.rotary_embedding_torch import ( 2 | apply_rotary_emb, 3 | RotaryEmbedding, 4 | apply_learned_rotations, 5 | broadcat 6 | ) 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'rotary-embedding-torch', 5 | packages = find_packages(), 6 | version = '0.8.9', 7 | license='MIT', 8 | description = 'Rotary Embedding - Pytorch', 9 | long_description_content_type = 'text/markdown', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/rotary-embedding-torch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'positional embedding' 17 | ], 18 | install_requires=[ 19 | 'einops>=0.7', 20 | 'torch>=2.0' 21 | ], 22 | classifiers=[ 23 | 'Development Status :: 4 - Beta', 24 | 'Intended Audience :: Developers', 25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Programming Language :: Python :: 3.6', 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Rotary Embeddings - Pytorch 4 | 5 | A standalone library for adding rotary embeddings to transformers in Pytorch, following their success as a relative positional encoding. Specifically, it makes it easy and efficient to rotate positional information into any axis of a tensor, whether the positions are fixed or learned. This library will give you state-of-the-art results for positional embedding, at little cost. 6 | 7 | My gut also tells me there is something more to rotations that can be exploited in artificial neural networks. 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install rotary-embedding-torch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | from rotary_embedding_torch import RotaryEmbedding 20 | 21 | # instantiate the positional embedding in your transformer and pass to all your attention layers 22 | 23 | rotary_emb = RotaryEmbedding(dim = 32) 24 | 25 | # mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc) 26 | 27 | q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head) 28 | k = torch.randn(1, 8, 1024, 64) # keys 29 | 30 | # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention) 31 | 32 | q = rotary_emb.rotate_queries_or_keys(q) 33 | k = rotary_emb.rotate_queries_or_keys(k) 34 | 35 | # then do your attention with your queries (q) and keys (k) as usual 36 | ``` 37 | 38 | If you do all the steps above correctly, you should see a dramatic improvement during training. 39 | 40 | ## Inference Key-Value Cache 41 | 42 | When dealing with key / value caches at inference, the query positions need to be offset by `key_value_seq_length - query_seq_length`. 43 | 44 | To make this easy, use the `rotate_queries_with_cached_keys` method. 45 | 46 | ```python 47 | q = torch.randn(1, 8, 1, 64) # only one query at a time 48 | k = torch.randn(1, 8, 1024, 64) # keys / values with the cache concatenated 49 | 50 | q, k = rotary_emb.rotate_queries_with_cached_keys(q, k) 51 | ``` 52 | 53 | You can also do this manually, like so: 54 | 55 | ```python 56 | q = rotary_emb.rotate_queries_or_keys(q, offset = k.shape[-2] - q.shape[-2]) 57 | ``` 58 | 59 | ## Axial Rotary Embeddings 60 | 61 | For easy use of n-dimensional axial relative positional embeddings, e.g. for video transformers:
62 | 63 | ```python 64 | import torch 65 | 66 | from rotary_embedding_torch import ( 67 | RotaryEmbedding, 68 | apply_rotary_emb 69 | ) 70 | 71 | pos_emb = RotaryEmbedding( 72 | dim = 16, 73 | freqs_for = 'pixel', 74 | max_freq = 256 75 | ) 76 | 77 | # queries and keys for frequencies to be rotated into 78 | # say, for a video with 8 frames and a rectangular image (feature dimension comes last) 79 | 80 | q = torch.randn(1, 8, 64, 32, 64) 81 | k = torch.randn(1, 8, 64, 32, 64) 82 | 83 | # get axial frequencies - (8, 64, 32, 16 * 3 = 48) 84 | # will automatically do partial rotary 85 | 86 | freqs = pos_emb.get_axial_freqs(8, 64, 32) 87 | 88 | # rotate in frequencies 89 | 90 | q = apply_rotary_emb(freqs, q) 91 | k = apply_rotary_emb(freqs, k) 92 | ``` 93 | 94 | ## Length Extrapolatable Rotary Embeddings 95 | 96 | In this paper, the authors were able to fix the length extrapolation issue with rotary embeddings by giving them a decay similar to ALiBi. They named this technique XPos, and you can use it by setting `use_xpos = True` on initialization. 97 | 98 | This can only be used for autoregressive transformers. 99 | 100 | ```python 101 | import torch 102 | from rotary_embedding_torch import RotaryEmbedding 103 | 104 | # instantiate the positional embedding in your transformer and pass to all your attention layers 105 | 106 | rotary_emb = RotaryEmbedding( 107 | dim = 32, 108 | use_xpos = True # set this to True to make rotary embeddings extrapolate better to sequence lengths greater than the one used at training time 109 | ) 110 | 111 | # mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc) 112 | 113 | q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head) 114 | k = torch.randn(1, 8, 1024, 64) # keys 115 | 116 | # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention) 117 | 118 | # instead of using `rotate_queries_or_keys`, you will use `rotate_queries_and_keys`, the rest is taken care of 119 | 120 | q, k = rotary_emb.rotate_queries_and_keys(q, k) 121 | ``` 122 | 123 | ## Interpolating Sequence Positions 124 | 125 | This MetaAI paper proposes simply fine-tuning on interpolations of the sequence positions to extend pretrained models to a longer context length. They show this performs much better than simply fine-tuning on the same sequence positions, extended further. 126 | 127 | You can use this by setting `interpolate_factor` on initialization to a value greater than `1.` (e.g. if the pretrained model was trained on 2048, setting `interpolate_factor = 2.` would allow fine-tuning to `2048 x 2 = 4096`). 128 | 129 | Update: someone in the community has reported that it does not work well. Please email me if you see either a positive or negative result. 130 | 131 | ```python 132 | import torch 133 | from rotary_embedding_torch import RotaryEmbedding 134 | 135 | rotary_emb = RotaryEmbedding( 136 | dim = 32, 137 | interpolate_factor = 2.
# add this line of code to pretrained model and fine-tune for ~1000 steps, as shown in paper 138 | ) 139 | ``` 140 | 141 | ## Citations 142 | 143 | ```bibtex 144 | @misc{su2021roformer, 145 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 146 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 147 | year = {2021}, 148 | eprint = {2104.09864}, 149 | archivePrefix = {arXiv}, 150 | primaryClass = {cs.CL} 151 | } 152 | ``` 153 | 154 | ```bibtex 155 | @inproceedings{Sun2022ALT, 156 | title = {A Length-Extrapolatable Transformer}, 157 | author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei}, 158 | year = {2022} 159 | } 160 | ``` 161 | 162 | ```bibtex 163 | @inproceedings{Chen2023ExtendingCW, 164 | title = {Extending Context Window of Large Language Models via Positional Interpolation}, 165 | author = {Shouyuan Chen and Sherman Wong and Liangjian Chen and Yuandong Tian}, 166 | year = {2023} 167 | } 168 | ``` 169 | 170 | ```bibtex 171 | @misc{bloc97-2023, 172 | title = {NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.}, 173 | author = {/u/bloc97}, 174 | url = {https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/} 175 | } 176 | ``` 177 | -------------------------------------------------------------------------------- /rotary_embedding_torch/rotary_embedding_torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from math import pi, log 3 | 4 | import torch 5 | from torch.amp import autocast 6 | from torch.nn import Module, ModuleList 7 | from torch import nn, einsum, broadcast_tensors, is_tensor, tensor, Tensor 8 | 9 | from einops import rearrange, repeat 10 | 11 | from typing import Literal 12 | 13 | # helper functions 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | def default(val, d): 19 | return val if exists(val) else d 20 | 21 | # broadcat, as tortoise-tts was using it 22 | 23 | def broadcat(tensors, dim = -1): 24 | broadcasted_tensors = broadcast_tensors(*tensors) 25 | return torch.cat(broadcasted_tensors, dim = dim) 26 | 27 | def slice_at_dim(t, dim_slice: slice, *, dim): 28 | dim += (t.ndim if dim < 0 else 0) 29 | colons = [slice(None)] * t.ndim 30 | colons[dim] = dim_slice 31 | return t[tuple(colons)] 32 | 33 | # rotary embedding helper functions 34 | 35 | def rotate_half(x): 36 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 37 | x1, x2 = x.unbind(dim = -1) 38 | x = torch.stack((-x2, x1), dim = -1) 39 | return rearrange(x, '... d r -> ...
(d r)') 40 | 41 | @autocast('cuda', enabled = False) 42 | def apply_rotary_emb( 43 | freqs, 44 | t, 45 | start_index = 0, 46 | scale = 1., 47 | seq_dim = -2, 48 | freqs_seq_dim = None 49 | ): 50 | dtype = t.dtype 51 | 52 | if not exists(freqs_seq_dim): 53 | if freqs.ndim == 2 or t.ndim == 3: 54 | freqs_seq_dim = 0 55 | 56 | if t.ndim == 3 or exists(freqs_seq_dim): 57 | seq_len = t.shape[seq_dim] 58 | freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) 59 | 60 | rot_dim = freqs.shape[-1] 61 | end_index = start_index + rot_dim 62 | 63 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 64 | 65 | # Split t into three parts: left, middle (to be transformed), and right 66 | t_left = t[..., :start_index] 67 | t_middle = t[..., start_index:end_index] 68 | t_right = t[..., end_index:] 69 | 70 | # Apply rotary embeddings without modifying t in place 71 | t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) 72 | 73 | out = torch.cat((t_left, t_transformed, t_right), dim=-1) 74 | 75 | return out.type(dtype) 76 | 77 | # learned rotation helpers 78 | 79 | def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): 80 | if exists(freq_ranges): 81 | rotations = einsum('..., f -> ... f', rotations, freq_ranges) 82 | rotations = rearrange(rotations, '... r f -> ... (r f)') 83 | 84 | rotations = repeat(rotations, '... n -> ... (n r)', r = 2) 85 | return apply_rotary_emb(rotations, t, start_index = start_index) 86 | 87 | # classes 88 | 89 | class RotaryEmbedding(Module): 90 | def __init__( 91 | self, 92 | dim, 93 | custom_freqs: Tensor | None = None, 94 | freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', 95 | theta = 10000, 96 | max_freq = 10, 97 | num_freqs = 1, 98 | learned_freq = False, 99 | use_xpos = False, 100 | xpos_scale_base = 512, 101 | interpolate_factor = 1., 102 | theta_rescale_factor = 1., 103 | seq_before_head_dim = False, 104 | cache_if_possible = True, 105 | cache_max_seq_len = 8192 106 | ): 107 | super().__init__() 108 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 109 | # has some connection to NTK literature 110 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 111 | 112 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 113 | 114 | self.freqs_for = freqs_for 115 | 116 | if exists(custom_freqs): 117 | freqs = custom_freqs 118 | elif freqs_for == 'lang': 119 | freqs = 1. 
/ (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 120 | elif freqs_for == 'pixel': 121 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 122 | elif freqs_for == 'constant': 123 | freqs = torch.ones(num_freqs).float() 124 | 125 | self.cache_if_possible = cache_if_possible 126 | self.cache_max_seq_len = cache_max_seq_len 127 | 128 | self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) 129 | self.cached_freqs_seq_len = 0 130 | 131 | self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) 132 | 133 | self.learned_freq = learned_freq 134 | 135 | # dummy for device 136 | 137 | self.register_buffer('dummy', torch.tensor(0), persistent = False) 138 | 139 | # default sequence dimension 140 | 141 | self.seq_before_head_dim = seq_before_head_dim 142 | self.default_seq_dim = -3 if seq_before_head_dim else -2 143 | 144 | # interpolation factors 145 | 146 | assert interpolate_factor >= 1. 147 | self.interpolate_factor = interpolate_factor 148 | 149 | # xpos 150 | 151 | self.use_xpos = use_xpos 152 | 153 | if not use_xpos: 154 | return 155 | 156 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) 157 | self.scale_base = xpos_scale_base 158 | 159 | self.register_buffer('scale', scale, persistent = False) 160 | self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) 161 | self.cached_scales_seq_len = 0 162 | 163 | # add apply_rotary_emb as static method 164 | 165 | self.apply_rotary_emb = staticmethod(apply_rotary_emb) 166 | 167 | @property 168 | def device(self): 169 | return self.dummy.device 170 | 171 | def get_seq_pos(self, seq_len, device = None, dtype = None, offset = 0): 172 | device = default(device, self.device) 173 | dtype = default(dtype, self.cached_freqs.dtype) 174 | 175 | return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor 176 | 177 | def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): 178 | seq_dim = default(seq_dim, self.default_seq_dim) 179 | 180 | assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' 181 | 182 | device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] 183 | 184 | seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) 185 | 186 | freqs = self.forward(seq, seq_len = seq_len, offset = offset) 187 | 188 | if seq_dim == -3: 189 | freqs = rearrange(freqs, 'n d -> n 1 d') 190 | 191 | return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) 192 | 193 | def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): 194 | dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) 195 | 196 | q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] 197 | assert q_len <= k_len 198 | 199 | q_scale = k_scale = 1. 
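        # the query positions are offset by (k_len - q_len) below, so they line up with the tail of the cached keys;
        # when xpos is enabled, queries and keys also receive reciprocal scales (q_scale and k_scale ** -1),
        # so the attention logits decay with the relative distance between query and key positions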
200 | 201 | if self.use_xpos: 202 | seq = self.get_seq_pos(k_len, dtype = dtype, device = device) 203 | 204 | q_scale = self.get_scale(seq[-q_len:]).type(dtype) 205 | k_scale = self.get_scale(seq).type(dtype) 206 | 207 | rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) 208 | rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) 209 | 210 | rotated_q = rotated_q.type(q.dtype) 211 | rotated_k = rotated_k.type(k.dtype) 212 | 213 | return rotated_q, rotated_k 214 | 215 | def rotate_queries_and_keys(self, q, k, seq_dim = None): 216 | seq_dim = default(seq_dim, self.default_seq_dim) 217 | 218 | assert self.use_xpos 219 | device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] 220 | 221 | seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) 222 | 223 | freqs = self.forward(seq, seq_len = seq_len) 224 | scale = self.get_scale(seq, seq_len = seq_len).to(dtype) 225 | 226 | if seq_dim == -3: 227 | freqs = rearrange(freqs, 'n d -> n 1 d') 228 | scale = rearrange(scale, 'n d -> n 1 d') 229 | 230 | rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) 231 | rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) 232 | 233 | rotated_q = rotated_q.type(q.dtype) 234 | rotated_k = rotated_k.type(k.dtype) 235 | 236 | return rotated_q, rotated_k 237 | 238 | def get_scale( 239 | self, 240 | t: Tensor, 241 | seq_len: int | None = None, 242 | offset = 0 243 | ): 244 | assert self.use_xpos 245 | 246 | should_cache = ( 247 | self.cache_if_possible and 248 | exists(seq_len) and 249 | (offset + seq_len) <= self.cache_max_seq_len 250 | ) 251 | 252 | if ( 253 | should_cache and \ 254 | exists(self.cached_scales) and \ 255 | (seq_len + offset) <= self.cached_scales_seq_len 256 | ): 257 | return self.cached_scales[offset:(offset + seq_len)] 258 | 259 | scale = 1. 260 | if self.use_xpos: 261 | power = (t - len(t) // 2) / self.scale_base 262 | scale = self.scale ** rearrange(power, 'n -> n 1') 263 | scale = repeat(scale, 'n d -> n (d r)', r = 2) 264 | 265 | if should_cache and offset == 0: 266 | self.cached_scales[:seq_len] = scale.detach() 267 | self.cached_scales_seq_len = seq_len 268 | 269 | return scale 270 | 271 | def get_axial_freqs( 272 | self, 273 | *dims, 274 | offsets: ( 275 | tuple[int | float, ...] 
| 276 | Tensor | 277 | None 278 | ) = None 279 | ): 280 | Colon = slice(None) 281 | all_freqs = [] 282 | 283 | # handle offset 284 | 285 | if exists(offsets): 286 | if not is_tensor(offsets): 287 | offsets = tensor(offsets) 288 | 289 | assert len(offsets) == len(dims) 290 | 291 | # get frequencies for each axis 292 | 293 | for ind, dim in enumerate(dims): 294 | 295 | offset = 0 296 | if exists(offsets): 297 | offset = offsets[ind] 298 | 299 | if self.freqs_for == 'pixel': 300 | pos = torch.linspace(-1, 1, steps = dim, device = self.device) 301 | else: 302 | pos = torch.arange(dim, device = self.device) 303 | 304 | pos = pos + offset 305 | 306 | freqs = self.forward(pos, seq_len = dim) 307 | 308 | all_axis = [None] * len(dims) 309 | all_axis[ind] = Colon 310 | 311 | new_axis_slice = (Ellipsis, *all_axis, Colon) 312 | all_freqs.append(freqs[new_axis_slice]) 313 | 314 | # concat all freqs 315 | 316 | all_freqs = broadcast_tensors(*all_freqs) 317 | return torch.cat(all_freqs, dim = -1) 318 | 319 | @autocast('cuda', enabled = False) 320 | def forward( 321 | self, 322 | t: Tensor, 323 | seq_len: int | None = None, 324 | offset = 0 325 | ): 326 | should_cache = ( 327 | self.cache_if_possible and 328 | not self.learned_freq and 329 | exists(seq_len) and 330 | self.freqs_for != 'pixel' and 331 | (offset + seq_len) <= self.cache_max_seq_len 332 | ) 333 | 334 | if ( 335 | should_cache and \ 336 | exists(self.cached_freqs) and \ 337 | (offset + seq_len) <= self.cached_freqs_seq_len 338 | ): 339 | return self.cached_freqs[offset:(offset + seq_len)].detach() 340 | 341 | freqs = self.freqs 342 | 343 | freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) 344 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2) 345 | 346 | if should_cache and offset == 0: 347 | self.cached_freqs[:seq_len] = freqs.detach() 348 | self.cached_freqs_seq_len = seq_len 349 | 350 | return freqs 351 | --------------------------------------------------------------------------------
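To round out the README's basic usage snippet ("then do your attention ... as usual"), here is a minimal sketch of a multi-head self-attention block that rotates queries and keys with this library right before the attention dot product. The `SelfAttention` module, its hyperparameters, the causal masking, and the use of `torch.nn.functional.scaled_dot_product_attention` are illustrative assumptions made for this sketch and are not part of the repository; only `RotaryEmbedding` and `rotate_queries_or_keys` come from the library itself.

```python
# minimal sketch (not part of the library) - self-attention with rotary embeddings

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

from rotary_embedding_torch import RotaryEmbedding

class SelfAttention(nn.Module):
    def __init__(self, dim = 512, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # rotate only half of each head dimension (partial rotary), a common choice
        self.rotary_emb = RotaryEmbedding(dim = dim_head // 2)

    def forward(self, x):
        # x has shape (batch, seq_len, dim)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # rotate positions into queries and keys after splitting out heads,
        # but before the dot product and softmax, as the README advises
        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

x = torch.randn(1, 1024, 512)
attn = SelfAttention()
out = attn(x)   # (1, 1024, 512)
```

Rotating only half of each head dimension (`dim = dim_head // 2`) is one reasonable partial-rotary setting; as the source above shows, `apply_rotary_emb` leaves the remaining feature dimensions untouched.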