├── .github └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── enwik8 │ └── enwik8.gz └── flowers │ └── labels.txt ├── pyproject.toml ├── tests └── test_transfusion.py ├── train_image_only.py ├── train_image_only_with_unet.py ├── train_latent_only.py ├── train_latent_with_text.py ├── train_mnist.py ├── train_mnist_vae.py ├── train_mnist_with_unet.py ├── train_text_only.py ├── train_toy.py ├── transfusion.png └── transfusion_pytorch ├── __init__.py └── transfusion.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests the examples in README 2 | on: [push, pull_request] 3 | 4 | env: 5 | TYPECHECK: True 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Install Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.11" 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install uv 19 | python -m uv pip install --upgrade pip 20 | python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu 21 | python -m uv pip install -e .[test] 22 | - name: Test with pytest 23 | run: | 24 | python -m pytest tests/ 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/mnist/ 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Transfusion - Pytorch 4 | 5 | Pytorch implementation of [Transfusion](https://www.arxiv.org/abs/2408.11039), "Predict the Next Token and Diffuse Images with One Multi-Modal Model", from MetaAI. 6 | 7 | In this repo, we will substitute diffusion with flow matching given the success of Flux from Black Forest Labs (but will keep the original paper title given Transflow does not have the same ring). This repository will also attempt to extend to any number of modalities. 
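For readers unfamiliar with the substitution, below is a minimal sketch of the flow matching (rectified flow) objective used in place of the diffusion loss. It is illustrative only: `model` stands in for a generic network taking noised latents and times, while the real training loss in this repository also interleaves the autoregressive text cross entropy across modalities.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, data):
    # sample one time t in [0, 1) per example, broadcastable over the latent dims
    times = torch.rand(data.shape[0], *((1,) * (data.ndim - 1)))

    noise = torch.randn_like(data)

    # straight-line interpolation: x_t = (1 - t) * noise + t * data
    noised = noise.lerp(data, times)

    # the regression target is the constant velocity along that line
    target_velocity = data - noise

    pred_velocity = model(noised, times)
    return F.mse_loss(pred_velocity, target_velocity)
```

At sampling time the learned velocity field is integrated from pure noise toward data, which is why `torchdiffeq` appears among the dependencies.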
8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install transfusion-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | One modality, say images 18 | 19 | ```python 20 | from torch import randint, randn 21 | from transfusion_pytorch import Transfusion 22 | 23 | model = Transfusion( 24 | num_text_tokens = 256, 25 | dim_latent = 384, 26 | modality_default_shape = (4,), # fallback, in the case the language model did not produce a valid modality shape 27 | transformer = dict( 28 | dim = 512, 29 | depth = 8 30 | ) 31 | ) 32 | 33 | # any torch.long is text, torch.float is modalities 34 | 35 | text_and_images = [ 36 | [randint(0, 256, (16,)), randn(4, 384), randint(0, 256, (8,)), randn(6, 384)], 37 | [randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), randn(2, 384), randint(0, 256, (9,))] 38 | ] 39 | 40 | loss = model(text_and_images) 41 | 42 | loss.backward() 43 | 44 | # after much training 45 | 46 | one_multimodal_sample = model.sample() 47 | ``` 48 | 49 | Multiple different modalities 50 | 51 | ```python 52 | from torch import randint, randn 53 | from transfusion_pytorch import Transfusion 54 | 55 | model = Transfusion( 56 | num_text_tokens = 256, 57 | dim_latent = (384, 192), # specify multiple latent dimensions 58 | modality_default_shape = ((4,), (2,)), # default shapes for first and second modality 59 | transformer = dict( 60 | dim = 512, 61 | depth = 8 62 | ) 63 | ) 64 | 65 | # then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position 66 | 67 | # any torch.long is text, torch.float is modalities 68 | 69 | text_images_and_audio = [ 70 | [randint(0, 256, (16,)), (0, randn(4, 384)), randint(0, 256, (8,)), (1, randn(6, 192))], 71 | [randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), (1, randn(2, 192)), randint(0, 256, (9,))] 72 | ] 73 | 74 | loss = model(text_images_and_audio) 75 | 76 | loss.backward() 77 | 78 | # after much training 79 | 80 | one_multimodal_sample = model.sample() 81 | ``` 82 | 83 | Automatically taking care of encoding and decoding of images 84 | 85 | ```python 86 | import torch 87 | from torch import nn, randint, randn 88 | from transfusion_pytorch import Transfusion, print_modality_sample 89 | 90 | mock_encoder = nn.Conv2d(3, 384, 3, padding = 1) 91 | mock_decoder = nn.Conv2d(384, 3, 3, padding = 1) 92 | 93 | model = Transfusion( 94 | num_text_tokens = 12, 95 | dim_latent = 384, 96 | channel_first_latent = True, 97 | modality_default_shape = (4, 4), 98 | modality_encoder = mock_encoder, 99 | modality_decoder = mock_decoder, 100 | transformer = dict( 101 | dim = 512, 102 | depth = 8 103 | ) 104 | ) 105 | 106 | text_and_images = [ 107 | [ 108 | randint(0, 12, (16,)), # 16 text tokens 109 | randn(3, 8, 8), # (8 x 8) 3 channeled image 110 | randint(0, 12, (8,)), # 8 text tokens 111 | randn(3, 7, 7) # (7 x 7) 3 channeled image 112 | ], 113 | [ 114 | randint(0, 12, (16,)), # 16 text tokens 115 | randn(3, 8, 5), # (8 x 5) 3 channeled image 116 | randint(0, 12, (5,)), # 5 text tokens 117 | randn(3, 2, 16), # (2 x 16) 3 channeled image 118 | randint(0, 12, (9,)) # 9 text tokens 119 | ] 120 | ] 121 | 122 | loss = model(text_and_images) 123 | 124 | loss.backward() 125 | 126 | # after much training 127 | 128 | one_multimodal_sample = model.sample() 129 | 130 | print_modality_sample(one_multimodal_sample) 131 | ``` 132 | 133 | To pretrain on language first, just pass in your text as type `Int['batch seq']` 134 | 135 | ```python 136 | import torch 137 | from transfusion_pytorch import Transfusion 138 | 139 
| model = Transfusion( 140 | num_text_tokens = 256, 141 | dim_latent = 384, 142 | transformer = dict( 143 | dim = 512, 144 | depth = 8, 145 | ) 146 | ).cuda() 147 | 148 | text = torch.randint(0, 256, (2, 1024)).cuda() 149 | 150 | loss = model(text) 151 | loss.backward() 152 | 153 | # after much training 154 | 155 | sampled = model.generate_text_only(text[:, :1], 1024) 156 | ``` 157 | 158 | ## Examples 159 | 160 | To run any of the examples `train_{example_name}.py` in the project root, simply install dependencies first as so 161 | 162 | ```bash 163 | $ pip install .[examples] 164 | ``` 165 | 166 | If you run into some weird error with `safetensors`, run this too 167 | 168 | ```bash 169 | $ pip install -U diffusers transformers accelerate scipy ftfy safetensors 170 | ``` 171 | 172 | ## Citations 173 | 174 | ```bibtex 175 | @inproceedings{Zhou2024TransfusionPT, 176 | title = {Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model}, 177 | author = {Chunting Zhou and Lili Yu and Arun Babu and Kushal Tirumala and Michihiro Yasunaga and Leonid Shamis and Jacob Kahn and Xuezhe Ma and Luke Zettlemoyer and Omer Levy}, 178 | year = {2024}, 179 | url = {https://api.semanticscholar.org/CorpusID:271909855} 180 | } 181 | ``` 182 | 183 | ```bibtex 184 | @misc{Rubin2024, 185 | author = {Ohad Rubin}, 186 | url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950} 187 | } 188 | ``` 189 | 190 | ```bibtex 191 | @article{Nguyen2024MinPS, 192 | title = {Min P Sampling: Balancing Creativity and Coherence at High Temperature}, 193 | author = {Minh Nguyen and Andrew Baker and Andreas Kirsch and Clement Neo}, 194 | journal = {ArXiv}, 195 | year = {2024}, 196 | volume = {abs/2407.01082}, 197 | url = {https://api.semanticscholar.org/CorpusID:270870613} 198 | } 199 | ``` 200 | 201 | ```bibtex 202 | @article{Bao2022AllAW, 203 | title = {All are Worth Words: A ViT Backbone for Diffusion Models}, 204 | author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu}, 205 | journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 206 | year = {2022}, 207 | pages = {22669-22679}, 208 | url = {https://api.semanticscholar.org/CorpusID:253581703} 209 | } 210 | ``` 211 | 212 | ```bibtex 213 | @inproceedings{Zhao2024MonoFormerOT, 214 | title = {MonoFormer: One Transformer for Both Diffusion and Autoregression}, 215 | author = {Chuyang Zhao and Yuxing Song and Wenhao Wang and Haocheng Feng and Errui Ding and Yifan Sun and Xinyan Xiao and Jingdong Wang}, 216 | year = {2024}, 217 | url = {https://api.semanticscholar.org/CorpusID:272832492} 218 | } 219 | ``` 220 | 221 | ```bibtex 222 | @article{Yang2024ConsistencyFM, 223 | title = {Consistency Flow Matching: Defining Straight Flows with Velocity Consistency}, 224 | author = {Ling Yang and Zixiang Zhang and Zhilong Zhang and Xingchao Liu and Minkai Xu and Wentao Zhang and Chenlin Meng and Stefano Ermon and Bin Cui}, 225 | journal = {ArXiv}, 226 | year = {2024}, 227 | volume = {abs/2407.02398}, 228 | url = {https://api.semanticscholar.org/CorpusID:270878436} 229 | } 230 | ``` 231 | 232 | ```bibtex 233 | @inproceedings{Zhou2024ValueRL, 234 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers}, 235 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, 236 | year = {2024}, 237 | url = {https://api.semanticscholar.org/CorpusID:273532030} 238 | } 239 | 
``` 240 | 241 | ```bibtex 242 | @inproceedings{Duvvuri2024LASERAW, 243 | title = {LASER: Attention with Exponential Transformation}, 244 | author = {Sai Surya Duvvuri and Inderjit S. Dhillon}, 245 | year = {2024}, 246 | url = {https://api.semanticscholar.org/CorpusID:273849947} 247 | } 248 | ``` 249 | 250 | ```bibtex 251 | @inproceedings{Dong2024HymbaAH, 252 | title = {Hymba: A Hybrid-head Architecture for Small Language Models}, 253 | author = {Xin Dong and Y. Fu and Shizhe Diao and Wonmin Byeon and Zijia Chen and Ameya Mahabaleshwarkar and Shih-Yang Liu and Matthijs Van Keirsbilck and Min-Hung Chen and Yoshi Suhara and Yingyan Lin and Jan Kautz and Pavlo Molchanov}, 254 | year = {2024}, 255 | url = {https://api.semanticscholar.org/CorpusID:274166163} 256 | ``` 257 | 258 | ```bibtex 259 | @article{Zhu2024HyperConnections, 260 | title = {Hyper-Connections}, 261 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou}, 262 | journal = {ArXiv}, 263 | year = {2024}, 264 | volume = {abs/2409.19606}, 265 | url = {https://api.semanticscholar.org/CorpusID:272987528} 266 | } 267 | ``` 268 | 269 | ```bibtex 270 | @article{Zhu2025FracConnectionsFE, 271 | title = {Frac-Connections: Fractional Extension of Hyper-Connections}, 272 | author = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou}, 273 | journal = {ArXiv}, 274 | year = {2025}, 275 | volume = {abs/2503.14125}, 276 | url = {https://api.semanticscholar.org/CorpusID:277104144} 277 | } 278 | ``` 279 | -------------------------------------------------------------------------------- /data/enwik8/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/transfusion-pytorch/ea045706cc45e2cf5cd1ce4523e5af6a49eff54a/data/enwik8/enwik8.gz -------------------------------------------------------------------------------- /data/flowers/labels.txt: -------------------------------------------------------------------------------- 1 | pink primrose 2 | hard-leaved pocket orchid 3 | canterbury bells 4 | sweet pea 5 | english marigold 6 | tiger lily 7 | moon orchid 8 | bird of paradise 9 | monkshood 10 | globe thistle 11 | snapdragon 12 | colt's foot 13 | king protea 14 | spear thistle 15 | yellow iris 16 | globe-flower 17 | purple coneflower 18 | peruvian lily 19 | balloon flower 20 | giant white arum lily 21 | fire lily 22 | pincushion flower 23 | fritillary 24 | red ginger 25 | grape hyacinth 26 | corn poppy 27 | prince of wales feathers 28 | stemless gentian 29 | artichoke 30 | sweet william 31 | carnation 32 | garden phlox 33 | love in the mist 34 | mexican aster 35 | alpine sea holly 36 | ruby-lipped cattleya 37 | cape flower 38 | great masterwort 39 | siam tulip 40 | lenten rose 41 | barbeton daisy 42 | daffodil 43 | sword lily 44 | poinsettia 45 | bolero deep blue 46 | wallflower 47 | marigold 48 | buttercup 49 | oxeye daisy 50 | common dandelion 51 | petunia 52 | wild pansy 53 | primula 54 | sunflower 55 | pelargonium 56 | bishop of llandaff 57 | gaura 58 | geranium 59 | orange dahlia 60 | pink-yellow dahlia? 
61 | cautleya spicata 62 | japanese anemone 63 | black-eyed susan 64 | silverbush 65 | californian poppy 66 | osteospermum 67 | spring crocus 68 | bearded iris 69 | windflower 70 | tree poppy 71 | gazania 72 | azalea 73 | water lily 74 | rose 75 | thorn apple 76 | morning glory 77 | passion flower 78 | lotus 79 | toad lily 80 | anthurium 81 | frangipani 82 | clematis 83 | hibiscus 84 | columbine 85 | desert-rose 86 | tree mallow 87 | magnolia 88 | cyclamen 89 | watercress 90 | canna lily 91 | hippeastrum 92 | bee balm 93 | ball moss 94 | foxglove 95 | bougainvillea 96 | camellia 97 | mallow 98 | mexican petunia 99 | bromelia 100 | blanket flower 101 | trumpet creeper 102 | blackberry lily -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "transfusion-pytorch" 3 | version = "0.11.0" 4 | description = "Transfusion in Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'attention mechanism', 16 | 'rectified flow', 17 | ] 18 | classifiers=[ 19 | 'Development Status :: 4 - Beta', 20 | 'Intended Audience :: Developers', 21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3.8', 24 | ] 25 | 26 | dependencies = [ 27 | 'axial-positional-embedding>=0.3.5', 28 | 'beartype', 29 | 'einx>=0.3.0', 30 | 'einops>=0.8.0', 31 | 'ema-pytorch', 32 | 'hyper-connections>=0.0.10', 33 | 'jaxtyping', 34 | 'loguru', 35 | 'rotary_embedding_torch>=0.8.4', 36 | 'torchdiffeq', 37 | 'torch>=2.0', 38 | 'tqdm' 39 | ] 40 | 41 | [project.urls] 42 | Homepage = "https://pypi.org/project/transfusion-pytorch/" 43 | Repository = "https://github.com/lucidrains/transfusion-pytorch" 44 | 45 | [build-system] 46 | requires = ["hatchling"] 47 | build-backend = "hatchling.build" 48 | 49 | [project.optional-dependencies] 50 | 51 | examples = [ 52 | "datasets", 53 | "diffusers" 54 | ] 55 | test = [ 56 | "pytest", 57 | ] 58 | 59 | [tool.ruff] 60 | line-length = 1000 61 | 62 | lint.ignore = [ 63 | "F722", # for jaxtyping shape annotation 64 | "F401", 65 | "F821", 66 | "E402" 67 | ] 68 | 69 | [tool.pytest.ini_options] 70 | pythonpath = [ 71 | "." 
72 | ] 73 | 74 | [tool.hatch.metadata] 75 | allow-direct-references = true 76 | 77 | [tool.hatch.build.targets.wheel] 78 | packages = ["transfusion_pytorch"] 79 | -------------------------------------------------------------------------------- /tests/test_transfusion.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from functools import partial 3 | from copy import deepcopy 4 | 5 | import torch 6 | from torch import nn, randint, randn, tensor, cuda 7 | 8 | from einops import rearrange 9 | 10 | import torch._dynamo 11 | torch._dynamo.config.suppress_errors = True 12 | 13 | cuda_available = cuda.is_available() 14 | 15 | from transfusion_pytorch.transfusion import ( 16 | Transfusion, 17 | flex_attention, 18 | exists, 19 | stack_same_shape_tensors_with_inverse, 20 | filter_with_inverse, 21 | apply_fn_modality_type 22 | ) 23 | 24 | @pytest.mark.parametrize('cache_kv', (False, True)) 25 | @pytest.mark.parametrize('use_flex_attn', (False, True)) 26 | @pytest.mark.parametrize('num_residual_streams', (1, 4)) 27 | @pytest.mark.parametrize('reconstruction_loss_weight', (0., 0.1)) 28 | def test_transfusion( 29 | cache_kv: bool, 30 | use_flex_attn: bool, 31 | num_residual_streams: int, 32 | reconstruction_loss_weight: float 33 | ): 34 | 35 | if use_flex_attn and (not exists(flex_attention) or not cuda_available): 36 | return pytest.skip() 37 | 38 | text_tokens = 8 39 | randint_ = partial(randint, 0, text_tokens) 40 | 41 | model = Transfusion( 42 | num_text_tokens = text_tokens, 43 | dim_latent = (384, 192), 44 | modality_default_shape = ((32,), (64,)), 45 | reconstruction_loss_weight = reconstruction_loss_weight, 46 | transformer = dict( 47 | dim = 64, 48 | depth = 2, 49 | use_flex_attn = use_flex_attn, 50 | num_residual_streams = num_residual_streams 51 | ) 52 | ) 53 | 54 | if use_flex_attn: 55 | model = model.cuda() 56 | 57 | # then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position 58 | 59 | text_images_and_audio = [ 60 | [randint_((16,)), (0, randn(4, 384)), randint_((8,)), (1, randn(6, 192))], 61 | [randint_((16,)), randn(7, 384), randint_((5,)), (1, randn(2, 192)), randint_((9,))] 62 | ] 63 | 64 | loss = model(text_images_and_audio) 65 | 66 | loss.backward() 67 | 68 | # after much training 69 | 70 | prime = [tensor(model.som_ids[0])] 71 | 72 | one_multimodal_sample = model.sample(prime, max_length = 128, cache_kv = cache_kv) 73 | 74 | 75 | @pytest.mark.parametrize('use_flex_attn', (False, True)) 76 | def test_auto_modality_transform( 77 | use_flex_attn: bool 78 | ): 79 | 80 | if use_flex_attn and (not exists(flex_attention) or not cuda_available): 81 | return pytest.skip() 82 | 83 | text_tokens = 8 84 | randint_ = partial(randint, 0, text_tokens) 85 | 86 | model = Transfusion( 87 | num_text_tokens = text_tokens, 88 | dim_latent = 384, 89 | channel_first_latent = True, 90 | modality_default_shape = (2, 2), 91 | transformer = dict( 92 | dim = 64, 93 | depth = 2, 94 | use_flex_attn = use_flex_attn 95 | ) 96 | ) 97 | 98 | text_and_images = [ 99 | [randint_((16,)), randn(384, 2, 2), randint_((8,)), randn(384, 2, 2)], 100 | [randint_((16,)), randn(384, 2, 2), randint_((5,)), randn(384, 2, 2), randint_((9,))] 101 | ] 102 | 103 | loss = model(text_and_images) 104 | 105 | loss.backward() 106 | 107 | # after much training 108 | 109 | prime = [tensor(model.som_ids[0])] 110 | 111 | one_multimodal_sample = model.sample(prime, max_length = 128) 112 | 113 | 
@pytest.mark.parametrize('use_flex_attn', (False, True)) 114 | @pytest.mark.parametrize('return_loss', (False, True)) 115 | def test_text( 116 | use_flex_attn: bool, 117 | return_loss: bool 118 | ): 119 | 120 | if use_flex_attn and (not exists(flex_attention) or not cuda_available): 121 | return pytest.skip() 122 | 123 | model = Transfusion( 124 | num_text_tokens = 256, 125 | dim_latent = 384, 126 | channel_first_latent = True, 127 | modality_default_shape = (32,), 128 | transformer = dict( 129 | dim = 64, 130 | depth = 2, 131 | use_flex_attn = use_flex_attn 132 | ) 133 | ) 134 | 135 | if use_flex_attn: 136 | model = model.cuda() 137 | 138 | text = randint(0, 256, (2, 1024)) 139 | 140 | model(text, return_loss = return_loss) 141 | 142 | @pytest.mark.parametrize('channel_first', (False, True)) 143 | def test_modality_only( 144 | channel_first: bool 145 | ): 146 | 147 | model = Transfusion( 148 | num_text_tokens = 256, 149 | dim_latent = (384, 192), 150 | channel_first_latent = channel_first, 151 | modality_default_shape = (32,), 152 | transformer = dict( 153 | dim = 64, 154 | depth = 2, 155 | use_flex_attn = False 156 | ) 157 | ) 158 | 159 | images = randn(2, 8, 8, 192) 160 | 161 | if channel_first: 162 | images = rearrange(images, 'b ... d -> b d ...') 163 | 164 | loss = model(images, return_loss = True, modality_type = 1) 165 | 166 | loss.backward() 167 | 168 | model.generate_modality_only(modality_type = 1) 169 | 170 | @pytest.mark.parametrize('custom_time_fn', (False, True)) 171 | def test_text_image_end_to_end( 172 | custom_time_fn: bool 173 | ): 174 | mock_vae_encoder = nn.Conv2d(3, 384, 3, padding = 1) 175 | mock_vae_decoder = nn.Conv2d(384, 3, 3, padding = 1) 176 | 177 | model = Transfusion( 178 | num_text_tokens = 4, 179 | dim_latent = 384, 180 | channel_first_latent = True, 181 | modality_default_shape = ((4, 4),), 182 | modality_encoder = mock_vae_encoder, 183 | modality_decoder = mock_vae_decoder, 184 | transformer = dict( 185 | dim = 64, 186 | depth = 2 187 | ) 188 | ) 189 | 190 | text_and_images = [ 191 | [ 192 | randint(0, 4, (16,)), 193 | randn(3, 8, 8), 194 | randint(0, 4, (8,)), 195 | randn(3, 7, 7) 196 | ], 197 | [ 198 | randint(0, 4, (16,)), 199 | randn(3, 8, 5), 200 | randint(0, 4, (5,)), 201 | randn(3, 2, 16), 202 | randint(0, 4, (9,)) 203 | ] 204 | ] 205 | 206 | # allow researchers to experiment with different time distributions across multiple modalities in a sample 207 | 208 | def num_modalities_to_times(num_modalities): 209 | batch = num_modalities.shape[0] 210 | device = num_modalities.device 211 | total_modalities = num_modalities.amax().item() 212 | return torch.ones((batch, total_modalities), device = device) 213 | 214 | time_fn = num_modalities_to_times if custom_time_fn else None 215 | 216 | # forward 217 | 218 | loss = model( 219 | text_and_images, 220 | num_modalities_to_times_fn = time_fn 221 | ) 222 | 223 | loss.backward() 224 | 225 | # after much training 226 | 227 | one_multimodal_sample = model.sample(max_length = 128) 228 | 229 | def test_velocity_consistency(): 230 | mock_encoder = nn.Conv2d(3, 384, 3, padding = 1) 231 | mock_decoder = nn.Conv2d(384, 3, 3, padding = 1) 232 | 233 | model = Transfusion( 234 | num_text_tokens = 12, 235 | dim_latent = 384, 236 | channel_first_latent = True, 237 | modality_default_shape = (4, 4), 238 | modality_encoder = mock_encoder, 239 | modality_decoder = mock_decoder, 240 | transformer = dict( 241 | dim = 64, 242 | depth = 1 243 | ) 244 | ) 245 | 246 | ema_model = deepcopy(model) 247 | 248 | text_and_images = [ 249 | 
[ 250 | randint(0, 12, (16,)), 251 | randn(3, 8, 8), 252 | randint(0, 12, (8,)), 253 | randn(3, 7, 7) 254 | ], 255 | [ 256 | randint(0, 12, (16,)), 257 | randn(3, 8, 5), 258 | randint(0, 12, (5,)), 259 | randn(3, 2, 16), 260 | randint(0, 12, (9,)) 261 | ] 262 | ] 263 | 264 | loss, breakdown = model( 265 | text_and_images, 266 | velocity_consistency_ema_model = ema_model, 267 | return_breakdown = True 268 | ) 269 | 270 | loss.backward() 271 | 272 | assert exists(breakdown.velocity) 273 | 274 | def test_axial_pos_emb(): 275 | model = Transfusion( 276 | num_text_tokens = 256, 277 | dim_latent = (384, 192), # specify multiple latent dimensions 278 | modality_default_shape = ((2, 2), (2,)), # default shapes for first and second modality 279 | fallback_to_default_shape_if_invalid = True, 280 | add_pos_emb = True, 281 | modality_num_dim = (2, 1), 282 | transformer = dict( 283 | dim = 64, 284 | depth = 8 285 | ) 286 | ) 287 | 288 | # then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position 289 | 290 | # any torch.long is text, torch.float is modalities 291 | 292 | text_images_and_audio = [ 293 | [randint(0, 256, (16,)), (0, randn(2, 3, 384)), randint(0, 256, (8,)), (1, randn(6, 192))], 294 | [randint(0, 256, (16,)), randn(1, 4, 384), randint(0, 256, (5,)), (1, randn(2, 192)), randint(0, 256, (9,))] 295 | ] 296 | 297 | loss = model(text_images_and_audio) 298 | 299 | loss.backward() 300 | 301 | # after much training 302 | 303 | one_multimodal_sample = model.sample(max_length = 128) 304 | 305 | # unet related 306 | 307 | def test_modality_only_with_unet(): 308 | 309 | model = Transfusion( 310 | num_text_tokens = 10, 311 | dim_latent = 4, 312 | modality_default_shape = (14, 14), 313 | pre_post_transformer_enc_dec = ( 314 | nn.Conv2d(4, 64, 3, 2, 1), 315 | nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1), 316 | ), 317 | channel_first_latent = True, 318 | add_pos_emb = True, 319 | modality_num_dim = 2, 320 | velocity_consistency_loss_weight = 0.1, 321 | transformer = dict( 322 | dim = 64, 323 | depth = 1, 324 | dim_head = 32, 325 | heads = 8 326 | ) 327 | ) 328 | 329 | x = torch.randn(1, 4, 14, 14) 330 | 331 | loss = model(x) 332 | loss.backward() 333 | 334 | sampled = model.generate_modality_only() 335 | 336 | def test_stack_similar_shape_fn(): 337 | from torch import zeros 338 | 339 | data = [ 340 | zeros(3, 5), 341 | zeros(2, 3), 342 | zeros(3, 5), 343 | zeros(2, 3), 344 | zeros(4, 5), 345 | zeros(4, 5) 346 | ] 347 | 348 | plus_one = lambda x: x + 1 349 | 350 | data = [d + i for i, d in enumerate(data)] 351 | data_plus_one = [plus_one(d) for d in data] 352 | 353 | stacked_tensors, inverse = stack_same_shape_tensors_with_inverse(data) 354 | 355 | stacked_tensors = {k: plus_one(v) for k, v in stacked_tensors.items()} 356 | 357 | batch_processed_data_plus_one = inverse(stacked_tensors) 358 | 359 | assert all([torch.allclose(tensor1, tensor2) for tensor1, tensor2 in zip(data_plus_one, batch_processed_data_plus_one)]) 360 | 361 | def test_filter_with_inverse(): 362 | x = [0, 1, 2, 3, 4] 363 | is_even = lambda el: (el % 2) == 0 364 | 365 | x_even, inverse = filter_with_inverse(is_even, x) 366 | x_even_times_ten = [el * 10 for el in x_even] 367 | 368 | y = inverse(x_even_times_ten) 369 | assert y == [0, 1, 20, 3, 40] 370 | 371 | def test_apply_fn_modality_type(): 372 | from torch import zeros 373 | 374 | modalities = [ 375 | [zeros(3, 5)], 376 | [zeros(1, 5)], 377 | [(1, zeros(3, 5))], 378 | [(1, zeros(2, 5))], 379 | [(0, zeros(1, 5)), 
(1, zeros(3, 5))], 380 | ] 381 | 382 | modalities = apply_fn_modality_type(lambda x: x + 1, modalities) 383 | 384 | modalities = apply_fn_modality_type(lambda x: x + 2, modalities, modality_type = 1) 385 | 386 | assert (modalities[0][0][-1] == 1).all() 387 | assert (modalities[2][0][-1] == 2).all() 388 | 389 | 390 | def test_zero_dimensional(): 391 | 392 | model = Transfusion( 393 | num_text_tokens = 256, 394 | dim_latent = 384, 395 | modality_default_shape = (), 396 | transformer = dict( 397 | dim = 512, 398 | depth = 8, 399 | num_residual_streams = 1 400 | ) 401 | ) 402 | 403 | # any torch.long is text, torch.float is modalities 404 | 405 | text_and_embeds = [ 406 | [randint(0, 256, (16,)), randn(384), randint(0, 256, (8,)), randn(384)], 407 | [randint(0, 256, (16,)), randn(384), randint(0, 256, (5,)), randn(384), randint(0, 256, (9,))] 408 | ] 409 | 410 | loss = model(text_and_embeds) 411 | 412 | loss.backward() 413 | 414 | # after much training 415 | 416 | one_multimodal_sample = model.sample(prompt = randn(384)) 417 | -------------------------------------------------------------------------------- /train_image_only.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import tensor 6 | from torch.nn import Module 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import Adam 9 | 10 | from einops import rearrange 11 | 12 | import torchvision 13 | import torchvision.transforms as T 14 | from torchvision.utils import save_image 15 | 16 | from transfusion_pytorch import Transfusion, print_modality_sample 17 | 18 | rmtree('./results', ignore_errors = True) 19 | results_folder = Path('./results') 20 | results_folder.mkdir(exist_ok = True, parents = True) 21 | 22 | # functions 23 | 24 | def divisible_by(num, den): 25 | return (num % den) == 0 26 | 27 | # encoder / decoder 28 | 29 | class Encoder(Module): 30 | def forward(self, x): 31 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... h w (p1 p2)', p1 = 2, p2 = 2) 32 | return x * 2 - 1 33 | 34 | class Decoder(Module): 35 | def forward(self, x): 36 | x = rearrange(x, '... h w (p1 p2) -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14) 37 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.) 
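# note on the two modules above: Encoder patchifies each 1 x 28 x 28 MNIST image into a 14 x 14 grid of 2 x 2 patches
# (feature size 4 on the last dim, matching dim_latent below) and rescales pixels to [-1, 1]; Decoder inverts both
# steps and clamps the output back to [0, 1]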
38 | 39 | model = Transfusion( 40 | num_text_tokens = 10, 41 | dim_latent = 4, 42 | channel_first_latent = False, 43 | modality_default_shape = (14, 14), 44 | modality_encoder = Encoder(), 45 | modality_decoder = Decoder(), 46 | add_pos_emb = True, 47 | modality_num_dim = 2, 48 | velocity_consistency_loss_weight = 0.1, 49 | reconstruction_loss_weight = 0.1, 50 | transformer = dict( 51 | dim = 64, 52 | depth = 4, 53 | dim_head = 32, 54 | heads = 8, 55 | attn_laser = True 56 | ) 57 | ).cuda() 58 | 59 | ema_model = model.create_ema() 60 | 61 | class MnistDataset(Dataset): 62 | def __init__(self): 63 | self.mnist = torchvision.datasets.MNIST( 64 | './data', 65 | download = True 66 | ) 67 | 68 | def __len__(self): 69 | return len(self.mnist) 70 | 71 | def __getitem__(self, idx): 72 | pil, labels = self.mnist[idx] 73 | digit_tensor = T.PILToTensor()(pil) 74 | return (digit_tensor / 255).float() 75 | 76 | def cycle(iter_dl): 77 | while True: 78 | for batch in iter_dl: 79 | yield batch 80 | 81 | dataset = MnistDataset() 82 | 83 | dataloader = DataLoader(dataset, batch_size = 32, shuffle = True) 84 | iter_dl = cycle(dataloader) 85 | 86 | optimizer = Adam(model.parameters(), lr = 8e-4) 87 | 88 | # train loop 89 | 90 | for step in range(1, 100_000 + 1): 91 | 92 | loss = model(next(iter_dl), velocity_consistency_ema_model = ema_model) 93 | loss.backward() 94 | 95 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 96 | 97 | optimizer.step() 98 | optimizer.zero_grad() 99 | 100 | ema_model.update() 101 | 102 | print(f'{step}: {loss.item():.3f}') 103 | 104 | if divisible_by(step, 500): 105 | image = ema_model.generate_modality_only(batch_size = 64) 106 | 107 | save_image( 108 | rearrange(image, '(gh gw) 1 h w -> 1 (gh h) (gw w)', gh = 8).detach().cpu(), 109 | str(results_folder / f'{step}.png') 110 | ) 111 | -------------------------------------------------------------------------------- /train_image_only_with_unet.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import tensor, nn 6 | from torch.nn import Module 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import Adam 9 | 10 | from einops import rearrange 11 | 12 | import torchvision 13 | import torchvision.transforms as T 14 | from torchvision.utils import save_image 15 | 16 | from transfusion_pytorch import Transfusion, print_modality_sample 17 | 18 | rmtree('./results', ignore_errors = True) 19 | results_folder = Path('./results') 20 | results_folder.mkdir(exist_ok = True, parents = True) 21 | 22 | # functions 23 | 24 | def divisible_by(num, den): 25 | return (num % den) == 0 26 | 27 | # encoder / decoder 28 | 29 | class Encoder(Module): 30 | def forward(self, x): 31 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... (p1 p2) h w', p1 = 2, p2 = 2) 32 | return x * 2 - 1 33 | 34 | class Decoder(Module): 35 | def forward(self, x): 36 | x = rearrange(x, '... (p1 p2) h w -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14) 37 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.) 
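# note on the two modules above: here Encoder emits channel-first 4 x 14 x 14 patch latents (channel_first_latent = True
# below); the pre_post_transformer_enc_dec pair additionally downsamples to 7 x 7 with a strided Conv2d before the
# transformer and upsamples back to 14 x 14 with a ConvTranspose2d afterwards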
38 | 39 | model = Transfusion( 40 | num_text_tokens = 10, 41 | dim_latent = 4, 42 | channel_first_latent = True, 43 | modality_default_shape = (14, 14), 44 | modality_encoder = Encoder(), 45 | modality_decoder = Decoder(), 46 | pre_post_transformer_enc_dec = ( 47 | nn.Conv2d(4, 64, 3, 2, 1), 48 | nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1), 49 | ), 50 | add_pos_emb = True, 51 | modality_num_dim = 2, 52 | velocity_consistency_loss_weight = 0.1, 53 | transformer = dict( 54 | dim = 64, 55 | depth = 4, 56 | dim_head = 32, 57 | heads = 8 58 | ) 59 | ).cuda() 60 | 61 | ema_model = model.create_ema() 62 | 63 | class MnistDataset(Dataset): 64 | def __init__(self): 65 | self.mnist = torchvision.datasets.MNIST( 66 | './data', 67 | download = True 68 | ) 69 | 70 | def __len__(self): 71 | return len(self.mnist) 72 | 73 | def __getitem__(self, idx): 74 | pil, labels = self.mnist[idx] 75 | digit_tensor = T.PILToTensor()(pil) 76 | return (digit_tensor / 255).float() 77 | 78 | def cycle(iter_dl): 79 | while True: 80 | for batch in iter_dl: 81 | yield batch 82 | 83 | dataset = MnistDataset() 84 | 85 | dataloader = DataLoader(dataset, batch_size = 32, shuffle = True) 86 | iter_dl = cycle(dataloader) 87 | 88 | optimizer = Adam(model.parameters(), lr = 8e-4) 89 | 90 | # train loop 91 | 92 | for step in range(1, 100_000 + 1): 93 | 94 | loss = model(next(iter_dl), velocity_consistency_ema_model = ema_model) 95 | loss.backward() 96 | 97 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 98 | 99 | optimizer.step() 100 | optimizer.zero_grad() 101 | 102 | ema_model.update() 103 | 104 | print(f'{step}: {loss.item():.3f}') 105 | 106 | if divisible_by(step, 500): 107 | image = ema_model.generate_modality_only(batch_size = 64) 108 | 109 | save_image( 110 | rearrange(image, '(gh gw) 1 h w -> 1 (gh h) (gw w)', gh = 8).detach().cpu(), 111 | str(results_folder / f'{step}.png') 112 | ) 113 | -------------------------------------------------------------------------------- /train_latent_only.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import tensor 6 | from torch.nn import Module 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import Adam 9 | 10 | from einops import rearrange 11 | 12 | import torchvision 13 | import torchvision.transforms as T 14 | from torchvision.utils import save_image 15 | 16 | from transfusion_pytorch import Transfusion, print_modality_sample 17 | 18 | # hf related 19 | 20 | from datasets import load_dataset 21 | from diffusers.models import AutoencoderKL 22 | 23 | vae = AutoencoderKL.from_pretrained("./path/to/your/autoencoder", subfolder = "vae") 24 | 25 | class Encoder(Module): 26 | def __init__(self, vae): 27 | super().__init__() 28 | self.vae = vae 29 | 30 | def forward(self, image): 31 | with torch.no_grad(): 32 | latent = self.vae.encode(image * 2 - 1) 33 | 34 | return 0.18215 * latent.latent_dist.sample() 35 | 36 | class Decoder(Module): 37 | def __init__(self, vae): 38 | super().__init__() 39 | self.vae = vae 40 | 41 | def forward(self, latents): 42 | latents = (1 / 0.18215) * latents 43 | 44 | with torch.no_grad(): 45 | image = self.vae.decode(latents).sample 46 | 47 | return (image / 2 + 0.5).clamp(0, 1) 48 | 49 | # results folder 50 | 51 | rmtree('./results', ignore_errors = True) 52 | results_folder = Path('./results') 53 | results_folder.mkdir(exist_ok = True, parents = True) 54 | 55 | # constants 56 | 57 | 
SAMPLE_EVERY = 100 58 | 59 | # functions 60 | 61 | def divisible_by(num, den): 62 | return (num % den) == 0 63 | 64 | # encoder / decoder 65 | 66 | model = Transfusion( 67 | num_text_tokens = 10, 68 | dim_latent = 4, 69 | channel_first_latent = True, 70 | modality_default_shape = (32, 32), 71 | modality_encoder = Encoder(vae), 72 | modality_decoder = Decoder(vae), 73 | add_pos_emb = True, 74 | modality_num_dim = 2, 75 | velocity_consistency_loss_weight = 0.1, 76 | reconstruction_loss_weight = 0.1, 77 | transformer = dict( 78 | dim = 256, 79 | depth = 8, 80 | dim_head = 64, 81 | heads = 8 82 | ) 83 | ).cuda() 84 | 85 | ema_model = model.create_ema(0.9) 86 | 87 | class FlowersDataset(Dataset): 88 | def __init__(self, image_size): 89 | self.ds = load_dataset("nelorth/oxford-flowers")['train'] 90 | 91 | self.transform = T.Compose([ 92 | T.Resize((image_size, image_size)), 93 | T.PILToTensor() 94 | ]) 95 | 96 | def __len__(self): 97 | return len(self.ds) 98 | 99 | def __getitem__(self, idx): 100 | pil = self.ds[idx]['image'] 101 | tensor = self.transform(pil) 102 | return tensor / 255. 103 | 104 | def cycle(iter_dl): 105 | while True: 106 | for batch in iter_dl: 107 | yield batch 108 | 109 | dataset = FlowersDataset(256) 110 | 111 | dataloader = DataLoader(dataset, batch_size = 4, shuffle = True) 112 | 113 | iter_dl = cycle(dataloader) 114 | 115 | optimizer = Adam(model.parameters(), lr = 8e-4) 116 | 117 | # train loop 118 | 119 | for step in range(1, 100_000 + 1): 120 | 121 | for _ in range(4): 122 | loss = model.forward_modality(next(iter_dl)) 123 | (loss / 4).backward() 124 | 125 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 126 | 127 | optimizer.step() 128 | optimizer.zero_grad() 129 | 130 | ema_model.update() 131 | 132 | print(f'{step}: {loss.item():.3f}') 133 | 134 | if divisible_by(step, SAMPLE_EVERY): 135 | image = ema_model.generate_modality_only(batch_size = 4) 136 | 137 | save_image( 138 | rearrange(image, '(gh gw) c h w -> c (gh h) (gw w)', gh = 2).detach().cpu(), 139 | str(results_folder / f'{step}.png') 140 | ) 141 | -------------------------------------------------------------------------------- /train_latent_with_text.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import nn, tensor, Tensor 6 | from torch.nn import Module 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import Adam 9 | 10 | from einops import rearrange 11 | 12 | import torchvision 13 | import torchvision.transforms as T 14 | from torchvision.utils import save_image 15 | 16 | from transfusion_pytorch import Transfusion, print_modality_sample 17 | 18 | # hf related 19 | 20 | from datasets import load_dataset 21 | from diffusers.models import AutoencoderKL 22 | 23 | vae = AutoencoderKL.from_pretrained("./path/to/your/autoencoder", subfolder = "vae") 24 | 25 | class Encoder(Module): 26 | def __init__(self, vae): 27 | super().__init__() 28 | self.vae = vae 29 | 30 | def forward(self, image): 31 | with torch.no_grad(): 32 | latent = self.vae.encode(image * 2 - 1) 33 | 34 | return 0.18215 * latent.latent_dist.sample() 35 | 36 | class Decoder(Module): 37 | def __init__(self, vae): 38 | super().__init__() 39 | self.vae = vae 40 | 41 | def forward(self, latents): 42 | latents = (1 / 0.18215) * latents 43 | 44 | with torch.no_grad(): 45 | image = self.vae.decode(latents).sample 46 | 47 | return (image / 2 + 0.5).clamp(0, 1) 48 | 49 | # results folder 50 | 
51 | rmtree('./results', ignore_errors = True) 52 | results_folder = Path('./results') 53 | results_folder.mkdir(exist_ok = True, parents = True) 54 | 55 | # constants 56 | 57 | SAMPLE_EVERY = 100 58 | 59 | with open("./data/flowers/labels.txt", "r") as file: 60 | content = file.read() 61 | 62 | LABELS_TEXT = content.split('\n') 63 | 64 | # functions 65 | 66 | def divisible_by(num, den): 67 | return (num % den) == 0 68 | 69 | def decode_token(token): 70 | return str(chr(max(32, token))) 71 | 72 | def decode_tokens(tokens: Tensor) -> str: 73 | return "".join(list(map(decode_token, tokens.tolist()))) 74 | 75 | def encode_tokens(str: str) -> Tensor: 76 | return tensor([*bytes(str, 'UTF-8')]) 77 | 78 | # encoder / decoder 79 | 80 | model = Transfusion( 81 | num_text_tokens = 256, 82 | dim_latent = 4, 83 | channel_first_latent = True, 84 | modality_default_shape = (8, 8), 85 | modality_encoder = Encoder(vae), 86 | modality_decoder = Decoder(vae), 87 | pre_post_transformer_enc_dec = ( 88 | nn.Conv2d(4, 128, 3, 2, 1), 89 | nn.ConvTranspose2d(128, 4, 3, 2, 1, output_padding = 1), 90 | ), 91 | add_pos_emb = False, 92 | modality_num_dim = 2, 93 | reconstruction_loss_weight = 0.1, 94 | transformer = dict( 95 | dim = 128, 96 | depth = 8, 97 | dim_head = 64, 98 | heads = 8, 99 | ) 100 | ).cuda() 101 | 102 | ema_model = model.create_ema(0.9) 103 | 104 | class FlowersDataset(Dataset): 105 | def __init__(self, image_size): 106 | self.ds = load_dataset("nelorth/oxford-flowers")['train'] 107 | 108 | self.transform = T.Compose([ 109 | T.Resize((image_size, image_size)), 110 | T.PILToTensor(), 111 | T.Lambda(lambda t: t / 255.) 112 | ]) 113 | 114 | def __len__(self): 115 | return len(self.ds) 116 | 117 | def __getitem__(self, idx): 118 | sample = self.ds[idx] 119 | pil = sample['image'] 120 | 121 | labels_int = sample['label'] 122 | labels_text = LABELS_TEXT[labels_int] 123 | 124 | tensor = self.transform(pil) 125 | return encode_tokens(labels_text), tensor 126 | 127 | def cycle(iter_dl): 128 | while True: 129 | for batch in iter_dl: 130 | yield batch 131 | 132 | dataset = FlowersDataset(128) 133 | 134 | dataloader = model.create_dataloader(dataset, batch_size = 4, shuffle = True) 135 | 136 | iter_dl = cycle(dataloader) 137 | 138 | optimizer = Adam(model.parameters(), lr = 8e-4) 139 | 140 | # train loop 141 | 142 | for step in range(1, 100_000 + 1): 143 | 144 | for _ in range(4): 145 | loss = model.forward(next(iter_dl)) 146 | (loss / 4).backward() 147 | 148 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 149 | 150 | optimizer.step() 151 | optimizer.zero_grad() 152 | 153 | ema_model.update() 154 | 155 | print(f'{step}: {loss.item():.3f}') 156 | 157 | if divisible_by(step, SAMPLE_EVERY): 158 | sample = ema_model.sample() 159 | 160 | print_modality_sample(sample) 161 | 162 | if len(sample) < 3: 163 | continue 164 | 165 | text_tensor, maybe_image, *_ = sample 166 | 167 | if not isinstance(maybe_image, tuple): 168 | continue 169 | 170 | _, image = maybe_image 171 | text_tensor = text_tensor[text_tensor < 256] # todo: offer a utility function for removing meta tags and special tokens 172 | 173 | text = decode_tokens(text_tensor) 174 | filename = str(results_folder / f'{step}.{text}.png') 175 | 176 | save_image( 177 | image.detach().cpu(), 178 | filename 179 | ) 180 | -------------------------------------------------------------------------------- /train_mnist.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 
4 | import torch 5 | from torch import tensor 6 | from torch.nn import Module 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import Adam 9 | 10 | from einops import rearrange 11 | 12 | import torchvision 13 | import torchvision.transforms as T 14 | from torchvision.utils import save_image 15 | 16 | from transfusion_pytorch.transfusion import Transfusion, print_modality_sample 17 | 18 | rmtree('./results', ignore_errors = True) 19 | results_folder = Path('./results') 20 | results_folder.mkdir(exist_ok = True, parents = True) 21 | 22 | # constants 23 | 24 | IMAGE_AFTER_TEXT = True # False for captioning, True for text-to-image 25 | USE_PROMPT = False # whether to use prompting, or synthesize from start token 26 | NUM_TRAIN_STEPS = 20_000 27 | SAMPLE_EVERY = 250 28 | CHANNEL_FIRST = True 29 | 30 | # functions 31 | 32 | def divisible_by(num, den): 33 | return (num % den) == 0 34 | 35 | # encoder / decoder 36 | 37 | class Encoder(Module): 38 | def forward(self, x): 39 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... h w (p1 p2)', p1 = 2, p2 = 2) 40 | 41 | if CHANNEL_FIRST: 42 | x = rearrange(x, 'b ... d -> b d ...') 43 | 44 | return x * 2 - 1 45 | 46 | class Decoder(Module): 47 | def forward(self, x): 48 | 49 | if CHANNEL_FIRST: 50 | x = rearrange(x, 'b d ... -> b ... d') 51 | 52 | x = rearrange(x, '... h w (p1 p2) -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2) 53 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.) 54 | 55 | model = Transfusion( 56 | num_text_tokens = 10, 57 | dim_latent = 4, 58 | modality_default_shape = (14, 14), 59 | modality_encoder = Encoder(), 60 | modality_decoder = Decoder(), 61 | add_pos_emb = True, 62 | modality_num_dim = 2, 63 | channel_first_latent = CHANNEL_FIRST, 64 | transformer = dict( 65 | dim = 64, 66 | depth = 4, 67 | dim_head = 32, 68 | heads = 8, 69 | ) 70 | ).cuda() 71 | 72 | ema_model = model.create_ema() 73 | 74 | class MnistDataset(Dataset): 75 | def __init__(self): 76 | self.mnist = torchvision.datasets.MNIST( 77 | './data/mnist', 78 | download = True 79 | ) 80 | 81 | def __len__(self): 82 | return len(self.mnist) 83 | 84 | def __getitem__(self, idx): 85 | pil, labels = self.mnist[idx] 86 | digit_tensor = T.PILToTensor()(pil) 87 | output = tensor(labels), (digit_tensor / 255).float() 88 | 89 | if IMAGE_AFTER_TEXT: 90 | return output 91 | 92 | first, second = output 93 | return second, first 94 | 95 | def cycle(iter_dl): 96 | while True: 97 | for batch in iter_dl: 98 | yield batch 99 | 100 | def collate_fn(data): 101 | data = [*map(list, data)] 102 | return data 103 | 104 | dataset = MnistDataset() 105 | dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True) 106 | 107 | iter_dl = cycle(dataloader) 108 | 109 | optimizer = Adam(model.parameters(), lr = 3e-4) 110 | 111 | # train loop 112 | 113 | for step in range(1, NUM_TRAIN_STEPS + 1): 114 | model.train() 115 | 116 | loss = model(next(iter_dl)) 117 | loss.backward() 118 | 119 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 120 | 121 | optimizer.step() 122 | optimizer.zero_grad() 123 | 124 | ema_model.update() 125 | 126 | print(f'{step}: {loss.item():.3f}') 127 | 128 | # eval 129 | 130 | if divisible_by(step, SAMPLE_EVERY): 131 | 132 | if not USE_PROMPT: 133 | # sampling from start to finish 134 | 135 | one_multimodal_sample = ema_model.sample(max_length = 384) 136 | else: 137 | # sampling using prompt 138 | # which differs depending on which comes first, text or images 139 | 140 | if IMAGE_AFTER_TEXT: 141 | 142 | text_label = torch.randint(0, 10, 
()).cuda() 143 | one_multimodal_sample = ema_model.sample(prompt = text_label, max_length = 384) 144 | 145 | else: 146 | 147 | rand_batch = next(iter_dl) 148 | rand_image = rand_batch[0][0] 149 | 150 | one_multimodal_sample = ema_model.sample(prompt = rand_image, max_length = 384) 151 | 152 | # make sure modality sample overall order of modalities look correct 153 | 154 | print_modality_sample(one_multimodal_sample) 155 | 156 | if len(one_multimodal_sample) < 2: 157 | continue 158 | 159 | if IMAGE_AFTER_TEXT: 160 | maybe_label, maybe_image, *_ = one_multimodal_sample 161 | else: 162 | _, maybe_image, maybe_label = one_multimodal_sample 163 | 164 | filename = f'{step}.{maybe_label[1].item()}.png' 165 | 166 | save_image( 167 | maybe_image[1].cpu(), 168 | str(results_folder / filename), 169 | ) 170 | -------------------------------------------------------------------------------- /train_mnist_vae.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import nn, tensor 6 | from torch.nn import Module 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | from torch.optim import Adam 10 | 11 | from einops import rearrange 12 | from einops.layers.torch import Rearrange 13 | 14 | import torchvision 15 | import torchvision.transforms as T 16 | from torchvision.utils import save_image 17 | 18 | from tqdm import tqdm 19 | 20 | from transfusion_pytorch import Transfusion, print_modality_sample 21 | 22 | rmtree('./results', ignore_errors = True) 23 | results_folder = Path('./results') 24 | results_folder.mkdir(exist_ok = True, parents = True) 25 | 26 | # functions 27 | 28 | def divisible_by(num, den): 29 | return (num % den) == 0 30 | 31 | def cycle(iter_dl): 32 | while True: 33 | for batch in iter_dl: 34 | yield batch 35 | 36 | # dataset 37 | 38 | class MnistDataset(Dataset): 39 | def __init__(self): 40 | self.mnist = torchvision.datasets.MNIST( 41 | './data/mnist', 42 | download = True 43 | ) 44 | 45 | self.transform = T.Compose([ 46 | T.PILToTensor(), 47 | T.RandomResizedCrop((28, 28), scale = (0.8, 1.)) 48 | ]) 49 | 50 | def __len__(self): 51 | return len(self.mnist) 52 | 53 | def __getitem__(self, idx): 54 | pil, labels = self.mnist[idx] 55 | digit_tensor = self.transform(pil) 56 | return tensor(labels), (digit_tensor / 255).float() 57 | 58 | dataset = MnistDataset() 59 | 60 | # contrived encoder / decoder with layernorm at bottleneck 61 | 62 | autoencoder_train_steps = 15_000 63 | dim_latent = 16 64 | 65 | class Normalize(Module): 66 | def forward(self, x): 67 | return F.normalize(x, dim = -1) 68 | 69 | encoder = nn.Sequential( 70 | nn.Conv2d(1, 4, 3, padding = 1), 71 | nn.Conv2d(4, 8, 4, 2, 1), 72 | nn.ReLU(), 73 | nn.Dropout(0.05), 74 | nn.Conv2d(8, dim_latent, 1), 75 | Rearrange('b d ... -> b ... d'), 76 | Normalize() 77 | ).cuda() 78 | 79 | decoder = nn.Sequential( 80 | Rearrange('b ... 
d -> b d ...'), 81 | nn.Conv2d(dim_latent, 8, 1), 82 | nn.ReLU(), 83 | nn.ConvTranspose2d(8, 4, 4, 2, 1), 84 | nn.Conv2d(4, 1, 3, padding = 1), 85 | ).cuda() 86 | 87 | # train autoencoder 88 | 89 | autoencoder_optimizer = Adam([*encoder.parameters(), *decoder.parameters()], lr = 3e-4) 90 | autoencoder_dataloader = DataLoader(dataset, batch_size = 32, shuffle = True) 91 | 92 | autoencoder_iter_dl = cycle(autoencoder_dataloader) 93 | 94 | print('training autoencoder') 95 | 96 | with tqdm(total = autoencoder_train_steps) as pbar: 97 | for _ in range(autoencoder_train_steps): 98 | _, images = next(autoencoder_iter_dl) 99 | images = images.cuda() 100 | 101 | latents = encoder(images) 102 | latents = latents.lerp(torch.randn_like(latents), torch.rand_like(latents) * 0.2) # add a bit of noise to latents 103 | reconstructed = decoder(latents) 104 | 105 | loss = F.mse_loss(images, reconstructed) 106 | 107 | loss.backward() 108 | 109 | pbar.set_description(f'loss: {loss.item():.5f}') 110 | 111 | autoencoder_optimizer.step() 112 | autoencoder_optimizer.zero_grad() 113 | 114 | pbar.update() 115 | 116 | # transfusion 117 | 118 | model = Transfusion( 119 | num_text_tokens = 10, 120 | dim_latent = dim_latent, 121 | modality_default_shape = (14, 14), 122 | modality_encoder = encoder, 123 | modality_decoder = decoder, 124 | add_pos_emb = True, 125 | modality_num_dim = 2, 126 | transformer = dict( 127 | dim = 64, 128 | depth = 4, 129 | dim_head = 32, 130 | heads = 8, 131 | ) 132 | ).cuda() 133 | 134 | # training transfusion 135 | 136 | dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True) 137 | iter_dl = cycle(dataloader) 138 | 139 | optimizer = Adam(model.parameters_without_encoder_decoder(), lr = 3e-4) 140 | 141 | # train loop 142 | 143 | transfusion_train_steps = 25_000 144 | 145 | print('training transfusion with autoencoder') 146 | 147 | with tqdm(total = transfusion_train_steps) as pbar: 148 | for index in range(transfusion_train_steps): 149 | step = index + 1 150 | 151 | model.train() 152 | 153 | loss = model(next(iter_dl)) 154 | loss.backward() 155 | 156 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 157 | 158 | optimizer.step() 159 | optimizer.zero_grad() 160 | 161 | pbar.set_description(f'loss: {loss.item():.3f}') 162 | 163 | pbar.update() 164 | 165 | # eval 166 | 167 | if divisible_by(step, 500): 168 | one_multimodal_sample = model.sample(max_length = 10) 169 | 170 | print_modality_sample(one_multimodal_sample) 171 | 172 | if len(one_multimodal_sample) < 2: 173 | continue 174 | 175 | maybe_label, maybe_image, *_ = one_multimodal_sample 176 | 177 | filename = f'{step}.{maybe_label[1].item()}.png' 178 | 179 | save_image( 180 | maybe_image[1].cpu().clamp(min = 0., max = 1.), 181 | str(results_folder / filename), 182 | ) 183 | -------------------------------------------------------------------------------- /train_mnist_with_unet.py: -------------------------------------------------------------------------------- 1 | from shutil import rmtree 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import tensor, nn 6 | from torch.nn import Module 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import Adam 9 | 10 | from einops import rearrange 11 | 12 | import torchvision 13 | import torchvision.transforms as T 14 | from torchvision.utils import save_image 15 | 16 | from transfusion_pytorch import Transfusion, print_modality_sample 17 | 18 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 19 | 20 | 
rmtree('./results', ignore_errors = True) 21 | results_folder = Path('./results') 22 | results_folder.mkdir(exist_ok = True, parents = True) 23 | 24 | # constants 25 | 26 | IMAGE_AFTER_TEXT = False 27 | NUM_TRAIN_STEPS = 20_000 28 | SAMPLE_EVERY = 500 29 | 30 | # functions 31 | 32 | def divisible_by(num, den): 33 | return (num % den) == 0 34 | 35 | # encoder / decoder 36 | 37 | class Encoder(Module): 38 | def forward(self, x): 39 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... (p1 p2) h w', p1 = 2, p2 = 2) 40 | return x * 2 - 1 41 | 42 | class Decoder(Module): 43 | def forward(self, x): 44 | x = rearrange(x, '... (p1 p2) h w -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14) 45 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.) 46 | 47 | model = Transfusion( 48 | num_text_tokens = 10, 49 | dim_latent = 4, 50 | modality_default_shape = (14, 14), 51 | modality_encoder = Encoder(), 52 | modality_decoder = Decoder(), 53 | pre_post_transformer_enc_dec = ( 54 | nn.Conv2d(4, 64, 3, 2, 1), 55 | nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1), 56 | ), 57 | add_pos_emb = True, 58 | modality_num_dim = 2, 59 | channel_first_latent = True, 60 | transformer = dict( 61 | dim = 64, 62 | depth = 4, 63 | dim_head = 32, 64 | heads = 8, 65 | ) 66 | ).to(device) 67 | 68 | ema_model = model.create_ema() 69 | 70 | class MnistDataset(Dataset): 71 | def __init__(self): 72 | self.mnist = torchvision.datasets.MNIST( 73 | './data/mnist', 74 | download = True 75 | ) 76 | 77 | def __len__(self): 78 | return len(self.mnist) 79 | 80 | def __getitem__(self, idx): 81 | pil, labels = self.mnist[idx] 82 | digit_tensor = T.PILToTensor()(pil) 83 | output = tensor(labels), (digit_tensor / 255).float() 84 | 85 | if not IMAGE_AFTER_TEXT: 86 | return output 87 | 88 | first, second = output 89 | return second, first 90 | 91 | def cycle(iter_dl): 92 | while True: 93 | for batch in iter_dl: 94 | yield batch 95 | 96 | def collate_fn(data): 97 | data = [*map(list, data)] 98 | return data 99 | 100 | dataset = MnistDataset() 101 | dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True) 102 | 103 | iter_dl = cycle(dataloader) 104 | 105 | optimizer = Adam(model.parameters(), lr = 3e-4) 106 | 107 | # train loop 108 | 109 | for step in range(1, NUM_TRAIN_STEPS + 1): 110 | model.train() 111 | 112 | loss = model(next(iter_dl)) 113 | loss.backward() 114 | 115 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 116 | 117 | optimizer.step() 118 | optimizer.zero_grad() 119 | 120 | ema_model.update() 121 | 122 | print(f'{step}: {loss.item():.3f}') 123 | 124 | # eval 125 | 126 | if divisible_by(step, SAMPLE_EVERY): 127 | one_multimodal_sample = ema_model.sample(max_length = 384) 128 | 129 | print_modality_sample(one_multimodal_sample) 130 | 131 | if len(one_multimodal_sample) < 2: 132 | continue 133 | 134 | if IMAGE_AFTER_TEXT: 135 | _, maybe_image, maybe_label = one_multimodal_sample 136 | else: 137 | maybe_label, maybe_image, *_ = one_multimodal_sample 138 | 139 | filename = f'{step}.{maybe_label[1].item()}.png' 140 | 141 | save_image( 142 | maybe_image[1].cpu(), 143 | str(results_folder / filename), 144 | ) 145 | -------------------------------------------------------------------------------- /train_text_only.py: -------------------------------------------------------------------------------- 1 | import math 2 | import gzip 3 | import random 4 | import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | from torch.optim import Adam 9 | from torch import Tensor 10 | from torch.utils.data import DataLoader, Dataset 
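# note: this script only exercises the text branch of Transfusion - no modalities are
# registered, so the model below reduces to a plain autoregressive byte-level language
# model over enwik8 (256 possible byte values, hence `num_text_tokens = 256`)
# purely illustrative example of the helpers defined further down:
#   decode_tokens([72, 101, 108, 108, 111]) -> 'Hello'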
11 | 
12 | from transfusion_pytorch import Transfusion
13 | 
14 | # constants
15 | 
16 | NUM_BATCHES = int(1e5)
17 | BATCH_SIZE = 4
18 | GRAD_ACCUM_EVERY = 4
19 | LEARNING_RATE = 1e-4
20 | VALIDATE_EVERY = 100
21 | PRIME_LENGTH = 64
22 | GENERATE_EVERY = 500
23 | GENERATE_LENGTH = 256
24 | SEQ_LEN = 256
25 | 
26 | # helpers
27 | 
28 | def exists(v):
29 |     return v is not None
30 | 
31 | def divisible_by(num, den):
32 |     return (num % den) == 0
33 | 
34 | def cycle(loader):
35 |     while True:
36 |         for data in loader:
37 |             yield data
38 | 
39 | def decode_token(token):
40 |     return str(chr(max(32, token)))
41 | 
42 | def decode_tokens(tokens):
43 |     return "".join(list(map(decode_token, tokens)))
44 | 
45 | # the transfusion char language model
46 | 
47 | model = Transfusion(
48 |     num_text_tokens = 256,
49 |     transformer = dict(
50 |         dim = 384,
51 |         depth = 8,
52 |         dim_head = 64,
53 |         heads = 8,
54 |         attn_laser = True
55 |     )
56 | ).cuda()
57 | 
58 | # prepare enwik8 data
59 | 
60 | with gzip.open('./data/enwik8/enwik8.gz') as file:
61 |     data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
62 |     np_train, np_valid = np.split(data, [int(90e6)])
63 |     data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
64 | 
65 | class TextSamplerDataset(Dataset):
66 |     def __init__(self, data, seq_len):
67 |         super().__init__()
68 |         self.data = data
69 |         self.seq_len = seq_len
70 |         self.data_length = data.shape[0]
71 | 
72 |     def __len__(self):
73 |         return self.data.size(0) // self.seq_len
74 | 
75 |     def __getitem__(self, index):
76 |         rand_start = torch.randint(0, self.data_length - self.seq_len, (1,))
77 |         full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
78 |         return full_seq
79 | 
80 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
81 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
82 | train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
83 | val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE)
84 | 
85 | # optimizer
86 | 
87 | optim = Adam(model.parameters(), lr = LEARNING_RATE)
88 | 
89 | train_loader = cycle(train_loader)
90 | val_loader = cycle(val_loader)
91 | 
92 | # training
93 | 
94 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
95 |     model.train()
96 | 
97 |     for _ in range(GRAD_ACCUM_EVERY):
98 |         data = next(train_loader)
99 | 
100 |         loss = model(data.cuda())
101 | 
102 |         (loss / GRAD_ACCUM_EVERY).backward()
103 | 
104 |     print(f'loss: {loss.item():.3f}')
105 | 
106 |     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
107 | 
108 |     optim.step()
109 |     optim.zero_grad()
110 | 
111 |     if divisible_by(i, VALIDATE_EVERY):
112 |         model.eval()
113 |         with torch.no_grad():
114 |             valid_data = next(val_loader)
115 |             loss = model(valid_data.cuda())
116 |             print(f'\nvalid loss: {loss.item():.3f}\n')
117 | 
118 |     if divisible_by(i, GENERATE_EVERY):
119 |         model.eval()
120 | 
121 |         inp = random.choice(val_dataset)[:PRIME_LENGTH]
122 |         inp = inp.cuda()
123 | 
124 |         prime = decode_tokens(inp)
125 |         print(f"\nprime: {prime}\n")
126 | 
127 |         prompt = inp[None, ...]
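# `inp` is a PRIME_LENGTH (64) character prime drawn from the validation set; adding
# the leading batch dimension gives a (1, 64) prompt. `generate_text_only` below then
# samples a continuation of up to GENERATE_LENGTH (256) byte tokens, which
# `decode_tokens` converts back into printable characters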
128 | 129 | sampled = model.generate_text_only(prompt, GENERATE_LENGTH) 130 | 131 | base_decode_output = decode_tokens(sampled[0]) 132 | 133 | print(f"\ngenerated: {base_decode_output}\n") 134 | -------------------------------------------------------------------------------- /train_toy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import randint, randn 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.optim import Adam 5 | 6 | from transfusion_pytorch import Transfusion, print_modality_sample 7 | 8 | def divisible_by(num, den): 9 | return (num % den) == 0 10 | 11 | model = Transfusion( 12 | num_text_tokens = 8, 13 | dim_latent = 16, 14 | modality_default_shape = (2,), 15 | transformer = dict( 16 | dim = 64, 17 | depth = 1, 18 | dim_head = 8, 19 | heads = 2 20 | ) 21 | ).cuda() 22 | 23 | class MockDataset(Dataset): 24 | def __len__(self): 25 | return 100 26 | 27 | def __getitem__(self, idx): 28 | return torch.ones((1,)).long(), randn(2, 16) 29 | 30 | def cycle(iter_dl): 31 | while True: 32 | for batch in iter_dl: 33 | yield batch 34 | 35 | def collate_fn(data): 36 | data = [*map(list, data)] 37 | return data 38 | 39 | mock_dataset = MockDataset() 40 | 41 | dataloader = DataLoader(mock_dataset, batch_size = 4, collate_fn = collate_fn) 42 | iter_dl = cycle(dataloader) 43 | 44 | optimizer = Adam(model.parameters(), lr = 3e-4) 45 | 46 | # train loop 47 | 48 | for step in range(1, 10_000 + 1): 49 | 50 | loss = model(next(iter_dl)) 51 | loss.backward() 52 | 53 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 54 | 55 | optimizer.step() 56 | optimizer.zero_grad() 57 | 58 | print(f'{step}: {loss.item():.3f}') 59 | 60 | # eval 61 | 62 | if divisible_by(step, 100): 63 | one_multimodal_sample = model.sample() 64 | print_modality_sample(one_multimodal_sample) 65 | -------------------------------------------------------------------------------- /transfusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/transfusion-pytorch/ea045706cc45e2cf5cd1ce4523e5af6a49eff54a/transfusion.png -------------------------------------------------------------------------------- /transfusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from transfusion_pytorch.transfusion import ( 2 | Transfusion, 3 | print_modality_sample, 4 | create_dataloader 5 | ) 6 | -------------------------------------------------------------------------------- /transfusion_pytorch/transfusion.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | """ 4 | global ein notation 5 | 6 | b - batch 7 | t - one modality type 8 | m - separate modality instance 9 | n - sequence 10 | d - dimension 11 | l - logits (text) 12 | i, j - sequence (row, col) 13 | p - positions 14 | s - residual streams 15 | """ 16 | 17 | import os 18 | import math 19 | from collections import defaultdict 20 | 21 | from random import randrange 22 | from itertools import count 23 | from functools import partial, wraps, cache 24 | from typing import NamedTuple, Callable, Literal 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | from torch import nn, Tensor, tensor, is_tensor, cat, stack 29 | from torch.nn import Module, ModuleList, Linear 30 | 31 | from torch.utils.data import Dataset, DataLoader 32 | from torch.nn.utils.rnn import pad_sequence 33 | from 
torch.utils._pytree import tree_map, tree_flatten, tree_unflatten 34 | 35 | from torchdiffeq import odeint 36 | 37 | import einx 38 | from einops.layers.torch import Rearrange 39 | from einops import rearrange, repeat, reduce, einsum, pack, unpack 40 | 41 | from ema_pytorch import EMA 42 | 43 | from axial_positional_embedding import ContinuousAxialPositionalEmbedding 44 | 45 | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb 46 | 47 | from hyper_connections import HyperConnections 48 | 49 | from tqdm import tqdm 50 | from loguru import logger 51 | 52 | pad_sequence = partial(pad_sequence, batch_first = True) 53 | 54 | # tensor typing 55 | 56 | import jaxtyping 57 | from jaxtyping import jaxtyped 58 | from beartype import beartype 59 | from beartype.door import is_bearable 60 | 61 | class TorchTyping: 62 | def __init__(self, abstract_dtype): 63 | self.abstract_dtype = abstract_dtype 64 | 65 | def __getitem__(self, shapes: str): 66 | return self.abstract_dtype[Tensor, shapes] 67 | 68 | Float = TorchTyping(jaxtyping.Float) 69 | Int = TorchTyping(jaxtyping.Int) 70 | Bool = TorchTyping(jaxtyping.Bool) 71 | 72 | # maybe flex attention 73 | 74 | try: 75 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask 76 | 77 | if torch.cuda.is_available(): 78 | flex_attention = torch.compile(flex_attention) 79 | 80 | except ImportError: 81 | flex_attention = None 82 | 83 | # types 84 | 85 | Scalar = Float[''] 86 | 87 | ModalitySample = list[Int[''] | Int['_'] | Float['...'] | tuple[int, Float['...']]] 88 | 89 | ModalityTokenTransform = str | Callable | None 90 | 91 | RawModalityPositions = list[list[tuple[int, int, int]]] 92 | 93 | GetPredFlows = dict[int, list[Callable[[Tensor], Tensor]]] 94 | 95 | class LossBreakdown(NamedTuple): 96 | total: Scalar 97 | text: Scalar 98 | flow: list[Scalar] 99 | velocity: list[Scalar] | None = None 100 | recon: list[Scalar] | None = None 101 | 102 | class ModalityInfo(NamedTuple): 103 | encoder: Module | None 104 | decoder: Module | None 105 | latent_to_model: Module 106 | model_to_latent: Module 107 | add_pos_emb: bool 108 | pos_emb_mlp: Module | None 109 | num_dim: int | None 110 | dim_latent: int 111 | default_shape: tuple[int, ...] 112 | som_id: int 113 | eom_id: int 114 | to_shape_fn: Callable | None 115 | channel_first_latent: bool 116 | modality_type: int 117 | 118 | # helper functions 119 | 120 | def exists(v): 121 | return v is not None 122 | 123 | def default(v, d): 124 | return v if exists(v) else d 125 | 126 | def identity(t): 127 | return t 128 | 129 | def always(val): 130 | def inner(*args, **kwargs): 131 | return val 132 | return inner 133 | 134 | def first(it): 135 | return it[0] 136 | 137 | def join(arr, delimiter = ''): 138 | return delimiter.join(arr) 139 | 140 | def divisible_by(num, den): 141 | return (num % den) == 0 142 | 143 | def cast_tuple(t, length = 1): 144 | return t if isinstance(t, tuple) else ((t,) * length) 145 | 146 | def tree_map_tensor(sample, fn: Callable): 147 | return tree_map(lambda t: t if not is_tensor(t) else fn(t), sample) 148 | 149 | def add_temp_batch_dim(fn: Callable): 150 | @wraps(fn) 151 | def inner(t: Tensor, *args, **kwargs) -> Tensor: 152 | t = rearrange(t, '... -> 1 ...') 153 | out = fn(t, *args, **kwargs) 154 | out = rearrange(out, '1 ... 
-> ...')
155 |         return out
156 |     return inner
157 | 
158 | def pack_with_inverse(t, pattern):
159 |     packed, packed_shape = pack(t, pattern)
160 | 
161 |     def inverse(out, inv_pattern = None):
162 |         inv_pattern = default(inv_pattern, pattern)
163 |         return unpack(out, packed_shape, inv_pattern)
164 | 
165 |     return packed, inverse
166 | 
167 | def pack_one_with_inverse(t, pattern):
168 |     packed, packed_shape = pack([t], pattern)
169 | 
170 |     def inverse(out, inv_pattern = None):
171 |         inv_pattern = default(inv_pattern, pattern)
172 |         return unpack(out, packed_shape, inv_pattern)[0]
173 | 
174 |     return packed, inverse
175 | 
176 | def eval_decorator(fn):
177 |     def inner(self, *args, **kwargs):
178 |         was_training = self.training
179 |         self.eval()
180 |         out = fn(self, *args, **kwargs)
181 |         self.train(was_training)
182 |         return out
183 |     return inner
184 | 
185 | # maybe typecheck
186 | 
187 | typecheck = jaxtyped(typechecker = beartype) if os.environ.get('TYPECHECK', '').lower() in ('1', 'true') else identity
188 | 
189 | # default function for constituting modality shape from string
190 | 
191 | def default_to_modality_shape_fn(maybe_shape_str) -> tuple[int, ...]:
192 |     return tuple([*map(int, maybe_shape_str.split(','))])
193 | 
194 | # default function for translating modality length to times (noise level, where 0 is highest noise)
195 | 
196 | def random_modality_length_to_time_fn(num_modalities: Int['b']) -> Float['b m']:
197 |     batch = num_modalities.shape[0]
198 |     device = num_modalities.device
199 |     total_modalities = num_modalities.amax().item()
200 |     return torch.rand((batch, total_modalities), device = device)
201 | 
202 | def default_modality_length_to_time_fn(num_modalities: Int['b']) -> Float['b m']:
203 |     batch, device = num_modalities.shape[0], num_modalities.device
204 |     total_modalities = num_modalities.amax().item()
205 | 
206 |     if total_modalities == 0:
207 |         return torch.empty((batch, 0), device = device, dtype = torch.float)
208 | 
209 |     rand_num_modalities = torch.floor(torch.rand_like(num_modalities.float()) * num_modalities)
210 |     seq = torch.arange(total_modalities, device = device)
211 | 
212 |     prev_decoded_modality = einx.less('m, b -> b m', seq, rand_num_modalities)
213 |     curr_modality_rand_time = torch.rand_like(num_modalities.float())
214 | 
215 |     # in paper, they fix previous decoded modalities to 500 / 1000 steps for discrete ddpm, here using flow matching with times 0 - 1 so corresponds to 0.5
216 |     return einx.where('b m, , b -> b m', prev_decoded_modality, 0.5, curr_modality_rand_time)
217 | 
218 | # pretty print
219 | 
220 | def concat_contiguous_text(
221 |     modality_sample: ModalitySample
222 | ) -> ModalitySample:
223 |     """ within a modality sample, any two adjacent tensors of type int / long will be concatenated together, so that all contiguous text ends up as a single tensor between modalities """
224 | 
225 |     output = []
226 | 
227 |     for modality in modality_sample:
228 |         if (
229 |             len(output) > 0 and
230 |             is_tensor(output[-1]) and is_tensor(modality) and
231 |             output[-1].dtype == modality.dtype and
232 |             modality.dtype in (torch.int, torch.long)
233 |         ):
234 |             packed_text, _ = pack((output[-1], modality), '*')
235 |             output[-1] = packed_text
236 | 
237 |         else:
238 |             output.append(modality)
239 | 
240 |     return output
241 | 
242 | def print_modality_sample(
243 |     modality_sample: ModalitySample
244 | ):
245 |     output = []
246 | 
247 |     for sample in modality_sample:
248 |         if isinstance(sample, tuple):
249 |             modality_type, sample = sample
250 | 
output.append((f'modality:{modality_type}', sample.shape)) 251 | elif sample.dtype in (torch.int, torch.long): 252 | output.append(('text', sample.shape)) 253 | else: 254 | output.append(('modality', sample.shape)) 255 | 256 | logger.info(output) 257 | 258 | # character based tokenizer 259 | 260 | def char_tokenize( 261 | text: str, 262 | device = None, 263 | offset = 0 264 | ) -> Tensor: 265 | tokenized = tensor([*map(ord, text)], device = device) + offset 266 | return tokenized.long() 267 | 268 | def decode_chars( 269 | t: Tensor, 270 | offset = 0, 271 | ) -> str: 272 | byte_list = (t - offset).clamp(min = 0, max = 127).tolist() 273 | return ''.join([*map(chr, byte_list)]) 274 | 275 | def get_tokens_since_rightmost_id( 276 | t: Tensor, 277 | rightmost_id: int 278 | ) -> Tensor: 279 | """ 280 | ex. [9] [2] [8] [4] [7] 281 | 2 would return [8] [4] [7] 282 | """ 283 | 284 | mask = t == rightmost_id 285 | 286 | if not mask.any(): 287 | return t[0:0] # return empty tensor if no id found 288 | 289 | reverse_cumsum = mask.flip(dims = (0,)).cumsum(dim = 0).flip(dims = (0,)) 290 | after_right_mask = reverse_cumsum == 0 291 | return t[after_right_mask] 292 | 293 | # tensor helpers 294 | 295 | def l2norm(t): 296 | return F.normalize(t, dim = -1) 297 | 298 | def softclamp(t, value = 50.): 299 | return (t / value).tanh() * value 300 | 301 | def max_neg_value(t): 302 | return -torch.finfo(t.dtype).max 303 | 304 | def append_dims(t, ndims): 305 | return t.reshape(*t.shape, *((1,) * ndims)) 306 | 307 | def is_empty(t): 308 | return t.numel() == 0 309 | 310 | def log(t, eps = 1e-20): 311 | return torch.log(t.clamp(min = eps)) 312 | 313 | def gumbel_noise(t): 314 | noise = torch.rand_like(t) 315 | return -log(-log(noise)) 316 | 317 | def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True): 318 | noise = gumbel_noise(t) * int(temperature > 0) 319 | return (t / temperature + noise).argmax(dim = dim, keepdim = keepdim) 320 | 321 | # dataloader related 322 | 323 | def collate_fn(data): 324 | return [*map(list, data)] 325 | 326 | @typecheck 327 | def create_dataloader(dataset: Dataset, **kwargs) -> DataLoader: 328 | return DataLoader(dataset, collate_fn = collate_fn, **kwargs) 329 | 330 | # flex attention mask construction 331 | # https://pytorch.org/blog/flexattention/ 332 | 333 | def causal(b, h, q_idx, kv_idx): 334 | return q_idx >= kv_idx 335 | 336 | def modality(offset, length): 337 | 338 | def mask_fn(b, h, q_idx, kv_idx): 339 | return (q_idx >= offset) & (kv_idx < (offset + length)) 340 | 341 | return mask_fn 342 | 343 | def transfusion_attn_mask(modalities: Int['b m 3']): 344 | modalities = modalities.long() 345 | 346 | def mask_mod(b, h, q_idx, kv_idx): 347 | mask = causal(b, h, q_idx, kv_idx) 348 | 349 | modality_batch = modalities[b] 350 | 351 | for _, offset, length in modality_batch: 352 | mask = mask | modality(offset, length)(b, h, q_idx, kv_idx) 353 | 354 | return mask 355 | 356 | return mask_mod 357 | 358 | def softcap_score_mod(softcap): 359 | def inner(score, b, h, q_idx, kv_idx): 360 | score = score / softcap 361 | score = torch.tanh(score) 362 | score = score * softcap 363 | return score 364 | return inner 365 | 366 | # converting a raw list of modality offsets and lengths to tensor 367 | 368 | @typecheck 369 | def modality_positions_to_tensor( 370 | modalities: RawModalityPositions, 371 | pad_value = 0, 372 | device = None 373 | ) -> Int['b m 2'] | Int['b m 3']: 374 | 375 | modalities: list[Tensor] = [tensor(modality, device = device) for modality in modalities] 376 | modalities 
= pad_sequence(modalities, padding_value = pad_value) 377 | 378 | if modalities.ndim == 2: 379 | modalities = modalities.reshape(*modalities.shape, 3) 380 | 381 | return modalities.long() 382 | 383 | # sanitizing modalities tensor, making sure it is ordered 384 | 385 | @typecheck 386 | def order_modality_positions_by_seq_offset( 387 | modalities: Int['b m 3'] 388 | ) -> tuple[Int['b m 3'], Int['b m']]: 389 | 390 | modality_type, offsets, lengths = modalities.unbind(dim = -1) 391 | 392 | no_modality_mask = lengths <= 0 # there may be uneven number of modalities per batch sample 393 | offsets_to_sort = offsets.masked_fill(no_modality_mask, 1e10) 394 | _, sorted_indices = offsets_to_sort.sort(dim = -1) 395 | 396 | # sort by ascending offset 397 | 398 | modalities = einx.get_at('b [mi] ..., b mo -> b mo ...', modalities, sorted_indices) 399 | return modalities, sorted_indices 400 | 401 | # deriving relative positions from modality positions 402 | # ex. given a sequence of 10 with an image at offset 3 with length 4 - [t] [t] [t] [i] [i] [i] [i] [t] [t] [t] 403 | # relative positions for rotary will be [0] [1] [2] [3] [3] [3] [3] [4] [5] [6] 404 | # rationale is that each modality will need the same position so there is no distance when conducting bidirectional attention, but should still have a relative distance to other text tokens and modalities 405 | 406 | def derive_rotary_positions_from_modality_positions( 407 | seq_len: int, 408 | modalities: Int['b m 3'] 409 | ) -> Int['b n']: 410 | 411 | device = modalities.device 412 | 413 | modality_mask = modality_positions_to_is_modality_mask(seq_len, modalities, offset = torch.tensor([1, -1])) 414 | is_any_modality = reduce(modality_mask, 'b t m n -> b n', 'any') 415 | 416 | return torch.arange(seq_len, device = device) - is_any_modality.cumsum(dim = -1) 417 | 418 | # modality tokens are given as list of tensors, can be then be embedded into the modality tokens for attending alongside text tokens 419 | 420 | @typecheck 421 | def embed_modality_tokens( 422 | seq_len: int, 423 | dim: int, 424 | modality_tokens: list[list[Float['...']]], 425 | modalities: Int['b m 3'], 426 | modality_id: int, 427 | channel_first: bool 428 | ) -> Float['b n d']: 429 | 430 | batch, device = modalities.shape[0], modalities.device 431 | 432 | shape = (batch, seq_len, dim) if not channel_first else (batch, dim, seq_len) 433 | output = torch.zeros(shape, device = device) 434 | 435 | for batch_ind, (one_modality, one_modality_token) in enumerate(zip(modalities, modality_tokens)): 436 | for (modality_type, offset, length), batch_modality_token in zip(one_modality, one_modality_token): 437 | 438 | if modality_id != modality_type or length <= 0: 439 | continue 440 | 441 | modality_shape = batch_modality_token.shape 442 | 443 | if channel_first: 444 | mod_dim, *mod_axial_shape = modality_shape 445 | batch_modality_token = rearrange(batch_modality_token, 'd ... -> d (...)') 446 | else: 447 | *mod_axial_shape, mod_dim = modality_shape 448 | batch_modality_token = rearrange(batch_modality_token, '... d -> (...) 
d') 449 | 450 | mod_length = math.prod(mod_axial_shape) 451 | 452 | assert length == mod_length, f'received a modality of shape {modality_shape} but sequence length in modalities info is {length}' 453 | assert dim == mod_dim, f'received modality [{modality_id}] with shape {modality_shape} but expected dimension of {dim}' 454 | 455 | if channel_first: 456 | output[batch_ind, :, offset:(offset + length)] = batch_modality_token 457 | else: 458 | output[batch_ind, offset:(offset + length), :] = batch_modality_token 459 | 460 | return output 461 | 462 | # functions for managing modality token mask 463 | 464 | @typecheck 465 | def modality_positions_to_is_modality_mask( 466 | seq_len: int, 467 | modalities: Int['b m 3'], 468 | offset: Int['2'] | None = None, 469 | device = None, 470 | num_modalities = 1 471 | ) -> Bool['b t m n']: 472 | 473 | device = modalities.device 474 | 475 | if exists(offset): 476 | offset = F.pad(offset, (1, 0)) 477 | modalities = modalities + offset.to(modalities) 478 | 479 | seq = torch.arange(seq_len, device = device) 480 | type_seq = torch.arange(num_modalities, device = device) 481 | 482 | modality_types = modalities[..., 0] 483 | 484 | left, right = modalities[..., 1:].cumsum(dim = -1).unbind(dim = -1) 485 | 486 | is_instance_for_type = einx.equal('b m, t -> b t m', modality_types, type_seq) 487 | 488 | is_modality_along_seq = ( 489 | einx.greater_equal('i, b m -> b m i', seq, left) & 490 | einx.less('j, b m -> b m j', seq, right) 491 | ) 492 | 493 | return einx.logical_and('b t m, b m n -> b t m n', is_instance_for_type, is_modality_along_seq) 494 | 495 | @typecheck 496 | def naive_attn_mask( 497 | seq_len: int, 498 | modalities: Int['b m 3'], 499 | device = None 500 | ) -> Bool['b i j']: 501 | 502 | _, offsets, length = modalities.unbind(dim = -1) 503 | 504 | seq = torch.arange(seq_len, device = device) 505 | 506 | is_causal = einx.greater_equal('i, j -> i j', seq, seq) 507 | 508 | is_modality = ( 509 | einx.greater_equal('i, b m -> b m i 1', seq, offsets) & 510 | einx.less('j, b m -> b m 1 j', seq, offsets + length) 511 | ) 512 | 513 | return is_causal | is_modality.any(dim = 1) 514 | 515 | # unet encoder related function 516 | 517 | def stack_same_shape_tensors_with_inverse(tensors: list[Tensor]): 518 | 519 | shape_tensors_dict = defaultdict(list) 520 | shape_batch_dict = defaultdict(int) # also store a shape -> num tensors dictionary to validate inverse function input 521 | inverse_index_list = [] 522 | 523 | for tensor in tensors: 524 | shape = tuple(tensor.shape) 525 | batch_el = shape_batch_dict[shape] 526 | 527 | shape_tensors_dict[shape].append(tensor) 528 | shape_batch_dict[shape] += 1 529 | 530 | inverse_index_list.append((shape, batch_el)) 531 | 532 | # stack all the tensors with same shapes to have a batch dimension 533 | 534 | shape_tensors_dict = {shape: torch.stack(same_shape_tensors) for shape, same_shape_tensors in shape_tensors_dict.items()} 535 | 536 | # inverse function 537 | 538 | def inverse( 539 | indexed_tensors: dict[tuple[int, ...], Tensor] 540 | ) -> list[Tensor]: 541 | 542 | out_shape_batch_dict = {shape: len(tensors) for shape, tensors in indexed_tensors.items()} 543 | 544 | assert out_shape_batch_dict == shape_batch_dict 545 | 546 | inversed = [] 547 | 548 | for shape, batch_el in inverse_index_list: 549 | tensor = indexed_tensors[shape][batch_el] 550 | inversed.append(tensor) 551 | 552 | return inversed 553 | 554 | return shape_tensors_dict, inverse 555 | 556 | def filter_with_inverse(cond, inp): 557 | 558 | indices = set() 559 | 
filtered = [] 560 | 561 | for ind, el in enumerate(inp): 562 | if cond(el): 563 | indices.add(ind) 564 | filtered.append(el) 565 | 566 | def inverse(inverse_inp): 567 | assert len(inverse_inp) == len(filtered) 568 | 569 | output = [] 570 | inverse_inp_index = 0 571 | 572 | for ind, el in enumerate(inp): 573 | if ind not in indices: 574 | output.append(el) 575 | continue 576 | 577 | inverse_inp_el = inverse_inp[inverse_inp_index] 578 | output.append(inverse_inp_el) 579 | inverse_inp_index += 1 580 | 581 | return output 582 | 583 | return filtered, inverse 584 | 585 | def apply_fn_modality_type( 586 | fn: Callable, 587 | modalities: ModalitySample | list[ModalitySample], 588 | modality_type = 0, 589 | return_untransformed = False 590 | ) -> ModalitySample | list[ModalitySample]: 591 | 592 | modalities, tree_spec = tree_flatten(modalities, is_leaf = lambda el: isinstance(el, tuple)) 593 | 594 | # standardize tuples to (, ) 595 | 596 | modalities = [(0, m) if (is_tensor(m) and m.dtype == torch.float) else m for m in modalities] 597 | 598 | # filter for specific modality type to transform 599 | 600 | modalities, inverse_filter = filter_with_inverse(lambda el: isinstance(el, tuple) and el[0] == modality_type, modalities) 601 | 602 | # remove the type 603 | 604 | modalities = [m for _, m in modalities] 605 | 606 | # batch process 607 | 608 | stacked_modalities, inverse_stack = stack_same_shape_tensors_with_inverse(modalities) 609 | 610 | out = {shape: fn(batched_modalities) for shape, batched_modalities in stacked_modalities.items()} 611 | 612 | out = inverse_stack(out) 613 | 614 | # add back the type 615 | 616 | if return_untransformed: 617 | out = [(modality_type, transformed_m, prev_m) for transformed_m, prev_m in zip(out, modalities)] 618 | else: 619 | out = [(modality_type, transformed_m) for transformed_m in out] 620 | 621 | # replace transformed modalities and untree flatten 622 | 623 | out = inverse_filter(out) 624 | 625 | return tree_unflatten(out, tree_spec) 626 | 627 | # sampling related functions 628 | 629 | # min_p for text 630 | # https://arxiv.org/abs/2407.01082 631 | 632 | def min_p_filter(logits, min_p = 0.1): 633 | probs = logits.softmax(dim = -1) 634 | max_probs = probs.amax(dim = -1, keepdim = True) 635 | limit = min_p * max_probs 636 | return torch.where(probs < limit, float('-inf'), logits) 637 | 638 | # random fourier embedding 639 | 640 | class RandomFourierEmbed(Module): 641 | def __init__(self, dim): 642 | super().__init__() 643 | assert divisible_by(dim, 2) 644 | self.dim = dim 645 | self.register_buffer('weights', torch.randn(dim // 2)) 646 | 647 | @typecheck 648 | def forward( 649 | self, 650 | times: Float['b n'] | Float['b'] 651 | ) -> Float['b n {self.dim+1}']: 652 | 653 | if times.ndim == 1: 654 | times = rearrange(times, 'b -> b 1') 655 | 656 | freqs = einx.multiply('... i, j -> ... 
i j', times, self.weights) * 2 * torch.pi 657 | fourier_embed, _ = pack((times, freqs.sin(), freqs.cos()), 'b n *') 658 | return fourier_embed 659 | 660 | # adaptive layernorm and ada-ln zero rolled into one wrapper 661 | # from DiT paper and sota for time conditioning for now 662 | 663 | class AdaptiveWrapper(Module): 664 | @beartype 665 | def __init__( 666 | self, 667 | fn: Module, 668 | dim, 669 | dim_cond, 670 | ada_ln_zero_init_bias = -2 671 | ): 672 | super().__init__() 673 | self.fn = fn 674 | self.dim = dim 675 | self.dim_cond = dim_cond 676 | 677 | self.layernorm = nn.LayerNorm(dim, elementwise_affine = False) 678 | 679 | # text will be subjected to normal layernorm bias 680 | # and for output will use layerscale 681 | 682 | self.layernorm_gamma = nn.Parameter(torch.zeros(dim)) 683 | self.layerscale = nn.Parameter(torch.zeros(dim)) 684 | 685 | # modalities will get the adaptive layernorm + ada-ln zero 686 | 687 | self.to_film = Linear(dim_cond, dim * 2) 688 | self.to_ada_ln_zero = Linear(dim_cond, dim) 689 | 690 | nn.init.zeros_(self.to_film.weight) 691 | nn.init.zeros_(self.to_ada_ln_zero.weight) 692 | nn.init.constant_(self.to_ada_ln_zero.bias, ada_ln_zero_init_bias) 693 | 694 | @typecheck 695 | def forward_text( 696 | self, 697 | x: Float['b n {self.dim}'], 698 | **kwargs 699 | ): 700 | x = self.layernorm(x) 701 | 702 | x = x * (self.layernorm_gamma + 1.) 703 | 704 | out = self.fn(x, **kwargs) 705 | 706 | (out, *rest), tree_spec = tree_flatten(out) 707 | 708 | out = out * (self.layerscale + 1.) 709 | 710 | out = tree_unflatten((out, *rest), tree_spec) 711 | 712 | return out 713 | 714 | @typecheck 715 | def forward_modality( 716 | self, 717 | x: Float['b n {self.dim}'], 718 | cond: Float['... {self.dim_cond}'], 719 | **kwargs 720 | ): 721 | x = self.layernorm(x) 722 | 723 | gamma, beta = self.to_film(cond).chunk(2, dim = -1) 724 | 725 | modality_tokens = x * (gamma + 1.) + beta 726 | 727 | # attention or feedforwards 728 | 729 | out = self.fn(modality_tokens, **kwargs) 730 | 731 | (out, *rest), tree_spec = tree_flatten(out) 732 | 733 | # take care of conditioning output separately for text vs modality 734 | 735 | modalities_out = out * self.to_ada_ln_zero(cond).sigmoid() 736 | 737 | # take care of function returning cache or value residual 738 | 739 | modalities_out = tree_unflatten((modalities_out, *rest), tree_spec) 740 | 741 | return modalities_out 742 | 743 | @typecheck 744 | def forward( 745 | self, 746 | x: Float['b n {self.dim}'], 747 | cond: Float['... {self.dim_cond}'] | None = None, 748 | is_any_modality: bool | Bool['b n'] | None = None, 749 | modality_only = False, 750 | **kwargs 751 | ): 752 | if exists(cond) and cond.ndim == 2: 753 | cond = rearrange(cond, 'b d -> b 1 d') 754 | 755 | if modality_only: 756 | return self.forward_modality(x, cond = cond, **kwargs) 757 | 758 | assert not (exists(cond) ^ exists(is_any_modality)) 759 | 760 | has_modality = exists(is_any_modality) 761 | 762 | if not has_modality: 763 | return self.forward_text(x, **kwargs) 764 | 765 | if isinstance(is_any_modality, bool): 766 | is_any_modality = torch.full((x.shape[:-1]), is_any_modality, device = x.device, dtype = torch.bool) 767 | 768 | is_any_modality = rearrange(is_any_modality, '... -> ... 1') 769 | 770 | x = self.layernorm(x) 771 | 772 | gamma, beta = self.to_film(cond).chunk(2, dim = -1) 773 | 774 | text_tokens = x * (self.layernorm_gamma + 1.) 775 | 776 | modality_tokens = x * (gamma + 1.) 
+ beta 777 | 778 | x = torch.where(is_any_modality, modality_tokens, text_tokens) 779 | 780 | # attention or feedforwards 781 | 782 | out = self.fn(x, **kwargs) 783 | 784 | (out, *rest), tree_spec = tree_flatten(out) 785 | 786 | # take care of conditioning output separately for text vs modality 787 | 788 | text_out = out * (self.layerscale + 1.) 789 | 790 | modalities_out = out * self.to_ada_ln_zero(cond).sigmoid() 791 | 792 | conditioned_out = torch.where(is_any_modality, modalities_out, text_out) 793 | 794 | # take care of function returning cache or value residual 795 | 796 | conditioned_out = tree_unflatten((conditioned_out, *rest), tree_spec) 797 | 798 | return conditioned_out 799 | 800 | # attention 801 | 802 | class RMSNorm(Module): 803 | def __init__(self, dim): 804 | super().__init__() 805 | self.scale = dim ** 0.5 806 | self.gamma = nn.Parameter(torch.zeros(dim)) 807 | 808 | def forward(self, x): 809 | return l2norm(x) * self.scale * (self.gamma + 1.) # use unit offset from Ohad Rubin 810 | 811 | class GEGLU(Module): 812 | def forward(self, x): 813 | x, gates = x.chunk(2, dim = -1) 814 | return F.gelu(gates) * x 815 | 816 | def FeedForward( 817 | dim, 818 | expansion_factor = 4., 819 | dropout = 0. 820 | ): 821 | dim_inner = int(dim * expansion_factor * 2 / 3) 822 | return nn.Sequential( 823 | Linear(dim, dim_inner * 2), 824 | GEGLU(), 825 | nn.Dropout(dropout), 826 | Linear(dim_inner, dim) 827 | ) 828 | 829 | class Attention(Module): 830 | def __init__( 831 | self, 832 | dim, 833 | dim_head = 64, 834 | heads = 8, 835 | dropout = 0., 836 | softcap_value = 50., 837 | use_flex_attn = False, 838 | gate_values = True, 839 | laser = False, 840 | laser_softclamp_value = 15., 841 | learned_value_residual_mix = False 842 | ): 843 | super().__init__() 844 | self.scale = dim_head ** -0.5 845 | dim_inner = dim_head * heads 846 | 847 | assert not (use_flex_attn and not exists(flex_attention)), 'flex attention is only available on torch 2.5.0 (nightly) onwards' 848 | self.use_flex_attn = use_flex_attn 849 | 850 | self.to_qkv = nn.Sequential( 851 | Linear(dim, dim_inner * 3, bias = False), 852 | Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads) 853 | ) 854 | 855 | self.to_learned_value_residual = nn.Sequential( 856 | nn.Linear(dim, heads), 857 | nn.Sigmoid(), 858 | Rearrange('b n h -> b h n 1') # add head dimension 859 | ) if learned_value_residual_mix else always(0.5) 860 | 861 | self.to_gates = nn.Sequential( 862 | nn.Linear(dim, heads, bias = False), 863 | Rearrange('b n h -> b h n 1', h = heads) 864 | ) if gate_values else None 865 | 866 | self.softcap_value = softcap_value 867 | 868 | self.laser = laser 869 | self.laser_softclamp_value = laser_softclamp_value 870 | 871 | self.dropout = nn.Dropout(dropout) 872 | 873 | self.to_out = nn.Sequential( 874 | Rearrange('b h n d -> b n (h d)'), 875 | Linear(dim_inner, dim, bias = False) 876 | ) 877 | 878 | def forward( 879 | self, 880 | x, 881 | attn_mask: Tensor | None = None, # for manual masking 882 | rotary_emb: Tensor | None = None, 883 | cache: Tensor | None = None, 884 | causal = False, 885 | block_mask = None, # only passed in for flex attention 886 | return_kv_cache = False, 887 | return_values = False, 888 | value_residual: Tensor | None = None 889 | ): 890 | device, input_is_cuda, is_decoding_with_cache = x.device, x.is_cuda, exists(cache) 891 | 892 | should_use_flex_attn = self.use_flex_attn and input_is_cuda 893 | 894 | # handle maybe mask 895 | # if receiving kv cache, assume decoding and turn off all masking 896 | 897 | 
if is_decoding_with_cache: 898 | block_mask = attn_mask = None 899 | 900 | assert not (exists(block_mask) and exists(attn_mask)) 901 | assert not (not self.use_flex_attn and exists(block_mask)), 'you cannot pass in the `block_mask` if `use_flex_attn` was not set to be `True`' 902 | 903 | # project to queries, keys, values 904 | 905 | q, k, v = self.to_qkv(x) 906 | 907 | # value residual 908 | 909 | orig_v = v 910 | 911 | if exists(value_residual): 912 | mix = self.to_learned_value_residual(x) 913 | v = v * mix + value_residual * (1. - mix) 914 | 915 | # handle cache being passed in 916 | 917 | if exists(cache): 918 | cached_k, cached_v = cache 919 | k = cat((cached_k, k), dim = -2) 920 | v = cat((cached_v, v), dim = -2) 921 | 922 | # maybe kv cache 923 | 924 | if return_kv_cache: 925 | kv_cache = stack((k, v)) 926 | 927 | # rotary embeddings 928 | 929 | if exists(rotary_emb): 930 | q, k = tuple(apply_rotary_emb(rotary_emb, t, freqs_seq_dim = -2) for t in (q, k)) 931 | 932 | # laser attention 933 | 934 | if self.laser: 935 | v = softclamp(v, self.laser_softclamp_value) 936 | v = v.exp() 937 | 938 | # whether to use flex attention or not 939 | 940 | if should_use_flex_attn: 941 | assert not causal, 'causal mask should be constructed in transformer' 942 | 943 | flex_attn_kwargs = dict(block_mask = block_mask) 944 | 945 | if self.softcap_value > 0.: 946 | flex_attn_kwargs.update(score_mod = softcap_score_mod(self.softcap_value)) 947 | 948 | out = flex_attention(q, k, v, **flex_attn_kwargs) 949 | 950 | else: 951 | q = q * self.scale 952 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j') 953 | 954 | sim = softclamp(sim, self.softcap_value) 955 | 956 | mask_value = max_neg_value(sim) 957 | 958 | if causal: 959 | i, j = sim.shape[-2:] 960 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1) 961 | sim = sim.masked_fill(causal_mask, mask_value) 962 | 963 | if exists(attn_mask): 964 | sim = einx.where('b i j, b h i j, -> b h i j', attn_mask, sim, mask_value) 965 | 966 | attn = sim.softmax(dim = -1) 967 | 968 | attn = self.dropout(attn) 969 | 970 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d') 971 | 972 | # laser attention 973 | 974 | if self.laser: 975 | out = log(out) 976 | 977 | # maybe gate values 978 | 979 | if exists(self.to_gates): 980 | out = out * self.to_gates(x).sigmoid() 981 | 982 | # combine heads and out 983 | 984 | out = self.to_out(out) 985 | 986 | if return_values: 987 | out = (out, orig_v) 988 | 989 | if not return_kv_cache: 990 | return out 991 | 992 | return out, kv_cache 993 | 994 | class Transformer(Module): 995 | @beartype 996 | def __init__( 997 | self, 998 | dim, 999 | *, 1000 | depth, 1001 | dim_head = 64, 1002 | heads = 8, 1003 | dropout = 0., 1004 | ff_expansion_factor = 4, 1005 | attn_kwargs: dict = dict(), 1006 | ff_kwargs: dict = dict(), 1007 | attn_laser = False, 1008 | unet_skips = True, 1009 | use_flex_attn = False, 1010 | num_residual_streams = 1, 1011 | num_residual_fracs = 4 1012 | ): 1013 | super().__init__() 1014 | self.use_flex_attn = use_flex_attn 1015 | 1016 | self.dim = dim 1017 | self.dim_head = dim_head 1018 | 1019 | self.to_time_cond = nn.Sequential( 1020 | RandomFourierEmbed(dim), 1021 | Linear(dim + 1, dim * 4), 1022 | nn.SiLU() 1023 | ) 1024 | 1025 | # hyper connections 1026 | 1027 | counter = count() 1028 | 1029 | init_residual_fn, self.expand_stream, self.reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_fracs = num_residual_fracs) 1030 | 1031 | # layers 
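# each layer below is assembled as follows:
#   - attention and feedforward are wrapped in AdaptiveWrapper, so they receive the
#     (dim * 4) time conditioning from `to_time_cond` as adaptive layernorm + ada-ln zero
#   - residuals go through the hyper-connection wrappers from `init_residual_fn`,
#     allowing multiple residual streams when `num_residual_streams > 1`
#   - layers in the latter half optionally receive a unet-style skip connection, hence
#     the `Linear(dim * 2, dim)` projection over the concatenated skip and current stream
#   - every attention layer except the first learns a value-residual mix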
1032 | 1033 | layers = ModuleList([]) 1034 | 1035 | for ind in range(depth): 1036 | is_first = ind == 0 1037 | 1038 | is_latter_half = ind >= (depth / 2) 1039 | 1040 | skip_proj = Linear(dim * 2, dim, bias = False) if is_latter_half and unet_skips else None 1041 | 1042 | attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, learned_value_residual_mix = not is_first, laser = attn_laser, **attn_kwargs) 1043 | 1044 | ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor, **ff_kwargs) 1045 | 1046 | attn = AdaptiveWrapper(attn, dim = dim, dim_cond = dim * 4) 1047 | ff = AdaptiveWrapper(ff, dim = dim, dim_cond = dim * 4) 1048 | 1049 | attn_residual = init_residual_fn(dim = dim, layer_index = next(counter)) 1050 | ff_residual = init_residual_fn(dim = dim, layer_index = next(counter)) 1051 | 1052 | layers.append(ModuleList([skip_proj, attn, attn_residual, ff, ff_residual])) 1053 | 1054 | self.layers = layers 1055 | self.norm = RMSNorm(dim) 1056 | 1057 | @typecheck 1058 | def forward( 1059 | self, 1060 | x, 1061 | times: Scalar | Float['b'] | Float['b n'] | None = None, 1062 | attn_mask: Bool['b i j'] | None = None, 1063 | modality_positions: RawModalityPositions | Int['b m 3'] | None = None, 1064 | is_any_modality: bool | Bool['b n'] | None = None, 1065 | rotary_emb: Tensor | None = None, 1066 | cache: Tensor | None = None, 1067 | decode_length: int | None = None, 1068 | modality_only = False, 1069 | causal_mask = False, 1070 | return_kv_cache = False 1071 | ): 1072 | batch, seq_len, device, input_is_cuda = x.shape[0], x.shape[-2], x.device, x.is_cuda 1073 | 1074 | is_decoding_with_cache = exists(cache) 1075 | needs_masking = not is_decoding_with_cache 1076 | 1077 | should_use_flex_attn = input_is_cuda and needs_masking and self.use_flex_attn 1078 | 1079 | assert not (exists(attn_mask) and exists(modality_positions)) 1080 | 1081 | # handle time 1082 | 1083 | cond = None 1084 | 1085 | if exists(times): 1086 | if times.ndim == 0: 1087 | times = repeat(times, ' -> b', b = batch) 1088 | 1089 | cond = self.to_time_cond(times) 1090 | 1091 | # create the specialized mask needed for autoregressive text + bidirectional flow attention 1092 | 1093 | attn_mask_kwargs = dict() 1094 | 1095 | if needs_masking: 1096 | if causal_mask: 1097 | if should_use_flex_attn: 1098 | block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True, device = device) 1099 | attn_mask_kwargs.update(block_mask = block_mask) 1100 | else: 1101 | attn_mask_kwargs.update(causal = True) 1102 | 1103 | if exists(modality_positions): 1104 | assert not causal_mask 1105 | 1106 | if should_use_flex_attn: 1107 | transfusion_mask_fn = transfusion_attn_mask(modality_positions) 1108 | block_mask = create_block_mask(transfusion_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True, device = device) 1109 | attn_mask_kwargs.update(block_mask = block_mask) 1110 | else: 1111 | attn_mask = naive_attn_mask(seq_len, modality_positions, device = device) 1112 | attn_mask_kwargs.update(attn_mask = attn_mask) 1113 | 1114 | if not exists(is_any_modality) and exists(modality_positions): 1115 | is_any_modality = modality_positions_to_is_modality_mask(seq_len, modality_positions).any(dim = 1) 1116 | is_any_modality = reduce(is_any_modality, 'b t n -> b n', 'any') 1117 | 1118 | # handle kv caching 1119 | 1120 | if is_decoding_with_cache: 1121 | assert exists(decode_length) 1122 | 1123 | x = x[..., -decode_length:, :] 
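# during cached decoding only the trailing `decode_length` positions are fed through
# the layers - keys / values for all earlier positions are re-joined from the cache
# inside each Attention block, and attention masking is skipped entirely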
1124 | cond = cond[..., -decode_length:, :] 1125 | 1126 | if is_tensor(is_any_modality): 1127 | is_any_modality = is_any_modality[..., -decode_length:] 1128 | 1129 | # adaptive layernorm kwargs, which handles text and modality tokens differently 1130 | 1131 | adaptive_kwargs = dict( 1132 | cond = cond, 1133 | modality_only = modality_only, 1134 | is_any_modality = is_any_modality 1135 | ) 1136 | 1137 | # handle cache 1138 | 1139 | cache = default(cache, (None,)) 1140 | iter_cache = iter(cache) 1141 | 1142 | # expand input into multiple residual streams for maybe hyper connection 1143 | 1144 | x = self.expand_stream(x) 1145 | 1146 | # transformer layers as usual, using mask from above 1147 | 1148 | skips = [] 1149 | value_residual = None 1150 | 1151 | new_cache = [] 1152 | 1153 | depth = len(self.layers) 1154 | 1155 | for ind, (skip_proj, attn, attn_residual, ff, ff_residual) in enumerate(self.layers): 1156 | layer = ind + 1 1157 | 1158 | # skip connection 1159 | 1160 | is_first_half = layer <= (depth // 2) 1161 | is_later_half = not is_first_half 1162 | 1163 | if is_first_half: 1164 | skips.append(x) 1165 | 1166 | if is_later_half and exists(skip_proj): 1167 | skip = skips.pop() 1168 | 1169 | residual = x 1170 | x = cat((x, skip), dim = -1) 1171 | x = skip_proj(x) + residual 1172 | 1173 | # attention and feedforward 1174 | 1175 | x, add_attn_residual = attn_residual(x) 1176 | 1177 | (attn_out, attn_values), kv_cache = attn( 1178 | x, 1179 | rotary_emb = rotary_emb, 1180 | cache = next(iter_cache, None), 1181 | return_kv_cache = True, 1182 | return_values = True, 1183 | value_residual = value_residual, 1184 | **attn_mask_kwargs, 1185 | **adaptive_kwargs 1186 | ) 1187 | 1188 | value_residual = default(value_residual, attn_values) 1189 | 1190 | new_cache.append(kv_cache) 1191 | 1192 | x = add_attn_residual(attn_out) 1193 | 1194 | x, add_ff_residual = ff_residual(x) 1195 | 1196 | ff_out = ff(x, **adaptive_kwargs) 1197 | 1198 | x = add_ff_residual(ff_out) 1199 | 1200 | # reduce multiple residual streams for maybe hyper connection 1201 | 1202 | x = self.reduce_stream(x) 1203 | 1204 | assert len(skips) == 0 1205 | 1206 | out = self.norm(x) 1207 | 1208 | if not return_kv_cache: 1209 | return out 1210 | 1211 | return out, stack(new_cache) 1212 | 1213 | # classes 1214 | 1215 | class Transfusion(Module): 1216 | @beartype 1217 | def __init__( 1218 | self, 1219 | *, 1220 | num_text_tokens, 1221 | transformer: dict | Transformer, 1222 | dim_latent: int | tuple[int, ...] | None = None, 1223 | channel_first_latent: bool | tuple[bool, ...] = False, 1224 | add_pos_emb: bool | tuple[bool, ...] = False, 1225 | modality_encoder: Module | tuple[Module, ...] | None = None, 1226 | modality_decoder: Module | tuple[Module, ...] | None = None, 1227 | pre_post_transformer_enc_dec: tuple[Module, Module] | tuple[tuple[Module, Module], ...] | None = None, 1228 | modality_default_shape: tuple[int, ...] | tuple[tuple[int, ...], ...] | None = None, 1229 | fallback_to_default_shape_if_invalid = False, 1230 | modality_num_dim: int | tuple[int, ...] | None = None, 1231 | to_modality_shape_fn: Callable | tuple[Callable, ...] 
= default_to_modality_shape_fn, 1232 | ignore_index = -1, 1233 | flow_loss_weight = 1., 1234 | text_loss_weight = 1., 1235 | velocity_consistency_loss_weight = 0.1, 1236 | reconstruction_loss_weight = 0., 1237 | modality_encoder_decoder_requires_batch_dim = True, # whether the modality encoder / decoder requires batch dimension, will auto assume it is needed 1238 | odeint_kwargs: dict = dict( 1239 | atol = 1e-5, 1240 | rtol = 1e-5, 1241 | method = 'midpoint' 1242 | ), 1243 | ): 1244 | super().__init__() 1245 | 1246 | # transformer 1247 | 1248 | if isinstance(transformer, dict): 1249 | transformer = Transformer(**transformer) 1250 | 1251 | self.transformer = transformer 1252 | dim = transformer.dim 1253 | 1254 | self.dim = dim 1255 | 1256 | # latent and model dimension not the same 1257 | # make it work for 1 modality for now 1258 | 1259 | dim_latent = default(dim_latent, dim) 1260 | 1261 | self.dim_latents = cast_tuple(dim_latent) 1262 | 1263 | # number of modalities 1264 | 1265 | self.num_modalities = len(self.dim_latents) 1266 | 1267 | # whether the latents are accepted to be channel first or channel last 1268 | # if channel first, will be rearrange(c ... -> ... c -> (...) c) 1269 | 1270 | self.channel_first_latent = cast_tuple(channel_first_latent, self.num_modalities) 1271 | assert len(self.channel_first_latent) == self.num_modalities 1272 | 1273 | # functions for converting the sampled language model meta string back to modality shape of tuple[int, ...] 1274 | 1275 | self.to_modality_shape_fn = cast_tuple(to_modality_shape_fn, self.num_modalities) 1276 | 1277 | # default token lengths for respective modality 1278 | # fallback if the language model does not come up with valid dimensions 1279 | 1280 | if not exists(modality_default_shape) or is_bearable(modality_default_shape, tuple[int, ...]): 1281 | modality_default_shape = (modality_default_shape,) * self.num_modalities 1282 | 1283 | self.modality_default_shape = modality_default_shape 1284 | 1285 | assert len(self.modality_default_shape) == self.num_modalities 1286 | 1287 | self.fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid 1288 | 1289 | # default `modality_num_dim` to `len(modality_default_shape)` if latter is specified but former not 1290 | 1291 | modality_num_dim = default(modality_num_dim, tuple(len(shape) for shape in self.modality_default_shape)) 1292 | 1293 | # specifying the number of dimensions for the modality, which will be hard validated 1294 | 1295 | self.modality_num_dim = cast_tuple(modality_num_dim, self.num_modalities) 1296 | 1297 | assert len(self.modality_num_dim) == self.num_modalities 1298 | 1299 | assert all([not exists(ndim) or not exists(shape) or len(shape) == ndim for ndim, shape in zip(self.modality_num_dim, self.modality_default_shape)]) 1300 | 1301 | # whether to add an extra axial positional embedding per modality 1302 | 1303 | self.add_pos_emb = cast_tuple(add_pos_emb, self.num_modalities) 1304 | assert len(self.add_pos_emb) == self.num_modalities 1305 | 1306 | self.pos_emb_mlp = ModuleList([]) 1307 | 1308 | for modality_add_pos_emb, modality_ndim in zip(self.add_pos_emb, self.modality_num_dim): 1309 | 1310 | if not modality_add_pos_emb: 1311 | self.pos_emb_mlp.append(None) 1312 | continue 1313 | 1314 | assert exists(modality_ndim), '`modality_num_dim` must be set if you wish to automatically inject axial positional embeddings' 1315 | 1316 | pos_generating_mlp = ContinuousAxialPositionalEmbedding( 1317 | dim = dim, 1318 | num_axial_dims = modality_ndim, 1319 | ) 1320 | 
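# illustrative note (values taken from the mnist scripts in this repo): with
# modality_default_shape = (14, 14) and modality_num_dim = 2, this axial positional
# MLP yields one `dim`-sized embedding per (row, col) grid position, which is added
# onto the projected modality tokens whenever `add_pos_emb` is set for that modality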
1321 | self.pos_emb_mlp.append(pos_generating_mlp) 1322 | 1323 | # modality encoders and decoders 1324 | 1325 | modality_encoder = cast_tuple(modality_encoder, 1 if exists(modality_encoder) else self.num_modalities) 1326 | modality_decoder = cast_tuple(modality_decoder, 1 if exists(modality_decoder) else self.num_modalities) 1327 | 1328 | self.modality_encoder = ModuleList(modality_encoder) 1329 | self.modality_decoder = ModuleList(modality_decoder) 1330 | 1331 | assert len(self.modality_encoder) == self.num_modalities 1332 | assert len(self.modality_decoder) == self.num_modalities 1333 | 1334 | # auto handle batch dimension for modality encoder / decoder 1335 | 1336 | self.maybe_add_temp_batch_dim = add_temp_batch_dim if modality_encoder_decoder_requires_batch_dim else identity 1337 | 1338 | # store number of text tokens 1339 | 1340 | self.num_text_tokens = num_text_tokens 1341 | 1342 | # entire "sentence" start and end id 1343 | 1344 | num_text_special_ids = 2 1345 | 1346 | self.sos_id, self.eos_id = num_text_tokens, (num_text_tokens + 1) 1347 | 1348 | # modality meta, start and end tokens - termed [mom] [som] [eom] in this repo 1349 | 1350 | num_modality_special_ids = self.num_modalities * 2 1351 | som_eom_tensor = torch.arange(num_modality_special_ids) + num_text_tokens + num_text_special_ids # shift to the very end 1352 | som_eom_tensor = rearrange(som_eom_tensor, '(start_end m) -> start_end m', start_end = 2) 1353 | 1354 | # modality meta, start and end ids 1355 | 1356 | self.som_ids, self.eom_ids = som_eom_tensor.tolist() 1357 | 1358 | # char tokenizing for modality meta information 1359 | 1360 | meta_token_offset = num_text_tokens + num_text_special_ids + num_modality_special_ids 1361 | 1362 | self.meta_id = meta_token_offset 1363 | 1364 | self.char_tokenizer = partial(char_tokenize, offset = meta_token_offset + 1) 1365 | self.decode_chars = partial(decode_chars, offset = meta_token_offset + 1) 1366 | 1367 | num_meta_tokens = 128 + 1 1368 | 1369 | # prepare pre-post transformer encoder / decoder, for the learnable unets as in paper 1370 | 1371 | if is_bearable(pre_post_transformer_enc_dec, tuple[Module, Module]): 1372 | pre_post_transformer_enc_dec = (pre_post_transformer_enc_dec,) 1373 | 1374 | pre_post_transformer_enc_dec = cast_tuple(pre_post_transformer_enc_dec, self.num_modalities) 1375 | assert len(pre_post_transformer_enc_dec) == self.num_modalities 1376 | 1377 | # latent to model and back 1378 | # by default will be Linear, with or without rearranges depending on channel_first_latent setting 1379 | # can also be overridden for the unet down/up as in the paper with `pre_post_transformer_enc_dec: tuple[Module, Module]` 1380 | 1381 | latent_to_model_projs = [] 1382 | model_to_latent_projs = [] 1383 | 1384 | for ( 1385 | dim_latent, 1386 | one_channel_first_latent, 1387 | enc_dec, 1388 | ) in zip( 1389 | self.dim_latents, 1390 | self.channel_first_latent, 1391 | pre_post_transformer_enc_dec 1392 | ): 1393 | 1394 | pre_attend_enc, post_attend_dec = default(enc_dec, (None, None)) 1395 | 1396 | latent_to_model_proj = Linear(dim_latent, dim) if dim_latent != dim else nn.Identity() 1397 | model_to_latent_proj = Linear(dim, dim_latent, bias = False) 1398 | 1399 | if one_channel_first_latent: 1400 | latent_to_model_proj = nn.Sequential(Rearrange('b d ... -> b ... d'), latent_to_model_proj) 1401 | model_to_latent_proj = nn.Sequential(model_to_latent_proj, Rearrange('b ... 
d -> b d ...')) 1402 | 1403 | if exists(pre_attend_enc): 1404 | pre_attend_enc = nn.Sequential(pre_attend_enc, Rearrange('b d ... -> b ... d')) 1405 | 1406 | if exists(post_attend_dec): 1407 | post_attend_dec = nn.Sequential(Rearrange('b ... d -> b d ...'), post_attend_dec) 1408 | 1409 | latent_to_model_projs.append(default(pre_attend_enc, latent_to_model_proj)) 1410 | model_to_latent_projs.append(default(post_attend_dec, model_to_latent_proj)) 1411 | 1412 | self.latent_to_model_projs = ModuleList(latent_to_model_projs) 1413 | self.model_to_latent_projs = ModuleList(model_to_latent_projs) 1414 | 1415 | # relative positions 1416 | 1417 | self.rotary_emb = RotaryEmbedding(transformer.dim_head) 1418 | 1419 | # embeddings and un-embeddings 1420 | 1421 | effective_num_text_tokens = num_text_tokens + num_text_special_ids + num_modality_special_ids + num_meta_tokens 1422 | 1423 | self.text_embed = nn.Embedding(effective_num_text_tokens, dim) 1424 | 1425 | self.to_text_logits = Linear(dim, effective_num_text_tokens, bias = False) 1426 | 1427 | text_only_mask = torch.arange(effective_num_text_tokens) < num_text_tokens 1428 | self.register_buffer('text_only_logits_mask', text_only_mask, persistent = False) 1429 | 1430 | # loss related 1431 | 1432 | self.ignore_index = ignore_index 1433 | self.flow_loss_weight = flow_loss_weight 1434 | self.text_loss_weight = text_loss_weight 1435 | 1436 | # velocity consistency weight - only added if EMA model is passed in during training 1437 | 1438 | self.velocity_consistency_loss_weight = velocity_consistency_loss_weight 1439 | 1440 | # additional reconstruction loss, through the decoder 1441 | 1442 | self.has_recon_loss = reconstruction_loss_weight > 0. 1443 | self.reconstruction_loss_weight = reconstruction_loss_weight 1444 | 1445 | # flow sampling related 1446 | 1447 | self.odeint_fn = partial(odeint, **odeint_kwargs) 1448 | 1449 | # dummy loss 1450 | 1451 | self.register_buffer('zero', tensor(0.), persistent = False) 1452 | 1453 | @property 1454 | def device(self): 1455 | return next(self.parameters()).device 1456 | 1457 | @cache 1458 | def get_modality_info( 1459 | self, 1460 | modality_type: int | None = None 1461 | ) -> ModalityInfo: 1462 | 1463 | modality_type = default(modality_type, 0) 1464 | 1465 | modality_encoder = self.modality_encoder[modality_type] 1466 | modality_decoder = self.modality_decoder[modality_type] 1467 | latent_to_model = self.latent_to_model_projs[modality_type] 1468 | model_to_latent = self.model_to_latent_projs[modality_type] 1469 | 1470 | add_pos_emb = self.add_pos_emb[modality_type] 1471 | pos_emb_mlp = self.pos_emb_mlp[modality_type] 1472 | modality_num_dim = self.modality_num_dim[modality_type] 1473 | 1474 | dim_latent = self.dim_latents[modality_type] 1475 | 1476 | default_shape = self.modality_default_shape[modality_type] 1477 | 1478 | som_id = self.som_ids[modality_type] 1479 | eom_id = self.eom_ids[modality_type] 1480 | 1481 | to_shape_fn = self.to_modality_shape_fn[modality_type] 1482 | 1483 | channel_first_latent = self.channel_first_latent[modality_type] 1484 | 1485 | return ModalityInfo( 1486 | encoder = modality_encoder, 1487 | decoder = modality_decoder, 1488 | latent_to_model = latent_to_model, 1489 | model_to_latent = model_to_latent, 1490 | add_pos_emb = add_pos_emb, 1491 | pos_emb_mlp = pos_emb_mlp, 1492 | num_dim = modality_num_dim, 1493 | dim_latent = dim_latent, 1494 | default_shape = default_shape, 1495 | som_id = som_id, 1496 | eom_id = eom_id, 1497 | to_shape_fn = to_shape_fn, 1498 | channel_first_latent = 
channel_first_latent, 1499 | modality_type = modality_type 1500 | ) 1501 | 1502 | def get_all_modality_info(self) -> list[ModalityInfo]: 1503 | return [self.get_modality_info(i) for i in range(self.num_modalities)] 1504 | 1505 | def get_modality_shape( 1506 | self, 1507 | modality: Float['...'], 1508 | modality_type: int | None = None 1509 | ) -> tuple[int, ...]: 1510 | 1511 | mod = self.get_modality_info(modality_type) 1512 | 1513 | if mod.channel_first_latent: 1514 | modality = rearrange(modality, 'c ... -> ... c') 1515 | 1516 | return tuple(modality.shape[:-1]) 1517 | 1518 | def parameters_without_encoder_decoder(self): 1519 | return ( 1520 | set(self.parameters()) - 1521 | set(self.modality_encoder.parameters()) - 1522 | set(self.modality_decoder.parameters()) 1523 | ) 1524 | 1525 | def create_dataloader( 1526 | self, 1527 | *args, 1528 | **kwargs 1529 | ): 1530 | return create_dataloader(*args, **kwargs) 1531 | 1532 | def create_ema( 1533 | self, 1534 | beta = 0.99, 1535 | *ema_kwargs 1536 | ) -> EMA: 1537 | 1538 | ema = EMA( 1539 | self, 1540 | beta = beta, 1541 | forward_method_names = ( 1542 | 'sample', 1543 | 'generate_text_only', 1544 | 'generate_modality_only' 1545 | ) 1546 | ) 1547 | 1548 | return ema 1549 | 1550 | @torch.no_grad() 1551 | @eval_decorator 1552 | @typecheck 1553 | def sample( 1554 | self, 1555 | prompt: ModalitySample | Tensor | tuple[int, Float['...']] | None = None, 1556 | max_length = 2048, 1557 | text_temperature = 1.5, 1558 | text_min_p = 0.1, 1559 | cache_kv = False, 1560 | fixed_modality_shape: tuple[int, ...] | None = None, 1561 | init_modality_noise: Float['n d'] | None = None, 1562 | modality_steps = 16, 1563 | return_unprocessed_modalities = False 1564 | ) -> ModalitySample: 1565 | 1566 | device = self.device 1567 | 1568 | # handle edge case where there are no text tokens 1569 | 1570 | if self.num_text_tokens == 0: 1571 | logger.warning(f'you have `num_text_tokens` set to 0, so `sample` will be forwarded to `generate_modality_only(batch_size: int, modality_type: int)` method') 1572 | 1573 | return self.generate_modality_only(batch_size = 1) 1574 | 1575 | # take care of prompt being a raw tensor, either text or raw modality (image, video, actions, latents, etc) 1576 | 1577 | if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit 1578 | prompt = (0, prompt) 1579 | 1580 | prompt_is_modality = isinstance(prompt, tuple) 1581 | 1582 | if is_tensor(prompt) and prompt.dtype in (torch.int, torch.long): # is text only prompt 1583 | prompt = [prompt] 1584 | 1585 | elif prompt_is_modality: 1586 | modality_type, modality = prompt 1587 | 1588 | mod = self.get_modality_info(modality_type) 1589 | 1590 | if exists(mod.encoder): 1591 | with torch.no_grad(): 1592 | mod.encoder.eval() 1593 | modality = self.maybe_add_temp_batch_dim(mod.encoder)(modality).detach() 1594 | 1595 | modality_shape_tuple = self.get_modality_shape(modality, modality_type) 1596 | modality_shape_str = join([*map(str, modality_shape_tuple)], ',') 1597 | modality_meta_info = self.char_tokenizer(modality_shape_str, device = device) 1598 | 1599 | prompt = [ 1600 | tensor([self.meta_id]), 1601 | modality_meta_info, 1602 | tensor([mod.som_id]), 1603 | (modality_type, modality), 1604 | tensor([mod.eom_id]), 1605 | ] 1606 | 1607 | # sos 1608 | 1609 | init_text_seq = tensor([self.sos_id], device = device) 1610 | 1611 | # just take care of prompt being zero dimensions 1612 | 1613 | modality_sample = [init_text_seq, *default(prompt, [])] 1614 | 1615 | # take care of moving to 
device 1616 | 1617 | modality_sample = tree_map_tensor(modality_sample, lambda t: t.to(device)) 1618 | modality_sample = tree_map_tensor(modality_sample, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t) 1619 | 1620 | modality_sample = concat_contiguous_text(modality_sample) 1621 | 1622 | *_, last_modality_sample = modality_sample 1623 | 1624 | curr_length = 0 1625 | curr_modality_id = None 1626 | modality_shape = None 1627 | 1628 | num_past_modalities = int(prompt_is_modality) # either 0 or 1 (if the prompt given is a modality) 1629 | 1630 | text_is_greedy = text_temperature == 0. 1631 | is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected 1632 | 1633 | def maybe_transition_to_modality_decoding(seq): 1634 | nonlocal modality_shape 1635 | nonlocal is_decoding_text 1636 | nonlocal curr_modality_id 1637 | 1638 | sampled_token_id = seq[-1] 1639 | 1640 | if sampled_token_id not in self.som_ids: 1641 | return 1642 | 1643 | curr_modality_id = self.som_ids.index(sampled_token_id) 1644 | 1645 | if exists(fixed_modality_shape): 1646 | modality_shape = fixed_modality_shape 1647 | 1648 | # get the tokens after the modality meta id 1649 | 1650 | maybe_meta_tensor = get_tokens_since_rightmost_id(seq, self.meta_id) 1651 | 1652 | mod = self.get_modality_info(curr_modality_id) 1653 | 1654 | default_shape = mod.default_shape 1655 | maybe_modality_num_dim = mod.num_dim 1656 | meta_str_to_modality_shape = mod.to_shape_fn 1657 | 1658 | if maybe_meta_tensor.numel() > 0: 1659 | meta_tensor = maybe_meta_tensor[:-1] 1660 | meta_str = self.decode_chars(meta_tensor) 1661 | 1662 | if not meta_str.isdigit() or int(meta_str) <= 0: 1663 | 1664 | assert exists(default_shape), 'invalid modality meta information detected, please set `modality_default_shape` in order to properly fallback' 1665 | modality_shape = default_shape 1666 | else: 1667 | modality_shape = meta_str_to_modality_shape(meta_str) 1668 | 1669 | modality_shape = default(modality_shape, default_shape) 1670 | 1671 | if self.fallback_to_default_shape_if_invalid: 1672 | 1673 | if exists(maybe_modality_num_dim) and len(modality_shape) != maybe_modality_num_dim: 1674 | logger.warning(f'invalid modality shape {modality_shape} for modality {curr_modality_id}. 
falling back to default shape {default_shape}') 1675 | modality_shape = default_shape 1676 | 1677 | assert exists(modality_shape), f'language model did not produce a proper modality shape for modality type {curr_modality_id} - please set a fallback shape with `modality_default_shape`' 1678 | assert not exists(maybe_modality_num_dim) or maybe_modality_num_dim == len(modality_shape), f'expected modality type {curr_modality_id} to have {maybe_modality_num_dim} dimensions but language model produced a shape of {modality_shape}' 1679 | 1680 | is_decoding_text = False 1681 | 1682 | # determine if to transition from start 1683 | 1684 | maybe_transition_to_modality_decoding(last_modality_sample) 1685 | 1686 | cache = None 1687 | 1688 | with tqdm(total = max_length) as pbar: 1689 | 1690 | while curr_length <= max_length: 1691 | 1692 | if is_decoding_text: 1693 | pbar.set_description('decoding text') 1694 | 1695 | *_, seq = modality_sample 1696 | 1697 | logits, new_kv_cache = self.forward( 1698 | [modality_sample], 1699 | return_loss = False, 1700 | cache = cache, 1701 | decode_length = 1, 1702 | decoding_text_or_modality = 'text', 1703 | return_kv_cache = True 1704 | ) 1705 | 1706 | logits = logits[0][-1] 1707 | 1708 | if text_is_greedy: 1709 | sampled = logits.argmax(dim = -1, keepdim = True) 1710 | else: 1711 | logits = min_p_filter(logits, min_p = text_min_p) 1712 | 1713 | probs = (logits / text_temperature).softmax(dim = -1) 1714 | sampled = torch.multinomial(probs, 1) 1715 | 1716 | seq = torch.cat((seq, sampled), dim = -1) 1717 | modality_sample[-1] = seq 1718 | 1719 | pbar.update(1) 1720 | curr_length += 1 1721 | 1722 | if cache_kv: 1723 | cache = new_kv_cache 1724 | 1725 | sampled_token_id = sampled.item() 1726 | 1727 | if sampled_token_id == self.eos_id: 1728 | logger.info(f'detecting an end of string token [{self.eos_id}], terminating sampling early') 1729 | break 1730 | 1731 | maybe_transition_to_modality_decoding(seq) 1732 | 1733 | else: 1734 | assert exists(curr_modality_id) 1735 | pbar.set_description(f'decoding modality [{curr_modality_id}]') 1736 | 1737 | mod = self.get_modality_info(curr_modality_id) 1738 | 1739 | modality_length = math.prod(modality_shape) 1740 | 1741 | if exists(init_modality_noise): 1742 | noise = init_modality_noise[:modality_length, :mod.dim_latent] 1743 | else: 1744 | assert exists(modality_length) 1745 | noise = torch.randn((modality_length, mod.dim_latent), device = device) 1746 | 1747 | assert noise.shape == (modality_length, mod.dim_latent) 1748 | 1749 | noise = noise.reshape(*modality_shape, mod.dim_latent) 1750 | 1751 | if mod.channel_first_latent: 1752 | noise = rearrange(noise, '... d -> d ...') 1753 | 1754 | new_kv_cache = None 1755 | 1756 | def ode_step_fn(step_times, denoised): 1757 | nonlocal new_kv_cache 1758 | 1759 | step_times = rearrange(step_times, ' -> 1 1') # batch size of 1 1760 | step_times = F.pad(step_times, (num_past_modalities, 0), value = 1.) # past decoded modalities receive a time conditioning of 1. 
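                    # note on the sampling scheme: `denoised` starts as pure gaussian noise at t = 0,
                    # the transformer is queried for the flow (data - noise) at the current time, and
                    # `odeint` integrates that velocity field up to t = 1, mirroring the training
                    # construction noised = t * data + (1 - t) * noise used in the forward passes below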
1761 | 1762 | (embeds, get_pred_flows), new_kv_cache = self.forward( 1763 | [[*modality_sample, (curr_modality_id, denoised)]], 1764 | times = step_times, 1765 | return_embed = True, 1766 | cache = cache, 1767 | decode_length = modality_length, 1768 | return_kv_cache = True, 1769 | decoding_text_or_modality = 'modality' 1770 | ) 1771 | 1772 | parse_embed = get_pred_flows[curr_modality_id][-1] 1773 | 1774 | parsed_embed = parse_embed(embeds, need_splice = not exists(cache)) 1775 | 1776 | flow = add_temp_batch_dim(mod.model_to_latent)(parsed_embed) 1777 | 1778 | return flow 1779 | 1780 | times = torch.linspace(0, 1, modality_steps, device = device) 1781 | 1782 | trajectory = self.odeint_fn(ode_step_fn, noise, times) 1783 | 1784 | # add the sampled modality tokens 1785 | 1786 | sampled_modality = trajectory[-1] 1787 | 1788 | modality_sample.append((curr_modality_id, sampled_modality)) 1789 | 1790 | # add the appropriate [eom] 1791 | 1792 | eom_id = mod.eom_id 1793 | modality_sample.append(tensor([eom_id], device = device)) 1794 | 1795 | # set kv cache if needed 1796 | 1797 | if cache_kv: 1798 | cache = new_kv_cache 1799 | 1800 | # back to decoding text 1801 | 1802 | pbar.update(modality_length) 1803 | curr_length += modality_length 1804 | 1805 | num_past_modalities += 1 1806 | curr_modality_id = None 1807 | modality_length = None 1808 | 1809 | is_decoding_text = True 1810 | 1811 | logger.info(f'sampling stopped at length: {curr_length} / {max_length}') 1812 | 1813 | if return_unprocessed_modalities: 1814 | return modality_sample 1815 | 1816 | # post process modality sample, decoding modality types if they have a decoder 1817 | 1818 | for mod in self.get_all_modality_info(): 1819 | decoder_fn = default(mod.decoder, nn.Identity()) 1820 | 1821 | with torch.no_grad(): 1822 | decoder_fn.eval() 1823 | modality_sample = apply_fn_modality_type(decoder_fn, modality_sample, modality_type = mod.modality_type) 1824 | 1825 | return modality_sample 1826 | 1827 | @typecheck 1828 | def forward_text( 1829 | self, 1830 | text: Int['b n'], 1831 | return_loss = True, 1832 | return_embed = False, 1833 | cache: Tensor | None = None, 1834 | return_kv_cache = False 1835 | ) -> ( 1836 | Scalar | 1837 | Float['b n d'] | 1838 | tuple[Float['b n d'], list[Float['...']]] 1839 | ): 1840 | 1841 | device = self.device 1842 | text = text.to(device) 1843 | 1844 | if return_loss: 1845 | text, labels = text[:, :-1], text[:, 1:] 1846 | 1847 | # embed text 1848 | 1849 | text = text.masked_fill(text == -1, 0) 1850 | tokens = self.text_embed(text) 1851 | 1852 | # rotary 1853 | 1854 | seq_len = tokens.shape[-2] 1855 | pos = torch.arange(seq_len, device = device) 1856 | 1857 | rotary_emb = self.rotary_emb(pos) 1858 | 1859 | # attention 1860 | 1861 | embed, kv_cache = self.transformer( 1862 | tokens, 1863 | rotary_emb = rotary_emb, 1864 | causal_mask = True, 1865 | cache = cache, 1866 | return_kv_cache = True 1867 | ) 1868 | 1869 | # text unembedding 1870 | 1871 | logits = self.to_text_logits(embed) 1872 | 1873 | if not return_loss: 1874 | if not return_kv_cache: 1875 | return logits 1876 | 1877 | return logits, kv_cache 1878 | 1879 | logits = logits.masked_fill(~self.text_only_logits_mask, max_neg_value(logits)) 1880 | 1881 | loss = F.cross_entropy( 1882 | rearrange(logits, 'b n l -> b l n'), 1883 | labels, 1884 | ignore_index = self.ignore_index 1885 | ) 1886 | 1887 | return loss 1888 | 1889 | @torch.no_grad() 1890 | @eval_decorator 1891 | @typecheck 1892 | def generate_text_only( 1893 | self, 1894 | prompt: Int['b n'], 1895 | 
seq_len: int, 1896 | temperature = 1.5, 1897 | min_p = 0.1, 1898 | ) -> Int['b no']: 1899 | 1900 | prompt_seq_len, out = prompt.shape[-1], prompt.clone() 1901 | sample_num_times = max(0, seq_len - prompt_seq_len) 1902 | 1903 | for _ in tqdm(range(sample_num_times)): 1904 | logits = self.forward_text(out, return_loss = False) 1905 | logits = logits[:, -1] 1906 | 1907 | logits = min_p_filter(logits, min_p = min_p) 1908 | 1909 | logits.masked_fill_(~self.text_only_logits_mask, max_neg_value(logits)) 1910 | 1911 | sample = gumbel_sample(logits, temperature = temperature, dim = -1) 1912 | 1913 | out = cat((out, sample), dim = -1) 1914 | 1915 | return out[..., prompt_seq_len:] 1916 | 1917 | @typecheck 1918 | def forward_modality( 1919 | self, 1920 | modalities: Float['b ...'], 1921 | times: Float['b'] | None = None, 1922 | modality_type: int | None = None, 1923 | encode_modality: bool = True, 1924 | velocity_consistency_ema_model: Transfusion | None = None, 1925 | velocity_consistency_delta_time = 1e-5, 1926 | return_loss = True, 1927 | return_loss_breakdown = False 1928 | ) -> Scalar | Float['b ...']: 1929 | requires_velocity_consistency = exists(velocity_consistency_ema_model) 1930 | 1931 | modalities = modalities.to(self.device) 1932 | orig_modalities = modalities 1933 | 1934 | if self.num_modalities > 1: 1935 | assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality' 1936 | 1937 | modality_type = default(modality_type, 0) 1938 | 1939 | mod = self.get_modality_info(modality_type) 1940 | 1941 | # maybe modality encode 1942 | 1943 | if encode_modality and exists(mod.encoder): 1944 | with torch.no_grad(): 1945 | mod.encoder.eval() 1946 | modalities = mod.encoder(modalities).detach() 1947 | 1948 | # shapes and device 1949 | 1950 | tokens = modalities 1951 | 1952 | batch, device = tokens.shape[0], tokens.device 1953 | 1954 | # times 1955 | 1956 | if not exists(times): 1957 | times = torch.rand((batch,), device = device) 1958 | 1959 | if return_loss: 1960 | 1961 | if requires_velocity_consistency: 1962 | orig_times = times.clone() 1963 | times = times * (1. - velocity_consistency_delta_time) # make sure times are max of 1. - small delta, for velocity consistency 1964 | 1965 | padded_times = append_dims(times, tokens.ndim - 1) 1966 | 1967 | noise = torch.randn_like(tokens) 1968 | 1969 | noised_tokens = padded_times * tokens + (1. 
- padded_times) * noise 1970 | 1971 | flow = tokens - noise 1972 | 1973 | else: 1974 | noised_tokens = tokens 1975 | 1976 | # from latent to model tokens 1977 | 1978 | noised_tokens = mod.latent_to_model(noised_tokens) 1979 | 1980 | # axial positions 1981 | 1982 | if mod.add_pos_emb: 1983 | assert exists(mod.num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding' 1984 | 1985 | _, *axial_dims, _ = noised_tokens.shape 1986 | 1987 | assert len(axial_dims) == mod.num_dim, f'received modalities of ndim {len(axial_dims)} but expected {mod.num_dim}' 1988 | 1989 | # maybe transform 1990 | 1991 | noised_tokens, inverse_pack_axial_dims = pack_one_with_inverse(noised_tokens, 'b * d') 1992 | 1993 | # maybe add axial pos emb 1994 | 1995 | if mod.add_pos_emb: 1996 | axial_pos_emb = mod.pos_emb_mlp(tensor(axial_dims), flatten = True) 1997 | noised_tokens = noised_tokens + axial_pos_emb 1998 | 1999 | # attention 2000 | 2001 | embed = self.transformer( 2002 | noised_tokens, 2003 | times = times, 2004 | modality_only = True, 2005 | ) 2006 | 2007 | embed = inverse_pack_axial_dims(embed) 2008 | 2009 | pred_flow = mod.model_to_latent(embed) 2010 | 2011 | if not return_loss: 2012 | return pred_flow 2013 | 2014 | # flow loss 2015 | 2016 | flow_loss = F.mse_loss(pred_flow, flow) 2017 | 2018 | # maybe velocity consistency loss 2019 | 2020 | velocity_loss = self.zero 2021 | 2022 | if requires_velocity_consistency: 2023 | 2024 | with torch.no_grad(): 2025 | flow_with_delta_time = velocity_consistency_ema_model.forward_modality( 2026 | modalities = modalities, 2027 | modality_type = modality_type, 2028 | times = orig_times + velocity_consistency_delta_time, 2029 | encode_modality = False, # modality already encoded 2030 | return_loss = False 2031 | ) 2032 | 2033 | velocity_loss = F.mse_loss(flow, flow_with_delta_time) 2034 | 2035 | # maybe recon loss 2036 | 2037 | recon_loss = self.zero 2038 | 2039 | if self.has_recon_loss: 2040 | assert encode_modality 2041 | 2042 | recon = noise + pred_flow * (1. - padded_times) 2043 | 2044 | if exists(mod.decoder): 2045 | with torch.no_grad(): 2046 | mod.decoder.eval() 2047 | recon = mod.decoder(recon) 2048 | 2049 | recon_loss = F.mse_loss( 2050 | recon, 2051 | orig_modalities 2052 | ) 2053 | 2054 | # total loss 2055 | 2056 | total_loss = ( 2057 | flow_loss + 2058 | velocity_loss * self.velocity_consistency_loss_weight + 2059 | recon_loss * self.reconstruction_loss_weight 2060 | ) 2061 | 2062 | if not return_loss_breakdown: 2063 | return total_loss 2064 | 2065 | return total_loss, (flow_loss, velocity_loss, recon_loss) 2066 | 2067 | @torch.no_grad() 2068 | @eval_decorator 2069 | @typecheck 2070 | def generate_modality_only( 2071 | self, 2072 | batch_size: int = 1, 2073 | modality_type: int | None = None, 2074 | fixed_modality_shape: tuple[int, ...] 
| None = None, 2075 | modality_steps = 16, 2076 | return_unprocessed_modalities = False 2077 | ) -> Tensor: 2078 | 2079 | device = self.device 2080 | 2081 | if self.num_modalities > 1: 2082 | assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality' 2083 | 2084 | mod = self.get_modality_info(modality_type) 2085 | 2086 | modality_shape = default(fixed_modality_shape, mod.default_shape) 2087 | 2088 | assert exists(modality_shape) 2089 | 2090 | noise = torch.randn((batch_size, *modality_shape, mod.dim_latent), device = device) 2091 | 2092 | if mod.channel_first_latent: 2093 | noise = rearrange(noise, 'b ... d -> b d ...') 2094 | 2095 | def ode_step_fn(step_times, denoised): 2096 | 2097 | step_times = repeat(step_times, ' -> b', b = batch_size) 2098 | 2099 | flow = self.forward_modality( 2100 | denoised, 2101 | times = step_times, 2102 | modality_type = modality_type, 2103 | encode_modality = False, 2104 | return_loss = False 2105 | ) 2106 | 2107 | return flow 2108 | 2109 | times = torch.linspace(0., 1., modality_steps, device = device) 2110 | trajectory = self.odeint_fn(ode_step_fn, noise, times) 2111 | 2112 | # add the sampled modality tokens 2113 | 2114 | sampled_modality = trajectory[-1] 2115 | 2116 | # decode 2117 | 2118 | if exists(mod.decoder): 2119 | mod.decoder.eval() 2120 | sampled_modality = mod.decoder(sampled_modality) 2121 | 2122 | return sampled_modality 2123 | 2124 | @typecheck 2125 | def forward( 2126 | self, 2127 | modalities: ( 2128 | list[ModalitySample] | 2129 | Int['b n'] | 2130 | Float['b ...'] 2131 | ), 2132 | times: Float['b m'] | None = None, 2133 | num_modalities_to_times_fn: Callable[[Int['b']], Float['b m']] | None = None, # allows a researcher to customize the times (noise level) based on the modality lengths in a given sample 2134 | modality_type: int | None = None, 2135 | cache: Tensor | None = None, 2136 | decode_length: int | None = None, 2137 | decoding_text_or_modality: Literal['text', 'modality'] | None = None, 2138 | velocity_consistency_ema_model: Transfusion | EMA | None = None, 2139 | velocity_consistency_delta_time = 1e-3, 2140 | return_only_pred_flows = False, 2141 | return_loss = True, 2142 | return_breakdown = False, 2143 | return_embed = False, 2144 | return_kv_cache = False, 2145 | ) -> ( 2146 | Float['b _ l'] | 2147 | tuple[Float['b _ d'], GetPredFlows] | 2148 | tuple[tuple[Float['b _ _'], GetPredFlows], Tensor] | 2149 | Scalar | 2150 | tuple[Scalar, LossBreakdown] | 2151 | list[Float['b _ _']] 2152 | ): 2153 | is_decoding = exists(decoding_text_or_modality) 2154 | 2155 | is_text_only = is_tensor(modalities) and modalities.dtype in (torch.int, torch.long) 2156 | is_modality_only = is_tensor(modalities) and modalities.dtype == torch.float 2157 | 2158 | # handle ema model being passed in for velocity consistency loss 2159 | 2160 | if isinstance(velocity_consistency_ema_model, EMA): 2161 | assert isinstance(velocity_consistency_ema_model.ema_model, Transfusion) 2162 | velocity_consistency_ema_model = velocity_consistency_ema_model.ema_model 2163 | 2164 | need_velocity_matching = not is_decoding and exists(velocity_consistency_ema_model) 2165 | 2166 | # return loss 2167 | 2168 | return_loss &= not (return_embed or is_decoding) 2169 | 2170 | if is_text_only: 2171 | 2172 | forward_text_kwargs = dict( 2173 | return_loss = return_loss, 2174 | return_embed = return_embed, 2175 | cache = cache, 2176 | return_kv_cache = return_kv_cache 2177 | ) 2178 | 2179 | return 
self.forward_text(modalities, **forward_text_kwargs) 2180 | 2181 | if is_modality_only: 2182 | assert return_loss 2183 | 2184 | forward_modality_kwargs = dict( 2185 | modality_type = modality_type, 2186 | velocity_consistency_ema_model = velocity_consistency_ema_model 2187 | ) 2188 | 2189 | return self.forward_modality(modalities, **forward_modality_kwargs) 2190 | 2191 | batch = len(modalities) 2192 | device = self.device 2193 | tensor_ = partial(tensor, device = device) 2194 | 2195 | # save a copy for ema model for velocity matching 2196 | 2197 | if need_velocity_matching: 2198 | velocity_modalities = modalities 2199 | 2200 | if isinstance(velocity_modalities, list): 2201 | velocity_modalities = [modality.copy() for modality in velocity_modalities] 2202 | 2203 | # add "sentence" start and end tokens when training 2204 | 2205 | if return_loss or need_velocity_matching: 2206 | modalities = modalities.copy() 2207 | 2208 | for i, modality in enumerate(modalities): 2209 | modalities[i] = [ 2210 | tensor_([self.sos_id]), 2211 | *modality, 2212 | tensor_([self.eos_id]) 2213 | ] 2214 | 2215 | # need axial pos emb 2216 | 2217 | need_axial_pos_emb = any(self.add_pos_emb) 2218 | 2219 | # standardize modalities to be tuple - type 0 modality is implicit if not given 2220 | # also store modality lengths for determining noising times 2221 | 2222 | num_modalities = [] 2223 | 2224 | for batch_modalities in modalities: 2225 | batch_num_modalities = 0 2226 | 2227 | for ind, modality in enumerate(batch_modalities): 2228 | 2229 | if is_tensor(modality) and modality.dtype == torch.float: 2230 | modality = (0, modality) 2231 | 2232 | if not isinstance(modality, tuple): 2233 | continue 2234 | 2235 | modality_type, modality_tensor = modality 2236 | batch_modalities[ind] = modality 2237 | batch_num_modalities += 1 2238 | 2239 | num_modalities.append(batch_num_modalities) 2240 | 2241 | num_modalities = tensor_(num_modalities) 2242 | 2243 | # determine the times 2244 | 2245 | if not exists(times): 2246 | if is_empty(num_modalities) or num_modalities.amax().item() == 0: 2247 | times = torch.empty((batch, 0), device = device, dtype = torch.float) 2248 | else: 2249 | num_modalities_to_times_fn = default(num_modalities_to_times_fn, default_modality_length_to_time_fn) 2250 | 2251 | if exists(num_modalities_to_times_fn): 2252 | times = num_modalities_to_times_fn(num_modalities) 2253 | 2254 | # if needs velocity matching, make sure times are in the range of 0 - (1. - ) 2255 | 2256 | if need_velocity_matching: 2257 | orig_times = times.clone() 2258 | times = times * (1. 
- velocity_consistency_delta_time) 2259 | 2260 | # process list of text and modalities interspersed with one another 2261 | 2262 | modality_positions = [] 2263 | modality_tokens = [] 2264 | modality_pos_emb = [] 2265 | 2266 | text = [] 2267 | 2268 | # auto move all tensors to device of model 2269 | 2270 | modalities = tree_map_tensor(modalities, lambda t: t.to(device)) 2271 | 2272 | # for all modalities, batch process same shaped modalities of the same type 2273 | 2274 | if not is_decoding: 2275 | for mod in self.get_all_modality_info(): 2276 | encode_fn = default(mod.encoder, nn.Identity()) 2277 | 2278 | with torch.no_grad(): 2279 | encode_fn.eval() 2280 | modalities = apply_fn_modality_type(encode_fn, modalities, modality_type = mod.modality_type) 2281 | 2282 | # for parsing out the predicted flow from flattened sequence of tokens coming out of transformer 2283 | 2284 | flows = defaultdict(list) # store flows for loss 2285 | 2286 | get_pred_flows: GetPredFlows = defaultdict(list) # functions for parsing modalities from Float['b n d'] for model back to latents or pixel space 2287 | 2288 | def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn): 2289 | 2290 | def inner(embed: Float['b n d'], need_splice = True) -> Float['...']: 2291 | embed = embed[batch_index] 2292 | 2293 | if need_splice: 2294 | embed = embed[start_index:(start_index + modality_length)] 2295 | 2296 | embed = unpack_fn(embed) 2297 | return embed 2298 | 2299 | return inner 2300 | 2301 | # for going from predicted flow -> reconstruction 2302 | 2303 | get_recon_losses: Callable[[Tensor], Tensor] = defaultdict(list) 2304 | 2305 | def get_recon_loss(noise, times, modality): 2306 | 2307 | def inner(pred_flow): 2308 | recon_modality = noise + pred_flow * (1. - times) 2309 | return F.mse_loss(modality, recon_modality) 2310 | 2311 | return inner 2312 | 2313 | # prepare storing of sizes of all modalities that require axial positions, for delayed application for efficiency 2314 | 2315 | pos_emb_max_axial_dims: dict[int, list[Tensor]] = defaultdict(list) 2316 | 2317 | # go through all modality samples and do necessary transform 2318 | 2319 | for batch_index, batch_modalities in enumerate(modalities): 2320 | 2321 | modality_index = 0 2322 | batch_modality_positions = [] 2323 | batch_modality_tokens = [] 2324 | batch_modality_pos_emb = [] 2325 | 2326 | batch_text = [] 2327 | 2328 | offset = 0 2329 | 2330 | for modality in batch_modalities: 2331 | # if non-text modality detected and not given as a tuple 2332 | # cast to (int, Tensor) where int is defaulted to type 0 (convenience for one modality) 2333 | 2334 | is_text = not isinstance(modality, tuple) 2335 | is_modality = not is_text 2336 | 2337 | if is_text: 2338 | modality_tensor = modality 2339 | else: 2340 | modality_type, modality_tensor, *_ = modality 2341 | 2342 | # auto move modality tensor to correct device 2343 | 2344 | mod = self.get_modality_info(modality_type) 2345 | 2346 | if is_modality: 2347 | assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. 
only {self.num_modalities} modalities specified' 2348 | 2349 | channel_dim = 0 if mod.channel_first_latent else -1 2350 | 2351 | assert mod.dim_latent == modality_tensor.shape[channel_dim], f'mismatch for modality latent dimension - expected {mod.dim_latent} but received {modality_tensor.shape[-1]} - modality shape is {tuple(modality_tensor.shape)}, perhaps you need to set `channel_first_latent` to the correct value' 2352 | assert mod.num_dim == (len(modality_tensor.shape) - 1), f'mismatch for modality number of dimensions - expected {mod.num_dim} but received {len(modality_tensor.shape) - 1} {modality_tensor.shape}' 2353 | 2354 | # auto ward against scalars (lone start end tokens) 2355 | 2356 | if modality_tensor.dtype in (torch.int, torch.long) and modality_tensor.ndim == 0: 2357 | modality_tensor = rearrange(modality_tensor, '-> 1') 2358 | 2359 | # handle text 2360 | 2361 | if is_text: 2362 | assert modality_tensor.ndim == 1 and modality_tensor.dtype in (torch.int, torch.long) 2363 | text_length = modality_tensor.shape[0] 2364 | 2365 | batch_text.append(modality_tensor) 2366 | zeros = torch.zeros(text_length, self.dim, device = device) 2367 | 2368 | batch_modality_tokens.append(zeros) 2369 | 2370 | offset += text_length 2371 | 2372 | if need_axial_pos_emb: 2373 | batch_modality_pos_emb.append(zeros) 2374 | 2375 | continue 2376 | 2377 | # otherwise handle a modality 2378 | 2379 | # get times for noising the modality 2380 | 2381 | modality_time = times[batch_index, modality_index] 2382 | 2383 | # noise 2384 | 2385 | if return_loss: 2386 | noise = torch.randn_like(modality_tensor) 2387 | 2388 | noised_modality = modality_tensor * modality_time + noise * (1. - modality_time) 2389 | 2390 | # the flow is the (data - noise) 2391 | 2392 | modality_flow = modality_tensor - noise 2393 | 2394 | # append to flow for loss 2395 | 2396 | flows[modality_type].append(modality_flow) 2397 | 2398 | modality_tensor = noised_modality 2399 | 2400 | # store function for deriving reconstruction loss from decoder 2401 | 2402 | get_recon_losses[modality_type].append(get_recon_loss(noise, modality_time, modality_tensor)) 2403 | 2404 | # go through maybe encoder 2405 | 2406 | modality_tensor = add_temp_batch_dim(mod.latent_to_model)(modality_tensor) 2407 | 2408 | # gather the modality length 2409 | 2410 | modality_shape_tuple = modality_tensor.shape[:-1] 2411 | modality_length = math.prod(modality_shape_tuple) 2412 | 2413 | text_tensor = torch.full((modality_length,), -1, device = device) # text is all -1 here, so text labels are not learned on 2414 | 2415 | # only add modality meta information when not returning embedding, which only occurs when sampling modality 2416 | 2417 | succeed_modality_tokens = precede_modality_tokens = 0 2418 | 2419 | if not return_embed: 2420 | # add the [som] and [eom] tokens for the modality type 2421 | 2422 | som_id, eom_id = mod.som_id, mod.eom_id 2423 | 2424 | # start by just storing the token length of the modality 2425 | 2426 | modality_shape_str = join([*map(str, modality_shape_tuple)], ',') 2427 | modality_meta_info = self.char_tokenizer(modality_shape_str, device = device) 2428 | 2429 | precede_modality_tokens = len(modality_meta_info) + 2 2430 | succeed_modality_tokens = 1 2431 | 2432 | text_tensor = cat(( 2433 | tensor_([self.meta_id]), 2434 | modality_meta_info, 2435 | tensor_([som_id]), 2436 | text_tensor, 2437 | tensor_([eom_id]) 2438 | )) 2439 | 2440 | batch_modality_positions.append((modality_type, offset + precede_modality_tokens, modality_length)) # offset + preceding 
meta tag length (which includes the modality start token) 2441 | 2442 | # store parsing out back to shape 2443 | 2444 | modality_tensor, unpack_modality_shape = pack_one_with_inverse(modality_tensor, '* d') 2445 | 2446 | inverse_fn = model_to_pred_flow(batch_index, offset + precede_modality_tokens, modality_length, unpack_modality_shape) 2447 | 2448 | get_pred_flows[modality_type].append(inverse_fn) 2449 | 2450 | # increment offset 2451 | 2452 | offset += modality_length + precede_modality_tokens + succeed_modality_tokens # +2 due to [som] and [eom] - then account for meta start id and modality shape information (or eventually any meta information about modality) 2453 | 2454 | modality_tensor = F.pad(modality_tensor, (0, 0, precede_modality_tokens, succeed_modality_tokens)) 2455 | 2456 | batch_modality_tokens.append(modality_tensor) 2457 | 2458 | batch_text.append(text_tensor) 2459 | 2460 | # handle axial positional embedding 2461 | 2462 | if need_axial_pos_emb: 2463 | 2464 | if exists(mod.pos_emb_mlp): 2465 | pos_emb_max_axial_dims[modality_type].append(tensor(modality_shape_tuple)) 2466 | pos_emb = (modality_type, modality_shape_tuple, (precede_modality_tokens, succeed_modality_tokens)) 2467 | 2468 | else: 2469 | pos_emb = torch.zeros(text_tensor.shape[0], self.dim, device = device) 2470 | 2471 | batch_modality_pos_emb.append(pos_emb) 2472 | 2473 | text.append(cat(batch_text)) 2474 | 2475 | if need_axial_pos_emb: 2476 | modality_pos_emb.append(batch_modality_pos_emb) 2477 | 2478 | modality_tokens.append(cat(batch_modality_tokens)) 2479 | modality_positions.append(batch_modality_positions) 2480 | 2481 | modality_index += 1 2482 | 2483 | if return_loss: 2484 | total_tokens = sum([t.numel() for t in text]) 2485 | 2486 | text = pad_sequence(text, padding_value = -1) 2487 | 2488 | modality_tokens = pad_sequence(modality_tokens, padding_value = 0.) 2489 | 2490 | # handle modality positional embedding 2491 | 2492 | if need_axial_pos_emb: 2493 | pos_emb_max_axial_dims = {mod_type: stack(sizes, dim = -1).amax(dim = -1) for mod_type, sizes in pos_emb_max_axial_dims.items()} 2494 | factorized_pos_emb = {mod_type: self.get_modality_info(mod_type).pos_emb_mlp(max_size, return_factorized = True) for mod_type, max_size in pos_emb_max_axial_dims.items()} 2495 | 2496 | # lazy evaluate the modality positional embedding from the factorized positional embedding from maximum axial dims 2497 | 2498 | evaluated_pos_emb = [] 2499 | 2500 | for batch_modality_pos_emb in modality_pos_emb: 2501 | evaluated_batch_pos_emb = [] 2502 | 2503 | for maybe_pos_emb_config in batch_modality_pos_emb: 2504 | 2505 | if is_tensor(maybe_pos_emb_config): 2506 | evaluated_batch_pos_emb.append(maybe_pos_emb_config) 2507 | continue 2508 | 2509 | mod_type, mod_size, padding = maybe_pos_emb_config 2510 | 2511 | mod_info = self.get_modality_info(mod_type) 2512 | mod_factorized_pos_emb = factorized_pos_emb[mod_type] 2513 | 2514 | mod_pos_emb = mod_info.pos_emb_mlp.combine_factorized(mod_factorized_pos_emb, mod_size, flatten = True) 2515 | mod_pos_emb = F.pad(mod_pos_emb, (0, 0, *padding), value = 0.) # handle padding for preceding and succeeding meta tokens 2516 | 2517 | evaluated_batch_pos_emb.append(mod_pos_emb) 2518 | 2519 | evaluated_pos_emb.append(cat(evaluated_batch_pos_emb, dim = -2)) 2520 | 2521 | modality_pos_emb = pad_sequence(evaluated_pos_emb, padding_value = 0.) 
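        # at this point, per batch sample: `text` holds token ids with -1 at modality positions,
        # `modality_tokens` holds the projected (noised) latents at those same positions, and
        # `modality_positions` stores (modality_type, offset, length) triples - what follows builds
        # the modality mask, embeds the text, and merges the two token streams for the transformer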
2522 | 2523 | # handle training mode and removal of last token 2524 | 2525 | if return_loss: 2526 | modality_tokens = modality_tokens[:, :-1] 2527 | 2528 | if need_axial_pos_emb: 2529 | modality_pos_emb = modality_pos_emb[:, :-1] 2530 | 2531 | # if returning loss, split text for next token prediction 2532 | 2533 | if return_loss: 2534 | text, text_labels = text[:, :-1], text[:, 1:] 2535 | 2536 | # derive is_modality mask for flow on the right tokens + flow loss 2537 | 2538 | batch, seq_len, device = *text.shape, text.device 2539 | 2540 | assert len(modality_positions) == batch 2541 | 2542 | if isinstance(modality_positions, list): 2543 | modality_positions = modality_positions_to_tensor(modality_positions, device = device) 2544 | 2545 | if modality_positions.shape[-1] == 2: # Int['b m 2'] -> Int['b m 3'] if type is not given (one modality) 2546 | modality_positions = F.pad(modality_positions, (1, 0)) 2547 | 2548 | # for now use dummy padding modality position info if empty (all zeros) 2549 | 2550 | if modality_positions.numel() == 0: 2551 | modality_positions = F.pad(modality_positions, (0, 0, 0, 1)) 2552 | 2553 | # sort the modalities tensor and sanitize, readying for noising of modalities 2554 | 2555 | modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions) 2556 | 2557 | is_modalities = modality_positions_to_is_modality_mask(seq_len, modality_positions, num_modalities = self.num_modalities, device = device) 2558 | 2559 | is_any_modality = reduce(is_modalities, 'b t m n -> b n', 'any') 2560 | 2561 | # embed text 2562 | 2563 | text = text.masked_fill(text == -1, 0) 2564 | 2565 | text_tokens = self.text_embed(text) 2566 | 2567 | # maybe add the axial positional embedding 2568 | 2569 | if need_axial_pos_emb: 2570 | modality_tokens = modality_tokens + modality_pos_emb 2571 | 2572 | # intersperse the modalities with the text for the joint transformer + flow system 2573 | 2574 | tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens) 2575 | 2576 | # derive rotary positions 2577 | 2578 | rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions) 2579 | 2580 | rotary_emb = self.rotary_emb(rotary_positions) 2581 | rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d') 2582 | 2583 | # take care of cache 2584 | 2585 | is_any_modality_when_decoding = None 2586 | 2587 | if exists(cache): 2588 | assert exists(decode_length), '`decode_length` must be passed in on forward for modality sampling. 
think of a cleaner way on some future date' 2589 | assert exists(decoding_text_or_modality) 2590 | 2591 | if decoding_text_or_modality == 'text': 2592 | decode_length = 1 2593 | 2594 | is_any_modality_when_decoding = decoding_text_or_modality == 'modality' 2595 | modality_positions = None 2596 | 2597 | # times 2598 | 2599 | times_per_token = einsum(is_modalities.float(), times, 'b t m n, b m -> b t n') 2600 | 2601 | times_cond = reduce(times_per_token, 'b t n -> b n', 'sum') 2602 | 2603 | # attention 2604 | 2605 | embed, kv_cache = self.transformer( 2606 | tokens, 2607 | times = times_cond, 2608 | rotary_emb = rotary_emb, 2609 | modality_positions = modality_positions, 2610 | is_any_modality = is_any_modality_when_decoding, 2611 | cache = cache, 2612 | decode_length = decode_length, 2613 | return_kv_cache = True 2614 | ) 2615 | 2616 | # early return for embedding for decoding modality 2617 | 2618 | if return_embed: 2619 | if not return_kv_cache: 2620 | return (embed, get_pred_flows) 2621 | 2622 | return (embed, get_pred_flows), kv_cache 2623 | 2624 | # text unembedding 2625 | 2626 | text_logits = self.to_text_logits(embed) 2627 | 2628 | if not return_loss: 2629 | if not return_kv_cache: 2630 | return text_logits 2631 | 2632 | return text_logits, kv_cache 2633 | 2634 | # flow loss 2635 | 2636 | pred_flows = [] 2637 | recon_losses = [] 2638 | 2639 | for modality_id in range(self.num_modalities): 2640 | mod = self.get_modality_info(modality_id) 2641 | 2642 | modality_get_pred_flows = get_pred_flows[modality_id] 2643 | modality_get_recon_losses = get_recon_losses[modality_id] 2644 | 2645 | modality_pred_flows = [] 2646 | modality_recon_losses = [] 2647 | 2648 | for get_pred_flow, get_recon_loss in zip(modality_get_pred_flows, modality_get_recon_losses): 2649 | 2650 | pred_flow = get_pred_flow(embed) 2651 | pred_flow = add_temp_batch_dim(mod.model_to_latent)(pred_flow) 2652 | modality_pred_flows.append(pred_flow) 2653 | 2654 | if not return_loss or not self.has_recon_loss: 2655 | continue 2656 | 2657 | modality_recon_losses.append(get_recon_loss(pred_flow)) 2658 | 2659 | pred_flows.append(modality_pred_flows) 2660 | recon_losses.append(modality_recon_losses) 2661 | 2662 | # early return for velocity consistency ema model 2663 | 2664 | if return_only_pred_flows: 2665 | return pred_flows 2666 | 2667 | # text autoregressive loss 2668 | 2669 | text_labels = text_labels.masked_fill(is_any_modality, self.ignore_index) 2670 | 2671 | text_loss = F.cross_entropy( 2672 | rearrange(text_logits, 'b n l -> b l n'), 2673 | text_labels, 2674 | ignore_index = self.ignore_index 2675 | ) 2676 | 2677 | text_loss_weight = (text_labels != self.ignore_index).sum() / total_tokens 2678 | 2679 | # calculate flow losses 2680 | 2681 | flow_losses = [] 2682 | 2683 | modality_loss_weights = [] 2684 | 2685 | for modality_id, (pred_flow, is_one_modality) in enumerate(zip(pred_flows, is_modalities.unbind(dim = 1))): 2686 | mod = self.get_modality_info(modality_id) 2687 | 2688 | is_one_modality = reduce(is_one_modality, 'b m n -> b n', 'any') 2689 | modality_loss_weight = is_one_modality.sum() / total_tokens 2690 | 2691 | modality_flows = flows[modality_id] 2692 | 2693 | pack_pattern = 'd *' if mod.channel_first_latent else '* d' 2694 | 2695 | modality_pred_flow, _ = pack(pred_flow, pack_pattern) 2696 | modality_flows, _ = pack(modality_flows, pack_pattern) 2697 | 2698 | flow_loss = F.mse_loss( 2699 | modality_pred_flow, 2700 | modality_flows 2701 | ) 2702 | 2703 | modality_loss_weights.append(modality_loss_weight) 2704 | 2705 
| flow_losses.append(flow_loss) 2706 | 2707 | modality_loss_weights = stack(modality_loss_weights) 2708 | 2709 | # only the token positions that are not modalities have autoregressive loss 2710 | 2711 | total_loss = ( 2712 | text_loss * text_loss_weight * self.text_loss_weight + 2713 | (stack(flow_losses) * modality_loss_weights).sum() * self.flow_loss_weight 2714 | ) 2715 | 2716 | # whether to handle velocity consistency 2717 | # for straightening the flow, from consistency flow matching paper https://arxiv.org/abs/2407.02398 2718 | 2719 | velocity_match_losses = None 2720 | 2721 | if need_velocity_matching: 2722 | 2723 | with torch.no_grad(): 2724 | velocity_consistency_ema_model.eval() 2725 | 2726 | ema_pred_flows = velocity_consistency_ema_model( 2727 | velocity_modalities, 2728 | times = orig_times + velocity_consistency_delta_time, 2729 | return_only_pred_flows = True 2730 | ) 2731 | 2732 | velocity_match_losses = [] 2733 | 2734 | for ema_pred_flow, pred_flow in zip(ema_pred_flows, pred_flows): 2735 | 2736 | pack_pattern = 'd *' if mod.channel_first_latent else '* d' 2737 | pred_flow, _ = pack(pred_flow, pack_pattern) 2738 | ema_pred_flow, _ = pack(ema_pred_flow, pack_pattern) 2739 | 2740 | velocity_match_loss = F.mse_loss( 2741 | pred_flow, 2742 | ema_pred_flow 2743 | ) 2744 | 2745 | velocity_match_losses.append(velocity_match_loss) 2746 | 2747 | total_loss = ( 2748 | total_loss + 2749 | (stack(velocity_match_losses) * modality_loss_weights).sum() * self.velocity_consistency_loss_weight 2750 | ) 2751 | 2752 | # maybe reconstruction loss 2753 | 2754 | if self.has_recon_loss: 2755 | 2756 | averaged_recon_losses = [] 2757 | 2758 | for modality_recon_loss in recon_losses: 2759 | averaged_recon_losses.append(sum(modality_recon_loss) / len(modality_recon_loss)) 2760 | 2761 | total_loss = ( 2762 | total_loss + 2763 | (stack(averaged_recon_losses) * modality_loss_weights).sum() * self.reconstruction_loss_weight 2764 | ) 2765 | 2766 | # return total loss if no breakdown needed 2767 | 2768 | if not return_breakdown: 2769 | return total_loss 2770 | 2771 | return total_loss, LossBreakdown(total_loss, text_loss, flow_losses, velocity_match_losses, recon_losses) 2772 | --------------------------------------------------------------------------------
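Below is a minimal usage sketch for the `Transfusion` class defined above, interleaving text tokens with a single latent modality (float tensors default to modality type 0, as handled in `forward`). The constructor arguments (`num_text_tokens`, `dim_latent`, `modality_default_shape`, and the `transformer` config dict) follow the repository's README example; treat the exact values and keyword names as illustrative assumptions rather than a verified API reference.

# hypothetical end-to-end sketch - keyword names mirror the README example and the
# attributes referenced in transfusion.py, not a verified contract

from torch import randint, randn
from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 256,          # text vocab size (special ids are appended internally)
    dim_latent = 384,               # latent dimension of the single (type 0) modality
    modality_default_shape = (4,),  # fallback shape if sampling produces invalid meta info
    transformer = dict(
        dim = 512,
        depth = 8
    )
)

# each training sample is a list interleaving Int text tokens and Float latents (modality type 0 is implicit)

text_and_latents = [
    [randint(0, 256, (16,)), randn(4, 384), randint(0, 256, (8,)), randn(6, 384)],
    [randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), randn(3, 384)]
]

loss = model(text_and_latents)
loss.backward()

# after training: autoregressive text decoding interleaved with flow-matched modality generation

sample = model.sample(max_length = 128)

# or generate a single modality unconditionally

latents = model.generate_modality_only(batch_size = 2)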