├── .github ├── FUNDING.yml └── workflows │ ├── python-publish.yml │ └── python-test.yaml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── images ├── all-attention.png ├── attention-on-attention.png ├── cosine-sim-attention.png ├── deepnorm.png ├── dynamic-pos-bias-linear.png ├── dynamic-pos-bias-log.png ├── dynamic-pos-bias-sinusoidal.png ├── dynamic-pos-bias.png ├── enhanced-recurrence.png ├── fcm.png ├── ffglu.png ├── flash-attention.png ├── gate_values.png ├── gating.png ├── length-extrapolation-scale.png ├── macaron-1.png ├── macaron-2.png ├── memory-transformer.png ├── normformer.png ├── pia.png ├── qknorm-analysis.png ├── resi_dual.png ├── residual_attn.png ├── rezero.png ├── rotary.png ├── sandwich-2.png ├── sandwich.png ├── sandwich_norm.png ├── scalenorm.png ├── talking-heads.png ├── topk-attention.png └── xval.png ├── pyproject.toml ├── tests └── test_x_transformers.py ├── train_belief_state.py ├── train_copy.py ├── train_entropy_tokenizer.py ├── train_enwik8.py ├── train_length_extrapolate.py ├── train_parity.py └── x_transformers ├── __init__.py ├── attend.py ├── autoregressive_wrapper.py ├── belief_state_wrapper.py ├── continuous.py ├── dpo.py ├── entropy_based_tokenizer.py ├── multi_input.py ├── neo_mlp.py ├── nonautoregressive_wrapper.py ├── x_transformers.py ├── xl_autoregressive_wrapper.py └── xval.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [lucidrains] 4 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/python-test.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip uv 29 | python -m uv pip install torch==2.1.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu 30 | python -m uv pip install .[test] 31 | - name: Test with pytest 32 | run: | 33 | pytest tests 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/data/enwik8.gz -------------------------------------------------------------------------------- /images/all-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/all-attention.png -------------------------------------------------------------------------------- /images/attention-on-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/attention-on-attention.png -------------------------------------------------------------------------------- /images/cosine-sim-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/cosine-sim-attention.png -------------------------------------------------------------------------------- /images/deepnorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/deepnorm.png -------------------------------------------------------------------------------- /images/dynamic-pos-bias-linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/dynamic-pos-bias-linear.png -------------------------------------------------------------------------------- /images/dynamic-pos-bias-log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/dynamic-pos-bias-log.png -------------------------------------------------------------------------------- /images/dynamic-pos-bias-sinusoidal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/dynamic-pos-bias-sinusoidal.png -------------------------------------------------------------------------------- /images/dynamic-pos-bias.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/dynamic-pos-bias.png -------------------------------------------------------------------------------- /images/enhanced-recurrence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/enhanced-recurrence.png 
-------------------------------------------------------------------------------- /images/fcm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/fcm.png -------------------------------------------------------------------------------- /images/ffglu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/ffglu.png -------------------------------------------------------------------------------- /images/flash-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/flash-attention.png -------------------------------------------------------------------------------- /images/gate_values.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/gate_values.png -------------------------------------------------------------------------------- /images/gating.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/gating.png -------------------------------------------------------------------------------- /images/length-extrapolation-scale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/length-extrapolation-scale.png -------------------------------------------------------------------------------- /images/macaron-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/macaron-1.png -------------------------------------------------------------------------------- /images/macaron-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/macaron-2.png -------------------------------------------------------------------------------- /images/memory-transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/memory-transformer.png -------------------------------------------------------------------------------- /images/normformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/normformer.png -------------------------------------------------------------------------------- /images/pia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/pia.png -------------------------------------------------------------------------------- /images/qknorm-analysis.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/qknorm-analysis.png -------------------------------------------------------------------------------- /images/resi_dual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/resi_dual.png -------------------------------------------------------------------------------- /images/residual_attn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/residual_attn.png -------------------------------------------------------------------------------- /images/rezero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/rezero.png -------------------------------------------------------------------------------- /images/rotary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/rotary.png -------------------------------------------------------------------------------- /images/sandwich-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/sandwich-2.png -------------------------------------------------------------------------------- /images/sandwich.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/sandwich.png -------------------------------------------------------------------------------- /images/sandwich_norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/sandwich_norm.png -------------------------------------------------------------------------------- /images/scalenorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/scalenorm.png -------------------------------------------------------------------------------- /images/talking-heads.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/talking-heads.png -------------------------------------------------------------------------------- /images/topk-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/topk-attention.png -------------------------------------------------------------------------------- /images/xval.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lucidrains/x-transformers/2133f4779302a0cefee37a70590ccbb0d51683a3/images/xval.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "x-transformers" 3 | version = "2.3.12" 4 | description = "X-Transformers" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'attention mechanism', 14 | 'transformers' 15 | ] 16 | classifiers=[ 17 | 'Development Status :: 4 - Beta', 18 | 'Intended Audience :: Developers', 19 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 20 | 'License :: OSI Approved :: MIT License', 21 | 'Programming Language :: Python :: 3.6', 22 | ] 23 | 24 | dependencies = [ 25 | 'einx>=0.3.0', 26 | 'einops>=0.8.0', 27 | 'loguru', 28 | 'packaging>=21.0', 29 | 'torch>=2.0', 30 | ] 31 | 32 | [project.urls] 33 | Homepage = "https://pypi.org/project/x-transformers/" 34 | Repository = "https://github.com/lucidrains/x-transformers" 35 | 36 | [project.optional-dependencies] 37 | examples = [ 38 | "lion-pytorch", 39 | "tqdm", 40 | ] 41 | 42 | test = [ 43 | "pytest", 44 | ] 45 | 46 | [build-system] 47 | requires = ["hatchling"] 48 | build-backend = "hatchling.build" 49 | 50 | 51 | [tool.pytest.ini_options] 52 | pythonpath = [ 53 | "." 54 | ] 55 | 56 | [tool.hatch.metadata] 57 | allow-direct-references = true 58 | 59 | [tool.hatch.build.targets.wheel] 60 | packages = ["x_transformers"] 61 | -------------------------------------------------------------------------------- /tests/test_x_transformers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module 6 | 7 | from x_transformers.x_transformers import ( 8 | XTransformer, 9 | TransformerWrapper, 10 | Encoder, 11 | Decoder, 12 | LinearNoBias, 13 | ) 14 | 15 | from x_transformers.neo_mlp import ( 16 | NeoMLP 17 | ) 18 | 19 | from x_transformers.multi_input import MultiInputTransformerWrapper 20 | 21 | def test_readme(): 22 | model = XTransformer( 23 | dim = 512, 24 | enc_num_tokens = 256, 25 | enc_depth = 6, 26 | enc_heads = 8, 27 | enc_max_seq_len = 1024, 28 | dec_num_tokens = 256, 29 | dec_depth = 6, 30 | dec_heads = 8, 31 | dec_max_seq_len = 1024, 32 | tie_token_emb = True 33 | ) 34 | 35 | src = torch.randint(0, 256, (1, 1024)) 36 | src_mask = torch.ones_like(src).bool() 37 | tgt = torch.randint(0, 256, (1, 1024)) 38 | 39 | loss = model(src, tgt, mask = src_mask) 40 | loss.backward() 41 | 42 | def test_kv_cache(): 43 | model = TransformerWrapper( 44 | num_tokens = 20000, 45 | max_seq_len = 1024, 46 | attn_layers = Decoder( 47 | dim = 8, 48 | depth = 2, 49 | heads = 4, 50 | cross_attend = True 51 | ) 52 | ) 53 | 54 | model.eval() 55 | 56 | prompts = torch.zeros((2, 16)) 57 | context = torch.randn(2, 8, 8) 58 | 59 | logits, cache = model( 60 | prompts, 61 | context = context, 62 | return_intermediates = True 63 | ) 64 | 65 | sampled = logits[:, -1].argmax(dim = -1, keepdim = True) 66 | prompts = torch.cat((prompts, sampled), dim = -1) 67 | 68 | next_logits = model(prompts, context = context) 69 | next_logits_with_cache = model(prompts, context = context, cache = cache) 70 | 71 | assert torch.allclose(next_logits[:, -1], next_logits_with_cache[:, -1], 
atol = 1e-6) 72 | 73 | def test_cope(): 74 | model = TransformerWrapper( 75 | num_tokens = 256, 76 | max_seq_len = 1024, 77 | attn_layers = Decoder( 78 | dim = 8, 79 | depth = 2, 80 | heads = 4, 81 | attn_use_cope = True 82 | ) 83 | ) 84 | 85 | seq = torch.randint(0, 256, (1, 1024)) 86 | logits = model(seq) 87 | 88 | def test_adaptive_layernorm(): 89 | model = TransformerWrapper( 90 | num_tokens = 20000, 91 | max_seq_len = 1024, 92 | attn_layers = Decoder( 93 | dim = 512, 94 | dim_condition = 768, 95 | depth = 12, 96 | heads = 8, 97 | use_adaptive_layernorm = True, 98 | use_adaptive_layerscale = True 99 | ) 100 | ) 101 | 102 | x = torch.randint(0, 256, (2, 1024)) 103 | condition = torch.randn(2, 768) 104 | 105 | model(x, condition = condition) 106 | 107 | def test_adaptive_rmsnorm(): 108 | model = TransformerWrapper( 109 | num_tokens = 20000, 110 | max_seq_len = 1024, 111 | attn_layers = Decoder( 112 | dim = 512, 113 | dim_condition = 768, 114 | depth = 12, 115 | heads = 8, 116 | use_adaptive_rmsnorm = True, 117 | adaptive_condition_mlp = True 118 | ) 119 | ) 120 | 121 | x = torch.randint(0, 256, (2, 1024)) 122 | condition = torch.randn(2, 768) 123 | 124 | model(x, condition = condition) 125 | 126 | def test_attn_softclamp_logits(): 127 | model = TransformerWrapper( 128 | num_tokens = 20000, 129 | max_seq_len = 1024, 130 | attn_layers = Decoder( 131 | dim = 512, 132 | dim_condition = 768, 133 | depth = 12, 134 | heads = 8, 135 | attn_softclamp_logits = True, 136 | ) 137 | ) 138 | 139 | x = torch.randint(0, 256, (1, 1024)) 140 | 141 | model(x) 142 | 143 | def test_multiple_input_embeds(): 144 | model = MultiInputTransformerWrapper( 145 | num_tokens = dict( 146 | note = 20000, 147 | pitch = 32, 148 | tone = 16 149 | ), 150 | max_seq_len = 1024, 151 | return_only_embed = True, 152 | attn_layers = Decoder( 153 | dim = 128, 154 | depth = 6, 155 | heads = 8 156 | ) 157 | ) 158 | 159 | x = dict( 160 | note = torch.randint(0, 20000, (2, 1024)), 161 | pitch = torch.randint(0, 32, (2, 1024)), 162 | tone = torch.randint(0, 16, (2, 1024)) 163 | ) 164 | 165 | embed = model(x) 166 | 167 | assert embed.shape == (2, 1024, 128) 168 | 169 | def test_average_pool_embed(): 170 | model = TransformerWrapper( 171 | num_tokens = 20000, 172 | max_seq_len = 1024, 173 | num_memory_tokens = 2, 174 | average_pool_embed = True, 175 | attn_layers = Encoder( 176 | dim = 128, 177 | depth = 6, 178 | heads = 8 179 | ) 180 | ) 181 | 182 | x = torch.randint(0, 20000, (2, 1024)) 183 | mask = torch.randint(0, 2, (2, 1024)).bool() 184 | 185 | logits = model(x, mask = mask) 186 | 187 | assert logits.shape == (2, 20000) 188 | 189 | @pytest.mark.parametrize('num_cls_tokens', (1, 2)) 190 | def test_cls_token(num_cls_tokens): 191 | model = TransformerWrapper( 192 | num_tokens = 20000, 193 | max_seq_len = 1024, 194 | num_memory_tokens = 2, 195 | use_cls_token = True, 196 | num_cls_tokens=num_cls_tokens, 197 | attn_layers = Encoder( 198 | dim = 128, 199 | depth = 6, 200 | heads = 8 201 | ) 202 | ) 203 | 204 | x = torch.randint(0, 20000, (2, 1024)) 205 | mask = torch.randint(0, 2, (2, 1024)).bool() 206 | 207 | logits = model(x, mask = mask) 208 | 209 | if num_cls_tokens == 1: 210 | expected_shape = (2, 20000) 211 | else: 212 | expected_shape = (2, num_cls_tokens, 20000) 213 | 214 | assert logits.shape == expected_shape 215 | 216 | def test_squeeze_logit_dim_one(): 217 | model = TransformerWrapper( 218 | num_tokens = 20000, 219 | max_seq_len = 1024, 220 | logits_dim = 1, 221 | average_pool_embed = True, 222 | squeeze_out_last_dim = True, 
223 | attn_layers = Encoder( 224 | dim = 128, 225 | depth = 6, 226 | heads = 8 227 | ) 228 | ) 229 | 230 | x = torch.randint(0, 20000, (2, 1024)) 231 | mask = torch.randint(0, 2, (2, 1024)).bool() 232 | 233 | logits = model(x, mask = mask) 234 | 235 | assert logits.shape == (2,) 236 | 237 | @pytest.mark.parametrize('depth', (4, 5)) 238 | def test_unet_skip(depth): 239 | 240 | model = TransformerWrapper( 241 | num_tokens = 20000, 242 | max_seq_len = 1024, 243 | attn_layers = Encoder( 244 | dim = 128, 245 | depth = depth, 246 | heads = 8, 247 | unet_skips = True 248 | ) 249 | ) 250 | 251 | x = torch.randint(0, 20000, (2, 1024)) 252 | mask = torch.randint(0, 2, (2, 1024)).bool() 253 | 254 | model(x, mask = mask) 255 | 256 | def test_recycling(): 257 | model = TransformerWrapper( 258 | num_tokens = 20000, 259 | max_seq_len = 1024, 260 | recycling = True, 261 | train_max_recycle_steps = 5, 262 | attn_layers = Decoder( 263 | dim = 128, 264 | depth = 6, 265 | heads = 8 266 | ) 267 | ) 268 | 269 | x = torch.randint(0, 20000, (2, 1024)) 270 | 271 | logits = model(x) 272 | 273 | model.eval() 274 | 275 | eval_logits = model(x, recycle_steps = 3) 276 | 277 | def test_mos(): 278 | model = TransformerWrapper( 279 | num_tokens = 20000, 280 | max_seq_len = 1024, 281 | mixture_of_softmax = True, 282 | attn_layers = Decoder( 283 | dim = 128, 284 | depth = 6, 285 | heads = 8 286 | ) 287 | ) 288 | 289 | x = torch.randint(0, 20000, (2, 1024)) 290 | 291 | logits = model(x) 292 | 293 | model.eval() 294 | 295 | eval_logits = model(x) 296 | 297 | @pytest.mark.parametrize('attn_one_kv_head', (True, False)) 298 | def test_l2_distance(attn_one_kv_head): 299 | 300 | model = TransformerWrapper( 301 | num_tokens = 20000, 302 | max_seq_len = 1024, 303 | attn_layers = Decoder( 304 | dim = 512, 305 | depth = 12, 306 | heads = 8, 307 | attn_l2_distance = True, 308 | attn_one_kv_head = attn_one_kv_head, 309 | ) 310 | ) 311 | 312 | x = torch.randint(0, 256, (1, 1024)) 313 | 314 | model(x) 315 | 316 | def test_reinject_input(): 317 | 318 | model = TransformerWrapper( 319 | num_tokens = 20000, 320 | max_seq_len = 1024, 321 | recycling = True, 322 | attn_layers = Decoder( 323 | dim = 512, 324 | depth = 12, 325 | heads = 8, 326 | reinject_input = True 327 | ) 328 | ) 329 | 330 | x = torch.randint(0, 256, (1, 1024)) 331 | 332 | model(x) # (1, 1024, 20000) 333 | 334 | @pytest.mark.parametrize('learned_value_residual_mix', (False, True)) 335 | def test_value_residual( 336 | learned_value_residual_mix: bool 337 | ): 338 | 339 | model = TransformerWrapper( 340 | num_tokens = 20000, 341 | max_seq_len = 1024, 342 | attn_layers = Decoder( 343 | dim = 128, 344 | depth = 6, 345 | heads = 8, 346 | add_value_residual = True, 347 | learned_value_residual_mix = learned_value_residual_mix 348 | ) 349 | ) 350 | 351 | x = torch.randint(0, 20000, (2, 1024)) 352 | 353 | model(x) 354 | 355 | @pytest.mark.parametrize('has_num_mem_kv', (False, True)) 356 | def test_forgetting_transformer( 357 | has_num_mem_kv: bool 358 | ): 359 | 360 | model = TransformerWrapper( 361 | num_tokens = 20000, 362 | max_seq_len = 1024, 363 | attn_layers = Decoder( 364 | dim = 128, 365 | depth = 6, 366 | heads = 8, 367 | attn_num_mem_kv = 1 if has_num_mem_kv else 0, 368 | attn_data_dependent_alibi = True 369 | ) 370 | ) 371 | 372 | x = torch.randint(0, 20000, (2, 1024)) 373 | 374 | embed = model(x) 375 | 376 | def test_neo_mlp(): 377 | 378 | mlp = NeoMLP( 379 | dim_in = 5, 380 | dim_out = 7, 381 | dim_hidden = 16, 382 | depth = 5, 383 | dim_model = 64, 384 | ) 385 | 386 | x 
= torch.randn(3, 5) 387 | 388 | out = mlp(x) 389 | assert out.shape == (3, 7) 390 | 391 | @pytest.mark.parametrize('flash', (True, False)) 392 | def test_custom_alibi(flash: bool): 393 | 394 | model = TransformerWrapper( 395 | num_tokens = 20_000, 396 | max_seq_len = 1024, 397 | attn_layers = Decoder( 398 | dim = 512, 399 | depth = 2, 400 | heads = 8, 401 | alibi_pos_bias = True, 402 | attn_flash = flash 403 | ) 404 | ) 405 | 406 | x = torch.randint(0, 20000, (2, 4)) 407 | 408 | pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]]) 409 | 410 | logits = model(x, pos = pos) 411 | 412 | @pytest.mark.parametrize('rotary_xpos', (True, False)) 413 | def test_custom_rotary_pos_emb(rotary_xpos): 414 | from einops import repeat 415 | 416 | model = TransformerWrapper( 417 | num_tokens = 20_000, 418 | max_seq_len = 1024, 419 | attn_layers = Decoder( 420 | dim = 512, 421 | depth = 2, 422 | heads = 8, 423 | rotary_pos_emb = True, 424 | rotary_xpos = rotary_xpos 425 | ) 426 | ) 427 | 428 | x = torch.randint(0, 20000, (4, 4)) 429 | 430 | pos = repeat(torch.arange(0, 4), "n -> b n", b=4) 431 | 432 | logits1 = model(x, pos = pos) 433 | logits2 = model(x) 434 | assert torch.allclose(logits1, logits2) 435 | 436 | @pytest.mark.parametrize('flash', (True, False)) 437 | def test_custom_alibi_across_heads(flash: bool): 438 | model = Decoder( 439 | dim = 512, 440 | depth = 2, 441 | heads = 2, 442 | alibi_pos_bias = True, 443 | rel_pos_kwargs = dict( 444 | slopes = [1, 1] 445 | ), 446 | attn_flash = flash 447 | ) 448 | 449 | x = torch.randn(2, 4, 512) 450 | 451 | pos = torch.tensor([ 452 | [[0, 1, 2, 4], [1, 3, 5, 7]], 453 | [[2, 3, 4, 5], [6, 8, 9, 10]] 454 | ]) 455 | 456 | embed = model(x, pos = pos) 457 | 458 | @pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom')) 459 | def test_embedder(embedder_type): 460 | num_tokens = 20000 461 | dim = 128 462 | token_emb_kwargs = {} 463 | 464 | if embedder_type == 'embedding': 465 | embedder = nn.Embedding(num_tokens, dim) 466 | elif embedder_type == 'none': 467 | embedder = None 468 | else: 469 | class CustomEmbedder(Module): 470 | """ 471 | Made up embedder that sums two embeddings. Just to check if we can pass additional input to the embedder's 472 | forward pass without breaking the model. 
473 | """ 474 | def __init__(self, num_tokens, dim): 475 | super().__init__() 476 | self.embed_x = nn.Embedding(num_tokens, dim) 477 | self.embed_y = nn.Embedding(num_tokens, dim) 478 | 479 | def forward(self, x, y): 480 | return self.embed_x(x) + self.embed_y(y) 481 | 482 | def init_(self): 483 | pass 484 | 485 | embedder = CustomEmbedder(num_tokens, dim) 486 | token_emb_kwargs['y'] = torch.randint(0, num_tokens, (2, 1024)) 487 | 488 | model = TransformerWrapper( 489 | num_tokens = num_tokens, 490 | max_seq_len = 1024, 491 | attn_layers = Decoder( 492 | dim = dim, 493 | depth = 6, 494 | heads = 8, 495 | ), 496 | token_emb = embedder, 497 | ) 498 | 499 | x = torch.randint(0, 20000, (2, 1024)) 500 | 501 | output = model(x, token_emb_kwargs=token_emb_kwargs) 502 | assert output.shape == (2, 1024, 20000) 503 | 504 | 505 | @pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer')) 506 | def test_to_logits(to_logits): 507 | num_tokens = 20000 508 | dim = 128 509 | 510 | to_logits_kwargs = {} 511 | 512 | if to_logits == 'linear': 513 | logit_mapper = LinearNoBias(dim, num_tokens) 514 | elif to_logits == 'none': 515 | logit_mapper = None 516 | else: 517 | class PointerNetworkLogits(Module): 518 | def __init__(self, dim): 519 | super().__init__() 520 | self.proj_to_pointers = nn.Linear(dim, dim) 521 | 522 | def forward(self, model_embeddings, input_embeddings): 523 | pointers = self.proj_to_pointers(model_embeddings) 524 | logits = torch.matmul(pointers, input_embeddings.permute(0, 2, 1)) 525 | return logits 526 | 527 | logit_mapper = PointerNetworkLogits(dim) 528 | to_logits_kwargs['input_embeddings'] = torch.randn(2, 20000, dim) 529 | 530 | model = TransformerWrapper( 531 | num_tokens = num_tokens, 532 | max_seq_len = 1024, 533 | attn_layers = Decoder( 534 | dim = dim, 535 | depth = 6, 536 | heads = 8, 537 | ), 538 | to_logits = logit_mapper, 539 | ) 540 | 541 | x = torch.randint(0, num_tokens, (2, 1024)) 542 | 543 | output = model(x, to_logits_kwargs=to_logits_kwargs) 544 | 545 | assert output.shape == (2, 1024, 20000) 546 | 547 | def test_laser(): 548 | model = TransformerWrapper( 549 | num_tokens = 20000, 550 | max_seq_len = 1024, 551 | attn_layers = Decoder( 552 | dim = 128, 553 | depth = 6, 554 | heads = 8, 555 | attn_laser = True 556 | ) 557 | ) 558 | 559 | x = torch.randint(0, 20000, (2, 1024)) 560 | 561 | model(x) 562 | 563 | @pytest.mark.parametrize('self_attn_custom_pos', (True, False)) 564 | @pytest.mark.parametrize('cross_attn_rotary', (True, False)) 565 | def test_cross_attn_rotary( 566 | self_attn_custom_pos: bool, 567 | cross_attn_rotary: bool 568 | ): 569 | 570 | x = torch.randn((1, 64, 256)) 571 | mask = torch.ones((1, 64)).bool() 572 | context = torch.randn((1, 128, 512)) 573 | context_mask = torch.ones((1, 128)).bool() 574 | 575 | model = Encoder( 576 | dim = 256, 577 | depth = 4, 578 | heads = 4, 579 | rotary_pos_emb = True, 580 | cross_attend = True, 581 | cross_attn_dim_context = 512 582 | ) 583 | 584 | pos = torch.arange(64) if self_attn_custom_pos else None 585 | context_pos = torch.arange(128) if cross_attn_rotary else None 586 | 587 | embed = model( 588 | x = x, 589 | mask = mask, 590 | context = context, 591 | pos = pos, 592 | context_pos = context_pos, 593 | context_mask = context_mask 594 | ) 595 | 596 | @pytest.mark.parametrize('tanh', (True, False)) 597 | def test_hyper_connections(tanh): 598 | 599 | model = TransformerWrapper( 600 | num_tokens = 20000, 601 | max_seq_len = 1024, 602 | attn_layers = Decoder( 603 | dim = 128, 604 | depth = 6, 605 | heads = 
8, 606 | num_residual_streams = 8, # 8 dynamic hyper connection residual streams 607 | residual_fn_kwargs = dict( 608 | tanh = tanh 609 | ) 610 | ) 611 | ) 612 | 613 | x = torch.randint(0, 20000, (2, 1024)) 614 | 615 | model(x) 616 | 617 | @pytest.mark.parametrize('hybrid_axial_dim', (1, 4)) 618 | def test_hybrid(hybrid_axial_dim): 619 | from torch.nn import GRU 620 | 621 | dec = TransformerWrapper( 622 | num_tokens = 20000, 623 | max_seq_len = 1024, 624 | attn_layers = Decoder( 625 | dim = 128, 626 | depth = 6, 627 | heads = 8, 628 | attn_dim_head = 64, 629 | attn_hybrid_fold_axial_dim = hybrid_axial_dim, 630 | attn_hybrid_module = GRU(128, 64 * 8, batch_first = True) 631 | ) 632 | ) 633 | 634 | x = torch.randint(0, 20000, (2, 1024)) 635 | 636 | embed = dec(x) 637 | 638 | enc = TransformerWrapper( 639 | num_tokens = 20000, 640 | max_seq_len = 1024, 641 | attn_layers = Encoder( 642 | dim = 128, 643 | depth = 6, 644 | heads = 8, 645 | attn_dim_head = 64, 646 | attn_hybrid_fold_axial_dim = hybrid_axial_dim, 647 | attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True) 648 | ) 649 | ) 650 | 651 | mask = torch.randint(0, 2, (2, 1024)).bool() 652 | embed = enc(x, mask = mask) 653 | 654 | def test_multi_latent_attention(): 655 | model = TransformerWrapper( 656 | num_tokens = 20000, 657 | max_seq_len = 1024, 658 | attn_layers = Decoder( 659 | dim = 128, 660 | depth = 6, 661 | heads = 8, 662 | attn_use_latent_q = True, 663 | attn_dim_latent_q = 128, 664 | attn_use_latent_kv = True, 665 | attn_dim_latent_kv = 128, 666 | attn_latent_rope_subheads = 4, 667 | rotary_pos_emb = False 668 | ) 669 | ) 670 | 671 | x = torch.randint(0, 20000, (2, 1024)) 672 | 673 | model(x) 674 | 675 | @pytest.mark.parametrize('num_residual_streams', (1, 4)) 676 | @pytest.mark.parametrize('integrate_layers', (False, True)) 677 | def test_lime( 678 | num_residual_streams, 679 | integrate_layers 680 | ): 681 | model = TransformerWrapper( 682 | num_tokens = 20000, 683 | max_seq_len = 1024, 684 | attn_layers = Decoder( 685 | dim = 128, 686 | depth = 6, 687 | heads = 8, 688 | num_residual_streams = num_residual_streams, 689 | integrate_layers = integrate_layers 690 | ) 691 | ) 692 | 693 | x = torch.randint(0, 20000, (2, 1024)) 694 | 695 | model(x) 696 | 697 | @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5)) 698 | @pytest.mark.parametrize('goal_suffix', (False, True)) 699 | @pytest.mark.parametrize('pred_distance', (False, True)) 700 | @pytest.mark.parametrize('variable_len', (False, True)) 701 | def test_belief_state_wrapper( 702 | backward_ar_loss_weight, 703 | goal_suffix, 704 | pred_distance, 705 | variable_len 706 | ): 707 | from x_transformers.belief_state_wrapper import BeliefStateWrapper 708 | 709 | forward_model = TransformerWrapper( 710 | num_tokens = 20000, 711 | max_seq_len = 1024, 712 | attn_layers = Decoder( 713 | dim = 512, 714 | depth = 6, 715 | heads = 8, 716 | rotary_pos_emb = True 717 | ) 718 | ) 719 | 720 | backward_model = TransformerWrapper( 721 | num_tokens = 20000, 722 | max_seq_len = 1024, 723 | attn_layers = Decoder( 724 | dim = 512, 725 | depth = 6, 726 | heads = 8, 727 | rotary_pos_emb = True 728 | ) 729 | ) 730 | 731 | model = BeliefStateWrapper( 732 | forward_decoder = forward_model, 733 | backward_decoder = backward_model, 734 | backward_ar_loss_weight = backward_ar_loss_weight, 735 | pred_distance = pred_distance 736 | ) 737 | 738 | seq = torch.randint(0, 20000, (2, 16)) 739 | 740 | lens = None 741 | 742 | if variable_len: 743 | lens = torch.randint(4, 16, 
(2,)) 744 | 745 | loss = model(seq, lens = lens) # backwards happen automatically 746 | loss.backward() 747 | 748 | suffix = None 749 | if goal_suffix: 750 | suffix = torch.randint(0, 20000, (2, 2)) 751 | 752 | sampled = model.generate_with_suffix_cond(seq[:, :1], 16, suffix = suffix) 753 | assert sampled.shape == (2, 16) 754 | 755 | def test_dynamic_tanh(): 756 | model = TransformerWrapper( 757 | num_tokens = 20000, 758 | max_seq_len = 1024, 759 | attn_layers = Decoder( 760 | dim = 128, 761 | depth = 6, 762 | heads = 8, 763 | use_dynamic_tanh = True, 764 | dynamic_tanh_init_alpha = 1.5 765 | ) 766 | ) 767 | 768 | x = torch.randint(0, 20000, (2, 1024)) 769 | 770 | model(x) 771 | 772 | @pytest.mark.parametrize('var_length', (False, True)) 773 | def test_entropy_based_tokenizer( 774 | var_length 775 | ): 776 | from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer 777 | 778 | model = TransformerWrapper( 779 | num_tokens = 20000, 780 | max_seq_len = 1024, 781 | attn_layers = Decoder( 782 | dim = 128, 783 | depth = 6, 784 | heads = 8, 785 | attn_dim_head = 64, 786 | ) 787 | ) 788 | 789 | tokenizer = EntropyBasedTokenizer(model, entropy_threshold = 9.738) 790 | 791 | seq = torch.randint(0, 20000, (2, 1024)) 792 | 793 | lens = None 794 | if var_length: 795 | lens = torch.randint(512, 768, (2,)) 796 | 797 | segmented_seq = tokenizer(seq, lens, return_segmented_seq = True) 798 | 799 | assert len(segmented_seq) == seq.shape[0] 800 | 801 | tokenizer(seq[0]) # able to handle without batch dim 802 | 803 | def test_entropy_based_tokenizer_max_token_len(): 804 | from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer 805 | 806 | model = TransformerWrapper( 807 | num_tokens = 20000, 808 | max_seq_len = 1024, 809 | attn_layers = Decoder( 810 | dim = 128, 811 | depth = 6, 812 | heads = 8, 813 | attn_dim_head = 64, 814 | ) 815 | ) 816 | 817 | tokenizer = EntropyBasedTokenizer( 818 | model, 819 | entropy_threshold = 100, 820 | max_token_size = 4 821 | ) 822 | 823 | seq = torch.randint(0, 20000, (1, 16,)) 824 | lens = torch.tensor([14]) 825 | 826 | token_lengths = tokenizer(seq, lens = lens) 827 | 828 | assert token_lengths.amax().item() <= 4 829 | assert token_lengths.sum().item() == 14 830 | 831 | def test_custom_ff_activation(): 832 | 833 | model = TransformerWrapper( 834 | num_tokens = 20000, 835 | max_seq_len = 1024, 836 | attn_layers = Decoder( 837 | dim = 128, 838 | depth = 6, 839 | heads = 8, 840 | attn_dim_head = 64, 841 | ff_custom_activation = nn.Sigmoid() 842 | ) 843 | ) 844 | 845 | seq = torch.randint(0, 20000, (2, 1024)) 846 | 847 | logits = model(seq) 848 | 849 | assert logits.shape == (2, 1024, 20000) 850 | 851 | def test_ff_deep_embed(): 852 | 853 | model = TransformerWrapper( 854 | num_tokens = 20000, 855 | max_seq_len = 1024, 856 | ff_deep_embed = True, 857 | attn_layers = Decoder( 858 | dim = 512, 859 | depth = 6, 860 | heads = 8, 861 | rotary_pos_emb = True, 862 | ) 863 | ) 864 | 865 | seq = torch.randint(0, 20000, (2, 1024)) 866 | 867 | logits = model(seq) 868 | 869 | assert logits.shape == (2, 1024, 20000) 870 | 871 | @pytest.mark.parametrize('probabilistic', (False, True)) 872 | @pytest.mark.parametrize('cache_kv', (False, True)) 873 | def test_continuous( 874 | probabilistic, 875 | cache_kv 876 | ): 877 | from x_transformers import ( 878 | ContinuousTransformerWrapper, 879 | Decoder, 880 | ContinuousAutoregressiveWrapper 881 | ) 882 | 883 | model = ContinuousTransformerWrapper( 884 | dim_in = 777, 885 | dim_out = 777, 886 | max_seq_len = 1024, 887 | 
probabilistic = probabilistic, 888 | attn_layers = Decoder( 889 | dim = 512, 890 | depth = 12, 891 | heads = 8 892 | ) 893 | ) 894 | 895 | # wrap it with the continuous autoregressive wrapper 896 | 897 | model = ContinuousAutoregressiveWrapper(model) 898 | 899 | # mock data 900 | 901 | x = torch.randn((1, 1024, 777)) 902 | mask = torch.ones(1, 1024).bool() 903 | 904 | # train on a lot of data above 905 | 906 | loss = model(x, mask = mask) 907 | loss.backward() 908 | 909 | # then generate 910 | 911 | start_emb = torch.randn(1, 777) 912 | generated = model.generate(start_emb, 17, cache_kv = cache_kv) # (17, 777) 913 | assert generated.shape == (17, 777) 914 | -------------------------------------------------------------------------------- /train_belief_state.py: -------------------------------------------------------------------------------- 1 | from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper 2 | from x_transformers.autoregressive_wrapper import AutoregressiveWrapper 3 | 4 | import random 5 | import tqdm 6 | import gzip 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | # constants 14 | 15 | NUM_BATCHES = int(1e5) 16 | BATCH_SIZE = 2 17 | GRADIENT_ACCUMULATE_EVERY = 8 18 | LEARNING_RATE = 1e-4 19 | VALIDATE_EVERY = 100 20 | GENERATE_EVERY = 500 21 | GENERATE_LENGTH = 256 22 | SEQ_LEN = 256 23 | 24 | FORWARD_BACKWARD_SAME_MODEL = True 25 | 26 | # helpers 27 | 28 | def cycle(loader): 29 | while True: 30 | for data in loader: 31 | yield data 32 | 33 | def decode_token(token): 34 | return str(chr(max(32, token))) 35 | 36 | def decode_tokens(tokens): 37 | return ''.join(list(map(decode_token, tokens))) 38 | 39 | # instantiate GPT-like decoder model for forward and backwards 40 | 41 | forward_model = TransformerWrapper( 42 | num_tokens = 256, 43 | max_seq_len = SEQ_LEN, 44 | attn_layers = Decoder( 45 | dim = 512, 46 | depth = 6, 47 | heads = 8, 48 | rotary_pos_emb = True 49 | ) 50 | ) 51 | 52 | backward_model = None 53 | 54 | if not FORWARD_BACKWARD_SAME_MODEL: 55 | backward_model = TransformerWrapper( 56 | num_tokens = 256, 57 | max_seq_len = SEQ_LEN, 58 | attn_layers = Decoder( 59 | dim = 512, 60 | depth = 4, # do a smaller backwards 61 | heads = 8, 62 | rotary_pos_emb = True 63 | ) 64 | ) 65 | 66 | model = BeliefStateWrapper( 67 | forward_decoder = forward_model, 68 | backward_decoder = backward_model 69 | ) 70 | 71 | model.cuda() 72 | 73 | # prepare enwik8 data 74 | 75 | with gzip.open('./data/enwik8.gz') as file: 76 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 77 | train_x, valid_x = np.split(data, [int(90e6)]) 78 | data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x) 79 | 80 | class TextSamplerDataset(Dataset): 81 | def __init__(self, data, seq_len): 82 | super().__init__() 83 | self.data = data 84 | self.seq_len = seq_len 85 | 86 | def __getitem__(self, index): 87 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,)) 88 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 89 | return full_seq.cuda() 90 | 91 | def __len__(self): 92 | return self.data.size(0) // self.seq_len 93 | 94 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 95 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 96 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)) 97 | val_loader = cycle(DataLoader(val_dataset, batch_size = 
BATCH_SIZE, drop_last = True)) 98 | 99 | # optimizer 100 | 101 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 102 | 103 | # training 104 | 105 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'): 106 | model.train() 107 | 108 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 109 | loss = model(next(train_loader)) 110 | (loss / GRADIENT_ACCUMULATE_EVERY).backward() 111 | 112 | print(f'training loss: {loss.item()}') 113 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 114 | optim.step() 115 | optim.zero_grad() 116 | 117 | if i % VALIDATE_EVERY == 0: 118 | model.eval() 119 | with torch.no_grad(): 120 | loss = model(next(val_loader)) 121 | print(f'validation loss: {loss.item()}') 122 | 123 | if i % GENERATE_EVERY == 0: 124 | model.eval() 125 | inp = random.choice(val_dataset)[:-1] 126 | prime = decode_tokens(inp) 127 | 128 | print(f'%s \n\n %s', (prime, '*' * 100)) 129 | 130 | print('forwards:\n') 131 | 132 | sample = model.generate_with_suffix_cond( 133 | prompts = inp, 134 | seq_len = GENERATE_LENGTH, 135 | cache_kv = True 136 | ) 137 | 138 | output_str = decode_tokens(sample) 139 | print(output_str) 140 | 141 | print('\nbackwards:\n') 142 | 143 | sample = model.generate_with_suffix_cond( 144 | prompts = inp, 145 | seq_len = GENERATE_LENGTH, 146 | cache_kv = True, 147 | decode_backwards = True 148 | ) 149 | 150 | output_str = decode_tokens(sample.flip(0)) 151 | print(output_str) 152 | -------------------------------------------------------------------------------- /train_copy.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | import torch.optim as optim 4 | from x_transformers import XTransformer 5 | 6 | # constants 7 | 8 | NUM_BATCHES = int(1e5) 9 | BATCH_SIZE = 32 10 | LEARNING_RATE = 3e-4 11 | GENERATE_EVERY = 100 12 | NUM_TOKENS = 16 + 2 13 | ENC_SEQ_LEN = 32 14 | DEC_SEQ_LEN = 64 + 1 15 | 16 | # helpers 17 | 18 | def cycle(): 19 | while True: 20 | prefix = torch.ones((BATCH_SIZE, 1)).long().cuda() 21 | src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda() 22 | tgt = torch.cat((prefix, src, src), 1) 23 | src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda() 24 | yield (src, tgt, src_mask) 25 | 26 | # instantiate model 27 | 28 | model = XTransformer( 29 | dim = 512, 30 | tie_token_emb = True, 31 | return_tgt_loss = True, 32 | enc_num_tokens=NUM_TOKENS, 33 | enc_depth = 3, 34 | enc_heads = 8, 35 | enc_max_seq_len = ENC_SEQ_LEN, 36 | dec_num_tokens = NUM_TOKENS, 37 | dec_depth = 3, 38 | dec_heads = 8, 39 | dec_max_seq_len = DEC_SEQ_LEN 40 | ).cuda() 41 | 42 | # optimizer 43 | 44 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 45 | 46 | # training 47 | 48 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 49 | model.train() 50 | 51 | src, tgt, src_mask = next(cycle()) 52 | 53 | loss = model(src, tgt, mask=src_mask) 54 | loss.backward() 55 | print(f'{i}: {loss.item()}') 56 | 57 | optim.step() 58 | optim.zero_grad() 59 | 60 | if i != 0 and i % GENERATE_EVERY == 0: 61 | model.eval() 62 | src, _, src_mask = next(cycle()) 63 | src, src_mask = src[:1], src_mask[:1] 64 | start_tokens = (torch.ones((1, 1)) * 1).long().cuda() 65 | 66 | sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask) 67 | incorrects = (src != sample).abs().sum() 68 | 69 | print(f"input: ", src) 70 | print(f"predicted output: ", sample) 71 | print(f"incorrects: {incorrects}") 72 | 
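A note on the copy task just above: each decoder target in train_copy.py is the reserved start token (id 1) followed by two copies of the 32-token source, which is why DEC_SEQ_LEN = 64 + 1, and evaluation primes generate with that same start token while cross-attending to the source. The following is a minimal CPU-only sketch of that batch construction with the constants shrunk for illustration (hypothetical values, same logic as cycle() above, not part of the repository):

import torch

BATCH_SIZE, NUM_TOKENS, ENC_SEQ_LEN = 2, 16 + 2, 4            # shrunk from train_copy.py for illustration

prefix = torch.ones((BATCH_SIZE, 1)).long()                    # start token, id 1 (ids 0 and 1 are reserved)
src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN))  # content tokens are sampled from 2 .. NUM_TOKENS - 1
tgt = torch.cat((prefix, src, src), dim = 1)                   # decoder target: [start, src, src]
src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool()         # attend to every source position

assert tgt.shape == (BATCH_SIZE, 2 * ENC_SEQ_LEN + 1)          # matches DEC_SEQ_LEN = 64 + 1 at full size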
-------------------------------------------------------------------------------- /train_entropy_tokenizer.py: -------------------------------------------------------------------------------- 1 | from x_transformers import TransformerWrapper, Decoder 2 | from x_transformers.autoregressive_wrapper import AutoregressiveWrapper 3 | from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer 4 | 5 | import random 6 | import tqdm 7 | import gzip 8 | import numpy as np 9 | import torch 10 | import torch.optim as optim 11 | from torch.nn import functional as F 12 | from torch.utils.data import DataLoader, Dataset 13 | 14 | # constants 15 | 16 | NUM_BATCHES = int(1e5) 17 | BATCH_SIZE = 4 18 | GRADIENT_ACCUMULATE_EVERY = 4 19 | LEARNING_RATE = 1e-4 20 | VALIDATE_EVERY = 100 21 | GENERATE_EVERY = 100 22 | GENERATE_LENGTH = 1024 23 | SEQ_LEN = 1024 24 | 25 | # helpers 26 | 27 | def cycle(loader): 28 | while True: 29 | for data in loader: 30 | yield data 31 | 32 | def decode_token(token): 33 | return str(chr(max(32, token))) 34 | 35 | def decode_tokens(tokens): 36 | return ''.join(list(map(decode_token, tokens))) 37 | 38 | # instantiate GPT-like decoder model 39 | 40 | model = TransformerWrapper( 41 | num_tokens = 256, 42 | max_seq_len = SEQ_LEN, 43 | attn_layers = Decoder( 44 | dim = 512, 45 | depth = 6, 46 | heads = 8, 47 | rotary_pos_emb = True 48 | ) 49 | ) 50 | 51 | tokenizer = EntropyBasedTokenizer( 52 | model, 53 | entropy_threshold = 2.5 54 | ) 55 | 56 | model = AutoregressiveWrapper(model) 57 | model.cuda() 58 | 59 | # prepare enwik8 data 60 | 61 | with gzip.open('./data/enwik8.gz') as file: 62 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 63 | train_x, valid_x = np.split(data, [int(90e6)]) 64 | data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x) 65 | 66 | class TextSamplerDataset(Dataset): 67 | def __init__(self, data, seq_len): 68 | super().__init__() 69 | self.data = data 70 | self.seq_len = seq_len 71 | 72 | def __getitem__(self, index): 73 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,)) 74 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 75 | return full_seq.cuda() 76 | 77 | def __len__(self): 78 | return self.data.size(0) // self.seq_len 79 | 80 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 81 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 82 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)) 83 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True)) 84 | 85 | # optimizer 86 | 87 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 88 | 89 | # training 90 | 91 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 92 | model.train() 93 | 94 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 95 | loss = model(next(train_loader)) 96 | (loss / GRADIENT_ACCUMULATE_EVERY).backward() 97 | 98 | print(f'training loss: {loss.item()}') 99 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 100 | optim.step() 101 | optim.zero_grad() 102 | 103 | if i % VALIDATE_EVERY == 0: 104 | model.eval() 105 | with torch.no_grad(): 106 | loss = model(next(val_loader)) 107 | print(f'validation loss: {loss.item()}') 108 | 109 | if i % GENERATE_EVERY == 0: 110 | model.eval() 111 | inp = random.choice(val_dataset)[:-1] 112 | 113 | tokens = tokenizer(inp, return_segmented_seq = True) 114 | 115 | delimiter = " \u275A " 116 | output_str = delimiter.join([decode_tokens(token) for token 
in tokens]) 117 | 118 | print(f"{output_str}\n\n") 119 | -------------------------------------------------------------------------------- /train_enwik8.py: -------------------------------------------------------------------------------- 1 | from x_transformers import TransformerWrapper, Decoder 2 | from x_transformers.autoregressive_wrapper import AutoregressiveWrapper 3 | 4 | import random 5 | import tqdm 6 | import gzip 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | # constants 14 | 15 | NUM_BATCHES = int(1e5) 16 | BATCH_SIZE = 4 17 | GRADIENT_ACCUMULATE_EVERY = 4 18 | LEARNING_RATE = 1e-4 19 | VALIDATE_EVERY = 100 20 | GENERATE_EVERY = 500 21 | GENERATE_LENGTH = 1024 22 | SEQ_LEN = 1024 23 | 24 | # helpers 25 | 26 | def cycle(loader): 27 | while True: 28 | for data in loader: 29 | yield data 30 | 31 | def decode_token(token): 32 | return str(chr(max(32, token))) 33 | 34 | def decode_tokens(tokens): 35 | return ''.join(list(map(decode_token, tokens))) 36 | 37 | # instantiate GPT-like decoder model 38 | 39 | model = TransformerWrapper( 40 | num_tokens = 256, 41 | max_seq_len = SEQ_LEN, 42 | attn_layers = Decoder( 43 | dim = 512, 44 | depth = 6, 45 | heads = 8, 46 | rotary_pos_emb = True 47 | ) 48 | ) 49 | 50 | model = AutoregressiveWrapper(model) 51 | model.cuda() 52 | 53 | # prepare enwik8 data 54 | 55 | with gzip.open('./data/enwik8.gz') as file: 56 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 57 | train_x, valid_x = np.split(data, [int(90e6)]) 58 | data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x) 59 | 60 | class TextSamplerDataset(Dataset): 61 | def __init__(self, data, seq_len): 62 | super().__init__() 63 | self.data = data 64 | self.seq_len = seq_len 65 | 66 | def __getitem__(self, index): 67 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,)) 68 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 69 | return full_seq.cuda() 70 | 71 | def __len__(self): 72 | return self.data.size(0) // self.seq_len 73 | 74 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 75 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 76 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)) 77 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True)) 78 | 79 | # optimizer 80 | 81 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 82 | 83 | # training 84 | 85 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 86 | model.train() 87 | 88 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 89 | loss = model(next(train_loader)) 90 | (loss / GRADIENT_ACCUMULATE_EVERY).backward() 91 | 92 | print(f'training loss: {loss.item()}') 93 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 94 | optim.step() 95 | optim.zero_grad() 96 | 97 | if i % VALIDATE_EVERY == 0: 98 | model.eval() 99 | with torch.no_grad(): 100 | loss = model(next(val_loader)) 101 | print(f'validation loss: {loss.item()}') 102 | 103 | if i % GENERATE_EVERY == 0: 104 | model.eval() 105 | inp = random.choice(val_dataset)[:-1] 106 | prime = decode_tokens(inp) 107 | print(f'%s \n\n %s', (prime, '*' * 100)) 108 | 109 | sample = model.generate( 110 | prompts = inp, 111 | seq_len = GENERATE_LENGTH, 112 | cache_kv = True 113 | ) 114 | 115 | output_str = decode_tokens(sample) 116 | print(output_str) 117 | 
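The three enwik8 scripts above (train_belief_state.py, train_entropy_tokenizer.py, train_enwik8.py) share the same byte-level setup: raw bytes are the tokens (num_tokens = 256), the first 90M bytes of enwik8 form the training split and the next 5M the validation split, and decoding simply maps each byte back to a printable character. A minimal sketch of that pipeline in isolation, assuming the repository's ./data/enwik8.gz is present and leaving out the CUDA calls:

import gzip
import numpy as np
import torch

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    train_x, valid_x = np.split(data, [int(90e6)])             # 90M training bytes, 5M validation bytes
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)

def decode_tokens(tokens):
    # bytes below 32 are control characters - the scripts above render them as spaces
    return ''.join(chr(max(32, int(token))) for token in tokens)

print(decode_tokens(data_val[:128]))                           # peek at the start of the validation text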
-------------------------------------------------------------------------------- /train_length_extrapolate.py: -------------------------------------------------------------------------------- 1 | from x_transformers import TransformerWrapper, Decoder 2 | from x_transformers.autoregressive_wrapper import AutoregressiveWrapper 3 | 4 | import random 5 | import tqdm 6 | import gzip 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | # constants 14 | 15 | NUM_BATCHES = int(1e5) 16 | BATCH_SIZE = 4 17 | GRADIENT_ACCUMULATE_EVERY = 4 18 | LEARNING_RATE = 1e-4 19 | GENERATE_EVERY = 500 20 | GENERATE_LENGTH = 256 21 | SEQ_LEN = 256 22 | 23 | VALIDATE_EVERY = 100 24 | VALIDATE_SEQ_LENS = (256, 512, 1024, 2048, 4096) 25 | 26 | # helpers 27 | 28 | def cycle(loader): 29 | while True: 30 | for data in loader: 31 | yield data 32 | 33 | def decode_token(token): 34 | return str(chr(max(32, token))) 35 | 36 | def decode_tokens(tokens): 37 | return ''.join(list(map(decode_token, tokens))) 38 | 39 | # instantiate GPT-like decoder model 40 | 41 | model = TransformerWrapper( 42 | num_tokens = 256, 43 | max_seq_len = SEQ_LEN, 44 | use_abs_pos_emb = False, 45 | attn_layers = Decoder( 46 | dim = 512, 47 | depth = 6, 48 | heads = 8, 49 | dynamic_pos_bias = True, 50 | ) 51 | ) 52 | 53 | model = AutoregressiveWrapper(model) 54 | model.cuda() 55 | 56 | # prepare enwik8 data 57 | 58 | with gzip.open('./data/enwik8.gz') as file: 59 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 60 | train_x, valid_x = np.split(data, [int(90e6)]) 61 | data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x) 62 | 63 | class TextSamplerDataset(Dataset): 64 | def __init__(self, data, seq_len): 65 | super().__init__() 66 | self.data = data 67 | self.seq_len = seq_len 68 | 69 | def __getitem__(self, index): 70 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,)) 71 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 72 | return full_seq.cuda() 73 | 74 | def __len__(self): 75 | return self.data.size(0) // self.seq_len 76 | 77 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 78 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)) 79 | 80 | val_dataset_generate = TextSamplerDataset(data_val, SEQ_LEN) 81 | 82 | # validation loaders with different sequence lengths 83 | 84 | val_loaders = dict() 85 | 86 | for valid_seq_len in VALIDATE_SEQ_LENS: 87 | val_dataset = TextSamplerDataset(data_val, valid_seq_len) 88 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True)) 89 | 90 | val_loaders[valid_seq_len] = val_loader 91 | 92 | # optimizer 93 | 94 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 95 | 96 | # training 97 | 98 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 99 | model.train() 100 | 101 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 102 | loss = model(next(train_loader)) 103 | (loss / GRADIENT_ACCUMULATE_EVERY).backward() 104 | 105 | print(f'training loss: {loss.item()}') 106 | 107 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 108 | optim.step() 109 | optim.zero_grad() 110 | 111 | if i % VALIDATE_EVERY == 0: 112 | print(f'validation losses:\n') 113 | 114 | model.eval() 115 | with torch.no_grad(): 116 | for valid_seq_len in VALIDATE_SEQ_LENS: 117 | val_loader = val_loaders[valid_seq_len] 118 | 119 | 
loss = model(next(val_loader)) 120 | print(f'[{valid_seq_len}]:\t {loss.item()}') 121 | 122 | print('\n') 123 | 124 | if i % GENERATE_EVERY == 0: 125 | model.eval() 126 | inp = random.choice(val_dataset_generate)[:-1] 127 | prime = decode_tokens(inp) 128 | print(f'%s \n\n %s', (prime, '*' * 100)) 129 | 130 | sample = model.generate( 131 | prompts = inp, 132 | seq_len = GENERATE_LENGTH, 133 | cache_kv = True 134 | ) 135 | 136 | output_str = decode_tokens(sample) 137 | print(f'{output_str}\n\n') 138 | -------------------------------------------------------------------------------- /train_parity.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | from x_transformers import TransformerWrapper, Decoder 7 | 8 | # constants 9 | 10 | BATCH_SIZE = 256 11 | LEARNING_RATE = 3e-4 12 | EVAL_EVERY = 500 13 | 14 | EVAL_LENGTHS = (16, 32, 64, 128, 256, 512) 15 | TRAIN_MAX_LENGTH = EVAL_LENGTHS[-2] 16 | 17 | LOSS_THRES_INCREASE_LEN = 1e-3 18 | MEET_CRITERIA_THRES_INCREASE_LEN = 10 19 | 20 | HYBRIDIZE_WITH_RNN = True 21 | 22 | # rnn for fully resolving state tracking by hybridization 23 | # but will also look into gated delta net + negative eigenvalues (Songlin Yang et al) as a parallel solution 24 | 25 | dim = 64 26 | heads = 4 27 | dim_head = 32 28 | decoder_kwargs = dict() 29 | 30 | if HYBRIDIZE_WITH_RNN: 31 | from torch.nn import GRU 32 | 33 | decoder_kwargs = dict( 34 | attn_hybrid_fold_axial_dim = 4, # even if recurrence is every 4 tokens, can generalize for parity 35 | attn_hybrid_learned_mix = True, 36 | attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True) 37 | ) 38 | 39 | # instantiate model 40 | 41 | model = TransformerWrapper( 42 | num_tokens = 2, 43 | max_seq_len = 0, 44 | attn_layers = Decoder( 45 | dim = dim, 46 | depth = 3, 47 | heads = heads, 48 | attn_dim_head = dim_head, 49 | shift_tokens = 1, # helps a lot with parity training, but not able to generalize on its own 50 | **decoder_kwargs 51 | ) 52 | ).cuda() 53 | 54 | # optimizer 55 | 56 | from lion_pytorch.cautious_lion import Lion 57 | 58 | optimizer = Lion(model.parameters(), lr = LEARNING_RATE, cautious_factor = 0.1) 59 | 60 | # data generator 61 | 62 | def cycle(length): 63 | while True: 64 | seq = torch.randint(0, 2, (BATCH_SIZE, length)).cuda() 65 | labels = (seq.cumsum(dim = -1) % 2) 66 | yield (seq, labels) 67 | 68 | # dataloaders 69 | 70 | train_dl = cycle(TRAIN_MAX_LENGTH) 71 | 72 | eval_dls = {eval_length: cycle(eval_length) for eval_length in EVAL_LENGTHS} 73 | 74 | print(f'training at max length: {TRAIN_MAX_LENGTH}') 75 | 76 | # training 77 | 78 | i = 0 79 | meet_criteria = 0 80 | train_seq_len = 1 81 | stop_length = EVAL_LENGTHS[-2] 82 | 83 | with tqdm.tqdm(mininterval = 10., desc = 'training') as pbar: 84 | 85 | while train_seq_len < stop_length: 86 | model.train() 87 | 88 | seq, labels = next(train_dl) 89 | 90 | # length curriculum learning 91 | 92 | seq = seq[:, :train_seq_len] 93 | labels = labels[:, :train_seq_len] 94 | 95 | logits = model(seq) 96 | 97 | loss = F.cross_entropy(logits.transpose(-1, -2), labels, reduction = 'none') 98 | last_loss = loss[:, -1].mean() 99 | loss.mean().backward() 100 | 101 | if last_loss.item() < LOSS_THRES_INCREASE_LEN: 102 | meet_criteria += 1 103 | else: 104 | meet_criteria = 0 105 | 106 | if meet_criteria >= MEET_CRITERIA_THRES_INCREASE_LEN: 107 | meet_criteria = 0 108 | train_seq_len += 1 109 | print(f'criteria met, incrementing to 
{train_seq_len}') 110 | 111 | print(f'({train_seq_len})| {i}: {last_loss.item()}') 112 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 113 | 114 | optimizer.step() 115 | optimizer.zero_grad() 116 | 117 | last_step = train_seq_len == stop_length 118 | 119 | if last_step: 120 | print(f'made it to training length {train_seq_len}. running final eval to check for generalization') 121 | 122 | if last_step or (i + 1) % EVAL_EVERY == 0: 123 | 124 | model.eval() 125 | print('\n') 126 | 127 | for eval_length, eval_dl in eval_dls.items(): 128 | incorrects = 0 129 | 130 | seq, labels = next(eval_dl) 131 | 132 | logits = model(seq) 133 | pred = logits[:, -1].argmax(dim = -1) 134 | incorrects = (pred != labels[:, -1]).abs().sum().item() 135 | 136 | frac_incorrect = incorrects * 100 / BATCH_SIZE 137 | 138 | print(f"{eval_length}\t - frac incorrect:\t {frac_incorrect:.1f}%") 139 | 140 | print('\n') 141 | 142 | i += 1 143 | pbar.update(1) 144 | -------------------------------------------------------------------------------- /x_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from x_transformers.x_transformers import ( 2 | XTransformer, 3 | Encoder, 4 | Decoder, 5 | PrefixDecoder, 6 | CrossAttender, 7 | Attention, 8 | FeedForward, 9 | RMSNorm, 10 | AdaptiveRMSNorm, 11 | TransformerWrapper, 12 | ViTransformerWrapper 13 | ) 14 | 15 | from x_transformers.autoregressive_wrapper import AutoregressiveWrapper 16 | from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper 17 | from x_transformers.belief_state_wrapper import BeliefStateWrapper 18 | 19 | from x_transformers.continuous import ( 20 | ContinuousTransformerWrapper, 21 | ContinuousAutoregressiveWrapper 22 | ) 23 | 24 | from x_transformers.multi_input import MultiInputTransformerWrapper 25 | 26 | from x_transformers.xval import ( 27 | XValTransformerWrapper, 28 | XValAutoregressiveWrapper 29 | ) 30 | 31 | from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper 32 | 33 | from x_transformers.dpo import ( 34 | DPO 35 | ) 36 | 37 | from x_transformers.neo_mlp import ( 38 | NeoMLP 39 | ) 40 | 41 | from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer 42 | -------------------------------------------------------------------------------- /x_transformers/attend.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | from typing import Tuple, Callable 5 | 6 | import torch 7 | from torch.nn import Module 8 | from torch import nn, einsum, Tensor 9 | import torch.nn.functional as F 10 | 11 | from collections import namedtuple 12 | from functools import wraps 13 | from packaging import version 14 | from dataclasses import dataclass 15 | 16 | from einops import rearrange, repeat, pack, unpack 17 | 18 | # constants 19 | 20 | @dataclass 21 | class Intermediates: 22 | qk_similarities: Tensor | None = None 23 | pre_softmax_attn: Tensor | None = None 24 | post_softmax_attn: Tensor | None = None 25 | values: Tensor | None = None 26 | cached_kv: Tuple[Tensor, Tensor] | None = None 27 | layer_type: str | None = None 28 | 29 | def to_tuple(self): 30 | return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn) 31 | 32 | # helpers 33 | 34 | def exists(val): 35 | return val is not None 36 | 37 | def default(val, d): 38 | return val if exists(val) else d 39 | 40 | def at_most_one_of(*bools): 41 | return sum([*map(int, bools)]) 
<= 1 42 | 43 | def compact(arr): 44 | return [*filter(exists, arr)] 45 | 46 | @torch.jit.script 47 | def softclamp(t: Tensor, value: float): 48 | return (t / value).tanh() * value 49 | 50 | def pack_one(t, pattern): 51 | return pack([t], pattern) 52 | 53 | def unpack_one(t, ps, pattern): 54 | return unpack(t, ps, pattern)[0] 55 | 56 | def once(fn): 57 | called = False 58 | @wraps(fn) 59 | def inner(x): 60 | nonlocal called 61 | if called: 62 | return 63 | called = True 64 | return fn(x) 65 | return inner 66 | 67 | print_once = once(print) 68 | 69 | # selective attention 70 | # https://arxiv.org/abs/2410.02703 - section 3.3 71 | # it is a technique to allow each token to prevent itself from being attended to by future tokens 72 | # if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework) 73 | 74 | def selective_attn( 75 | sim, 76 | sim_head_gate = None, 77 | no_mask_sos = True 78 | ): 79 | i, j, device = *sim.shape[-2:], sim.device 80 | sim_head_gate = default(sim_head_gate, sim[:, 0]) 81 | 82 | gate = F.relu(sim_head_gate) # only positive 83 | 84 | if no_mask_sos: 85 | gate = gate.clone() 86 | gate[..., -i] = 0. 87 | 88 | eye = torch.eye(i, device = device) 89 | 90 | if j > i: 91 | eye = F.pad(eye, (j - i, 0), value = 1.) 92 | 93 | gate = (1. - eye) * gate 94 | gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future 95 | gate = gate.cumsum(dim = -2) 96 | 97 | return sim - rearrange(gate, 'b i j -> b 1 i j') 98 | 99 | # alternative distance functions 100 | 101 | def qk_l2_dist_squared(q, k): 102 | if k.ndim == 3: 103 | k = repeat(k, 'b j d -> b h j d', h = q.shape[1]) 104 | 105 | q, packed_shape = pack_one(q, '* i d') 106 | k, _ = pack_one(k, '* j d') 107 | 108 | l2_dist_squared = torch.cdist(q, k) ** 2 109 | return unpack_one(l2_dist_squared, packed_shape, '* i j') 110 | 111 | # one-hot straight through softmax 112 | 113 | def one_hot_straight_through(logits, temperature = 1.): 114 | one_hot_indices = logits.argmax(dim = -1, keepdim = True) 115 | one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.) 
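    # straight through estimator - the hard one hot is used in the forward pass, while
    # gradients flow through the soft attention, since (one_hot + soft - soft.detach())
    # equals one_hot in value but carries the gradient of the softmax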
116 | 117 | soft_attn = (logits / temperature).softmax(dim = -1) 118 | return one_hot + soft_attn - soft_attn.detach() 119 | 120 | # sparse topk attention - only keep topk attn logits for softmax 121 | # optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True` 122 | 123 | def sparse_topk_attn( 124 | logits, 125 | sparse_topk, 126 | temperature = 1., 127 | straight_through = False 128 | ): 129 | orig_logits = logits 130 | 131 | mask_value = -torch.finfo(logits.dtype).max 132 | top_values, _ = logits.topk(sparse_topk, dim = -1) 133 | sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value) 134 | logits = logits.masked_fill(~sparse_topk_mask, mask_value) 135 | topk_attn = logits.softmax(dim = -1) 136 | 137 | if not straight_through: 138 | return topk_attn 139 | 140 | soft_attn = (orig_logits / temperature).softmax(dim = -1) 141 | return topk_attn.detach() + soft_attn - soft_attn.detach() 142 | 143 | # functions for creating causal mask 144 | # need a special one for onnx cpu (no support for .triu) 145 | 146 | def create_causal_mask(i, j, device): 147 | return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) 148 | 149 | def onnx_create_causal_mask(i, j, device): 150 | r = torch.arange(i, device = device) 151 | causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j') 152 | causal_mask = F.pad(causal_mask, (j - i, 0), value = False) 153 | return causal_mask 154 | 155 | # main class 156 | 157 | class Attend(Module): 158 | def __init__( 159 | self, 160 | *, 161 | dropout = 0., 162 | causal = False, 163 | heads = None, 164 | pre_talking_heads = False, 165 | post_talking_heads = False, 166 | pre_scale_post_talking_heads = False, 167 | sparse_topk = None, 168 | sparse_topk_straight_through = False, 169 | scale = None, 170 | qk_norm = False, 171 | l2_distance = False, 172 | sigmoid = False, 173 | custom_attn_fn: Callable | None = None, 174 | flash = False, 175 | softclamp_logits = False, 176 | logit_softclamp_value = 50., 177 | add_zero_kv = False, 178 | selective = False, 179 | hard = False, 180 | cope = None, 181 | onnxable = False, 182 | sdp_kwargs: dict = dict( 183 | enable_flash = True, 184 | enable_math = True, 185 | enable_mem_efficient = True 186 | ) 187 | ): 188 | super().__init__() 189 | self.scale = scale 190 | 191 | # causal related 192 | 193 | self.causal = causal 194 | self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask 195 | 196 | # attention type 197 | 198 | is_sparse_topk_attn = exists(sparse_topk) 199 | 200 | assert not (flash and sigmoid), 'sigmoid attention not available for flash' 201 | assert not (flash and hard), 'hard attention not available for flash' 202 | assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash' 203 | 204 | assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn) 205 | 206 | if exists(custom_attn_fn): 207 | self.attn_fn = custom_attn_fn 208 | elif sigmoid: 209 | self.attn_fn = F.sigmoid 210 | elif hard: 211 | self.attn_fn = one_hot_straight_through 212 | elif is_sparse_topk_attn: 213 | self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through) 214 | else: 215 | softmax_fn = partial(F.softmax, dim = -1) 216 | self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn 217 | 218 | # dropouts 219 | 220 | self.dropout = dropout 221 | self.attn_dropout = nn.Dropout(dropout) 222 | 223 | # talking heads 224 | 
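        # talking heads (Shazeer et al.) mixes attention across the head dimension with a
        # learned head-to-head projection, implemented below as a 1x1 Conv2d over heads and
        # initialized to identity (dirac) so it starts out as a no-op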
225 | assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention' 226 | 227 | self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None 228 | self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None 229 | self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None 230 | 231 | if exists(self.pre_softmax_talking_heads): 232 | nn.init.dirac_(self.pre_softmax_talking_heads.weight) 233 | 234 | if exists(self.post_softmax_talking_heads): 235 | nn.init.dirac_(self.post_softmax_talking_heads.weight) 236 | 237 | if exists(self.pre_scale_post_talking_heads): 238 | # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention 239 | nn.init.dirac_(self.pre_scale_post_talking_heads.weight) 240 | 241 | # selective attention 242 | 243 | assert not (flash and selective), 'selective attention cannot work on flash attention' 244 | assert not (selective and not causal), 'selective attention is designed for autoregressive' 245 | self.selective = selective 246 | 247 | # l2 distance attention 248 | 249 | self.l2_distance = l2_distance 250 | 251 | # add a key / value token composed of zeros 252 | # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html 253 | 254 | self.add_zero_kv = add_zero_kv 255 | 256 | # soft clamp attention logit value 257 | 258 | if softclamp_logits: 259 | assert not flash, 'flash attention not compatible with logit softclamp value yet' 260 | assert logit_softclamp_value > 0. 261 | 262 | self.softclamp_logits = softclamp_logits 263 | self.logit_softclamp_value = logit_softclamp_value 264 | 265 | # contextual positional encoding 266 | 267 | self.cope = cope 268 | 269 | # flash attention 270 | 271 | self.flash = flash 272 | 273 | torch_version = version.parse(torch.__version__) 274 | assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 275 | 276 | # torch 2.3 uses new backend and context manager 277 | 278 | if torch_version >= version.parse('2.3'): 279 | from torch.nn.attention import SDPBackend 280 | 281 | str_to_backend = dict( 282 | enable_flash = SDPBackend.FLASH_ATTENTION, 283 | enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION, 284 | enable_math = SDPBackend.MATH, 285 | enable_cudnn = SDPBackend.CUDNN_ATTENTION 286 | ) 287 | 288 | sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable] 289 | 290 | self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends) 291 | else: 292 | self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs) 293 | 294 | def flash_attn( 295 | self, 296 | q, k, v, 297 | mask = None, 298 | attn_bias = None 299 | ): 300 | batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 301 | 302 | # Recommended for multi-query single-key-value attention by Tri Dao 303 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) 304 | 305 | if k.ndim == 3: 306 | k = repeat(k, 'b ... -> b h ...', h = q.shape[1]) 307 | 308 | if v.ndim == 3: 309 | v = repeat(v, 'b ... 
-> b h ...', h = q.shape[1]) 310 | 311 | # handle maybe l2 distance 312 | 313 | if self.l2_distance: 314 | k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2 315 | k = F.pad(k, (0, 1), value = -1.) 316 | k = torch.cat((k, k_norm_sq), dim = -1) 317 | 318 | q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2 319 | q = torch.cat((2 * q, q_norm_sq), dim = -1) 320 | q = F.pad(q, (0, 1), value = -1.) 321 | 322 | # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention 323 | 324 | if exists(self.scale): 325 | default_scale = q.shape[-1] ** -0.5 326 | q = q * (self.scale / default_scale) 327 | 328 | # Check if mask exists and expand to compatible shape 329 | # The mask is B L, so it would have to be expanded to B H N L 330 | 331 | causal = self.causal 332 | 333 | # in the case of kv caching with one token (q_len == 1), just turn off causal masking 334 | # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there 335 | 336 | if q_len == 1 and causal: 337 | causal = False 338 | 339 | # expand key padding mask 340 | 341 | if exists(mask): 342 | assert mask.ndim == 4 343 | mask = mask.expand(batch, heads, q_len, k_len) 344 | 345 | # handle kv cache - this should be bypassable in updated flash attention 2 346 | 347 | if k_len > q_len and causal: 348 | causal_mask = self.create_causal_mask(q_len, k_len, device = device) 349 | if not exists(mask): 350 | mask = ~causal_mask 351 | else: 352 | mask = mask & ~causal_mask 353 | causal = False 354 | 355 | # manually handle causal mask, if another mask was given 356 | 357 | if exists(mask) and causal: 358 | causal_mask = self.create_causal_mask(q_len, k_len, device = device) 359 | mask = mask & ~causal_mask 360 | causal = False 361 | 362 | # protect against an entire row being masked out 363 | 364 | row_is_entirely_masked = None 365 | 366 | if exists(mask): 367 | row_is_entirely_masked = ~mask.any(dim = -1) 368 | 369 | # handle alibi positional bias 370 | # convert from bool to float 371 | 372 | if exists(attn_bias): 373 | attn_bias = attn_bias.expand(batch, heads, -1, -1) 374 | 375 | # if mask given, the mask would already contain the causal mask from above logic 376 | # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number 377 | 378 | mask_value = -torch.finfo(q.dtype).max 379 | 380 | if exists(mask): 381 | attn_bias = attn_bias.masked_fill(~mask, mask_value // 2) 382 | elif causal: 383 | causal_mask = self.create_causal_mask(q_len, k_len, device = device) 384 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2) 385 | causal = False 386 | 387 | # scaled_dot_product_attention handles attn_mask either as bool or additive bias 388 | # make it an additive bias here 389 | 390 | mask = attn_bias 391 | 392 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 393 | 394 | with self.sdp_context_manager(): 395 | out = F.scaled_dot_product_attention( 396 | q, k, v, 397 | attn_mask = mask, 398 | dropout_p = self.dropout if self.training else 0., 399 | is_causal = causal 400 | ) 401 | 402 | # for a row that is entirely masked out, should zero out the output of that row token 403 | 404 | if exists(row_is_entirely_masked) and row_is_entirely_masked.any(): 405 | out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
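        # (the zeroing above guards against nan outputs from scaled_dot_product_attention
        # when every key position of a query row has been masked out)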
406 | 407 | return out, Intermediates() 408 | 409 | def forward( 410 | self, 411 | q, k, v, 412 | mask = None, 413 | attn_bias = None, 414 | prev_attn = None 415 | ): 416 | """ 417 | einstein notation 418 | b - batch 419 | h - heads 420 | n, i, j - sequence length (base sequence length, source, target) 421 | d - feature dimension 422 | """ 423 | 424 | n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device 425 | 426 | scale = default(self.scale, q.shape[-1] ** -0.5) 427 | 428 | causal = self.causal 429 | 430 | # handle key padding mask 431 | 432 | if exists(mask) and mask.ndim == 2: 433 | mask = rearrange(mask, 'b j -> b 1 1 j') 434 | 435 | # handle kv cached decoding 436 | 437 | if n == 1 and causal: 438 | causal = False 439 | 440 | # handle grouped multi-query attention 441 | 442 | if kv_heads == 1: 443 | k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v)) 444 | elif kv_heads < heads: 445 | k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v)) 446 | 447 | # handle zero kv, as means for allowing network to attend to nothing 448 | 449 | if self.add_zero_kv: 450 | k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v)) 451 | 452 | if exists(mask): 453 | mask = F.pad(mask, (1, 0), value = True) 454 | 455 | if exists(attn_bias): 456 | attn_bias = F.pad(attn_bias, (1, 0), value = 0.) 457 | 458 | if self.flash: 459 | assert not exists(prev_attn), 'residual attention not compatible with flash attention' 460 | return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias) 461 | 462 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' 463 | 464 | if not self.l2_distance: 465 | sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) 466 | else: 467 | sim = -qk_l2_dist_squared(q, k) 468 | 469 | sim = sim * scale 470 | 471 | if exists(prev_attn): 472 | sim = sim + prev_attn 473 | 474 | qk_similarities = sim.clone() 475 | 476 | if exists(self.pre_scale_post_talking_heads): 477 | pre_to_post_scale = self.pre_scale_post_talking_heads(sim) 478 | 479 | if exists(self.pre_softmax_talking_heads): 480 | sim = sim + self.pre_softmax_talking_heads(sim) 481 | 482 | if exists(attn_bias): 483 | sim = sim + attn_bias 484 | 485 | if self.softclamp_logits: 486 | sim = softclamp(sim, self.logit_softclamp_value) 487 | 488 | i, j, dtype = *sim.shape[-2:], sim.dtype 489 | 490 | mask_value = -torch.finfo(sim.dtype).max 491 | 492 | if exists(mask): 493 | sim = sim.masked_fill(~mask, mask_value) 494 | 495 | if causal: 496 | causal_mask = self.create_causal_mask(i, j, device = device) 497 | sim = sim.masked_fill(causal_mask, mask_value) 498 | 499 | row_is_entirely_masked = None 500 | 501 | if exists(mask): 502 | row_is_entirely_masked = ~mask.any(dim = -1) 503 | 504 | if exists(self.cope): 505 | sim = sim + self.cope(q, sim) 506 | 507 | if self.selective: 508 | sim = selective_attn(sim) 509 | 510 | pre_softmax_attn = sim 511 | 512 | attn = self.attn_fn(sim) 513 | 514 | attn = attn.type(dtype) 515 | 516 | post_softmax_attn = attn 517 | 518 | attn = self.attn_dropout(attn) 519 | 520 | if exists(self.post_softmax_talking_heads): 521 | attn = self.post_softmax_talking_heads(attn) 522 | 523 | if exists(self.pre_scale_post_talking_heads): 524 | attn = attn * pre_to_post_scale 525 | 526 | out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) 527 | 528 | intermediates = Intermediates( 529 | qk_similarities = qk_similarities, 530 | pre_softmax_attn = pre_softmax_attn, 531 | post_softmax_attn = post_softmax_attn 532 | ) 533 | 534 | if 
exists(row_is_entirely_masked) and row_is_entirely_masked.any(): 535 | out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 536 | 537 | return out, intermediates 538 | -------------------------------------------------------------------------------- /x_transformers/autoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from math import ceil, log 4 | from typing import Tuple, Callable 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import Module 9 | import torch.nn.functional as F 10 | 11 | from einops import rearrange, pack, unpack 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | def default(val, d): 17 | return val if exists(val) else d 18 | 19 | def identity(t, *args, **kwargs): 20 | return t 21 | 22 | def join(arr, delimiter = ', '): 23 | return delimiter.join(arr) 24 | 25 | def cast_tuple(t, length = 1): 26 | return t if isinstance(t, tuple) else (t,) * length 27 | 28 | def eval_decorator(fn): 29 | def inner(self, *args, **kwargs): 30 | was_training = self.training 31 | self.eval() 32 | out = fn(self, *args, **kwargs) 33 | self.train(was_training) 34 | return out 35 | return inner 36 | 37 | # for variable lengthed prefixes 38 | 39 | def align_right(t, lens, pad_id = 0): 40 | batch, seq_len, device, dtype = *t.shape, t.device, t.dtype 41 | 42 | assert lens.ndim == 1 and lens.shape[0] == batch 43 | assert lens.amax() <= seq_len 44 | 45 | pad_lens = seq_len - lens 46 | max_pad_len = pad_lens.amax() 47 | 48 | batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None] 49 | prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long) 50 | 51 | t = F.pad(t, (max_pad_len, 0), value = pad_id) 52 | offset = max_pad_len - pad_lens 53 | 54 | aligned = t[batch_arange, prompt_len_arange + offset[..., None]] 55 | return aligned 56 | 57 | # nucleus 58 | 59 | def top_p(logits, thres = 0.9): 60 | sorted_logits, sorted_indices = torch.sort(logits, descending = True) 61 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1) 62 | 63 | sorted_indices_to_remove = cum_probs > thres 64 | sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False) 65 | 66 | sorted_logits[sorted_indices_to_remove] = float('-inf') 67 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 68 | 69 | # topk 70 | 71 | def top_k(logits, frac_num_tokens = 0.1, k = None): 72 | num_tokens = logits.shape[-1] 73 | 74 | k = default(k, ceil(frac_num_tokens * num_tokens)) 75 | k = min(k, num_tokens) 76 | 77 | val, ind = torch.topk(logits, k) 78 | probs = torch.full_like(logits, float('-inf')) 79 | probs.scatter_(1, ind, val) 80 | return probs 81 | 82 | # top_a 83 | 84 | def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02): 85 | probs = logits.softmax(dim = -1) 86 | max_probs = probs.amax(dim = -1, keepdim = True) 87 | limit = torch.pow(max_probs, min_p_pow) * min_p_ratio 88 | return torch.where(probs < limit, float('-inf'), logits) 89 | 90 | # min_p 91 | # https://arxiv.org/abs/2407.01082 92 | 93 | def min_p(logits, min_p = 0.1): 94 | probs = logits.softmax(dim = -1) 95 | max_probs = probs.amax(dim = -1, keepdim = True) 96 | limit = min_p * max_probs 97 | return torch.where(probs < limit, float('-inf'), logits) 98 | 99 | # filter logits functions dict[str -> Callable] 100 | 101 | FILTER_LOGITS_FN = dict( 102 | top_p = top_p, 103 | top_k = top_k, 104 | top_a = top_a, 105 | min_p = min_p 106 | ) 107 | 108 | 
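# rough usage sketch (not part of the library code): any of the keys above can be passed by
# name to AutoregressiveWrapper.generate together with keyword arguments for that filter,
# assuming `wrapper` is an AutoregressiveWrapper instance, e.g.
#
#   wrapper.generate(prompts, seq_len = 256, filter_logits_fn = 'min_p', filter_kwargs = dict(min_p = 0.1))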
# contrastive decoding function 109 | 110 | def contrastive_decode_fn( 111 | expert_logits, 112 | amateur_logits, 113 | alpha = 0.1, 114 | beta = 0.5 115 | ): 116 | """ 117 | Appendix A Algorithm 2 118 | https://arxiv.org/abs/2309.09117 119 | """ 120 | 121 | cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True) 122 | diffs = (1 + beta) * expert_logits - beta * amateur_logits 123 | contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max) 124 | return contrastive_decode_logits 125 | 126 | # autoregressive wrapper class 127 | 128 | class AutoregressiveWrapper(Module): 129 | def __init__( 130 | self, 131 | net, 132 | ignore_index = -100, 133 | pad_value = 0, 134 | mask_prob = 0., 135 | add_attn_z_loss = False 136 | ): 137 | super().__init__() 138 | self.pad_value = pad_value 139 | self.ignore_index = ignore_index 140 | 141 | self.net = net 142 | self.max_seq_len = net.max_seq_len 143 | 144 | # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432 145 | assert mask_prob < 1. 146 | self.mask_prob = mask_prob 147 | 148 | # whether to add router z-loss 149 | self.add_attn_z_loss = add_attn_z_loss 150 | 151 | @torch.no_grad() 152 | @eval_decorator 153 | def generate( 154 | self, 155 | prompts, 156 | seq_len, 157 | eos_token = None, 158 | temperature = 1., 159 | prompt_lens: Tensor | None = None, 160 | filter_logits_fn: str | Callable = top_k, 161 | restrict_to_max_seq_len = True, 162 | amateur_model: Module | Tuple[Module] | None = None, 163 | filter_kwargs: dict = dict(), 164 | contrastive_decode_kwargs: dict | Tuple[dict] = dict( 165 | beta = 0.5, 166 | alpha = 0.1 167 | ), 168 | cache_kv = True, 169 | **kwargs 170 | ): 171 | max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device 172 | 173 | prompts, ps = pack([prompts], '* n') 174 | 175 | b, t = prompts.shape 176 | 177 | # handle filter logits fn given as string 178 | 179 | if isinstance(filter_logits_fn, str): 180 | assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available" 181 | 182 | filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn] 183 | 184 | # handle variable lengthed prompts (prefixes) 185 | 186 | seq_start_pos = None 187 | if exists(prompt_lens): 188 | prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value) 189 | seq_start_pos = t - prompt_lens 190 | 191 | # output from which sampled tokens appended to 192 | 193 | out = prompts 194 | 195 | # kv caches 196 | 197 | cache = None 198 | 199 | # if doing contrastive decoding, turn off filter automatically 200 | 201 | if exists(amateur_model): 202 | amateur_model = cast_tuple(amateur_model) 203 | contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs) 204 | 205 | assert len(amateur_model) == len(contrastive_decode_kwargs) 206 | 207 | amateur_caches = [None] * len(amateur_model) 208 | filter_logits_fn = identity 209 | 210 | for i, module in enumerate(amateur_model): 211 | if isinstance(module, AutoregressiveWrapper): 212 | amateur_model[i] = module.net 213 | 214 | module.eval() 215 | 216 | # sampling up to seq_len 217 | 218 | for _ in range(seq_len): 219 | 220 | if restrict_to_max_seq_len: 221 | max_len_exceeded = out.shape[-1] > max_seq_len 222 | 223 | assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. 
most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue' 224 | 225 | x = out[:, -max_seq_len:] 226 | 227 | if exists(cache): 228 | for inter in cache.attn_intermediates: 229 | if inter.layer_type == 'a': 230 | inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv] 231 | 232 | logits, new_cache = self.net( 233 | x, 234 | return_intermediates = True, 235 | cache = cache, 236 | seq_start_pos = seq_start_pos, 237 | **kwargs 238 | ) 239 | 240 | if cache_kv and self.net.can_cache_kv: 241 | cache = new_cache 242 | 243 | logits = logits[:, -1] 244 | 245 | # handle contrastive decoding, Li et al. 246 | # https://arxiv.org/abs/2210.15097 247 | 248 | if exists(amateur_model): 249 | for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)): 250 | amateur_logits, next_amateur_cache = amateur( 251 | x, 252 | return_intermediates = True, 253 | cache = amateur_cache, 254 | seq_start_pos = seq_start_pos, 255 | **kwargs 256 | ) 257 | 258 | amateur_logits = amateur_logits[:, -1] 259 | 260 | assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model' 261 | logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs) 262 | 263 | if cache_kv and amateur.can_cache_kv: 264 | amateur_caches[i] = next_amateur_cache 265 | 266 | # filter by top_k, top_p (nucleus), top_a, or custom 267 | 268 | if greedy: 269 | sample = logits.argmax(dim = -1, keepdim = True) 270 | else: 271 | filtered_logits = filter_logits_fn(logits, **filter_kwargs) 272 | probs = F.softmax(filtered_logits / temperature, dim=-1) 273 | sample = torch.multinomial(probs, 1) 274 | 275 | # concat sample 276 | 277 | out = torch.cat((out, sample), dim=-1) 278 | 279 | if not exists(eos_token): 280 | continue 281 | 282 | is_eos_tokens = (out == eos_token) 283 | 284 | if is_eos_tokens.any(dim = -1).all(): 285 | break 286 | 287 | if exists(eos_token): 288 | # mask out everything after the eos tokens 289 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) 290 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 291 | out = out.masked_fill(mask, self.pad_value) 292 | 293 | out = out[:, t:] 294 | 295 | out, = unpack(out, ps, '* n') 296 | 297 | return out 298 | 299 | def forward(self, x, return_outputs = False, **kwargs): 300 | seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss 301 | 302 | inp, target = x[:, :-1], x[:, 1:] 303 | inp = torch.where(inp == ignore_index, self.pad_value, inp) 304 | 305 | if self.mask_prob > 0.: 306 | rand = torch.randn(inp.shape, device = x.device) 307 | rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out 308 | num_mask = min(int(seq * self.mask_prob), seq - 1) 309 | indices = rand.topk(num_mask, dim = -1).indices 310 | mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() 311 | kwargs.update(self_attn_kv_mask = mask) 312 | 313 | logits, cache = self.net( 314 | inp, 315 | return_intermediates = True, 316 | return_attn_z_loss = add_attn_z_loss, 317 | **kwargs 318 | ) 319 | 320 | loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss 321 | 322 | loss = loss_fn( 323 | rearrange(logits, 'b n c -> b c n'), 324 | target, 325 | ignore_index = ignore_index 326 | ) 327 | 328 | if add_attn_z_loss: 329 | loss = loss + cache.attn_z_loss 330 | 331 | if not return_outputs: 332 | return loss 
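        # with return_outputs = True, the raw logits and the intermediates cache are also
        # returned, for callers that need them (custom metrics, inspecting attention, etc)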
333 | 334 | return loss, (logits, cache) 335 | -------------------------------------------------------------------------------- /x_transformers/belief_state_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | # Belief State Transformer 3 | 4 | # Hu et al. https://arxiv.org/abs/2410.23506 5 | # https://www.youtube.com/watch?v=aqhbRtB2Fyg 6 | 7 | from __future__ import annotations 8 | from random import random 9 | 10 | import torch 11 | from torch.autograd import Function 12 | from torch.nn import Module, ModuleList 13 | from torch import nn, cat, stack, tensor, Tensor, arange, cartesian_prod 14 | import torch.nn.functional as F 15 | 16 | from x_transformers.autoregressive_wrapper import ( 17 | eval_decorator, 18 | min_p, 19 | ) 20 | 21 | from x_transformers.x_transformers import ( 22 | Decoder, 23 | TransformerWrapper 24 | ) 25 | 26 | import einx 27 | from einops import rearrange, repeat, pack, unpack 28 | from einops.layers.torch import Rearrange 29 | 30 | # helper functions 31 | 32 | def exists(v): 33 | return v is not None 34 | 35 | def default(v, d): 36 | return v if exists(v) else d 37 | 38 | # a custom flip that can handle variable lengths across batch 39 | 40 | def flip(x, dim = 1, lens = None): 41 | if not exists(lens): 42 | return x.flip(dim) 43 | 44 | batch, seq_len, device = *x.shape[:2], x.device 45 | seq = arange(seq_len, device = device) 46 | 47 | mask = einx.less('j, i -> i j', seq, lens) 48 | masked_seq = einx.where('i j, j,', mask, seq, -1) 49 | 50 | flip_indices = masked_seq.argsort(dim = -1, descending = True) 51 | 52 | if x.ndim == 3: 53 | flip_indices = repeat(flip_indices, '... -> ... d', d = x.shape[-1]) 54 | 55 | return x.gather(dim, flip_indices) 56 | 57 | # detach multiple tensors and backward the gradients once 58 | 59 | class DetachMultiple(Function): 60 | 61 | @classmethod 62 | def forward(self, ctx, *tensors): 63 | detached_tensors = tuple(t.detach() for t in tensors) 64 | 65 | for detached_tensor in detached_tensors: 66 | detached_tensor.requires_grad_() 67 | 68 | return detached_tensors 69 | 70 | @classmethod 71 | def backward(self, ctx, *grads): 72 | 73 | return grads 74 | 75 | detach_multiple = DetachMultiple.apply 76 | 77 | # wrappers 78 | 79 | class BeliefStateWrapper(Module): 80 | """ 81 | Figure 13. 
in https://arxiv.org/abs/2410.23506 82 | """ 83 | 84 | def __init__( 85 | self, 86 | forward_decoder: TransformerWrapper, 87 | backward_decoder: TransformerWrapper | None = None, 88 | train_frac_forward_backward_pairs: float = 1., 89 | text_head: Module | None = None, 90 | backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc 91 | pred_distance = False, 92 | pred_distance_loss_weight: float = 1., 93 | cond_on_distance = False, 94 | cond_on_distance_prob = 0.5, 95 | max_pred_distance = None 96 | ): 97 | super().__init__() 98 | backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token 99 | 100 | assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension' 101 | assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens' 102 | 103 | dim = forward_decoder.emb_dim 104 | num_tokens = forward_decoder.num_tokens 105 | max_seq_len = forward_decoder.max_seq_len 106 | 107 | self.num_tokens = num_tokens 108 | 109 | # the suffix token 110 | 111 | self.suffix_token = nn.Parameter(torch.zeros(dim)) 112 | nn.init.normal_(self.suffix_token, std = 0.02) 113 | 114 | # the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences 115 | 116 | if not exists(text_head): 117 | text_head = nn.Sequential( 118 | nn.Linear(dim * 2, dim), 119 | nn.LeakyReLU(), 120 | nn.Linear(dim, num_tokens * 2), 121 | ) 122 | 123 | self.text_head = text_head 124 | 125 | # predicting terminal state (when suffix and prefix predict the same token) 126 | 127 | self.max_pred_distance = default(max_pred_distance, max_seq_len) 128 | 129 | self.to_distance_logits = nn.Sequential( 130 | nn.Linear(dim * 2, dim), 131 | nn.LeakyReLU(), 132 | nn.Linear(dim, self.max_pred_distance), 133 | ) if pred_distance else None 134 | 135 | self.pred_distance_loss_weight = pred_distance_loss_weight 136 | 137 | # conditioning on distance 138 | 139 | assert 0. < cond_on_distance_prob < 1. 140 | 141 | self.cond_on_distance = cond_on_distance 142 | self.cond_on_distance_prob = cond_on_distance_prob 143 | 144 | if cond_on_distance: 145 | self.to_distance_cond = nn.Sequential( 146 | Rearrange('... -> ... 1'), 147 | nn.Linear(1, dim), 148 | nn.LeakyReLU(), 149 | nn.Linear(dim, dim * 2), 150 | ) 151 | 152 | # the two decoders, one which is causal forward, the other causal backwards 153 | 154 | self.forward_decoder = forward_decoder 155 | self.backward_decoder = backward_decoder 156 | 157 | # what fraction of forward backward pairs to train on 158 | # for further memory efficiency 159 | 160 | assert 0 < train_frac_forward_backward_pairs <= 1. 161 | self.train_frac_fb_pairs = train_frac_forward_backward_pairs 162 | self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1. 163 | 164 | # loss weighting 165 | 166 | self.backward_ar_loss_weight = backward_ar_loss_weight 167 | self.needs_loss_weight = backward_ar_loss_weight != 1. 
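        # per-direction loss weights, [forward, backward], registered as a buffer and
        # applied over the fb dimension of the unreduced loss in the forward pass below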
168 | 169 | self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight])) 170 | 171 | # sampling 172 | 173 | self.max_seq_len = self.forward_decoder.max_seq_len 174 | 175 | @torch.no_grad() 176 | @eval_decorator 177 | def generate_with_suffix_cond( 178 | self, 179 | prompts, 180 | seq_len, 181 | temperature = 1.25, 182 | cache_kv = False, 183 | suffix: Tensor | None = None, # the goal conditioning 184 | filter_logits_fn = min_p, 185 | filter_kwargs = dict( 186 | min_p = 0.1 187 | ), 188 | decode_backwards = False, 189 | **kwargs 190 | ): 191 | max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device 192 | 193 | prompts, batch_ps = pack([prompts], '* d') 194 | 195 | batch, orig_seq_len = prompts.shape 196 | 197 | # allow for decoding backwards, to make sure it is working 198 | 199 | main_decoder = self.forward_decoder 200 | 201 | if decode_backwards: 202 | prompts = prompts.flip(1) 203 | main_decoder = self.backward_decoder 204 | 205 | out = prompts 206 | 207 | # kv caches 208 | 209 | cache = None 210 | 211 | # get the encoded suffix token once 212 | 213 | suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d') 214 | 215 | suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch) 216 | 217 | if not decode_backwards: 218 | if exists(suffix): 219 | if suffix.ndim == 1: 220 | suffix = repeat(suffix, 'n -> b n', b = batch) 221 | 222 | suffix = suffix.flip(1) # reverse autoregressive 223 | 224 | suffix_embed = self.backward_decoder( 225 | suffix, 226 | prepend_embeds = suffix_sos_tokens, 227 | return_embeddings = True 228 | ) 229 | 230 | # pick out the last embedding for fill in the middle 231 | 232 | suffix_embed = suffix_embed[:, -1:] 233 | 234 | else: 235 | # just grab a random token for now for prefix 236 | 237 | prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device) 238 | 239 | prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True) 240 | 241 | # sampling up to seq_len 242 | 243 | for _ in range(seq_len): 244 | 245 | embeds, new_cache = main_decoder( 246 | out, 247 | prepend_embeds = suffix_sos_tokens if decode_backwards else None, 248 | return_intermediates = True, 249 | return_embeddings = True, 250 | cache = cache, 251 | **kwargs 252 | ) 253 | 254 | last_embeds = embeds[:, -1:] 255 | 256 | if not decode_backwards: 257 | embeds = cat((last_embeds, suffix_embed), dim = -1) 258 | else: 259 | embeds = cat((prefix_embed, last_embeds), dim = -1) 260 | 261 | if cache_kv and self.forward_decoder.can_cache_kv: 262 | cache = new_cache 263 | 264 | forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1) 265 | 266 | logits = forward_logits if not decode_backwards else backward_logits 267 | 268 | logits = logits[:, -1] 269 | 270 | if greedy: 271 | sample = logits.argmax(dim = -1, keepdim = True) 272 | else: 273 | filtered_logits = filter_logits_fn(logits, **filter_kwargs) 274 | probs = F.softmax(filtered_logits / temperature, dim = -1) 275 | sample = torch.multinomial(probs, 1) 276 | 277 | # concat sample 278 | 279 | out = torch.cat((out, sample), dim = -1) 280 | 281 | out = out[:, orig_seq_len:] 282 | 283 | out, = unpack(out, batch_ps, '* n') 284 | 285 | return out 286 | 287 | def forward( 288 | self, 289 | seq, 290 | lens: Tensor | None = None, # Int['b'] 291 | loss_weight_by_fb_indices: callable | None = None 292 | ): 293 | batch, seq_len, device = *seq.shape, seq.device 294 | 295 | # handle variable length sequences 296 | 297 | seq_for_labels = seq 298 | 299 | if 
exists(lens): 300 | mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens) 301 | seq_for_labels = torch.where(mask, seq, -1) 302 | 303 | # forward autoregressive 304 | 305 | forward_embeds = self.forward_decoder(seq, return_embeddings = True) 306 | 307 | # backward autoregressive 308 | 309 | backward_seq = flip(seq, lens = lens) 310 | 311 | suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch) 312 | 313 | backward_embeds = self.backward_decoder( 314 | backward_seq, 315 | prepend_embeds = suffix_tokens, 316 | return_embeddings = True 317 | ) 318 | 319 | backward_embeds = flip(backward_embeds, lens = lens) 320 | 321 | # trick to reduce memory on backwards pass 322 | 323 | forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds) 324 | 325 | # belief state objective 326 | 327 | seq_arange = arange(seq_len, device = device) 328 | 329 | fb_pairs = cartesian_prod(seq_arange, seq_arange + 1) # plus one for suffix token 330 | 331 | # filter down to valid pairs, as in figure 11 332 | # f - forward, b - backward, i - indices 333 | 334 | fi, bi = fb_pairs.unbind(dim = -1) 335 | 336 | valid_mask = (bi - fi) >= 2 337 | 338 | fb_pairs = fb_pairs[valid_mask] 339 | 340 | # maybe subsample fb pairs 341 | 342 | if self.needs_subsample_fb_pairs: 343 | num_pairs = fb_pairs.shape[0] 344 | 345 | num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1) 346 | 347 | rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled] 348 | 349 | fb_pairs = fb_pairs[rand_subsampled_indices] 350 | 351 | # get labels for both 352 | 353 | fi, bi = fb_pairs.unbind(dim = -1) 354 | 355 | labels_fi, labels_bi = (fi + 1), (bi - 1) 356 | 357 | forward_labels, backward_labels = seq_for_labels[:, labels_fi], seq_for_labels[:, labels_bi] 358 | 359 | labels = cat((forward_labels, backward_labels), dim = -1) 360 | 361 | # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions 362 | 363 | fb_embeds = cat(( 364 | forward_embeds[:, fi], 365 | backward_embeds[:, bi] 366 | ), dim = -1) 367 | 368 | logits = self.text_head(fb_embeds) 369 | 370 | # cross entropy loss 371 | 372 | loss = F.cross_entropy( 373 | rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2), 374 | labels, 375 | reduction = 'none' if self.needs_loss_weight else 'mean', 376 | ignore_index = -1 377 | ) 378 | 379 | # maybe condition on distance 380 | 381 | cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob) 382 | 383 | if cond_on_distance: 384 | distance = (bi - fi).float() 385 | distance_cond = self.to_distance_cond(distance) 386 | 387 | fb_embeds = fb_embeds * distance_cond 388 | 389 | # maybe predict distance 390 | 391 | if exists(self.to_distance_logits) and not cond_on_distance: 392 | distance_logits = self.to_distance_logits(fb_embeds) 393 | 394 | distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1) 395 | distance_labels = repeat(distance_labels, 'n -> b n', b = batch) 396 | 397 | pred_dist_loss = F.cross_entropy( 398 | rearrange(distance_logits, 'b n l -> b l n'), 399 | distance_labels 400 | ) 401 | 402 | loss = ( 403 | loss + 404 | pred_dist_loss * self.pred_distance_loss_weight 405 | ) 406 | 407 | # maybe loss weighting 408 | 409 | needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices)) 410 | 411 | if needs_loss_weight: 412 | loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2) 413 | 414 | if self.needs_loss_weight: 415 | loss = 
einx.multiply('b fb n, fb', loss, self.loss_weights) 416 | 417 | # allow researcher to pass in a function that acts on the the forward backward indices Int['n fb'] 418 | # the reason this may be needed is because the earlier tokens will have more eligible pairs for training, and perhaps this could be normalized 419 | 420 | if exists(loss_weight_by_fb_indices): 421 | loss_weight = loss_weight_by_fb_indices(fb_pairs) 422 | 423 | if loss_weight.ndim == 1: 424 | loss = einx.multiply('b fb n, n', loss, loss_weight) 425 | elif loss_weight.ndim == 2: 426 | loss = einx.multiply('b fb n, n fb', loss, loss_weight) 427 | else: 428 | raise ValueError('invalid loss weight dims') 429 | 430 | loss = loss.mean() 431 | 432 | return loss 433 | -------------------------------------------------------------------------------- /x_transformers/continuous.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import nn, cat, stack 5 | from torch.nn import Module 6 | import torch.nn.functional as F 7 | from torch.distributions import Normal 8 | 9 | import einx 10 | from einops import rearrange, reduce, pack, repeat, unpack 11 | 12 | from x_transformers.x_transformers import ( 13 | Attention, 14 | AttentionLayers, 15 | ScaledSinusoidalEmbedding, 16 | AbsolutePositionalEmbedding, 17 | LayerNorm, 18 | masked_mean, 19 | always, 20 | pad_at_dim 21 | ) 22 | 23 | # helper functions 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | def default(val, d): 29 | if exists(val): 30 | return val 31 | return d() if not isinstance(d, Module) and callable(d) else d 32 | 33 | def masked_mean(t, mask): 34 | t = einx.where('b n, b n d, -> b n d', mask, t, 0.) 35 | 36 | num = reduce(t, 'b n d -> b', 'sum') 37 | den = mask.sum(dim = -1) 38 | 39 | masked_average = num / den.clamp(min = 1.) 
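    # the clamp guards against division by zero for rows whose mask is entirely False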
40 | return masked_average 41 | 42 | # probabilistic loss fn 43 | 44 | class GaussianNLL(Module): 45 | def forward(self, pred, target): 46 | mean, var = pred 47 | return F.gaussian_nll_loss(mean, target, var, reduction = 'none') 48 | 49 | # main classes 50 | 51 | class ContinuousTransformerWrapper(Module): 52 | def __init__( 53 | self, 54 | *, 55 | max_seq_len, 56 | attn_layers: AttentionLayers, 57 | dim_in = None, 58 | dim_out = None, 59 | emb_dim = None, 60 | max_mem_len = 0, 61 | num_memory_tokens = None, 62 | post_emb_norm = False, 63 | emb_dropout = 0., 64 | use_abs_pos_emb = True, 65 | scaled_sinu_pos_emb = False, 66 | average_pool_embed = False, 67 | probabilistic = False 68 | ): 69 | super().__init__() 70 | dim = attn_layers.dim 71 | 72 | self.max_seq_len = max_seq_len 73 | 74 | self.max_mem_len = max_mem_len 75 | 76 | no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb) 77 | 78 | if no_abs_pos_emb: 79 | self.pos_emb = always(0) 80 | elif scaled_sinu_pos_emb: 81 | self.pos_emb = ScaledSinusoidalEmbedding(dim) 82 | else: 83 | self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) 84 | 85 | self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity() 86 | self.emb_dropout = nn.Dropout(emb_dropout) 87 | 88 | # memory tokens 89 | 90 | num_memory_tokens = default(num_memory_tokens, 0) 91 | self.has_memory_tokens = num_memory_tokens > 0 92 | 93 | if num_memory_tokens > 0: 94 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 95 | 96 | # attention layers 97 | 98 | self.attn_layers = attn_layers 99 | 100 | # average pool 101 | 102 | self.average_pool_embed = average_pool_embed 103 | 104 | # project in and out 105 | 106 | self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity() 107 | 108 | # output is multipled by 2 for outputting mean and log variance 109 | 110 | self.probabilistic = probabilistic 111 | 112 | self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity() 113 | 114 | # can cache kv 115 | 116 | self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)]) 117 | 118 | def forward( 119 | self, 120 | x, 121 | return_embeddings = False, 122 | return_intermediates = False, 123 | return_mems = False, 124 | mask = None, 125 | lens = None, 126 | return_attn = False, 127 | mems = None, 128 | mem_masks = None, 129 | pos = None, 130 | sum_embeds = None, 131 | prepend_embeds = None, 132 | prepend_mask = None, 133 | **kwargs 134 | ): 135 | batch, seq, orig_mask, device = *x.shape[:2], mask, x.device 136 | 137 | # maybe seq lengths passed in 138 | 139 | if exists(lens): 140 | assert not exists(mask), 'either `mask` or `lens` passed in, but not both' 141 | seq_arange = torch.arange(seq, device = device) 142 | 143 | mask = einx.less('j, i -> i j', seq_arange, lens) 144 | 145 | # project in + positional embedding 146 | 147 | x = self.project_in(x) 148 | x = x + self.pos_emb(x, pos = pos) 149 | 150 | if exists(sum_embeds): 151 | x = x + sum_embeds 152 | 153 | x = self.post_emb_norm(x) 154 | 155 | # memory tokens 156 | 157 | if self.has_memory_tokens: 158 | m = repeat(self.memory_tokens, 'm d -> b m d', b = batch) 159 | x, mem_ps = pack([m, x], 'b * d') 160 | 161 | if exists(mask): 162 | num_mems = m.shape[-2] 163 | mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True) 164 | 165 | # whether to append embeds, as in PaLI, for image embeddings 166 | 167 | if 
exists(prepend_embeds): 168 | prepend_seq, prepend_dim = prepend_embeds.shape[1:] 169 | 170 | assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions' 171 | 172 | x = cat((prepend_embeds, x), dim = -2) 173 | 174 | if exists(prepend_mask) or exists(mask): 175 | mask = default(mask, lambda: torch.ones((batch, seq), device = device, dtype = torch.bool)) 176 | prepend_mask = default(prepend_mask, lambda: torch.ones((batch, prepend_seq), device = device, dtype = torch.bool)) 177 | 178 | mask = cat((prepend_mask, mask), dim = -1) 179 | 180 | x = self.emb_dropout(x) 181 | 182 | # attention layers 183 | 184 | x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs) 185 | 186 | # splice out memory tokens 187 | 188 | if self.has_memory_tokens: 189 | m, x = unpack(x, mem_ps, 'b * d') 190 | intermediates.memory_tokens = m 191 | 192 | if self.average_pool_embed: 193 | x = masked_mean(x, mask = orig_mask) 194 | 195 | # maybe linear project out 196 | 197 | out = self.project_out(x) if not return_embeddings else x 198 | 199 | if not return_embeddings and self.probabilistic: 200 | mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2) 201 | variance = log_var.exp() 202 | out = stack((mean, variance)) 203 | 204 | if return_intermediates: 205 | return out, intermediates 206 | 207 | if return_mems: 208 | hiddens = intermediates.hiddens 209 | new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens) 210 | return out, new_mems 211 | 212 | if return_attn: 213 | attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates) 214 | return out, attn_maps 215 | 216 | return out 217 | 218 | class ContinuousAutoregressiveWrapper(Module): 219 | def __init__( 220 | self, 221 | net: ContinuousTransformerWrapper, 222 | loss_fn: Module | None = None, 223 | equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token) 224 | ): 225 | super().__init__() 226 | self.net = net 227 | self.max_seq_len = net.max_seq_len 228 | 229 | probabilistic = net.probabilistic 230 | self.probabilistic = probabilistic 231 | 232 | loss_fn = default(loss_fn, nn.MSELoss(reduction = 'none') if not probabilistic else GaussianNLL()) 233 | 234 | self.loss_fn = loss_fn 235 | self.equal_loss_weight_batch = equal_loss_weight_batch 236 | 237 | @torch.no_grad() 238 | def generate( 239 | self, 240 | start_tokens, 241 | seq_len, 242 | temperature = 1., 243 | cache_kv = True, 244 | **kwargs 245 | ): 246 | should_cache_kv = cache_kv and self.net.can_cache_kv 247 | device = start_tokens.device 248 | 249 | was_training = self.net.training 250 | num_dims = len(start_tokens.shape) 251 | 252 | assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2' 253 | 254 | if num_dims == 2: 255 | start_tokens = start_tokens[None, :] 256 | 257 | b, t, _, device = *start_tokens.shape, start_tokens.device 258 | 259 | self.net.eval() 260 | out = start_tokens 261 | 262 | cache = None 263 | 264 | for _ in range(seq_len): 265 | x = out[:, -self.max_seq_len:] 266 | 267 | net_out, new_cache = self.net(x, cache = cache, return_intermediates = True, **kwargs) 268 | 269 | last_output = net_out[..., -1:, :] 270 | 271 | if self.probabilistic: 272 | mean, var = last_output 273 | stddev = var.clamp(min = 1e-5).sqrt() 274 | 275 | last_output = 
torch.normal(mean, stddev * temperature) 276 | 277 | out = cat((out, last_output), dim = -2) 278 | 279 | if should_cache_kv: 280 | cache = new_cache 281 | 282 | out = out[:, t:] 283 | 284 | if num_dims == 2: 285 | out = out.squeeze(0) 286 | 287 | self.net.train(was_training) 288 | return out 289 | 290 | def forward( 291 | self, 292 | x, 293 | **kwargs 294 | ): 295 | inp, target = x[:, :-1], x[:, 1:] 296 | 297 | assert 'prepend_embeds' not in kwargs 298 | 299 | # lens 300 | 301 | lens = kwargs.pop('lens', None) 302 | 303 | if exists(lens): 304 | assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both' 305 | seq_len, device = inp.shape[1], inp.device 306 | seq_arange = torch.arange(seq_len, device = device) 307 | mask = einx.less('j, i -> i j', seq_arange, lens) 308 | 309 | kwargs['mask'] = mask 310 | 311 | # mask 312 | 313 | mask = kwargs.get('mask', None) 314 | 315 | if exists(mask) and mask.shape[1] == x.shape[1]: 316 | mask = mask[:, :-1] 317 | kwargs['mask'] = mask 318 | 319 | out = self.net(inp, **kwargs) 320 | 321 | loss = self.loss_fn(out, target) 322 | 323 | if exists(mask): 324 | assert loss.ndim > 1, 'loss should not be reduced if mask is passed in' 325 | 326 | if self.equal_loss_weight_batch: 327 | loss = masked_mean(loss, mask) 328 | else: 329 | loss = loss[mask] 330 | 331 | return loss.mean() 332 | -------------------------------------------------------------------------------- /x_transformers/dpo.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | from torch.nn import Module 5 | import torch.nn.functional as F 6 | from x_transformers.x_transformers import TransformerWrapper 7 | 8 | import einx 9 | from einops import rearrange 10 | 11 | # helper functions 12 | 13 | def exists(v): 14 | return v is not None 15 | 16 | def freeze_all_layers_(module): 17 | for param in module.parameters(): 18 | param.requires_grad = False 19 | 20 | def log_prob_from_model_and_seq(model, seq): 21 | src_seq, tgt_seq = seq[:, :-1], seq[:, 1:] 22 | logits = model(src_seq) 23 | log_prob = logits.log_softmax(dim = -1) 24 | return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq) 25 | 26 | def masked_mean(log_probs, mask = None): 27 | if not exists(mask): 28 | return log_probs.mean(dim = -1) 29 | 30 | if mask.shape[-1] == (log_probs.shape[-1] + 1): 31 | mask = mask[:, :-1] 32 | 33 | log_probs = log_probs.masked_fill(~mask, 0.) 
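    # masked positions contribute zero to the sum, and the denominator counts only the
    # unmasked tokens, giving a mean log prob over just the unmasked (response) portion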
34 | num = log_probs.sum(dim = -1) 35 | den = mask.sum(dim = -1) 36 | return num / den.clamp(min = 1e-5) 37 | 38 | def maybe_and_mask(*masks): 39 | masks = [*filter(exists, masks)] 40 | if len(masks) == 0: 41 | return None 42 | 43 | mask, *rest_masks = masks 44 | for rest_mask in rest_masks: 45 | mask = mask & rest_mask 46 | 47 | return mask 48 | 49 | # main class 50 | 51 | class DPO(Module): 52 | def __init__( 53 | self, 54 | model: TransformerWrapper, 55 | *, 56 | beta = 0.1, 57 | pad_id = None 58 | ): 59 | super().__init__() 60 | self.policy_model = model 61 | 62 | self.ref_model = deepcopy(model) 63 | freeze_all_layers_(self.ref_model) 64 | 65 | self.beta = beta 66 | self.pad_id = pad_id 67 | 68 | def parameters(self): 69 | return self.policy_model.parameters() 70 | 71 | def forward( 72 | self, 73 | preferred_seq, 74 | unpreferred_seq, 75 | *, 76 | prompt_mask, 77 | preferred_seq_mask = None, 78 | unpreferred_seq_mask = None, 79 | ): 80 | assert preferred_seq.ndim == 2 81 | assert preferred_seq.shape == unpreferred_seq.shape 82 | 83 | if exists(self.pad_id): 84 | if not exists(preferred_seq_mask): 85 | preferred_seq_mask = preferred_seq != self.pad_id 86 | 87 | if not exists(unpreferred_seq_mask): 88 | unpreferred_seq_mask = unpreferred_seq != self.pad_id 89 | 90 | """ 91 | Following Appendix B in https://arxiv.org/abs/2305.18290 92 | """ 93 | 94 | with torch.no_grad(): 95 | self.ref_model.eval() 96 | ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq) 97 | ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq) 98 | 99 | policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq) 100 | policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq) 101 | 102 | # masked mean of log probs 103 | 104 | preferred_seq_mask = maybe_and_mask(~prompt_mask, preferred_seq_mask) 105 | unpreferred_seq_mask = maybe_and_mask(~prompt_mask, unpreferred_seq_mask) 106 | 107 | ref_preferred_logprob, policy_preferred_logprob = map(lambda t: masked_mean(t, preferred_seq_mask), (ref_preferred_logprob, policy_preferred_logprob)) 108 | ref_unpreferred_logprob, policy_unpreferred_logprob = map(lambda t: masked_mean(t, unpreferred_seq_mask), (ref_unpreferred_logprob, policy_unpreferred_logprob)) 109 | 110 | # main dpo formula 111 | 112 | policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob 113 | ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob 114 | 115 | losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios)) 116 | 117 | return losses.mean() 118 | -------------------------------------------------------------------------------- /x_transformers/entropy_based_tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from itertools import zip_longest 3 | 4 | import torch 5 | from torch import tensor 6 | import torch.nn.functional as F 7 | from torch.nn import Module 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | import einx 11 | from einops import repeat, rearrange, pack, unpack 12 | 13 | # helper functions 14 | 15 | def exists(v): 16 | return v is not None 17 | 18 | def default(v, d): 19 | return v if exists(v) else d 20 | 21 | def log(t, eps = 1e-20): 22 | return t.clamp(min = eps).log() 23 | 24 | def calc_entropy_from_logits(logits): 25 | prob = logits.softmax(dim = -1) 26 | return -(prob * log(prob)).sum(dim = -1) 27 | 28 | # entropy based tokenizer 
applied in byte-latent transformer paper 29 | # they use a simple entropy threshold for segmenting a string into variable sized tokens 30 | 31 | # https://arxiv.org/abs/2412.09871 32 | 33 | class EntropyBasedTokenizer(Module): 34 | def __init__( 35 | self, 36 | decoder: Module, 37 | entropy_threshold: float, 38 | max_token_size: int | None = None 39 | ): 40 | super().__init__() 41 | self.decoder = decoder 42 | self.entropy_threshold = entropy_threshold 43 | 44 | self.max_token_size = max_token_size 45 | 46 | @torch.no_grad() 47 | def forward( 48 | self, 49 | seq, # Float['b n'] | Float['n'] 50 | lens = None, # Int['b'] 51 | return_segmented_seq = False, 52 | decoder_forward_kwargs: dict = dict() 53 | ): 54 | no_batch_dim = seq.ndim == 1 55 | seq, maybe_batch_ps = pack((seq,), '* n') 56 | 57 | self.decoder.eval() 58 | 59 | is_var_length = exists(lens) 60 | batch, seq_len, device, max_token_size = *seq.shape, seq.device, self.max_token_size 61 | 62 | arange = torch.arange(seq_len, device = device) 63 | 64 | # forward through a small trained decoder and get the entropies of the logits 65 | 66 | logits = self.decoder(seq, **decoder_forward_kwargs) 67 | 68 | entropies = calc_entropy_from_logits(logits) 69 | 70 | # get length mask for boundaries 71 | 72 | mask = tensor(True, device = device) 73 | 74 | if is_var_length: 75 | mask = einx.less('n, b -> b n', arange, lens) 76 | 77 | # the mask for tokens that were of a sufficient surprise level 78 | 79 | over_thres_mask = (entropies >= self.entropy_threshold) & mask 80 | 81 | # needed for selecting out indices at entropy threshold mask 82 | 83 | arange_plus_one = arange + 1 84 | arange_plus_one = repeat(arange_plus_one, 'n -> b n', b = batch) 85 | 86 | # get a tensor of Int['b num_tokens'] with the token lengths, zero padded 87 | 88 | boundaries = over_thres_mask.clone() 89 | 90 | # set the boundary of the last token 91 | 92 | # if `lens` not given, assume always last token 93 | # but if `lens` were given, then properly set the index 94 | 95 | if not is_var_length: 96 | boundaries[..., -1] = True 97 | else: 98 | scatter_indices = rearrange(lens - 1, 'b -> b 1') 99 | boundaries.scatter_(-1, scatter_indices, True) 100 | 101 | # handle max token size - technique has the flaw that repeating subsequences are grouped into one large token 102 | 103 | if exists(max_token_size): 104 | token_ids = boundaries.cumsum(dim = -1) 105 | token_ids = F.pad(token_ids, (1, -1), value = 0) 106 | 107 | max_num_tokens = boundaries.sum(dim = -1).amax().item() 108 | token_ids_seq = torch.arange(max_num_tokens, device = device) 109 | 110 | token_mask = einx.equal('j, b i -> b j i', token_ids_seq, token_ids) 111 | 112 | token_sub_seq_arange = token_mask.cumsum(dim = -1) 113 | 114 | sub_seq_boundaries = (token_sub_seq_arange % max_token_size == 0) 115 | sub_seq_boundaries = (sub_seq_boundaries & token_mask).any(dim = 1) 116 | 117 | boundaries = boundaries | sub_seq_boundaries 118 | 119 | if exists(mask): 120 | boundaries = boundaries & mask 121 | 122 | # number of tokens 123 | 124 | num_tokens = boundaries.sum(dim = -1) 125 | 126 | # get number of tokens as well as derived indices 127 | 128 | indices = arange_plus_one[boundaries].split(num_tokens.tolist()) 129 | 130 | # get the token lengths 131 | 132 | token_lengths = [] 133 | 134 | for one_indices in indices: 135 | padded_indices = F.pad(one_indices, (1, 0), value = 0.) 
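# worked example (illustrative): if the boundary indices for one sequence are [3, 7, 12],
# the leading zero pad gives [0, 3, 7, 12] and the successive differences below yield
# the token lengths [3, 4, 5], which tile the first 12 positions of that sequence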
136 | one_token_lengths = padded_indices[1:] - padded_indices[:-1] 137 | 138 | token_lengths.append(one_token_lengths) 139 | 140 | token_lengths = pad_sequence(token_lengths, batch_first = True) 141 | 142 | # early return 143 | 144 | if not return_segmented_seq: 145 | token_lengths, = unpack(token_lengths, maybe_batch_ps, '* num_tokens') 146 | 147 | return token_lengths 148 | 149 | # segment the sequence based on the token lengths 150 | 151 | lens = default(lens, (None,)) 152 | segmented_seq = [] 153 | 154 | for one_seq, one_len, one_token_length in zip_longest(seq, lens, token_lengths): 155 | 156 | if exists(one_len): 157 | one_seq = one_seq[:one_len] 158 | 159 | one_token_length = one_token_length[one_token_length > 0] 160 | 161 | splitted_seq = one_seq.split(one_token_length.tolist()) 162 | segmented_seq.append(splitted_seq) 163 | 164 | if no_batch_dim: 165 | segmented_seq = segmented_seq[0] 166 | 167 | return segmented_seq 168 | -------------------------------------------------------------------------------- /x_transformers/multi_input.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | from torch.nn import Module, ModuleDict 6 | import torch.nn.functional as F 7 | 8 | from typing import Dict 9 | 10 | from einops import pack, repeat, unpack 11 | 12 | from x_transformers.x_transformers import ( 13 | AttentionLayers, 14 | ScaledSinusoidalEmbedding, 15 | AbsolutePositionalEmbedding, 16 | LayerIntermediates, 17 | LayerNorm, 18 | always, 19 | pad_at_dim, 20 | is_empty, 21 | ) 22 | 23 | # helper functions 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | def default(val, d): 29 | if exists(val): 30 | return val 31 | return d() if callable(d) else d 32 | 33 | 34 | class MultiInputTransformerWrapper(Module): 35 | def __init__( 36 | self, 37 | *, 38 | num_tokens: Dict[str, int] = dict(), 39 | max_seq_len, 40 | attn_layers: AttentionLayers, 41 | emb_dim = None, 42 | max_mem_len = 0, 43 | shift_mem_down = 0, 44 | emb_dropout = 0., 45 | post_emb_norm = False, 46 | num_memory_tokens = None, 47 | memory_tokens_interspersed_every = None, 48 | return_only_embed = False, 49 | use_abs_pos_emb = True, 50 | scaled_sinu_pos_emb = False, 51 | emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1 52 | attn_z_loss_weight = 1e-4, 53 | ): 54 | super().__init__() 55 | 56 | dim = attn_layers.dim 57 | emb_dim = default(emb_dim, dim) 58 | self.emb_dim = emb_dim 59 | 60 | self.max_seq_len = max_seq_len 61 | self.max_mem_len = max_mem_len 62 | self.shift_mem_down = shift_mem_down 63 | 64 | no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb) 65 | 66 | if no_abs_pos_emb: 67 | self.pos_emb = always(0) 68 | elif scaled_sinu_pos_emb: 69 | self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) 70 | else: 71 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) 72 | 73 | # additional embeddings - say type embedding from BERT 74 | 75 | self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()}) 76 | 77 | # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 78 | 79 | self.emb_frac_gradient = emb_frac_gradient 80 | 81 | self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity() 82 | self.emb_dropout = nn.Dropout(emb_dropout) 83 | 84 | self.project_emb = nn.Linear(emb_dim, dim) 
if emb_dim != dim else nn.Identity() 85 | self.attn_layers = attn_layers 86 | 87 | # output head, usually to logits of num_tokens 88 | 89 | if return_only_embed: 90 | self.to_logits = None 91 | else: 92 | self.to_logits = ModuleDict({name: nn.Linear(dim, logits_dim, bias = False) for name, logits_dim in num_tokens.items()}) 93 | 94 | # memory tokens (like [cls]) from Memory Transformers paper 95 | 96 | num_memory_tokens = default(num_memory_tokens, 0) 97 | self.num_memory_tokens = num_memory_tokens 98 | if num_memory_tokens > 0: 99 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 100 | 101 | self.memory_tokens_interspersed_every = memory_tokens_interspersed_every 102 | 103 | # whether can do cached kv decoding 104 | 105 | self.can_cache_kv = self.num_memory_tokens == 0 106 | self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb 107 | 108 | def forward( 109 | self, 110 | x: Dict[str, Tensor], 111 | return_embeddings = False, 112 | return_logits_and_embeddings = False, 113 | return_intermediates = False, 114 | mask = None, 115 | return_mems = False, 116 | return_attn = False, 117 | mems = None, 118 | mem_masks = None, 119 | pos = None, 120 | prepend_embeds = None, 121 | prepend_mask = None, 122 | sum_embeds = None, 123 | return_attn_z_loss = False, 124 | attn_z_loss_weight = 1e-4, 125 | seq_start_pos = None, 126 | cache: LayerIntermediates | None = None, 127 | **kwargs 128 | ): 129 | assert not is_empty(x) 130 | first_input = list(x.values())[0] 131 | 132 | b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *first_input.shape, first_input.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient 133 | 134 | return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss 135 | return_embeddings = return_embeddings | (not exists(self.to_logits)) 136 | 137 | # token embedding 138 | 139 | assert len(x) == len(self.embeds) 140 | 141 | token_emb = 0. 
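# illustrative input: x might be {'tokens': Int['b n'], 'segment': Int['b n']} - each named
# id tensor is looked up in its correspondingly named embedding table below and the
# resulting embeddings are summed position-wise into token_emb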
142 | 143 | for name, embed_id in x.items(): 144 | embed_key = f'{name}_embed' 145 | 146 | assert embed_key in self.embeds 147 | embed = self.embeds[embed_key](embed_id) 148 | 149 | token_emb = token_emb + embed 150 | 151 | # absolute positional embedding 152 | 153 | external_pos_emb = exists(pos) and pos.dtype != torch.long 154 | pos_emb = self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos 155 | 156 | token_emb = token_emb + pos_emb 157 | 158 | # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training 159 | 160 | if exists(sum_embeds): 161 | token_emb = token_emb + sum_embeds 162 | 163 | # set back to `x` 164 | 165 | x = token_emb 166 | 167 | # post embedding norm, purportedly leads to greater stabilization 168 | 169 | x = self.post_emb_norm(x) 170 | 171 | # whether to append embeds, as in PaLI, for image embeddings 172 | 173 | if exists(prepend_embeds): 174 | prepend_seq, prepend_dim = prepend_embeds.shape[1:] 175 | assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions' 176 | 177 | x = torch.cat((prepend_embeds, x), dim = -2) 178 | 179 | if exists(prepend_mask) or exists(mask): 180 | mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool)) 181 | prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool)) 182 | 183 | mask = torch.cat((prepend_mask, mask), dim = -1) 184 | 185 | # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model 186 | 187 | if emb_frac_gradient < 1: 188 | assert emb_frac_gradient > 0 189 | x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient) 190 | 191 | # embedding dropout 192 | 193 | x = self.emb_dropout(x) 194 | 195 | x = self.project_emb(x) 196 | 197 | if has_memory_tokens: 198 | mem_every = self.memory_tokens_interspersed_every 199 | 200 | if exists(mem_every): 201 | assert mem_every > 0 202 | assert isinstance(self.attn_layers, Decoder), 'only for decoder' 203 | next_seq_len = math.ceil(n / mem_every) * mem_every 204 | 205 | x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.) 
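# worked example (illustrative): with n = 10 and memory_tokens_interspersed_every = 4,
# next_seq_len = 12, so the sequence is right-padded by 2 positions, reshaped below into
# 3 windows of 4 tokens each, and the memory tokens are then packed onto the front of every window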
206 | x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every) 207 | 208 | mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0]) 209 | x, mem_packed_shape = pack((mem, x), 'b * d') 210 | 211 | # auto-handle masking after appending memory tokens 212 | if not exists(mem_every) and exists(mask): 213 | mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True) 214 | 215 | if exists(mem_every): 216 | x = rearrange(x, '(b n) m d -> b (n m) d', b = b) 217 | 218 | if self.shift_mem_down and exists(mems): 219 | mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] 220 | mems = [*mems_r, *mems_l] 221 | 222 | x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs) 223 | 224 | # handle memories post-attention 225 | 226 | if has_memory_tokens: 227 | if exists(mem_every): 228 | x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems)) 229 | 230 | mem, x = unpack(x, mem_packed_shape, 'b * d') 231 | 232 | intermediates.memory_tokens = mem 233 | 234 | if exists(mem_every): 235 | x = rearrange(x, '(b n) m d -> b (n m) d', b = b) 236 | 237 | x = x[:, :n] 238 | 239 | # projecting to logits 240 | 241 | if not return_embeddings: 242 | logits = {name: fn(x) for name, fn in self.to_logits.items()} 243 | 244 | # different returns 245 | 246 | if return_logits_and_embeddings: 247 | out = (logits, x) 248 | elif return_embeddings: 249 | out = x 250 | else: 251 | out = logits 252 | 253 | # aux loss 254 | 255 | if return_attn_z_loss: 256 | pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates] 257 | intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight) 258 | return_intermediates = True 259 | 260 | if return_mems: 261 | hiddens = intermediates.hiddens 262 | new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens 263 | new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems] 264 | 265 | if not return_intermediates: 266 | return out, new_mems 267 | 268 | intermediates.mems = new_mems 269 | 270 | if return_intermediates: 271 | return out, intermediates 272 | 273 | if return_attn: 274 | attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] 275 | return out, attn_maps 276 | 277 | return out 278 | -------------------------------------------------------------------------------- /x_transformers/neo_mlp.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torch import nn, tensor, pi, is_tensor 5 | import torch.nn.functional as F 6 | from torch.nn import Module, ModuleList 7 | 8 | from einops import rearrange, repeat, einsum, pack, unpack 9 | 10 | from x_transformers.x_transformers import ( 11 | Encoder 12 | ) 13 | 14 | # helpers 15 | 16 | def exists(v): 17 | return v is not None 18 | 19 | def default(v, d): 20 | return v if exists(v) else d 21 | 22 | # random fourier 23 | 24 | class RandomFourierEmbed(Module): 25 | 26 | def __init__(self, dim): 27 | super().__init__() 28 | self.proj = nn.Linear(1, dim) 29 | self.proj.requires_grad_(False) 30 | 31 | def forward( 32 | self, 33 | times, 34 | ): 35 | 36 | times = rearrange(times, '... -> ... 
1') 37 | rand_proj = self.proj(times) 38 | return torch.cos(2 * pi * rand_proj) 39 | 40 | # class 41 | 42 | class NeoMLP(Module): 43 | """ https://openreview.net/forum?id=A8Vuf2e8y6 """ 44 | """ https://haian-jin.github.io/projects/LVSM/ """ 45 | 46 | def __init__( 47 | self, 48 | *, 49 | dim_in, 50 | dim_hidden, 51 | dim_out, 52 | dim_model, 53 | depth, 54 | encoder_kwargs: dict = dict( 55 | attn_dim_head = 16, 56 | heads = 4 57 | ) 58 | ): 59 | super().__init__() 60 | 61 | # input and output embeddings 62 | 63 | self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model)) 64 | self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model)) 65 | self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model)) 66 | 67 | nn.init.normal_(self.input_embed, std = 0.02) 68 | nn.init.normal_(self.hidden_embed, std = 0.02) 69 | nn.init.normal_(self.output_embed, std = 0.02) 70 | 71 | # they use random fourier for continuous features 72 | 73 | self.random_fourier = nn.Sequential( 74 | RandomFourierEmbed(dim_model), 75 | nn.Linear(dim_model, dim_model) 76 | ) 77 | 78 | # hidden dimensions of mlp replaced with nodes with message passing 79 | # which comes back to self attention as a fully connected graph. 80 | 81 | self.transformer = Encoder( 82 | dim = dim_model, 83 | depth = depth, 84 | **encoder_kwargs 85 | ) 86 | 87 | # output 88 | 89 | self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model)) 90 | self.to_output_bias = nn.Parameter(torch.zeros(dim_out)) 91 | 92 | def forward( 93 | self, 94 | x, 95 | return_embeds = False 96 | ): 97 | no_batch = x.ndim == 1 98 | 99 | if no_batch: 100 | x = rearrange(x, '... -> 1 ...') 101 | 102 | batch = x.shape[0] 103 | 104 | fouriered_input = self.random_fourier(x) 105 | 106 | # add fouriered input to the input embedding 107 | 108 | input_embed = fouriered_input + self.input_embed 109 | 110 | hidden_embed, output_embed = tuple(repeat(t, '... -> b ...', b = batch) for t in (self.hidden_embed, self.output_embed)) 111 | 112 | # pack all the inputs into one string of tokens for self attention 113 | 114 | embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d') 115 | 116 | # attention is all you need 117 | 118 | embed = self.transformer(embed) 119 | 120 | # unpack 121 | 122 | input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d') 123 | 124 | # project for output 125 | 126 | output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n') 127 | output = output + self.to_output_bias 128 | 129 | if no_batch: 130 | output = rearrange(output, '1 ... 
-> ...') 131 | 132 | if not return_embeds: 133 | return output 134 | 135 | return output, (input_embed, hidden_embed, output_embed) 136 | -------------------------------------------------------------------------------- /x_transformers/nonautoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | from random import random 3 | from contextlib import nullcontext 4 | from collections import namedtuple 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from einops import rearrange, repeat, pack, unpack 11 | 12 | from x_transformers.x_transformers import TransformerWrapper 13 | from typing import Optional 14 | 15 | # constants 16 | 17 | Losses = namedtuple('Losses', ['loss', 'generator_loss', 'critic_loss']) 18 | 19 | # helper functions 20 | 21 | def exists(val): 22 | return val is not None 23 | 24 | def default(val, d): 25 | return val if exists(val) else d 26 | 27 | # sampling helpers 28 | 29 | def top_k(logits, thres = 0.9): 30 | k = math.ceil((1 - thres) * logits.shape[-1]) 31 | val, ind = logits.topk(k, dim = -1) 32 | probs = torch.full_like(logits, float('-inf')) 33 | probs.scatter_(2, ind, val) 34 | return probs 35 | 36 | def log(t, eps = 1e-10): 37 | return torch.log(t + eps) 38 | 39 | def gumbel_noise(t): 40 | noise = torch.zeros_like(t).uniform_(0, 1) 41 | return -log(-log(noise)) 42 | 43 | def gumbel_sample(t, temperature = 1., dim = -1): 44 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) 45 | 46 | # prob helpers 47 | 48 | def sample_prob(prob): 49 | return random() < prob 50 | 51 | def coin_flip(): 52 | return sample_prob(0.5) 53 | 54 | # tensor helpers 55 | 56 | def get_mask_subset_prob(mask, prob, min_mask = 0): 57 | batch, seq, device = *mask.shape, mask.device 58 | num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask) 59 | logits = torch.rand((batch, seq), device = device) 60 | logits = logits.masked_fill(~mask, -1) 61 | 62 | randperm = logits.argsort(dim = -1).argsort(dim = -1).float() 63 | 64 | num_padding = (~mask).sum(dim = -1, keepdim = True) 65 | randperm -= num_padding 66 | 67 | subset_mask = randperm < num_to_mask 68 | subset_mask.masked_fill_(~mask, False) 69 | return subset_mask 70 | 71 | # schedules 72 | 73 | def linear_schedule(t): 74 | return 1 - t 75 | 76 | def cosine_schedule(t): 77 | """ https://arxiv.org/abs/2202.04200 """ 78 | return torch.cos(t * math.pi / 2) 79 | 80 | # self token critic 81 | # inspired by Nijkamp et al. 
- https://aclanthology.org/2021.naacl-main.409/ 82 | 83 | class SelfCritic(nn.Module): 84 | def __init__(self, net): 85 | super().__init__() 86 | self.net = net 87 | 88 | dim = net.attn_layers.dim 89 | self.to_logits = nn.Linear(dim, 1) 90 | 91 | def forward(self, x): 92 | embed = self.net(x, return_embeddings = True) 93 | return self.to_logits(embed) 94 | 95 | class NonAutoregressiveWrapper(nn.Module): 96 | """ 97 | https://arxiv.org/abs/1904.09324 98 | https://arxiv.org/abs/2202.04200 99 | """ 100 | 101 | def __init__( 102 | self, 103 | net, 104 | *, 105 | mask_id, 106 | steps = 18, 107 | self_cond = False, 108 | self_cond_train_prob = 0.75, 109 | no_replace_prob = 0.15, # which percentage of the tokens masked will stay the same, done in original MLM paper 110 | random_token_prob = 0.1, # which percentage of tokens to be replaced with random token, done in original MLM paper 111 | schedule = 'linear', 112 | can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked 113 | token_critic: Optional[TransformerWrapper] = None, 114 | self_token_critic = False, 115 | critic_loss_weight = 1. 116 | ): 117 | super().__init__() 118 | assert not (self_token_critic and exists(token_critic)) 119 | 120 | self.net = net 121 | 122 | dim = net.emb_dim 123 | self.dim = dim 124 | self.num_tokens = net.num_tokens 125 | 126 | self.mask_id = mask_id 127 | 128 | # afaict, maskgit paper did not do this 129 | # but may help for self conditioning, as used successfully in original BERT 130 | 131 | self.no_replace_prob = no_replace_prob 132 | self.random_token_prob = random_token_prob 133 | 134 | self.max_seq_len = net.max_seq_len 135 | self.steps = steps 136 | 137 | if callable(schedule): 138 | self.schedule_fn = schedule 139 | if schedule == 'linear': 140 | self.schedule_fn = linear_schedule 141 | elif schedule == 'cosine': 142 | self.schedule_fn = cosine_schedule 143 | else: 144 | raise ValueError(f'invalid schedule {schedule}') 145 | 146 | self.can_mask_prev_unmasked = can_mask_prev_unmasked 147 | 148 | # self conditioning 149 | 150 | self.self_cond = self_cond 151 | 152 | if self_cond: 153 | self.null_embed = nn.Parameter(torch.randn(dim)) 154 | self.to_self_cond = nn.Linear(dim, dim, bias = False) if self_cond else None 155 | self.self_cond_train_prob = self_cond_train_prob 156 | 157 | # token critic 158 | 159 | self.token_critic = token_critic 160 | 161 | if self_token_critic: 162 | self.token_critic = SelfCritic(net) 163 | 164 | self.critic_loss_weight = critic_loss_weight 165 | 166 | @torch.no_grad() 167 | def generate( 168 | self, 169 | batch_size = None, 170 | start_temperature = 1., 171 | filter_thres = 0.7, 172 | noise_level_scale = 1., 173 | **kwargs 174 | ): 175 | sample_one = not exists(batch_size) 176 | batch_size = default(batch_size, 1) 177 | 178 | device = next(self.net.parameters()).device 179 | 180 | was_training = self.training 181 | self.eval() 182 | 183 | times = torch.linspace(0., 1., self.steps + 1) 184 | 185 | # sequence starts off as all masked 186 | 187 | shape = (batch_size, self.max_seq_len) 188 | 189 | seq = torch.full(shape, self.mask_id, device = device) 190 | mask = torch.full(shape, True, device = device) 191 | 192 | # slowly demask 193 | 194 | all_mask_num_tokens = (self.schedule_fn(times[1:]) * self.max_seq_len).long() 195 | 196 | # self conditioning 197 | 198 | has_self_cond = self.self_cond 199 | last_embed = self.null_embed if has_self_cond else None 200 | 201 | for mask_num_tokens, steps_until_x0 in zip(all_mask_num_tokens.tolist(), 
reversed(range(self.steps))): 202 | 203 | self_cond = self.to_self_cond(last_embed) if has_self_cond else None 204 | 205 | logits, embeds = self.net( 206 | seq, 207 | sum_embeds = self_cond, 208 | return_logits_and_embeddings = True, 209 | **kwargs 210 | ) 211 | 212 | if has_self_cond: 213 | last_embed = embeds 214 | 215 | if exists(filter_thres): 216 | logits = top_k(logits, filter_thres) 217 | 218 | annealing_scale = steps_until_x0 / self.steps 219 | temperature = start_temperature * annealing_scale 220 | 221 | probs = (logits / max(temperature, 1e-3)).softmax(dim = -1) 222 | 223 | sampled_ids = gumbel_sample(logits, temperature = max(temperature, 1e-3)) 224 | 225 | seq = torch.where(mask, sampled_ids, seq) 226 | 227 | if exists(self.token_critic): 228 | scores = self.token_critic(seq) 229 | scores = rearrange(scores, 'b n 1 -> b n') 230 | scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale 231 | else: 232 | scores = 1 - logits.softmax(dim = -1) 233 | scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1')) 234 | scores = rearrange(scores, 'b n 1 -> b n') 235 | 236 | if mask_num_tokens == 0: 237 | pass 238 | 239 | if not self.can_mask_prev_unmasked: 240 | scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max) 241 | 242 | mask_indices = scores.topk(mask_num_tokens, dim = -1).indices 243 | mask = torch.zeros_like(scores, dtype = torch.bool).scatter(1, mask_indices, True) 244 | seq = seq.masked_fill(mask, self.mask_id) 245 | 246 | self.train(was_training) 247 | 248 | if sample_one: 249 | seq = rearrange(seq, '1 n -> n') 250 | 251 | return seq 252 | 253 | def forward( 254 | self, 255 | x, 256 | only_train_generator = False, 257 | only_train_critic = False, 258 | generator_sample_temperature = None, 259 | **kwargs 260 | ): 261 | b, n, device = *x.shape, x.device 262 | assert n == self.max_seq_len 263 | 264 | orig_seq = x.clone() 265 | 266 | rand_times = torch.empty(b, device = device).uniform_(0, 1) 267 | batched_randperm = torch.rand((b, n), device = device).argsort(dim = -1).float() 268 | 269 | rand_probs = self.schedule_fn(rand_times) 270 | num_tokens_mask = (rand_probs * n).clamp(min = 1.) 271 | mask = batched_randperm < rearrange(num_tokens_mask, 'b -> b 1') 272 | 273 | # to ensure all tokens produce embeddings, instead of just the ones with [mask] input, as done in seminal BERT MLM paper 274 | # potentially needed for self-conditioning (on embedding) to work well 275 | 276 | replace_mask_id_mask = mask.clone() 277 | frac_seq_left = 1. 278 | 279 | if self.no_replace_prob > 0. and coin_flip(): 280 | frac_seq_left -= self.no_replace_prob 281 | 282 | no_replace_prob_mask = get_mask_subset_prob(mask, self.no_replace_prob) 283 | replace_mask_id_mask &= ~no_replace_prob_mask 284 | 285 | if self.random_token_prob > 0. 
and coin_flip(): 286 | random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left) 287 | random_tokens = torch.randint(0, self.num_tokens, (b, n), device = device) 288 | 289 | x = torch.where(random_token_prob_mask, random_tokens, x) 290 | replace_mask_id_mask &= ~random_token_prob_mask 291 | 292 | masked = torch.where(replace_mask_id_mask, self.mask_id, x) 293 | 294 | # self conditioning 295 | 296 | if self.self_cond: 297 | self_cond = self.null_embed 298 | 299 | if sample_prob(self.self_cond_train_prob): 300 | with torch.no_grad(): 301 | self_cond = self.net(masked, return_embeddings = True, **kwargs).detach() 302 | 303 | kwargs.update(sum_embeds = self.to_self_cond(self_cond)) 304 | 305 | # logits 306 | 307 | context = torch.no_grad if only_train_critic else nullcontext 308 | 309 | with context(): 310 | logits = self.net(masked, **kwargs) 311 | 312 | loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss 313 | 314 | # cross entropy loss 315 | 316 | loss = loss_fn( 317 | logits[mask], 318 | orig_seq[mask] 319 | ) 320 | 321 | if not exists(self.token_critic) or only_train_generator: 322 | return Losses(loss, loss, None) 323 | 324 | sampled_ids = gumbel_sample(logits, temperature = default(generator_sample_temperature, random())) 325 | generated = torch.where(mask, sampled_ids, orig_seq) 326 | 327 | critic_logits = self.token_critic(generated) 328 | critic_labels = (sampled_ids != orig_seq).float() 329 | 330 | critic_loss = F.binary_cross_entropy_with_logits( 331 | rearrange(critic_logits, '... 1 -> ...'), 332 | critic_labels 333 | ) 334 | 335 | # determine losses to be returned based on what researcher wants to train 336 | 337 | if only_train_critic: 338 | total_loss = critic_loss 339 | loss = None 340 | else: 341 | total_loss = loss + critic_loss * self.critic_loss_weight 342 | 343 | return Losses(total_loss, loss, critic_loss) 344 | -------------------------------------------------------------------------------- /x_transformers/xl_autoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, pack, unpack 8 | from x_transformers.autoregressive_wrapper import top_p, top_k, eval_decorator 9 | 10 | # helper functions 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def divisible_by(numer, denom): 16 | return (numer % denom) == 0 17 | 18 | # xl autoregressive wrapper class 19 | 20 | class XLAutoregressiveWrapper(nn.Module): 21 | def __init__( 22 | self, 23 | net, 24 | ignore_index = -100, 25 | pad_value = 0 26 | ): 27 | super().__init__() 28 | self.pad_value = pad_value 29 | self.ignore_index = ignore_index 30 | 31 | self.net = net 32 | self.max_seq_len = net.max_seq_len 33 | 34 | @torch.no_grad() 35 | @eval_decorator 36 | def generate( 37 | self, 38 | start_tokens, 39 | seq_len, 40 | eos_token = None, 41 | temperature = 1., 42 | filter_logits_fn = top_k, 43 | filter_kwargs: dict = dict(), 44 | mems = None, 45 | **kwargs 46 | ): 47 | device, max_seq_len = start_tokens.device, self.max_seq_len 48 | 49 | start_tokens, ps = pack([start_tokens], '* n') 50 | 51 | b, t = start_tokens.shape 52 | 53 | *all_leading_tokens, _ = start_tokens.split(max_seq_len, dim = -1) 54 | 55 | # catch the memory up to the current segment 56 | 57 | for leading_tokens in all_leading_tokens: 58 | _, mems = self.net( 59 | leading_tokens, 60 | mems 
= mems, 61 | return_mems = True, 62 | **kwargs 63 | ) 64 | 65 | # now start sampling from the current segment 66 | 67 | curr_pos = len(all_leading_tokens) * max_seq_len 68 | curr_mems = mems 69 | 70 | cache = None 71 | out = start_tokens 72 | 73 | for _ in range(seq_len): 74 | curr_segment_len = out.shape[-1] 75 | is_last_segment_tokens = divisible_by(curr_segment_len, max_seq_len) 76 | 77 | x = out[:, curr_pos:] 78 | 79 | logits, cache = self.net( 80 | x, 81 | mems = curr_mems, 82 | cache = cache, 83 | return_mems = True, 84 | return_intermediates = True, 85 | **kwargs 86 | ) 87 | 88 | mems = cache.mems 89 | 90 | logits = logits[:, -1] 91 | filtered_logits = filter_logits_fn(logits, **filter_kwargs) 92 | probs = F.softmax(filtered_logits / temperature, dim=-1) 93 | 94 | sample = torch.multinomial(probs, 1) 95 | 96 | if is_last_segment_tokens: 97 | curr_pos = curr_segment_len 98 | curr_mems = mems 99 | 100 | out = torch.cat((out, sample), dim=-1) 101 | 102 | if exists(eos_token): 103 | is_eos_tokens = (out == eos_token) 104 | 105 | if is_eos_tokens.any(dim = -1).all(): 106 | # mask out everything after the eos tokens 107 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) 108 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 109 | out = out.masked_fill(mask, self.pad_value) 110 | break 111 | 112 | out = out[:, t:] 113 | 114 | out, = unpack(out, ps, '* n') 115 | 116 | return out 117 | 118 | def forward( 119 | self, 120 | x, 121 | mems = None, 122 | **kwargs 123 | ): 124 | ignore_index, max_seq_len = self.ignore_index, self.max_seq_len 125 | 126 | x, labels = x[:, :-1], x[:, 1:] 127 | 128 | seq_len = x.shape[1] 129 | 130 | # prepare chunks 131 | 132 | split_x = x.split(max_seq_len, dim = -1) 133 | split_labels = labels.split(max_seq_len, dim = -1) 134 | loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x) 135 | 136 | loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss 137 | 138 | # go through each chunk and derive weighted losses 139 | 140 | total_loss = 0. 
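# worked example (illustrative): with seq_len = 1280 and max_seq_len = 512, the chunks have
# lengths 512, 512 and 256 and loss weights 0.4, 0.4 and 0.2, so the weighted sum accumulated
# below approximates a token-level mean cross entropy over the whole sequence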
141 | 142 | for chunk, chunk_labels, loss_weight in zip(split_x, split_labels, loss_weights): 143 | 144 | logits, mems = self.net( 145 | chunk, 146 | mems = mems, 147 | return_mems = True, 148 | **kwargs 149 | ) 150 | 151 | loss = loss_fn( 152 | rearrange(logits, 'b n c -> b c n'), 153 | chunk_labels, 154 | ignore_index = ignore_index 155 | ) 156 | 157 | total_loss = total_loss + loss * loss_weight 158 | 159 | return total_loss 160 | -------------------------------------------------------------------------------- /x_transformers/xval.py: -------------------------------------------------------------------------------- 1 | """ 2 | regular transformer with discrete tokens, but continuous for number 3 | generalizes better for arithmetic 4 | https://arxiv.org/abs/2310.02989 5 | """ 6 | 7 | import torch 8 | from torch import nn, Tensor 9 | import torch.nn.functional as F 10 | 11 | from typing import Callable 12 | from collections import namedtuple 13 | 14 | from einops import rearrange 15 | from einops.layers.torch import Rearrange 16 | 17 | from x_transformers.x_transformers import ( 18 | AttentionLayers, 19 | TokenEmbedding, 20 | ScaledSinusoidalEmbedding, 21 | AbsolutePositionalEmbedding, 22 | always 23 | ) 24 | 25 | from x_transformers.autoregressive_wrapper import ( 26 | top_k, 27 | top_p 28 | ) 29 | 30 | # constants 31 | 32 | LossBreakdown = namedtuple('LossBreakdown', ['cross_entropy_loss', 'numerical_mse_loss']) 33 | 34 | GenerateReturn = namedtuple('GenerateReturn', ['sampled_token_ids', 'sampled_numbers', 'is_number_mask']) 35 | 36 | # helper functions 37 | 38 | def exists(val): 39 | return val is not None 40 | 41 | def default(val, d): 42 | if exists(val): 43 | return val 44 | return d() if callable(d) else d 45 | 46 | # main classes 47 | 48 | class XValTransformerWrapper(nn.Module): 49 | def __init__( 50 | self, 51 | *, 52 | num_tokens, 53 | max_seq_len, 54 | numerical_token_id, 55 | attn_layers: AttentionLayers, 56 | emb_dim = None, 57 | logits_dim = None, 58 | tie_embedding = False, 59 | max_mem_len = 0, 60 | num_memory_tokens = None, 61 | emb_dropout = 0., 62 | use_abs_pos_emb = True, 63 | scaled_sinu_pos_emb = False 64 | ): 65 | super().__init__() 66 | dim = attn_layers.dim 67 | emb_dim = default(emb_dim, dim) 68 | 69 | self.emb_dim = emb_dim 70 | self.token_emb = TokenEmbedding(emb_dim, num_tokens) 71 | 72 | self.numerical_token_id = numerical_token_id 73 | 74 | self.max_seq_len = max_seq_len 75 | 76 | self.max_mem_len = max_mem_len 77 | 78 | if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb): 79 | self.pos_emb = always(0) 80 | elif scaled_sinu_pos_emb: 81 | self.pos_emb = ScaledSinusoidalEmbedding(dim) 82 | else: 83 | self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) 84 | 85 | self.emb_dropout = nn.Dropout(emb_dropout) 86 | 87 | # memory tokens 88 | 89 | num_memory_tokens = default(num_memory_tokens, 0) 90 | self.has_memory_tokens = num_memory_tokens > 0 91 | 92 | if num_memory_tokens > 0: 93 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 94 | 95 | # attention layers 96 | 97 | self.attn_layers = attn_layers 98 | 99 | # to logits 100 | 101 | logits_dim = default(logits_dim, num_tokens) 102 | self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t() 103 | 104 | self.to_numerical_output = nn.Sequential( 105 | nn.Linear(dim, 1), 106 | Rearrange('... 
1 -> ...') 107 | ) 108 | 109 | def forward( 110 | self, 111 | x: Tensor, 112 | x_num: Tensor, 113 | return_embeddings = False, 114 | return_intermediates = False, 115 | return_mems = False, 116 | mask = None, 117 | return_attn = False, 118 | mems = None, 119 | pos = None, 120 | prepend_embeds = None, 121 | **kwargs 122 | ): 123 | assert x.shape == x_num.shape 124 | 125 | batch = x.shape[0] 126 | 127 | is_number_mask = x == self.numerical_token_id 128 | 129 | x = self.token_emb(x) 130 | 131 | scale = torch.where(is_number_mask, x_num, 1.) 132 | scale = rearrange(scale, '... -> ... 1') 133 | 134 | x = x * scale 135 | 136 | x = x + self.pos_emb(x, pos = pos) 137 | 138 | # memory tokens 139 | 140 | if self.has_memory_tokens: 141 | m = repeat(self.memory_tokens, 'm d -> b m d', b = batch) 142 | x, mem_ps = pack([m, x], 'b * d') 143 | 144 | if exists(mask): 145 | num_mems = m.shape[-2] 146 | mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True) 147 | 148 | # whether to append embeds, as in PaLI, for image embeddings 149 | 150 | if exists(prepend_embeds): 151 | _, prepend_dim = prepend_embeds.shape[1:] 152 | assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions' 153 | 154 | x = torch.cat((prepend_embeds, x), dim = -2) 155 | 156 | x = self.emb_dropout(x) 157 | 158 | # attention layers 159 | 160 | x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) 161 | 162 | # splice out memory tokens 163 | 164 | if self.has_memory_tokens: 165 | m, x = unpack(x, mem_ps, 'b * d') 166 | intermediates.memory_tokens = m 167 | 168 | if not return_embeddings: 169 | logits = self.to_logits(x) 170 | numerical_pred = self.to_numerical_output(x) 171 | out = (logits, numerical_pred) 172 | else: 173 | out = x 174 | 175 | if return_intermediates: 176 | return out, intermediates 177 | 178 | if return_mems: 179 | hiddens = intermediates.hiddens 180 | new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens) 181 | return out, new_mems 182 | 183 | if return_attn: 184 | attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates) 185 | return out, attn_maps 186 | 187 | return out 188 | 189 | class XValAutoregressiveWrapper(nn.Module): 190 | def __init__( 191 | self, 192 | net: XValTransformerWrapper, 193 | ignore_index = -100, 194 | pad_value = 0, 195 | numerical_loss_weight = 1. 
196 | ): 197 | super().__init__() 198 | self.net = net 199 | self.max_seq_len = net.max_seq_len 200 | self.numerical_loss_weight = numerical_loss_weight 201 | self.ignore_index = ignore_index 202 | 203 | @torch.no_grad() 204 | def generate( 205 | self, 206 | start_tokens: Tensor, 207 | start_numbers: Tensor, 208 | seq_len, 209 | filter_logits_fn: Callable = top_k, 210 | filter_kwargs: dict = dict(), 211 | temperature = 1., 212 | **kwargs 213 | ): 214 | device = start_tokens.device 215 | was_training = self.net.training 216 | num_dims = len(start_tokens.shape) 217 | 218 | assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2' 219 | assert start_tokens.shape == start_numbers.shape 220 | 221 | b, t, device = *start_tokens.shape, start_tokens.device 222 | 223 | self.net.eval() 224 | out = start_tokens 225 | num_out = start_numbers 226 | 227 | for _ in range(seq_len): 228 | x = out[:, -self.max_seq_len:] 229 | x_num = num_out[:, -self.max_seq_len:] 230 | 231 | logits, numerical_pred = self.net(x, x_num, **kwargs) 232 | 233 | last_logits = logits[:, -1] 234 | last_num_pred = numerical_pred[:, -1:] 235 | 236 | filtered_logits = filter_logits_fn(last_logits, **filter_kwargs) 237 | 238 | probs = F.softmax(filtered_logits / temperature, dim=-1) 239 | 240 | sample = torch.multinomial(probs, 1) 241 | 242 | out = torch.cat((out, sample), dim = -1) 243 | num_out = torch.cat((num_out, last_num_pred), dim = -1) 244 | 245 | out = out[:, t:] 246 | num_out = num_out[:, t:] 247 | 248 | is_number = out == self.net.numerical_token_id 249 | num_out = torch.where(is_number, num_out, float('nan')) 250 | 251 | self.net.train(was_training) 252 | return GenerateReturn(out, num_out, is_number) 253 | 254 | def forward( 255 | self, 256 | x: Tensor, 257 | x_num: Tensor, 258 | return_loss_breakdown = False, 259 | **kwargs 260 | ): 261 | inp, target = x[:, :-1], x[:, 1:] 262 | x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:] 263 | 264 | # ignore index 265 | 266 | target_mask = target != self.ignore_index 267 | 268 | # key padding mask 269 | 270 | mask = kwargs.get('mask', None) 271 | if exists(mask): 272 | target_mask &= mask 273 | 274 | if mask.shape[1] == x.shape[1]: 275 | mask = mask[:, :-1] 276 | kwargs['mask'] = mask 277 | 278 | logits, numerical_pred = self.net(inp, x_num_inp, **kwargs) 279 | 280 | logits = rearrange(logits, 'b n c -> b c n') 281 | 282 | cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index) 283 | 284 | # protect against nan in `x_num` input tensor 285 | 286 | target_is_number_mask = target == self.net.numerical_token_id 287 | x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.) 288 | 289 | # numerical mse loss 290 | 291 | numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none') 292 | 293 | numerical_mse_loss = numerical_mse_loss * target_mask 294 | numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.) 295 | 296 | # combine losses 297 | 298 | loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight 299 | 300 | loss = loss[target_mask] 301 | loss = loss.mean() 302 | 303 | if not return_loss_breakdown: 304 | return loss 305 | 306 | return loss, LossBreakdown(cross_entropy_loss, numerical_mse_loss) 307 | --------------------------------------------------------------------------------
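Usage sketch for the xval wrappers defined above. This is a minimal, illustrative example only: the hyperparameters, tensor shapes, and the choice of reserved `numerical_token_id` are assumptions, not values taken from this repository.

import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

model = XValTransformerWrapper(
    num_tokens = 256,
    numerical_token_id = 255,                     # one token id reserved to mean "a number goes here" (assumption)
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = XValAutoregressiveWrapper(model)

ids  = torch.randint(0, 256, (2, 1024))           # discrete token ids
nums = torch.randn(2, 1024)                       # continuous values, only used where ids == numerical_token_id

loss = wrapper(ids, nums)                         # cross entropy + weighted numerical mse
loss.backward()

# generation returns the sampled token ids, the predicted numbers, and a mask of which positions are numbers
sampled_ids, sampled_nums, is_number = wrapper.generate(ids[:, :1], nums[:, :1], seq_len = 16)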