├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── enwik8
│   │   └── enwik8.gz
│   └── flowers
│       └── labels.txt
├── pyproject.toml
├── tests
│   └── test_transfusion.py
├── train_image_only.py
├── train_image_only_with_unet.py
├── train_latent_only.py
├── train_latent_with_text.py
├── train_mnist.py
├── train_mnist_vae.py
├── train_mnist_with_unet.py
├── train_text_only.py
├── train_toy.py
├── transfusion.png
└── transfusion_pytorch
    ├── __init__.py
    └── transfusion.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests the examples in README
2 | on: [push, pull_request]
3 |
4 | env:
5 | TYPECHECK: True
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Install Python
13 | uses: actions/setup-python@v5
14 | with:
15 | python-version: "3.11"
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install uv
19 | python -m uv pip install --upgrade pip
20 | python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
21 | python -m uv pip install -e .[test]
22 | - name: Test with pytest
23 | run: |
24 | python -m pytest tests/
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/mnist/
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113 | .pdm.toml
114 | .pdm-python
115 | .pdm-build/
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Transfusion - Pytorch
4 |
5 | Pytorch implementation of [Transfusion](https://www.arxiv.org/abs/2408.11039), "Predict the Next Token and Diffuse Images with One Multi-Modal Model", from Meta AI.
6 |
7 | In this repo, we will replace diffusion with flow matching, given the success of Flux from Black Forest Labs (but will keep the original paper title, as "Transflow" does not have quite the same ring). This repository will also attempt to extend to any number of modalities.
8 |
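For reference, flow matching here means the rectified flow style objective: a data sample and Gaussian noise are linearly interpolated, and the network is trained to regress the constant velocity of that straight path. A minimal self-contained sketch of the objective (the `velocity_model(noised, times)` call is illustrative only and not this repository's API):

```python
import torch
import torch.nn.functional as F

def rectified_flow_loss(velocity_model, data):
    # one random time per sample, broadcastable over the remaining dimensions
    times = torch.rand(data.shape[0], *((1,) * (data.ndim - 1)), device = data.device)

    noise = torch.randn_like(data)
    noised = noise.lerp(data, times)   # x_t = (1 - t) * noise + t * data
    flow = data - noise                # constant velocity along the straight path

    pred_flow = velocity_model(noised, times)
    return F.mse_loss(pred_flow, flow)
```

At sampling time, the learned velocity field is integrated from noise toward data (note `torchdiffeq` among the dependencies in `pyproject.toml`).
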
9 | ## Install
10 |
11 | ```bash
12 | $ pip install transfusion-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | One modality, say images
18 |
19 | ```python
20 | from torch import randint, randn
21 | from transfusion_pytorch import Transfusion
22 |
23 | model = Transfusion(
24 | num_text_tokens = 256,
25 | dim_latent = 384,
26 | modality_default_shape = (4,), # fallback, in the case the language model did not produce a valid modality shape
27 | transformer = dict(
28 | dim = 512,
29 | depth = 8
30 | )
31 | )
32 |
33 | # any torch.long is text, torch.float is modalities
34 |
35 | text_and_images = [
36 | [randint(0, 256, (16,)), randn(4, 384), randint(0, 256, (8,)), randn(6, 384)],
37 | [randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), randn(2, 384), randint(0, 256, (9,))]
38 | ]
39 |
40 | loss = model(text_and_images)
41 |
42 | loss.backward()
43 |
44 | # after much training
45 |
46 | one_multimodal_sample = model.sample()
47 | ```
48 |
49 | Multiple different modalities
50 |
51 | ```python
52 | from torch import randint, randn
53 | from transfusion_pytorch import Transfusion
54 |
55 | model = Transfusion(
56 | num_text_tokens = 256,
57 | dim_latent = (384, 192), # specify multiple latent dimensions
58 | modality_default_shape = ((4,), (2,)), # default shapes for first and second modality
59 | transformer = dict(
60 | dim = 512,
61 | depth = 8
62 | )
63 | )
64 |
65 | # then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position
66 |
67 | # any torch.long is text, torch.float is modalities
68 |
69 | text_images_and_audio = [
70 | [randint(0, 256, (16,)), (0, randn(4, 384)), randint(0, 256, (8,)), (1, randn(6, 192))],
71 | [randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), (1, randn(2, 192)), randint(0, 256, (9,))]
72 | ]
73 |
74 | loss = model(text_images_and_audio)
75 |
76 | loss.backward()
77 |
78 | # after much training
79 |
80 | one_multimodal_sample = model.sample()
81 | ```
82 |
83 | Automatically taking care of encoding and decoding of images
84 |
85 | ```python
86 | import torch
87 | from torch import nn, randint, randn
88 | from transfusion_pytorch import Transfusion, print_modality_sample
89 |
90 | mock_encoder = nn.Conv2d(3, 384, 3, padding = 1)
91 | mock_decoder = nn.Conv2d(384, 3, 3, padding = 1)
92 |
93 | model = Transfusion(
94 | num_text_tokens = 12,
95 | dim_latent = 384,
96 | channel_first_latent = True,
97 | modality_default_shape = (4, 4),
98 | modality_encoder = mock_encoder,
99 | modality_decoder = mock_decoder,
100 | transformer = dict(
101 | dim = 512,
102 | depth = 8
103 | )
104 | )
105 |
106 | text_and_images = [
107 | [
108 | randint(0, 12, (16,)), # 16 text tokens
109 | randn(3, 8, 8), # (8 x 8) 3 channeled image
110 | randint(0, 12, (8,)), # 8 text tokens
111 | randn(3, 7, 7) # (7 x 7) 3 channeled image
112 | ],
113 | [
114 | randint(0, 12, (16,)), # 16 text tokens
115 | randn(3, 8, 5), # (8 x 5) 3 channeled image
116 | randint(0, 12, (5,)), # 5 text tokens
117 | randn(3, 2, 16), # (2 x 16) 3 channeled image
118 | randint(0, 12, (9,)) # 9 text tokens
119 | ]
120 | ]
121 |
122 | loss = model(text_and_images)
123 |
124 | loss.backward()
125 |
126 | # after much training
127 |
128 | one_multimodal_sample = model.sample()
129 |
130 | print_modality_sample(one_multimodal_sample)
131 | ```
132 |
133 | To pretrain on language first, just pass in your text as type `Int['batch seq']`
134 |
135 | ```python
136 | import torch
137 | from transfusion_pytorch import Transfusion
138 |
139 | model = Transfusion(
140 | num_text_tokens = 256,
141 | dim_latent = 384,
142 | transformer = dict(
143 | dim = 512,
144 | depth = 8,
145 | )
146 | ).cuda()
147 |
148 | text = torch.randint(0, 256, (2, 1024)).cuda()
149 |
150 | loss = model(text)
151 | loss.backward()
152 |
153 | # after much training
154 |
155 | sampled = model.generate_text_only(text[:, :1], 1024)
156 | ```
157 |
158 | ## Examples
159 |
160 | To run any of the examples `train_{example_name}.py` in the project root, first install the dependencies like so
161 |
162 | ```bash
163 | $ pip install .[examples]
164 | ```
165 |
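Then run whichever example you want from the project root, e.g.

```bash
$ python train_mnist.py
```
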
166 | If you run into some weird error with `safetensors`, run this too
167 |
168 | ```bash
169 | $ pip install -U diffusers transformers accelerate scipy ftfy safetensors
170 | ```
171 |
172 | ## Citations
173 |
174 | ```bibtex
175 | @inproceedings{Zhou2024TransfusionPT,
176 | title = {Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model},
177 | author = {Chunting Zhou and Lili Yu and Arun Babu and Kushal Tirumala and Michihiro Yasunaga and Leonid Shamis and Jacob Kahn and Xuezhe Ma and Luke Zettlemoyer and Omer Levy},
178 | year = {2024},
179 | url = {https://api.semanticscholar.org/CorpusID:271909855}
180 | }
181 | ```
182 |
183 | ```bibtex
184 | @misc{Rubin2024,
185 | author = {Ohad Rubin},
186 | url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
187 | }
188 | ```
189 |
190 | ```bibtex
191 | @article{Nguyen2024MinPS,
192 | title = {Min P Sampling: Balancing Creativity and Coherence at High Temperature},
193 | author = {Minh Nguyen and Andrew Baker and Andreas Kirsch and Clement Neo},
194 | journal = {ArXiv},
195 | year = {2024},
196 | volume = {abs/2407.01082},
197 | url = {https://api.semanticscholar.org/CorpusID:270870613}
198 | }
199 | ```
200 |
201 | ```bibtex
202 | @article{Bao2022AllAW,
203 | title = {All are Worth Words: A ViT Backbone for Diffusion Models},
204 | author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
205 | journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
206 | year = {2022},
207 | pages = {22669-22679},
208 | url = {https://api.semanticscholar.org/CorpusID:253581703}
209 | }
210 | ```
211 |
212 | ```bibtex
213 | @inproceedings{Zhao2024MonoFormerOT,
214 | title = {MonoFormer: One Transformer for Both Diffusion and Autoregression},
215 | author = {Chuyang Zhao and Yuxing Song and Wenhao Wang and Haocheng Feng and Errui Ding and Yifan Sun and Xinyan Xiao and Jingdong Wang},
216 | year = {2024},
217 | url = {https://api.semanticscholar.org/CorpusID:272832492}
218 | }
219 | ```
220 |
221 | ```bibtex
222 | @article{Yang2024ConsistencyFM,
223 | title = {Consistency Flow Matching: Defining Straight Flows with Velocity Consistency},
224 | author = {Ling Yang and Zixiang Zhang and Zhilong Zhang and Xingchao Liu and Minkai Xu and Wentao Zhang and Chenlin Meng and Stefano Ermon and Bin Cui},
225 | journal = {ArXiv},
226 | year = {2024},
227 | volume = {abs/2407.02398},
228 | url = {https://api.semanticscholar.org/CorpusID:270878436}
229 | }
230 | ```
231 |
232 | ```bibtex
233 | @inproceedings{Zhou2024ValueRL,
234 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
235 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
236 | year = {2024},
237 | url = {https://api.semanticscholar.org/CorpusID:273532030}
238 | }
239 | ```
240 |
241 | ```bibtex
242 | @inproceedings{Duvvuri2024LASERAW,
243 | title = {LASER: Attention with Exponential Transformation},
244 | author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
245 | year = {2024},
246 | url = {https://api.semanticscholar.org/CorpusID:273849947}
247 | }
248 | ```
249 |
250 | ```bibtex
251 | @inproceedings{Dong2024HymbaAH,
252 | title = {Hymba: A Hybrid-head Architecture for Small Language Models},
253 | author = {Xin Dong and Y. Fu and Shizhe Diao and Wonmin Byeon and Zijia Chen and Ameya Mahabaleshwarkar and Shih-Yang Liu and Matthijs Van Keirsbilck and Min-Hung Chen and Yoshi Suhara and Yingyan Lin and Jan Kautz and Pavlo Molchanov},
254 | year = {2024},
255 |     url     = {https://api.semanticscholar.org/CorpusID:274166163}
256 | }
257 | ```
257 |
258 | ```bibtex
259 | @article{Zhu2024HyperConnections,
260 | title = {Hyper-Connections},
261 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
262 | journal = {ArXiv},
263 | year = {2024},
264 | volume = {abs/2409.19606},
265 | url = {https://api.semanticscholar.org/CorpusID:272987528}
266 | }
267 | ```
268 |
269 | ```bibtex
270 | @article{Zhu2025FracConnectionsFE,
271 | title = {Frac-Connections: Fractional Extension of Hyper-Connections},
272 | author = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
273 | journal = {ArXiv},
274 | year = {2025},
275 | volume = {abs/2503.14125},
276 | url = {https://api.semanticscholar.org/CorpusID:277104144}
277 | }
278 | ```
279 |
--------------------------------------------------------------------------------
/data/enwik8/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/transfusion-pytorch/ea045706cc45e2cf5cd1ce4523e5af6a49eff54a/data/enwik8/enwik8.gz
--------------------------------------------------------------------------------
/data/flowers/labels.txt:
--------------------------------------------------------------------------------
1 | pink primrose
2 | hard-leaved pocket orchid
3 | canterbury bells
4 | sweet pea
5 | english marigold
6 | tiger lily
7 | moon orchid
8 | bird of paradise
9 | monkshood
10 | globe thistle
11 | snapdragon
12 | colt's foot
13 | king protea
14 | spear thistle
15 | yellow iris
16 | globe-flower
17 | purple coneflower
18 | peruvian lily
19 | balloon flower
20 | giant white arum lily
21 | fire lily
22 | pincushion flower
23 | fritillary
24 | red ginger
25 | grape hyacinth
26 | corn poppy
27 | prince of wales feathers
28 | stemless gentian
29 | artichoke
30 | sweet william
31 | carnation
32 | garden phlox
33 | love in the mist
34 | mexican aster
35 | alpine sea holly
36 | ruby-lipped cattleya
37 | cape flower
38 | great masterwort
39 | siam tulip
40 | lenten rose
41 | barbeton daisy
42 | daffodil
43 | sword lily
44 | poinsettia
45 | bolero deep blue
46 | wallflower
47 | marigold
48 | buttercup
49 | oxeye daisy
50 | common dandelion
51 | petunia
52 | wild pansy
53 | primula
54 | sunflower
55 | pelargonium
56 | bishop of llandaff
57 | gaura
58 | geranium
59 | orange dahlia
60 | pink-yellow dahlia?
61 | cautleya spicata
62 | japanese anemone
63 | black-eyed susan
64 | silverbush
65 | californian poppy
66 | osteospermum
67 | spring crocus
68 | bearded iris
69 | windflower
70 | tree poppy
71 | gazania
72 | azalea
73 | water lily
74 | rose
75 | thorn apple
76 | morning glory
77 | passion flower
78 | lotus
79 | toad lily
80 | anthurium
81 | frangipani
82 | clematis
83 | hibiscus
84 | columbine
85 | desert-rose
86 | tree mallow
87 | magnolia
88 | cyclamen
89 | watercress
90 | canna lily
91 | hippeastrum
92 | bee balm
93 | ball moss
94 | foxglove
95 | bougainvillea
96 | camellia
97 | mallow
98 | mexican petunia
99 | bromelia
100 | blanket flower
101 | trumpet creeper
102 | blackberry lily
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "transfusion-pytorch"
3 | version = "0.11.0"
4 | description = "Transfusion in Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.8"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'attention mechanism',
16 | 'rectified flow',
17 | ]
18 | classifiers=[
19 | 'Development Status :: 4 - Beta',
20 | 'Intended Audience :: Developers',
21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
22 | 'License :: OSI Approved :: MIT License',
23 | 'Programming Language :: Python :: 3.8',
24 | ]
25 |
26 | dependencies = [
27 | 'axial-positional-embedding>=0.3.5',
28 | 'beartype',
29 | 'einx>=0.3.0',
30 | 'einops>=0.8.0',
31 | 'ema-pytorch',
32 | 'hyper-connections>=0.0.10',
33 | 'jaxtyping',
34 | 'loguru',
35 | 'rotary_embedding_torch>=0.8.4',
36 | 'torchdiffeq',
37 | 'torch>=2.0',
38 | 'tqdm'
39 | ]
40 |
41 | [project.urls]
42 | Homepage = "https://pypi.org/project/transfusion-pytorch/"
43 | Repository = "https://github.com/lucidrains/transfusion-pytorch"
44 |
45 | [build-system]
46 | requires = ["hatchling"]
47 | build-backend = "hatchling.build"
48 |
49 | [project.optional-dependencies]
50 |
51 | examples = [
52 | "datasets",
53 | "diffusers"
54 | ]
55 | test = [
56 | "pytest",
57 | ]
58 |
59 | [tool.ruff]
60 | line-length = 1000
61 |
62 | lint.ignore = [
63 | "F722", # for jaxtyping shape annotation
64 | "F401",
65 | "F821",
66 | "E402"
67 | ]
68 |
69 | [tool.pytest.ini_options]
70 | pythonpath = [
71 | "."
72 | ]
73 |
74 | [tool.hatch.metadata]
75 | allow-direct-references = true
76 |
77 | [tool.hatch.build.targets.wheel]
78 | packages = ["transfusion_pytorch"]
79 |
--------------------------------------------------------------------------------
/tests/test_transfusion.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from functools import partial
3 | from copy import deepcopy
4 |
5 | import torch
6 | from torch import nn, randint, randn, tensor, cuda
7 |
8 | from einops import rearrange
9 |
10 | import torch._dynamo
11 | torch._dynamo.config.suppress_errors = True
12 |
13 | cuda_available = cuda.is_available()
14 |
15 | from transfusion_pytorch.transfusion import (
16 | Transfusion,
17 | flex_attention,
18 | exists,
19 | stack_same_shape_tensors_with_inverse,
20 | filter_with_inverse,
21 | apply_fn_modality_type
22 | )
23 |
24 | @pytest.mark.parametrize('cache_kv', (False, True))
25 | @pytest.mark.parametrize('use_flex_attn', (False, True))
26 | @pytest.mark.parametrize('num_residual_streams', (1, 4))
27 | @pytest.mark.parametrize('reconstruction_loss_weight', (0., 0.1))
28 | def test_transfusion(
29 | cache_kv: bool,
30 | use_flex_attn: bool,
31 | num_residual_streams: int,
32 | reconstruction_loss_weight: float
33 | ):
34 |
35 | if use_flex_attn and (not exists(flex_attention) or not cuda_available):
36 | return pytest.skip()
37 |
38 | text_tokens = 8
39 | randint_ = partial(randint, 0, text_tokens)
40 |
41 | model = Transfusion(
42 | num_text_tokens = text_tokens,
43 | dim_latent = (384, 192),
44 | modality_default_shape = ((32,), (64,)),
45 | reconstruction_loss_weight = reconstruction_loss_weight,
46 | transformer = dict(
47 | dim = 64,
48 | depth = 2,
49 | use_flex_attn = use_flex_attn,
50 | num_residual_streams = num_residual_streams
51 | )
52 | )
53 |
54 | if use_flex_attn:
55 | model = model.cuda()
56 |
57 | # then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position
58 |
59 | text_images_and_audio = [
60 | [randint_((16,)), (0, randn(4, 384)), randint_((8,)), (1, randn(6, 192))],
61 | [randint_((16,)), randn(7, 384), randint_((5,)), (1, randn(2, 192)), randint_((9,))]
62 | ]
63 |
64 | loss = model(text_images_and_audio)
65 |
66 | loss.backward()
67 |
68 | # after much training
69 |
70 | prime = [tensor(model.som_ids[0])]
71 |
72 | one_multimodal_sample = model.sample(prime, max_length = 128, cache_kv = cache_kv)
73 |
74 |
75 | @pytest.mark.parametrize('use_flex_attn', (False, True))
76 | def test_auto_modality_transform(
77 | use_flex_attn: bool
78 | ):
79 |
80 | if use_flex_attn and (not exists(flex_attention) or not cuda_available):
81 | return pytest.skip()
82 |
83 | text_tokens = 8
84 | randint_ = partial(randint, 0, text_tokens)
85 |
86 | model = Transfusion(
87 | num_text_tokens = text_tokens,
88 | dim_latent = 384,
89 | channel_first_latent = True,
90 | modality_default_shape = (2, 2),
91 | transformer = dict(
92 | dim = 64,
93 | depth = 2,
94 | use_flex_attn = use_flex_attn
95 | )
96 | )
97 |
98 | text_and_images = [
99 | [randint_((16,)), randn(384, 2, 2), randint_((8,)), randn(384, 2, 2)],
100 | [randint_((16,)), randn(384, 2, 2), randint_((5,)), randn(384, 2, 2), randint_((9,))]
101 | ]
102 |
103 | loss = model(text_and_images)
104 |
105 | loss.backward()
106 |
107 | # after much training
108 |
109 | prime = [tensor(model.som_ids[0])]
110 |
111 | one_multimodal_sample = model.sample(prime, max_length = 128)
112 |
113 | @pytest.mark.parametrize('use_flex_attn', (False, True))
114 | @pytest.mark.parametrize('return_loss', (False, True))
115 | def test_text(
116 | use_flex_attn: bool,
117 | return_loss: bool
118 | ):
119 |
120 | if use_flex_attn and (not exists(flex_attention) or not cuda_available):
121 | return pytest.skip()
122 |
123 | model = Transfusion(
124 | num_text_tokens = 256,
125 | dim_latent = 384,
126 | channel_first_latent = True,
127 | modality_default_shape = (32,),
128 | transformer = dict(
129 | dim = 64,
130 | depth = 2,
131 | use_flex_attn = use_flex_attn
132 | )
133 | )
134 |
135 | if use_flex_attn:
136 | model = model.cuda()
137 |
138 | text = randint(0, 256, (2, 1024))
139 |
140 | model(text, return_loss = return_loss)
141 |
142 | @pytest.mark.parametrize('channel_first', (False, True))
143 | def test_modality_only(
144 | channel_first: bool
145 | ):
146 |
147 | model = Transfusion(
148 | num_text_tokens = 256,
149 | dim_latent = (384, 192),
150 | channel_first_latent = channel_first,
151 | modality_default_shape = (32,),
152 | transformer = dict(
153 | dim = 64,
154 | depth = 2,
155 | use_flex_attn = False
156 | )
157 | )
158 |
159 | images = randn(2, 8, 8, 192)
160 |
161 | if channel_first:
162 | images = rearrange(images, 'b ... d -> b d ...')
163 |
164 | loss = model(images, return_loss = True, modality_type = 1)
165 |
166 | loss.backward()
167 |
168 | model.generate_modality_only(modality_type = 1)
169 |
170 | @pytest.mark.parametrize('custom_time_fn', (False, True))
171 | def test_text_image_end_to_end(
172 | custom_time_fn: bool
173 | ):
174 | mock_vae_encoder = nn.Conv2d(3, 384, 3, padding = 1)
175 | mock_vae_decoder = nn.Conv2d(384, 3, 3, padding = 1)
176 |
177 | model = Transfusion(
178 | num_text_tokens = 4,
179 | dim_latent = 384,
180 | channel_first_latent = True,
181 | modality_default_shape = ((4, 4),),
182 | modality_encoder = mock_vae_encoder,
183 | modality_decoder = mock_vae_decoder,
184 | transformer = dict(
185 | dim = 64,
186 | depth = 2
187 | )
188 | )
189 |
190 | text_and_images = [
191 | [
192 | randint(0, 4, (16,)),
193 | randn(3, 8, 8),
194 | randint(0, 4, (8,)),
195 | randn(3, 7, 7)
196 | ],
197 | [
198 | randint(0, 4, (16,)),
199 | randn(3, 8, 5),
200 | randint(0, 4, (5,)),
201 | randn(3, 2, 16),
202 | randint(0, 4, (9,))
203 | ]
204 | ]
205 |
206 | # allow researchers to experiment with different time distributions across multiple modalities in a sample
207 |
208 | def num_modalities_to_times(num_modalities):
209 | batch = num_modalities.shape[0]
210 | device = num_modalities.device
211 | total_modalities = num_modalities.amax().item()
212 | return torch.ones((batch, total_modalities), device = device)
213 |
214 | time_fn = num_modalities_to_times if custom_time_fn else None
215 |
216 | # forward
217 |
218 | loss = model(
219 | text_and_images,
220 | num_modalities_to_times_fn = time_fn
221 | )
222 |
223 | loss.backward()
224 |
225 | # after much training
226 |
227 | one_multimodal_sample = model.sample(max_length = 128)
228 |
229 | def test_velocity_consistency():
230 | mock_encoder = nn.Conv2d(3, 384, 3, padding = 1)
231 | mock_decoder = nn.Conv2d(384, 3, 3, padding = 1)
232 |
233 | model = Transfusion(
234 | num_text_tokens = 12,
235 | dim_latent = 384,
236 | channel_first_latent = True,
237 | modality_default_shape = (4, 4),
238 | modality_encoder = mock_encoder,
239 | modality_decoder = mock_decoder,
240 | transformer = dict(
241 | dim = 64,
242 | depth = 1
243 | )
244 | )
245 |
246 | ema_model = deepcopy(model)
247 |
248 | text_and_images = [
249 | [
250 | randint(0, 12, (16,)),
251 | randn(3, 8, 8),
252 | randint(0, 12, (8,)),
253 | randn(3, 7, 7)
254 | ],
255 | [
256 | randint(0, 12, (16,)),
257 | randn(3, 8, 5),
258 | randint(0, 12, (5,)),
259 | randn(3, 2, 16),
260 | randint(0, 12, (9,))
261 | ]
262 | ]
263 |
264 | loss, breakdown = model(
265 | text_and_images,
266 | velocity_consistency_ema_model = ema_model,
267 | return_breakdown = True
268 | )
269 |
270 | loss.backward()
271 |
272 | assert exists(breakdown.velocity)
273 |
274 | def test_axial_pos_emb():
275 | model = Transfusion(
276 | num_text_tokens = 256,
277 | dim_latent = (384, 192), # specify multiple latent dimensions
278 | modality_default_shape = ((2, 2), (2,)), # default shapes for first and second modality
279 | fallback_to_default_shape_if_invalid = True,
280 | add_pos_emb = True,
281 | modality_num_dim = (2, 1),
282 | transformer = dict(
283 | dim = 64,
284 | depth = 8
285 | )
286 | )
287 |
288 | # then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position
289 |
290 | # any torch.long is text, torch.float is modalities
291 |
292 | text_images_and_audio = [
293 | [randint(0, 256, (16,)), (0, randn(2, 3, 384)), randint(0, 256, (8,)), (1, randn(6, 192))],
294 | [randint(0, 256, (16,)), randn(1, 4, 384), randint(0, 256, (5,)), (1, randn(2, 192)), randint(0, 256, (9,))]
295 | ]
296 |
297 | loss = model(text_images_and_audio)
298 |
299 | loss.backward()
300 |
301 | # after much training
302 |
303 | one_multimodal_sample = model.sample(max_length = 128)
304 |
305 | # unet related
306 |
307 | def test_modality_only_with_unet():
308 |
309 | model = Transfusion(
310 | num_text_tokens = 10,
311 | dim_latent = 4,
312 | modality_default_shape = (14, 14),
313 | pre_post_transformer_enc_dec = (
314 | nn.Conv2d(4, 64, 3, 2, 1),
315 | nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1),
316 | ),
317 | channel_first_latent = True,
318 | add_pos_emb = True,
319 | modality_num_dim = 2,
320 | velocity_consistency_loss_weight = 0.1,
321 | transformer = dict(
322 | dim = 64,
323 | depth = 1,
324 | dim_head = 32,
325 | heads = 8
326 | )
327 | )
328 |
329 | x = torch.randn(1, 4, 14, 14)
330 |
331 | loss = model(x)
332 | loss.backward()
333 |
334 | sampled = model.generate_modality_only()
335 |
336 | def test_stack_similar_shape_fn():
337 | from torch import zeros
338 |
339 | data = [
340 | zeros(3, 5),
341 | zeros(2, 3),
342 | zeros(3, 5),
343 | zeros(2, 3),
344 | zeros(4, 5),
345 | zeros(4, 5)
346 | ]
347 |
348 | plus_one = lambda x: x + 1
349 |
350 | data = [d + i for i, d in enumerate(data)]
351 | data_plus_one = [plus_one(d) for d in data]
352 |
353 | stacked_tensors, inverse = stack_same_shape_tensors_with_inverse(data)
354 |
355 | stacked_tensors = {k: plus_one(v) for k, v in stacked_tensors.items()}
356 |
357 | batch_processed_data_plus_one = inverse(stacked_tensors)
358 |
359 | assert all([torch.allclose(tensor1, tensor2) for tensor1, tensor2 in zip(data_plus_one, batch_processed_data_plus_one)])
360 |
361 | def test_filter_with_inverse():
362 | x = [0, 1, 2, 3, 4]
363 | is_even = lambda el: (el % 2) == 0
364 |
365 | x_even, inverse = filter_with_inverse(is_even, x)
366 | x_even_times_ten = [el * 10 for el in x_even]
367 |
368 | y = inverse(x_even_times_ten)
369 | assert y == [0, 1, 20, 3, 40]
370 |
371 | def test_apply_fn_modality_type():
372 | from torch import zeros
373 |
374 | modalities = [
375 | [zeros(3, 5)],
376 | [zeros(1, 5)],
377 | [(1, zeros(3, 5))],
378 | [(1, zeros(2, 5))],
379 | [(0, zeros(1, 5)), (1, zeros(3, 5))],
380 | ]
381 |
382 | modalities = apply_fn_modality_type(lambda x: x + 1, modalities)
383 |
384 | modalities = apply_fn_modality_type(lambda x: x + 2, modalities, modality_type = 1)
385 |
386 | assert (modalities[0][0][-1] == 1).all()
387 | assert (modalities[2][0][-1] == 2).all()
388 |
389 |
390 | def test_zero_dimensional():
391 |
392 | model = Transfusion(
393 | num_text_tokens = 256,
394 | dim_latent = 384,
395 | modality_default_shape = (),
396 | transformer = dict(
397 | dim = 512,
398 | depth = 8,
399 | num_residual_streams = 1
400 | )
401 | )
402 |
403 | # any torch.long is text, torch.float is modalities
404 |
405 | text_and_embeds = [
406 | [randint(0, 256, (16,)), randn(384), randint(0, 256, (8,)), randn(384)],
407 | [randint(0, 256, (16,)), randn(384), randint(0, 256, (5,)), randn(384), randint(0, 256, (9,))]
408 | ]
409 |
410 | loss = model(text_and_embeds)
411 |
412 | loss.backward()
413 |
414 | # after much training
415 |
416 | one_multimodal_sample = model.sample(prompt = randn(384))
417 |
--------------------------------------------------------------------------------
/train_image_only.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import tensor
6 | from torch.nn import Module
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.optim import Adam
9 |
10 | from einops import rearrange
11 |
12 | import torchvision
13 | import torchvision.transforms as T
14 | from torchvision.utils import save_image
15 |
16 | from transfusion_pytorch import Transfusion, print_modality_sample
17 |
18 | rmtree('./results', ignore_errors = True)
19 | results_folder = Path('./results')
20 | results_folder.mkdir(exist_ok = True, parents = True)
21 |
22 | # functions
23 |
24 | def divisible_by(num, den):
25 | return (num % den) == 0
26 |
27 | # encoder / decoder
28 |
29 | class Encoder(Module):
30 | def forward(self, x):
31 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... h w (p1 p2)', p1 = 2, p2 = 2)
32 | return x * 2 - 1
33 |
34 | class Decoder(Module):
35 | def forward(self, x):
36 | x = rearrange(x, '... h w (p1 p2) -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14)
37 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
38 |
39 | model = Transfusion(
40 | num_text_tokens = 10,
41 | dim_latent = 4,
42 | channel_first_latent = False,
43 | modality_default_shape = (14, 14),
44 | modality_encoder = Encoder(),
45 | modality_decoder = Decoder(),
46 | add_pos_emb = True,
47 | modality_num_dim = 2,
48 | velocity_consistency_loss_weight = 0.1,
49 | reconstruction_loss_weight = 0.1,
50 | transformer = dict(
51 | dim = 64,
52 | depth = 4,
53 | dim_head = 32,
54 | heads = 8,
55 | attn_laser = True
56 | )
57 | ).cuda()
58 |
59 | ema_model = model.create_ema()
60 |
61 | class MnistDataset(Dataset):
62 | def __init__(self):
63 | self.mnist = torchvision.datasets.MNIST(
64 | './data',
65 | download = True
66 | )
67 |
68 | def __len__(self):
69 | return len(self.mnist)
70 |
71 | def __getitem__(self, idx):
72 | pil, labels = self.mnist[idx]
73 | digit_tensor = T.PILToTensor()(pil)
74 | return (digit_tensor / 255).float()
75 |
76 | def cycle(iter_dl):
77 | while True:
78 | for batch in iter_dl:
79 | yield batch
80 |
81 | dataset = MnistDataset()
82 |
83 | dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)
84 | iter_dl = cycle(dataloader)
85 |
86 | optimizer = Adam(model.parameters(), lr = 8e-4)
87 |
88 | # train loop
89 |
90 | for step in range(1, 100_000 + 1):
91 |
92 | loss = model(next(iter_dl), velocity_consistency_ema_model = ema_model)
93 | loss.backward()
94 |
95 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
96 |
97 | optimizer.step()
98 | optimizer.zero_grad()
99 |
100 | ema_model.update()
101 |
102 | print(f'{step}: {loss.item():.3f}')
103 |
104 | if divisible_by(step, 500):
105 | image = ema_model.generate_modality_only(batch_size = 64)
106 |
107 | save_image(
108 | rearrange(image, '(gh gw) 1 h w -> 1 (gh h) (gw w)', gh = 8).detach().cpu(),
109 | str(results_folder / f'{step}.png')
110 | )
111 |
--------------------------------------------------------------------------------
/train_image_only_with_unet.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import tensor, nn
6 | from torch.nn import Module
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.optim import Adam
9 |
10 | from einops import rearrange
11 |
12 | import torchvision
13 | import torchvision.transforms as T
14 | from torchvision.utils import save_image
15 |
16 | from transfusion_pytorch import Transfusion, print_modality_sample
17 |
18 | rmtree('./results', ignore_errors = True)
19 | results_folder = Path('./results')
20 | results_folder.mkdir(exist_ok = True, parents = True)
21 |
22 | # functions
23 |
24 | def divisible_by(num, den):
25 | return (num % den) == 0
26 |
27 | # encoder / decoder
28 |
29 | class Encoder(Module):
30 | def forward(self, x):
31 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... (p1 p2) h w', p1 = 2, p2 = 2)
32 | return x * 2 - 1
33 |
34 | class Decoder(Module):
35 | def forward(self, x):
36 | x = rearrange(x, '... (p1 p2) h w -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14)
37 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
38 |
39 | model = Transfusion(
40 | num_text_tokens = 10,
41 | dim_latent = 4,
42 | channel_first_latent = True,
43 | modality_default_shape = (14, 14),
44 | modality_encoder = Encoder(),
45 | modality_decoder = Decoder(),
46 | pre_post_transformer_enc_dec = (
47 | nn.Conv2d(4, 64, 3, 2, 1),
48 | nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1),
49 | ),
50 | add_pos_emb = True,
51 | modality_num_dim = 2,
52 | velocity_consistency_loss_weight = 0.1,
53 | transformer = dict(
54 | dim = 64,
55 | depth = 4,
56 | dim_head = 32,
57 | heads = 8
58 | )
59 | ).cuda()
60 |
61 | ema_model = model.create_ema()
62 |
63 | class MnistDataset(Dataset):
64 | def __init__(self):
65 | self.mnist = torchvision.datasets.MNIST(
66 | './data',
67 | download = True
68 | )
69 |
70 | def __len__(self):
71 | return len(self.mnist)
72 |
73 | def __getitem__(self, idx):
74 | pil, labels = self.mnist[idx]
75 | digit_tensor = T.PILToTensor()(pil)
76 | return (digit_tensor / 255).float()
77 |
78 | def cycle(iter_dl):
79 | while True:
80 | for batch in iter_dl:
81 | yield batch
82 |
83 | dataset = MnistDataset()
84 |
85 | dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)
86 | iter_dl = cycle(dataloader)
87 |
88 | optimizer = Adam(model.parameters(), lr = 8e-4)
89 |
90 | # train loop
91 |
92 | for step in range(1, 100_000 + 1):
93 |
94 | loss = model(next(iter_dl), velocity_consistency_ema_model = ema_model)
95 | loss.backward()
96 |
97 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
98 |
99 | optimizer.step()
100 | optimizer.zero_grad()
101 |
102 | ema_model.update()
103 |
104 | print(f'{step}: {loss.item():.3f}')
105 |
106 | if divisible_by(step, 500):
107 | image = ema_model.generate_modality_only(batch_size = 64)
108 |
109 | save_image(
110 | rearrange(image, '(gh gw) 1 h w -> 1 (gh h) (gw w)', gh = 8).detach().cpu(),
111 | str(results_folder / f'{step}.png')
112 | )
113 |
--------------------------------------------------------------------------------
/train_latent_only.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import tensor
6 | from torch.nn import Module
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.optim import Adam
9 |
10 | from einops import rearrange
11 |
12 | import torchvision
13 | import torchvision.transforms as T
14 | from torchvision.utils import save_image
15 |
16 | from transfusion_pytorch import Transfusion, print_modality_sample
17 |
18 | # hf related
19 |
20 | from datasets import load_dataset
21 | from diffusers.models import AutoencoderKL
22 |
23 | vae = AutoencoderKL.from_pretrained("./path/to/your/autoencoder", subfolder = "vae")
24 |
25 | class Encoder(Module):
26 | def __init__(self, vae):
27 | super().__init__()
28 | self.vae = vae
29 |
30 | def forward(self, image):
31 | with torch.no_grad():
32 | latent = self.vae.encode(image * 2 - 1)
33 |
34 | return 0.18215 * latent.latent_dist.sample()
35 |
36 | class Decoder(Module):
37 | def __init__(self, vae):
38 | super().__init__()
39 | self.vae = vae
40 |
41 | def forward(self, latents):
42 | latents = (1 / 0.18215) * latents
43 |
44 | with torch.no_grad():
45 | image = self.vae.decode(latents).sample
46 |
47 | return (image / 2 + 0.5).clamp(0, 1)
48 |
49 | # results folder
50 |
51 | rmtree('./results', ignore_errors = True)
52 | results_folder = Path('./results')
53 | results_folder.mkdir(exist_ok = True, parents = True)
54 |
55 | # constants
56 |
57 | SAMPLE_EVERY = 100
58 |
59 | # functions
60 |
61 | def divisible_by(num, den):
62 | return (num % den) == 0
63 |
64 | # encoder / decoder
65 |
66 | model = Transfusion(
67 | num_text_tokens = 10,
68 | dim_latent = 4,
69 | channel_first_latent = True,
70 | modality_default_shape = (32, 32),
71 | modality_encoder = Encoder(vae),
72 | modality_decoder = Decoder(vae),
73 | add_pos_emb = True,
74 | modality_num_dim = 2,
75 | velocity_consistency_loss_weight = 0.1,
76 | reconstruction_loss_weight = 0.1,
77 | transformer = dict(
78 | dim = 256,
79 | depth = 8,
80 | dim_head = 64,
81 | heads = 8
82 | )
83 | ).cuda()
84 |
85 | ema_model = model.create_ema(0.9)
86 |
87 | class FlowersDataset(Dataset):
88 | def __init__(self, image_size):
89 | self.ds = load_dataset("nelorth/oxford-flowers")['train']
90 |
91 | self.transform = T.Compose([
92 | T.Resize((image_size, image_size)),
93 | T.PILToTensor()
94 | ])
95 |
96 | def __len__(self):
97 | return len(self.ds)
98 |
99 | def __getitem__(self, idx):
100 | pil = self.ds[idx]['image']
101 | tensor = self.transform(pil)
102 | return tensor / 255.
103 |
104 | def cycle(iter_dl):
105 | while True:
106 | for batch in iter_dl:
107 | yield batch
108 |
109 | dataset = FlowersDataset(256)
110 |
111 | dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)
112 |
113 | iter_dl = cycle(dataloader)
114 |
115 | optimizer = Adam(model.parameters(), lr = 8e-4)
116 |
117 | # train loop
118 |
119 | for step in range(1, 100_000 + 1):
120 |
121 | for _ in range(4):
122 | loss = model.forward_modality(next(iter_dl))
123 | (loss / 4).backward()
124 |
125 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
126 |
127 | optimizer.step()
128 | optimizer.zero_grad()
129 |
130 | ema_model.update()
131 |
132 | print(f'{step}: {loss.item():.3f}')
133 |
134 | if divisible_by(step, SAMPLE_EVERY):
135 | image = ema_model.generate_modality_only(batch_size = 4)
136 |
137 | save_image(
138 | rearrange(image, '(gh gw) c h w -> c (gh h) (gw w)', gh = 2).detach().cpu(),
139 | str(results_folder / f'{step}.png')
140 | )
141 |
--------------------------------------------------------------------------------
/train_latent_with_text.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import nn, tensor, Tensor
6 | from torch.nn import Module
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.optim import Adam
9 |
10 | from einops import rearrange
11 |
12 | import torchvision
13 | import torchvision.transforms as T
14 | from torchvision.utils import save_image
15 |
16 | from transfusion_pytorch import Transfusion, print_modality_sample
17 |
18 | # hf related
19 |
20 | from datasets import load_dataset
21 | from diffusers.models import AutoencoderKL
22 |
23 | vae = AutoencoderKL.from_pretrained("./path/to/your/autoencoder", subfolder = "vae")
24 |
25 | class Encoder(Module):
26 | def __init__(self, vae):
27 | super().__init__()
28 | self.vae = vae
29 |
30 | def forward(self, image):
31 | with torch.no_grad():
32 | latent = self.vae.encode(image * 2 - 1)
33 |
34 | return 0.18215 * latent.latent_dist.sample()
35 |
36 | class Decoder(Module):
37 | def __init__(self, vae):
38 | super().__init__()
39 | self.vae = vae
40 |
41 | def forward(self, latents):
42 | latents = (1 / 0.18215) * latents
43 |
44 | with torch.no_grad():
45 | image = self.vae.decode(latents).sample
46 |
47 | return (image / 2 + 0.5).clamp(0, 1)
48 |
49 | # results folder
50 |
51 | rmtree('./results', ignore_errors = True)
52 | results_folder = Path('./results')
53 | results_folder.mkdir(exist_ok = True, parents = True)
54 |
55 | # constants
56 |
57 | SAMPLE_EVERY = 100
58 |
59 | with open("./data/flowers/labels.txt", "r") as file:
60 | content = file.read()
61 |
62 | LABELS_TEXT = content.split('\n')
63 |
64 | # functions
65 |
66 | def divisible_by(num, den):
67 | return (num % den) == 0
68 |
69 | def decode_token(token):
70 | return str(chr(max(32, token)))
71 |
72 | def decode_tokens(tokens: Tensor) -> str:
73 | return "".join(list(map(decode_token, tokens.tolist())))
74 |
75 | def encode_tokens(str: str) -> Tensor:
76 | return tensor([*bytes(str, 'UTF-8')])
77 |
78 | # encoder / decoder
79 |
80 | model = Transfusion(
81 | num_text_tokens = 256,
82 | dim_latent = 4,
83 | channel_first_latent = True,
84 | modality_default_shape = (8, 8),
85 | modality_encoder = Encoder(vae),
86 | modality_decoder = Decoder(vae),
87 | pre_post_transformer_enc_dec = (
88 | nn.Conv2d(4, 128, 3, 2, 1),
89 | nn.ConvTranspose2d(128, 4, 3, 2, 1, output_padding = 1),
90 | ),
91 | add_pos_emb = False,
92 | modality_num_dim = 2,
93 | reconstruction_loss_weight = 0.1,
94 | transformer = dict(
95 | dim = 128,
96 | depth = 8,
97 | dim_head = 64,
98 | heads = 8,
99 | )
100 | ).cuda()
101 |
102 | ema_model = model.create_ema(0.9)
103 |
104 | class FlowersDataset(Dataset):
105 | def __init__(self, image_size):
106 | self.ds = load_dataset("nelorth/oxford-flowers")['train']
107 |
108 | self.transform = T.Compose([
109 | T.Resize((image_size, image_size)),
110 | T.PILToTensor(),
111 | T.Lambda(lambda t: t / 255.)
112 | ])
113 |
114 | def __len__(self):
115 | return len(self.ds)
116 |
117 | def __getitem__(self, idx):
118 | sample = self.ds[idx]
119 | pil = sample['image']
120 |
121 | labels_int = sample['label']
122 | labels_text = LABELS_TEXT[labels_int]
123 |
124 | tensor = self.transform(pil)
125 | return encode_tokens(labels_text), tensor
126 |
127 | def cycle(iter_dl):
128 | while True:
129 | for batch in iter_dl:
130 | yield batch
131 |
132 | dataset = FlowersDataset(128)
133 |
134 | dataloader = model.create_dataloader(dataset, batch_size = 4, shuffle = True)
135 |
136 | iter_dl = cycle(dataloader)
137 |
138 | optimizer = Adam(model.parameters(), lr = 8e-4)
139 |
140 | # train loop
141 |
142 | for step in range(1, 100_000 + 1):
143 |
144 | for _ in range(4):
145 | loss = model.forward(next(iter_dl))
146 | (loss / 4).backward()
147 |
148 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
149 |
150 | optimizer.step()
151 | optimizer.zero_grad()
152 |
153 | ema_model.update()
154 |
155 | print(f'{step}: {loss.item():.3f}')
156 |
157 | if divisible_by(step, SAMPLE_EVERY):
158 | sample = ema_model.sample()
159 |
160 | print_modality_sample(sample)
161 |
162 | if len(sample) < 3:
163 | continue
164 |
165 | text_tensor, maybe_image, *_ = sample
166 |
167 | if not isinstance(maybe_image, tuple):
168 | continue
169 |
170 | _, image = maybe_image
171 | text_tensor = text_tensor[text_tensor < 256] # todo: offer a utility function for removing meta tags and special tokens
172 |
173 | text = decode_tokens(text_tensor)
174 | filename = str(results_folder / f'{step}.{text}.png')
175 |
176 | save_image(
177 | image.detach().cpu(),
178 | filename
179 | )
180 |
--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import tensor
6 | from torch.nn import Module
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.optim import Adam
9 |
10 | from einops import rearrange
11 |
12 | import torchvision
13 | import torchvision.transforms as T
14 | from torchvision.utils import save_image
15 |
16 | from transfusion_pytorch.transfusion import Transfusion, print_modality_sample
17 |
18 | rmtree('./results', ignore_errors = True)
19 | results_folder = Path('./results')
20 | results_folder.mkdir(exist_ok = True, parents = True)
21 |
22 | # constants
23 |
24 | IMAGE_AFTER_TEXT = True # False for captioning, True for text-to-image
25 | USE_PROMPT = False # whether to use prompting, or synthesize from start token
26 | NUM_TRAIN_STEPS = 20_000
27 | SAMPLE_EVERY = 250
28 | CHANNEL_FIRST = True
29 |
30 | # functions
31 |
32 | def divisible_by(num, den):
33 | return (num % den) == 0
34 |
35 | # encoder / decoder
36 |
37 | class Encoder(Module):
38 | def forward(self, x):
39 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... h w (p1 p2)', p1 = 2, p2 = 2)
40 |
41 | if CHANNEL_FIRST:
42 | x = rearrange(x, 'b ... d -> b d ...')
43 |
44 | return x * 2 - 1
45 |
46 | class Decoder(Module):
47 | def forward(self, x):
48 |
49 | if CHANNEL_FIRST:
50 | x = rearrange(x, 'b d ... -> b ... d')
51 |
52 | x = rearrange(x, '... h w (p1 p2) -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2)
53 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
54 |
55 | model = Transfusion(
56 | num_text_tokens = 10,
57 | dim_latent = 4,
58 | modality_default_shape = (14, 14),
59 | modality_encoder = Encoder(),
60 | modality_decoder = Decoder(),
61 | add_pos_emb = True,
62 | modality_num_dim = 2,
63 | channel_first_latent = CHANNEL_FIRST,
64 | transformer = dict(
65 | dim = 64,
66 | depth = 4,
67 | dim_head = 32,
68 | heads = 8,
69 | )
70 | ).cuda()
71 |
72 | ema_model = model.create_ema()
73 |
74 | class MnistDataset(Dataset):
75 | def __init__(self):
76 | self.mnist = torchvision.datasets.MNIST(
77 | './data/mnist',
78 | download = True
79 | )
80 |
81 | def __len__(self):
82 | return len(self.mnist)
83 |
84 | def __getitem__(self, idx):
85 | pil, labels = self.mnist[idx]
86 | digit_tensor = T.PILToTensor()(pil)
87 | output = tensor(labels), (digit_tensor / 255).float()
88 |
89 | if IMAGE_AFTER_TEXT:
90 | return output
91 |
92 | first, second = output
93 | return second, first
94 |
95 | def cycle(iter_dl):
96 | while True:
97 | for batch in iter_dl:
98 | yield batch
99 |
100 | def collate_fn(data):
101 | data = [*map(list, data)]
102 | return data
103 |
104 | dataset = MnistDataset()
105 | dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True)
106 |
107 | iter_dl = cycle(dataloader)
108 |
109 | optimizer = Adam(model.parameters(), lr = 3e-4)
110 |
111 | # train loop
112 |
113 | for step in range(1, NUM_TRAIN_STEPS + 1):
114 | model.train()
115 |
116 | loss = model(next(iter_dl))
117 | loss.backward()
118 |
119 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
120 |
121 | optimizer.step()
122 | optimizer.zero_grad()
123 |
124 | ema_model.update()
125 |
126 | print(f'{step}: {loss.item():.3f}')
127 |
128 | # eval
129 |
130 | if divisible_by(step, SAMPLE_EVERY):
131 |
132 | if not USE_PROMPT:
133 | # sampling from start to finish
134 |
135 | one_multimodal_sample = ema_model.sample(max_length = 384)
136 | else:
137 | # sampling using prompt
138 | # which differs depending on which comes first, text or images
139 |
140 | if IMAGE_AFTER_TEXT:
141 |
142 | text_label = torch.randint(0, 10, ()).cuda()
143 | one_multimodal_sample = ema_model.sample(prompt = text_label, max_length = 384)
144 |
145 | else:
146 |
147 | rand_batch = next(iter_dl)
148 | rand_image = rand_batch[0][0]
149 |
150 | one_multimodal_sample = ema_model.sample(prompt = rand_image, max_length = 384)
151 |
152 | # make sure modality sample overall order of modalities look correct
153 |
154 | print_modality_sample(one_multimodal_sample)
155 |
156 | if len(one_multimodal_sample) < 2:
157 | continue
158 |
159 | if IMAGE_AFTER_TEXT:
160 | maybe_label, maybe_image, *_ = one_multimodal_sample
161 | else:
162 | _, maybe_image, maybe_label = one_multimodal_sample
163 |
164 | filename = f'{step}.{maybe_label[1].item()}.png'
165 |
166 | save_image(
167 | maybe_image[1].cpu(),
168 | str(results_folder / filename),
169 | )
170 |
--------------------------------------------------------------------------------
/train_mnist_vae.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import nn, tensor
6 | from torch.nn import Module
7 | import torch.nn.functional as F
8 | from torch.utils.data import Dataset, DataLoader
9 | from torch.optim import Adam
10 |
11 | from einops import rearrange
12 | from einops.layers.torch import Rearrange
13 |
14 | import torchvision
15 | import torchvision.transforms as T
16 | from torchvision.utils import save_image
17 |
18 | from tqdm import tqdm
19 |
20 | from transfusion_pytorch import Transfusion, print_modality_sample
21 |
22 | rmtree('./results', ignore_errors = True)
23 | results_folder = Path('./results')
24 | results_folder.mkdir(exist_ok = True, parents = True)
25 |
26 | # functions
27 |
28 | def divisible_by(num, den):
29 | return (num % den) == 0
30 |
31 | def cycle(iter_dl):
32 | while True:
33 | for batch in iter_dl:
34 | yield batch
35 |
36 | # dataset
37 |
38 | class MnistDataset(Dataset):
39 | def __init__(self):
40 | self.mnist = torchvision.datasets.MNIST(
41 | './data/mnist',
42 | download = True
43 | )
44 |
45 | self.transform = T.Compose([
46 | T.PILToTensor(),
47 | T.RandomResizedCrop((28, 28), scale = (0.8, 1.))
48 | ])
49 |
50 | def __len__(self):
51 | return len(self.mnist)
52 |
53 | def __getitem__(self, idx):
54 | pil, labels = self.mnist[idx]
55 | digit_tensor = self.transform(pil)
56 | return tensor(labels), (digit_tensor / 255).float()
57 |
58 | dataset = MnistDataset()
59 |
60 | # contrived encoder / decoder with layernorm at bottleneck
61 |
62 | autoencoder_train_steps = 15_000
63 | dim_latent = 16
64 |
65 | class Normalize(Module):
66 | def forward(self, x):
67 | return F.normalize(x, dim = -1)
68 |
69 | encoder = nn.Sequential(
70 | nn.Conv2d(1, 4, 3, padding = 1),
71 | nn.Conv2d(4, 8, 4, 2, 1),
72 | nn.ReLU(),
73 | nn.Dropout(0.05),
74 | nn.Conv2d(8, dim_latent, 1),
75 | Rearrange('b d ... -> b ... d'),
76 | Normalize()
77 | ).cuda()
78 |
79 | decoder = nn.Sequential(
80 | Rearrange('b ... d -> b d ...'),
81 | nn.Conv2d(dim_latent, 8, 1),
82 | nn.ReLU(),
83 | nn.ConvTranspose2d(8, 4, 4, 2, 1),
84 | nn.Conv2d(4, 1, 3, padding = 1),
85 | ).cuda()
86 |
87 | # train autoencoder
88 |
89 | autoencoder_optimizer = Adam([*encoder.parameters(), *decoder.parameters()], lr = 3e-4)
90 | autoencoder_dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)
91 |
92 | autoencoder_iter_dl = cycle(autoencoder_dataloader)
93 |
94 | print('training autoencoder')
95 |
96 | with tqdm(total = autoencoder_train_steps) as pbar:
97 | for _ in range(autoencoder_train_steps):
98 | _, images = next(autoencoder_iter_dl)
99 | images = images.cuda()
100 |
101 | latents = encoder(images)
102 | latents = latents.lerp(torch.randn_like(latents), torch.rand_like(latents) * 0.2) # add a bit of noise to latents
103 | reconstructed = decoder(latents)
104 |
105 | loss = F.mse_loss(images, reconstructed)
106 |
107 | loss.backward()
108 |
109 | pbar.set_description(f'loss: {loss.item():.5f}')
110 |
111 | autoencoder_optimizer.step()
112 | autoencoder_optimizer.zero_grad()
113 |
114 | pbar.update()
115 |
116 | # transfusion
117 |
118 | model = Transfusion(
119 | num_text_tokens = 10,
120 | dim_latent = dim_latent,
121 | modality_default_shape = (14, 14),
122 | modality_encoder = encoder,
123 | modality_decoder = decoder,
124 | add_pos_emb = True,
125 | modality_num_dim = 2,
126 | transformer = dict(
127 | dim = 64,
128 | depth = 4,
129 | dim_head = 32,
130 | heads = 8,
131 | )
132 | ).cuda()
133 |
134 | # training transfusion
135 |
136 | dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True)
137 | iter_dl = cycle(dataloader)
138 |
139 | optimizer = Adam(model.parameters_without_encoder_decoder(), lr = 3e-4)
140 |
141 | # train loop
142 |
143 | transfusion_train_steps = 25_000
144 |
145 | print('training transfusion with autoencoder')
146 |
147 | with tqdm(total = transfusion_train_steps) as pbar:
148 | for index in range(transfusion_train_steps):
149 | step = index + 1
150 |
151 | model.train()
152 |
153 | loss = model(next(iter_dl))
154 | loss.backward()
155 |
156 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
157 |
158 | optimizer.step()
159 | optimizer.zero_grad()
160 |
161 | pbar.set_description(f'loss: {loss.item():.3f}')
162 |
163 | pbar.update()
164 |
165 | # eval
166 |
167 | if divisible_by(step, 500):
168 | one_multimodal_sample = model.sample(max_length = 10)
169 |
170 | print_modality_sample(one_multimodal_sample)
171 |
172 | if len(one_multimodal_sample) < 2:
173 | continue
174 |
175 | maybe_label, maybe_image, *_ = one_multimodal_sample
176 |
177 | filename = f'{step}.{maybe_label[1].item()}.png'
178 |
179 | save_image(
180 | maybe_image[1].cpu().clamp(min = 0., max = 1.),
181 | str(results_folder / filename),
182 | )
183 |
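184 | # illustrative addition (not in the original script): a quick sanity check of the frozen
185 | # autoencoder once training is done - encode / decode one dataset image and save the
186 | # reconstruction next to the sampled digits ('autoencoder.recon.png' is just an arbitrary filename)
187 | 
188 | encoder.eval()
189 | decoder.eval()
190 | 
191 | with torch.no_grad():
192 |     _, sample_image = dataset[0]
193 |     reconstructed = decoder(encoder(sample_image.unsqueeze(0).cuda()))
194 |     save_image(reconstructed.clamp(min = 0., max = 1.).cpu(), str(results_folder / 'autoencoder.recon.png'))
195 | 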
--------------------------------------------------------------------------------
/train_mnist_with_unet.py:
--------------------------------------------------------------------------------
1 | from shutil import rmtree
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import tensor, nn
6 | from torch.nn import Module
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.optim import Adam
9 |
10 | from einops import rearrange
11 |
12 | import torchvision
13 | import torchvision.transforms as T
14 | from torchvision.utils import save_image
15 |
16 | from transfusion_pytorch import Transfusion, print_modality_sample
17 |
18 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19 |
20 | rmtree('./results', ignore_errors = True)
21 | results_folder = Path('./results')
22 | results_folder.mkdir(exist_ok = True, parents = True)
23 |
24 | # constants
25 |
26 | IMAGE_AFTER_TEXT = False
27 | NUM_TRAIN_STEPS = 20_000
28 | SAMPLE_EVERY = 500
29 |
30 | # functions
31 |
32 | def divisible_by(num, den):
33 | return (num % den) == 0
34 |
35 | # encoder / decoder
36 |
37 | class Encoder(Module):
38 | def forward(self, x):
39 | x = rearrange(x, '... 1 (h p1) (w p2) -> ... (p1 p2) h w', p1 = 2, p2 = 2)
40 | return x * 2 - 1
41 |
42 | class Decoder(Module):
43 | def forward(self, x):
44 | x = rearrange(x, '... (p1 p2) h w -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14)
45 | return ((x + 1) * 0.5).clamp(min = 0., max = 1.)
46 |
47 | model = Transfusion(
48 | num_text_tokens = 10,
49 | dim_latent = 4,
50 | modality_default_shape = (14, 14),
51 | modality_encoder = Encoder(),
52 | modality_decoder = Decoder(),
53 | pre_post_transformer_enc_dec = (
54 | nn.Conv2d(4, 64, 3, 2, 1),
55 | nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1),
56 | ),
57 | add_pos_emb = True,
58 | modality_num_dim = 2,
59 | channel_first_latent = True,
60 | transformer = dict(
61 | dim = 64,
62 | depth = 4,
63 | dim_head = 32,
64 | heads = 8,
65 | )
66 | ).to(device)
67 |
68 | ema_model = model.create_ema()
69 |
70 | class MnistDataset(Dataset):
71 | def __init__(self):
72 | self.mnist = torchvision.datasets.MNIST(
73 | './data/mnist',
74 | download = True
75 | )
76 |
77 | def __len__(self):
78 | return len(self.mnist)
79 |
80 | def __getitem__(self, idx):
81 | pil, labels = self.mnist[idx]
82 | digit_tensor = T.PILToTensor()(pil)
83 | output = tensor(labels), (digit_tensor / 255).float()
84 |
85 | if not IMAGE_AFTER_TEXT:
86 | return output
87 |
88 | first, second = output
89 | return second, first
90 |
91 | def cycle(iter_dl):
92 | while True:
93 | for batch in iter_dl:
94 | yield batch
95 |
96 | def collate_fn(data):
97 | data = [*map(list, data)]
98 | return data
99 |
100 | dataset = MnistDataset()
101 | dataloader = model.create_dataloader(dataset, batch_size = 16, shuffle = True)
102 |
103 | iter_dl = cycle(dataloader)
104 |
105 | optimizer = Adam(model.parameters(), lr = 3e-4)
106 |
107 | # train loop
108 |
109 | for step in range(1, NUM_TRAIN_STEPS + 1):
110 | model.train()
111 |
112 | loss = model(next(iter_dl))
113 | loss.backward()
114 |
115 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
116 |
117 | optimizer.step()
118 | optimizer.zero_grad()
119 |
120 | ema_model.update()
121 |
122 | print(f'{step}: {loss.item():.3f}')
123 |
124 | # eval
125 |
126 | if divisible_by(step, SAMPLE_EVERY):
127 | one_multimodal_sample = ema_model.sample(max_length = 384)
128 |
129 | print_modality_sample(one_multimodal_sample)
130 |
131 | if len(one_multimodal_sample) < 2:
132 | continue
133 |
134 | if IMAGE_AFTER_TEXT:
135 | _, maybe_image, maybe_label = one_multimodal_sample
136 | else:
137 | maybe_label, maybe_image, *_ = one_multimodal_sample
138 |
139 | filename = f'{step}.{maybe_label[1].item()}.png'
140 |
141 | save_image(
142 | maybe_image[1].cpu(),
143 | str(results_folder / filename),
144 | )
145 |
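146 | # illustrative addition (not in the original script): the patchify Encoder and unpatchify Decoder
147 | # above should be exact inverses for pixel values in [0, 1], since they only rearrange 2x2 patches
148 | # and rescale - a tiny self check to make that assumption explicit
149 | 
150 | with torch.no_grad():
151 |     test_image = torch.rand(1, 1, 28, 28)
152 |     assert torch.allclose(Decoder()(Encoder()(test_image)), test_image, atol = 1e-6)
153 | 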
--------------------------------------------------------------------------------
/train_text_only.py:
--------------------------------------------------------------------------------
1 | import math
2 | import gzip
3 | import random
4 | import tqdm
5 | import numpy as np
6 |
7 | import torch
8 | from torch.optim import Adam
9 | from torch import Tensor
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 | from transfusion_pytorch import Transfusion
13 |
14 | # constants
15 |
16 | NUM_BATCHES = int(1e5)
17 | BATCH_SIZE = 4
18 | GRAD_ACCUM_EVERY = 4
19 | LEARNING_RATE = 1e-4
20 | VALIDATE_EVERY = 100
21 | PRIME_LENGTH = 64
22 | GENERATE_EVERY = 500
23 | GENERATE_LENGTH = 256
24 | SEQ_LEN = 256
25 |
26 | # helpers
27 |
28 | def exists(v):
29 | return v is not None
30 |
31 | def divisible_by(num, den):
32 | return (num % den) == 0
33 |
34 | def cycle(loader):
35 | while True:
36 | for data in loader:
37 | yield data
38 |
39 | def decode_token(token):
40 | return str(chr(max(32, token)))
41 |
42 | def decode_tokens(tokens):
43 | return "".join(list(map(decode_token, tokens)))
44 |
45 | # the character-level transfusion language model
46 |
47 | model = Transfusion(
48 | num_text_tokens = 256,
49 | transformer = dict(
50 | dim = 384,
51 | depth = 8,
52 | dim_head = 64,
53 | heads = 8,
54 | attn_laser = True
55 | )
56 | ).cuda()
57 |
58 | # prepare enwik8 data
59 |
60 | with gzip.open('./data/enwik8/enwik8.gz') as file:
61 | data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
62 | np_train, np_valid = np.split(data, [int(90e6)])
63 | data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
64 |
65 | class TextSamplerDataset(Dataset):
66 | def __init__(self, data, seq_len):
67 | super().__init__()
68 | self.data = data
69 | self.seq_len = seq_len
70 | self.data_length = data.shape[0]
71 |
72 | def __len__(self):
73 | return self.data.size(0) // self.seq_len
74 |
75 | def __getitem__(self, index):
76 | rand_start = torch.randint(0, self.data_length - self.seq_len, (1,))
77 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
78 | return full_seq
79 |
80 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
81 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
82 | train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
83 | val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE)
84 |
85 | # optimizer
86 |
87 | optim = Adam(model.parameters(), lr = LEARNING_RATE)
88 |
89 | train_loader = cycle(train_loader)
90 | val_loader = cycle(val_loader)
91 |
92 | # training
93 |
94 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
95 | model.train()
96 |
97 | for _ in range(GRAD_ACCUM_EVERY):
98 | data = next(train_loader)
99 |
100 | loss = model(data.cuda())
101 |
102 | (loss / GRAD_ACCUM_EVERY).backward()
103 |
104 | print(f'loss: {loss.item():.3f}')
105 |
106 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
107 |
108 | optim.step()
109 | optim.zero_grad()
110 |
111 | if divisible_by(i, VALIDATE_EVERY):
112 | model.eval()
113 | with torch.no_grad():
114 | valid_data = next(val_loader)
115 | loss = model(valid_data.cuda())
116 | print(f'\nvalid loss: {loss.item():.3f}\n')
117 |
118 | if divisible_by(i, GENERATE_EVERY):
119 | model.eval()
120 |
121 | inp = random.choice(val_dataset)[:PRIME_LENGTH]
122 | inp = inp.cuda()
123 |
124 | prime = decode_tokens(inp)
125 | print(f"\nprime: {prime}\n")
126 |
127 | prompt = inp[None, ...]
128 |
129 | sampled = model.generate_text_only(prompt, GENERATE_LENGTH)
130 |
131 | base_decode_output = decode_tokens(sampled[0])
132 |
133 | print(f"\ngenerated: {base_decode_output}\n")
134 |
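135 | # illustrative addition (not in the original script): enwik8 results are commonly reported in
136 | # bits per byte - assuming the text loss above is the usual mean cross entropy in nats, the
137 | # conversion is a single division by ln(2)
138 | 
139 | def nats_to_bits_per_byte(loss_nats):
140 |     return loss_nats / math.log(2)
141 | 
142 | # ex. print(f'valid bpb: {nats_to_bits_per_byte(loss.item()):.3f}')
143 | 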
--------------------------------------------------------------------------------
/train_toy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import randint, randn
3 | from torch.utils.data import Dataset, DataLoader
4 | from torch.optim import Adam
5 |
6 | from transfusion_pytorch import Transfusion, print_modality_sample
7 |
8 | def divisible_by(num, den):
9 | return (num % den) == 0
10 |
11 | model = Transfusion(
12 | num_text_tokens = 8,
13 | dim_latent = 16,
14 | modality_default_shape = (2,),
15 | transformer = dict(
16 | dim = 64,
17 | depth = 1,
18 | dim_head = 8,
19 | heads = 2
20 | )
21 | ).cuda()
22 |
23 | class MockDataset(Dataset):
24 | def __len__(self):
25 | return 100
26 |
27 | def __getitem__(self, idx):
28 | return torch.ones((1,)).long(), randn(2, 16)
29 |
30 | def cycle(iter_dl):
31 | while True:
32 | for batch in iter_dl:
33 | yield batch
34 |
35 | def collate_fn(data):
36 | data = [*map(list, data)]
37 | return data
38 |
39 | mock_dataset = MockDataset()
40 |
41 | dataloader = DataLoader(mock_dataset, batch_size = 4, collate_fn = collate_fn)
42 | iter_dl = cycle(dataloader)
43 |
44 | optimizer = Adam(model.parameters(), lr = 3e-4)
45 |
46 | # train loop
47 |
48 | for step in range(1, 10_000 + 1):
49 |
50 | loss = model(next(iter_dl))
51 | loss.backward()
52 |
53 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
54 |
55 | optimizer.step()
56 | optimizer.zero_grad()
57 |
58 | print(f'{step}: {loss.item():.3f}')
59 |
60 | # eval
61 |
62 | if divisible_by(step, 100):
63 | one_multimodal_sample = model.sample()
64 | print_modality_sample(one_multimodal_sample)
65 |
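66 | # illustrative addition (not in the original script): each mock sample pairs a single text token
67 | # with a random latent of shape (2, 16), which lines up with the `modality_default_shape = (2,)`
68 | # and `dim_latent = 16` passed to Transfusion above - a tiny assertion to make that explicit
69 | 
70 | _, mock_latent = mock_dataset[0]
71 | assert mock_latent.shape == (2, 16)
72 | 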
--------------------------------------------------------------------------------
/transfusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/transfusion-pytorch/ea045706cc45e2cf5cd1ce4523e5af6a49eff54a/transfusion.png
--------------------------------------------------------------------------------
/transfusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from transfusion_pytorch.transfusion import (
2 | Transfusion,
3 | print_modality_sample,
4 | create_dataloader
5 | )
6 |
--------------------------------------------------------------------------------
/transfusion_pytorch/transfusion.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | """
4 | global ein notation
5 |
6 | b - batch
7 | t - one modality type
8 | m - separate modality instance
9 | n - sequence
10 | d - dimension
11 | l - logits (text)
12 | i, j - sequence (row, col)
13 | p - positions
14 | s - residual streams
15 | """
16 |
17 | import os
18 | import math
19 | from collections import defaultdict
20 |
21 | from random import randrange
22 | from itertools import count
23 | from functools import partial, wraps, cache
24 | from typing import NamedTuple, Callable, Literal
25 |
26 | import torch
27 | import torch.nn.functional as F
28 | from torch import nn, Tensor, tensor, is_tensor, cat, stack
29 | from torch.nn import Module, ModuleList, Linear
30 |
31 | from torch.utils.data import Dataset, DataLoader
32 | from torch.nn.utils.rnn import pad_sequence
33 | from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
34 |
35 | from torchdiffeq import odeint
36 |
37 | import einx
38 | from einops.layers.torch import Rearrange
39 | from einops import rearrange, repeat, reduce, einsum, pack, unpack
40 |
41 | from ema_pytorch import EMA
42 |
43 | from axial_positional_embedding import ContinuousAxialPositionalEmbedding
44 |
45 | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
46 |
47 | from hyper_connections import HyperConnections
48 |
49 | from tqdm import tqdm
50 | from loguru import logger
51 |
52 | pad_sequence = partial(pad_sequence, batch_first = True)
53 |
54 | # tensor typing
55 |
56 | import jaxtyping
57 | from jaxtyping import jaxtyped
58 | from beartype import beartype
59 | from beartype.door import is_bearable
60 |
61 | class TorchTyping:
62 | def __init__(self, abstract_dtype):
63 | self.abstract_dtype = abstract_dtype
64 |
65 | def __getitem__(self, shapes: str):
66 | return self.abstract_dtype[Tensor, shapes]
67 |
68 | Float = TorchTyping(jaxtyping.Float)
69 | Int = TorchTyping(jaxtyping.Int)
70 | Bool = TorchTyping(jaxtyping.Bool)
71 |
72 | # maybe flex attention
73 |
74 | try:
75 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask
76 |
77 | if torch.cuda.is_available():
78 | flex_attention = torch.compile(flex_attention)
79 |
80 | except ImportError:
81 | flex_attention = None
82 |
83 | # types
84 |
85 | Scalar = Float['']
86 |
87 | ModalitySample = list[Int[''] | Int['_'] | Float['...'] | tuple[int, Float['...']]]
88 |
89 | ModalityTokenTransform = str | Callable | None
90 |
91 | RawModalityPositions = list[list[tuple[int, int, int]]]
92 |
93 | GetPredFlows = dict[int, list[Callable[[Tensor], Tensor]]]
94 |
95 | class LossBreakdown(NamedTuple):
96 | total: Scalar
97 | text: Scalar
98 | flow: list[Scalar]
99 | velocity: list[Scalar] | None = None
100 | recon: list[Scalar] | None = None
101 |
102 | class ModalityInfo(NamedTuple):
103 | encoder: Module | None
104 | decoder: Module | None
105 | latent_to_model: Module
106 | model_to_latent: Module
107 | add_pos_emb: bool
108 | pos_emb_mlp: Module | None
109 | num_dim: int | None
110 | dim_latent: int
111 | default_shape: tuple[int, ...]
112 | som_id: int
113 | eom_id: int
114 | to_shape_fn: Callable | None
115 | channel_first_latent: bool
116 | modality_type: int
117 |
118 | # helper functions
119 |
120 | def exists(v):
121 | return v is not None
122 |
123 | def default(v, d):
124 | return v if exists(v) else d
125 |
126 | def identity(t):
127 | return t
128 |
129 | def always(val):
130 | def inner(*args, **kwargs):
131 | return val
132 | return inner
133 |
134 | def first(it):
135 | return it[0]
136 |
137 | def join(arr, delimiter = ''):
138 | return delimiter.join(arr)
139 |
140 | def divisible_by(num, den):
141 | return (num % den) == 0
142 |
143 | def cast_tuple(t, length = 1):
144 | return t if isinstance(t, tuple) else ((t,) * length)
145 |
146 | def tree_map_tensor(sample, fn: Callable):
147 | return tree_map(lambda t: t if not is_tensor(t) else fn(t), sample)
148 |
149 | def add_temp_batch_dim(fn: Callable):
150 | @wraps(fn)
151 | def inner(t: Tensor, *args, **kwargs) -> Tensor:
152 | t = rearrange(t, '... -> 1 ...')
153 | out = fn(t, *args, **kwargs)
154 | out = rearrange(out, '1 ... -> ...')
155 | return out
156 | return inner
157 |
158 | def pack_with_inverse(t, pattern):
159 | packed, packed_shape = pack(t, pattern)
160 |
161 | def inverse(out, inv_pattern = None):
162 | inv_pattern = default(inv_pattern, pattern)
163 | return unpack(out, packed_shape, inv_pattern)
164 |
165 | return packed, inverse
166 |
167 | def pack_one_with_inverse(t, pattern):
168 | packed, packed_shape = pack([t], pattern)
169 |
170 | def inverse(out, inv_pattern = None):
171 | inv_pattern = default(inv_pattern, pattern)
172 | return unpack(out, packed_shape, inv_pattern)[0]
173 |
174 | return packed, inverse
175 |
176 | def eval_decorator(fn):
177 | def inner(self, *args, **kwargs):
178 | was_training = self.training
179 | self.eval()
180 | out = fn(self, *args, **kwargs)
181 | self.train(was_training)
182 | return out
183 | return inner
184 |
185 | # maybe typecheck
186 |
187 | typecheck = jaxtyped(typechecker = beartype) if os.environ.get('TYPECHECK', '').lower() in ('1', 'true') else identity
188 |
189 | # default function for constituting modality shape from string
190 |
191 | def default_to_modality_shape_fn(maybe_shape_str) -> tuple[int, ...]:
192 | return tuple([*map(int, maybe_shape_str.split(','))])
193 |
194 | # default function for translating modality length to times (noise level, where 0 is highest noise)
195 |
196 | def random_modality_length_to_time_fn(num_modalities: Int['b']) -> Float['b m']:
197 | batch = num_modalities.shape[0]
198 | device = num_modalities.device
199 |     total_modalities = num_modalities.amax().item()
200 | return torch.rand((batch, total_modalities), device = device)
201 |
202 | def default_modality_length_to_time_fn(num_modalities: Int['b']) -> Float['b m']:
203 | batch, device = num_modalities.shape[0], num_modalities.device
204 | total_modalities = num_modalities.amax().item()
205 |
206 | if total_modalities == 0:
207 | return torch.empty((batch, 0), device = device, dtype = torch.float)
208 |
209 | rand_num_modalities = torch.floor(torch.rand_like(num_modalities.float()) * num_modalities)
210 | seq = torch.arange(total_modalities, device = device)
211 |
212 | prev_decoded_modality = einx.less('m, b -> b m', seq, rand_num_modalities)
213 | curr_modality_rand_time = torch.rand_like(num_modalities.float())
214 |
215 |     # in the paper, they fix previously decoded modalities to 500 / 1000 steps for discrete ddpm; here we use flow matching with times in 0 - 1, so that corresponds to 0.5
216 | return einx.where('b m, , b -> b m', prev_decoded_modality, 0.5, curr_modality_rand_time)
217 |
218 | # pretty print
219 |
220 | def concat_contiguous_text(
221 | modality_sample: ModalitySample
222 | ) -> ModalitySample:
223 | """ within a modality sample, any two tensors of type int / long will be concatted together if next to each other, so all text is followed by a modality, and all modality followed by text """
224 |
225 | output = []
226 |
227 | for modality in modality_sample:
228 | if (
229 | len(output) > 0 and
230 | is_tensor(output[-1]) and is_tensor(modality) and
231 | output[-1].dtype == modality.dtype and
232 | modality.dtype in (torch.int, torch.long)
233 | ):
234 | packed_text, _ = pack((output[-1], modality), '*')
235 | output[-1] = packed_text
236 |
237 | else:
238 | output.append(modality)
239 |
240 | return output
241 |
242 | def print_modality_sample(
243 | modality_sample: ModalitySample
244 | ):
245 | output = []
246 |
247 | for sample in modality_sample:
248 | if isinstance(sample, tuple):
249 | modality_type, sample = sample
250 | output.append((f'modality:{modality_type}', sample.shape))
251 | elif sample.dtype in (torch.int, torch.long):
252 | output.append(('text', sample.shape))
253 | else:
254 | output.append(('modality', sample.shape))
255 |
256 | logger.info(output)
257 |
258 | # character based tokenizer
259 |
260 | def char_tokenize(
261 | text: str,
262 | device = None,
263 | offset = 0
264 | ) -> Tensor:
265 | tokenized = tensor([*map(ord, text)], device = device) + offset
266 | return tokenized.long()
267 |
268 | def decode_chars(
269 | t: Tensor,
270 | offset = 0,
271 | ) -> str:
272 | byte_list = (t - offset).clamp(min = 0, max = 127).tolist()
273 | return ''.join([*map(chr, byte_list)])
274 |
275 | def get_tokens_since_rightmost_id(
276 | t: Tensor,
277 | rightmost_id: int
278 | ) -> Tensor:
279 | """
280 | ex. [9] [2] [8] [4] [7]
281 | 2 would return [8] [4] [7]
282 | """
283 |
284 | mask = t == rightmost_id
285 |
286 | if not mask.any():
287 | return t[0:0] # return empty tensor if no id found
288 |
289 | reverse_cumsum = mask.flip(dims = (0,)).cumsum(dim = 0).flip(dims = (0,))
290 | after_right_mask = reverse_cumsum == 0
291 | return t[after_right_mask]
292 |
293 | # tensor helpers
294 |
295 | def l2norm(t):
296 | return F.normalize(t, dim = -1)
297 |
298 | def softclamp(t, value = 50.):
299 | return (t / value).tanh() * value
300 |
301 | def max_neg_value(t):
302 | return -torch.finfo(t.dtype).max
303 |
304 | def append_dims(t, ndims):
305 | return t.reshape(*t.shape, *((1,) * ndims))
306 |
307 | def is_empty(t):
308 | return t.numel() == 0
309 |
310 | def log(t, eps = 1e-20):
311 | return torch.log(t.clamp(min = eps))
312 |
313 | def gumbel_noise(t):
314 | noise = torch.rand_like(t)
315 | return -log(-log(noise))
316 |
317 | def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
318 | noise = gumbel_noise(t) * int(temperature > 0)
319 | return (t / temperature + noise).argmax(dim = dim, keepdim = keepdim)
320 |
321 | # dataloader related
322 |
323 | def collate_fn(data):
324 | return [*map(list, data)]
325 |
326 | @typecheck
327 | def create_dataloader(dataset: Dataset, **kwargs) -> DataLoader:
328 | return DataLoader(dataset, collate_fn = collate_fn, **kwargs)
329 |
330 | # flex attention mask construction
331 | # https://pytorch.org/blog/flexattention/
332 |
333 | def causal(b, h, q_idx, kv_idx):
334 | return q_idx >= kv_idx
335 |
336 | def modality(offset, length):
337 |
338 | def mask_fn(b, h, q_idx, kv_idx):
339 | return (q_idx >= offset) & (kv_idx < (offset + length))
340 |
341 | return mask_fn
342 |
343 | def transfusion_attn_mask(modalities: Int['b m 3']):
344 | modalities = modalities.long()
345 |
346 | def mask_mod(b, h, q_idx, kv_idx):
347 | mask = causal(b, h, q_idx, kv_idx)
348 |
349 | modality_batch = modalities[b]
350 |
351 | for _, offset, length in modality_batch:
352 | mask = mask | modality(offset, length)(b, h, q_idx, kv_idx)
353 |
354 | return mask
355 |
356 | return mask_mod
357 |
358 | def softcap_score_mod(softcap):
359 | def inner(score, b, h, q_idx, kv_idx):
360 | score = score / softcap
361 | score = torch.tanh(score)
362 | score = score * softcap
363 | return score
364 | return inner
365 |
366 | # converting a raw list of modality offsets and lengths to tensor
367 |
368 | @typecheck
369 | def modality_positions_to_tensor(
370 | modalities: RawModalityPositions,
371 | pad_value = 0,
372 | device = None
373 | ) -> Int['b m 2'] | Int['b m 3']:
374 |
375 | modalities: list[Tensor] = [tensor(modality, device = device) for modality in modalities]
376 | modalities = pad_sequence(modalities, padding_value = pad_value)
377 |
378 | if modalities.ndim == 2:
379 | modalities = modalities.reshape(*modalities.shape, 3)
380 |
381 | return modalities.long()
382 |
383 | # sanitizing modalities tensor, making sure it is ordered
384 |
385 | @typecheck
386 | def order_modality_positions_by_seq_offset(
387 | modalities: Int['b m 3']
388 | ) -> tuple[Int['b m 3'], Int['b m']]:
389 |
390 | modality_type, offsets, lengths = modalities.unbind(dim = -1)
391 |
392 |     no_modality_mask = lengths <= 0 # there may be an uneven number of modalities per batch sample
393 | offsets_to_sort = offsets.masked_fill(no_modality_mask, 1e10)
394 | _, sorted_indices = offsets_to_sort.sort(dim = -1)
395 |
396 | # sort by ascending offset
397 |
398 | modalities = einx.get_at('b [mi] ..., b mo -> b mo ...', modalities, sorted_indices)
399 | return modalities, sorted_indices
400 |
401 | # deriving relative positions from modality positions
402 | # ex. given a sequence of 10 with an image at offset 3 with length 4 - [t] [t] [t] [i] [i] [i] [i] [t] [t] [t]
403 | # relative positions for rotary will be [0] [1] [2] [3] [3] [3] [3] [4] [5] [6]
404 | # rationale is that each modality will need the same position so there is no distance when conducting bidirectional attention, but should still have a relative distance to other text tokens and modalities
405 |
406 | def derive_rotary_positions_from_modality_positions(
407 | seq_len: int,
408 | modalities: Int['b m 3']
409 | ) -> Int['b n']:
410 |
411 | device = modalities.device
412 |
413 | modality_mask = modality_positions_to_is_modality_mask(seq_len, modalities, offset = torch.tensor([1, -1]))
414 | is_any_modality = reduce(modality_mask, 'b t m n -> b n', 'any')
415 |
416 | return torch.arange(seq_len, device = device) - is_any_modality.cumsum(dim = -1)
417 |
418 | # modality tokens are given as a list of tensors, which can then be embedded at their positions in the sequence for attending alongside text tokens
419 |
420 | @typecheck
421 | def embed_modality_tokens(
422 | seq_len: int,
423 | dim: int,
424 | modality_tokens: list[list[Float['...']]],
425 | modalities: Int['b m 3'],
426 | modality_id: int,
427 | channel_first: bool
428 | ) -> Float['b n d']:
429 |
430 | batch, device = modalities.shape[0], modalities.device
431 |
432 | shape = (batch, seq_len, dim) if not channel_first else (batch, dim, seq_len)
433 | output = torch.zeros(shape, device = device)
434 |
435 | for batch_ind, (one_modality, one_modality_token) in enumerate(zip(modalities, modality_tokens)):
436 | for (modality_type, offset, length), batch_modality_token in zip(one_modality, one_modality_token):
437 |
438 | if modality_id != modality_type or length <= 0:
439 | continue
440 |
441 | modality_shape = batch_modality_token.shape
442 |
443 | if channel_first:
444 | mod_dim, *mod_axial_shape = modality_shape
445 | batch_modality_token = rearrange(batch_modality_token, 'd ... -> d (...)')
446 | else:
447 | *mod_axial_shape, mod_dim = modality_shape
448 | batch_modality_token = rearrange(batch_modality_token, '... d -> (...) d')
449 |
450 | mod_length = math.prod(mod_axial_shape)
451 |
452 | assert length == mod_length, f'received a modality of shape {modality_shape} but sequence length in modalities info is {length}'
453 | assert dim == mod_dim, f'received modality [{modality_id}] with shape {modality_shape} but expected dimension of {dim}'
454 |
455 | if channel_first:
456 | output[batch_ind, :, offset:(offset + length)] = batch_modality_token
457 | else:
458 | output[batch_ind, offset:(offset + length), :] = batch_modality_token
459 |
460 | return output
461 |
462 | # functions for managing modality token mask
463 |
464 | @typecheck
465 | def modality_positions_to_is_modality_mask(
466 | seq_len: int,
467 | modalities: Int['b m 3'],
468 | offset: Int['2'] | None = None,
469 | device = None,
470 | num_modalities = 1
471 | ) -> Bool['b t m n']:
472 |
473 | device = modalities.device
474 |
475 | if exists(offset):
476 | offset = F.pad(offset, (1, 0))
477 | modalities = modalities + offset.to(modalities)
478 |
479 | seq = torch.arange(seq_len, device = device)
480 | type_seq = torch.arange(num_modalities, device = device)
481 |
482 | modality_types = modalities[..., 0]
483 |
484 | left, right = modalities[..., 1:].cumsum(dim = -1).unbind(dim = -1)
485 |
486 | is_instance_for_type = einx.equal('b m, t -> b t m', modality_types, type_seq)
487 |
488 | is_modality_along_seq = (
489 | einx.greater_equal('i, b m -> b m i', seq, left) &
490 | einx.less('j, b m -> b m j', seq, right)
491 | )
492 |
493 | return einx.logical_and('b t m, b m n -> b t m n', is_instance_for_type, is_modality_along_seq)
494 |
495 | @typecheck
496 | def naive_attn_mask(
497 | seq_len: int,
498 | modalities: Int['b m 3'],
499 | device = None
500 | ) -> Bool['b i j']:
501 |
502 | _, offsets, length = modalities.unbind(dim = -1)
503 |
504 | seq = torch.arange(seq_len, device = device)
505 |
506 | is_causal = einx.greater_equal('i, j -> i j', seq, seq)
507 |
508 | is_modality = (
509 | einx.greater_equal('i, b m -> b m i 1', seq, offsets) &
510 | einx.less('j, b m -> b m 1 j', seq, offsets + length)
511 | )
512 |
513 | return is_causal | is_modality.any(dim = 1)
514 |
515 | # unet encoder related function
516 |
517 | def stack_same_shape_tensors_with_inverse(tensors: list[Tensor]):
518 |
519 | shape_tensors_dict = defaultdict(list)
520 | shape_batch_dict = defaultdict(int) # also store a shape -> num tensors dictionary to validate inverse function input
521 | inverse_index_list = []
522 |
523 | for tensor in tensors:
524 | shape = tuple(tensor.shape)
525 | batch_el = shape_batch_dict[shape]
526 |
527 | shape_tensors_dict[shape].append(tensor)
528 | shape_batch_dict[shape] += 1
529 |
530 | inverse_index_list.append((shape, batch_el))
531 |
532 | # stack all the tensors with same shapes to have a batch dimension
533 |
534 | shape_tensors_dict = {shape: torch.stack(same_shape_tensors) for shape, same_shape_tensors in shape_tensors_dict.items()}
535 |
536 | # inverse function
537 |
538 | def inverse(
539 | indexed_tensors: dict[tuple[int, ...], Tensor]
540 | ) -> list[Tensor]:
541 |
542 | out_shape_batch_dict = {shape: len(tensors) for shape, tensors in indexed_tensors.items()}
543 |
544 | assert out_shape_batch_dict == shape_batch_dict
545 |
546 | inversed = []
547 |
548 | for shape, batch_el in inverse_index_list:
549 | tensor = indexed_tensors[shape][batch_el]
550 | inversed.append(tensor)
551 |
552 | return inversed
553 |
554 | return shape_tensors_dict, inverse
555 |
556 | def filter_with_inverse(cond, inp):
557 |
558 | indices = set()
559 | filtered = []
560 |
561 | for ind, el in enumerate(inp):
562 | if cond(el):
563 | indices.add(ind)
564 | filtered.append(el)
565 |
566 | def inverse(inverse_inp):
567 | assert len(inverse_inp) == len(filtered)
568 |
569 | output = []
570 | inverse_inp_index = 0
571 |
572 | for ind, el in enumerate(inp):
573 | if ind not in indices:
574 | output.append(el)
575 | continue
576 |
577 | inverse_inp_el = inverse_inp[inverse_inp_index]
578 | output.append(inverse_inp_el)
579 | inverse_inp_index += 1
580 |
581 | return output
582 |
583 | return filtered, inverse
584 |
585 | def apply_fn_modality_type(
586 | fn: Callable,
587 | modalities: ModalitySample | list[ModalitySample],
588 | modality_type = 0,
589 | return_untransformed = False
590 | ) -> ModalitySample | list[ModalitySample]:
591 |
592 | modalities, tree_spec = tree_flatten(modalities, is_leaf = lambda el: isinstance(el, tuple))
593 |
594 |     # standardize to (modality_type, modality) tuples
595 |
596 | modalities = [(0, m) if (is_tensor(m) and m.dtype == torch.float) else m for m in modalities]
597 |
598 | # filter for specific modality type to transform
599 |
600 | modalities, inverse_filter = filter_with_inverse(lambda el: isinstance(el, tuple) and el[0] == modality_type, modalities)
601 |
602 | # remove the type
603 |
604 | modalities = [m for _, m in modalities]
605 |
606 | # batch process
607 |
608 | stacked_modalities, inverse_stack = stack_same_shape_tensors_with_inverse(modalities)
609 |
610 | out = {shape: fn(batched_modalities) for shape, batched_modalities in stacked_modalities.items()}
611 |
612 | out = inverse_stack(out)
613 |
614 | # add back the type
615 |
616 | if return_untransformed:
617 | out = [(modality_type, transformed_m, prev_m) for transformed_m, prev_m in zip(out, modalities)]
618 | else:
619 | out = [(modality_type, transformed_m) for transformed_m in out]
620 |
621 | # replace transformed modalities and untree flatten
622 |
623 | out = inverse_filter(out)
624 |
625 | return tree_unflatten(out, tree_spec)
626 |
627 | # sampling related functions
628 |
629 | # min_p for text
630 | # https://arxiv.org/abs/2407.01082
631 |
632 | def min_p_filter(logits, min_p = 0.1):
633 | probs = logits.softmax(dim = -1)
634 | max_probs = probs.amax(dim = -1, keepdim = True)
635 | limit = min_p * max_probs
636 | return torch.where(probs < limit, float('-inf'), logits)
637 |
638 | # random fourier embedding
639 |
640 | class RandomFourierEmbed(Module):
641 | def __init__(self, dim):
642 | super().__init__()
643 | assert divisible_by(dim, 2)
644 | self.dim = dim
645 | self.register_buffer('weights', torch.randn(dim // 2))
646 |
647 | @typecheck
648 | def forward(
649 | self,
650 | times: Float['b n'] | Float['b']
651 | ) -> Float['b n {self.dim+1}']:
652 |
653 | if times.ndim == 1:
654 | times = rearrange(times, 'b -> b 1')
655 |
656 | freqs = einx.multiply('... i, j -> ... i j', times, self.weights) * 2 * torch.pi
657 | fourier_embed, _ = pack((times, freqs.sin(), freqs.cos()), 'b n *')
658 | return fourier_embed
659 |
660 | # adaptive layernorm and ada-ln zero rolled into one wrapper
661 | # from DiT paper and sota for time conditioning for now
662 |
663 | class AdaptiveWrapper(Module):
664 | @beartype
665 | def __init__(
666 | self,
667 | fn: Module,
668 | dim,
669 | dim_cond,
670 | ada_ln_zero_init_bias = -2
671 | ):
672 | super().__init__()
673 | self.fn = fn
674 | self.dim = dim
675 | self.dim_cond = dim_cond
676 |
677 | self.layernorm = nn.LayerNorm(dim, elementwise_affine = False)
678 |
679 | # text will be subjected to normal layernorm bias
680 | # and for output will use layerscale
681 |
682 | self.layernorm_gamma = nn.Parameter(torch.zeros(dim))
683 | self.layerscale = nn.Parameter(torch.zeros(dim))
684 |
685 | # modalities will get the adaptive layernorm + ada-ln zero
686 |
687 | self.to_film = Linear(dim_cond, dim * 2)
688 | self.to_ada_ln_zero = Linear(dim_cond, dim)
689 |
690 | nn.init.zeros_(self.to_film.weight)
691 | nn.init.zeros_(self.to_ada_ln_zero.weight)
692 | nn.init.constant_(self.to_ada_ln_zero.bias, ada_ln_zero_init_bias)
693 |
694 | @typecheck
695 | def forward_text(
696 | self,
697 | x: Float['b n {self.dim}'],
698 | **kwargs
699 | ):
700 | x = self.layernorm(x)
701 |
702 | x = x * (self.layernorm_gamma + 1.)
703 |
704 | out = self.fn(x, **kwargs)
705 |
706 | (out, *rest), tree_spec = tree_flatten(out)
707 |
708 | out = out * (self.layerscale + 1.)
709 |
710 | out = tree_unflatten((out, *rest), tree_spec)
711 |
712 | return out
713 |
714 | @typecheck
715 | def forward_modality(
716 | self,
717 | x: Float['b n {self.dim}'],
718 | cond: Float['... {self.dim_cond}'],
719 | **kwargs
720 | ):
721 | x = self.layernorm(x)
722 |
723 | gamma, beta = self.to_film(cond).chunk(2, dim = -1)
724 |
725 | modality_tokens = x * (gamma + 1.) + beta
726 |
727 | # attention or feedforwards
728 |
729 | out = self.fn(modality_tokens, **kwargs)
730 |
731 | (out, *rest), tree_spec = tree_flatten(out)
732 |
733 | # take care of conditioning output separately for text vs modality
734 |
735 | modalities_out = out * self.to_ada_ln_zero(cond).sigmoid()
736 |
737 | # take care of function returning cache or value residual
738 |
739 | modalities_out = tree_unflatten((modalities_out, *rest), tree_spec)
740 |
741 | return modalities_out
742 |
743 | @typecheck
744 | def forward(
745 | self,
746 | x: Float['b n {self.dim}'],
747 | cond: Float['... {self.dim_cond}'] | None = None,
748 | is_any_modality: bool | Bool['b n'] | None = None,
749 | modality_only = False,
750 | **kwargs
751 | ):
752 | if exists(cond) and cond.ndim == 2:
753 | cond = rearrange(cond, 'b d -> b 1 d')
754 |
755 | if modality_only:
756 | return self.forward_modality(x, cond = cond, **kwargs)
757 |
758 | assert not (exists(cond) ^ exists(is_any_modality))
759 |
760 | has_modality = exists(is_any_modality)
761 |
762 | if not has_modality:
763 | return self.forward_text(x, **kwargs)
764 |
765 | if isinstance(is_any_modality, bool):
766 | is_any_modality = torch.full((x.shape[:-1]), is_any_modality, device = x.device, dtype = torch.bool)
767 |
768 | is_any_modality = rearrange(is_any_modality, '... -> ... 1')
769 |
770 | x = self.layernorm(x)
771 |
772 | gamma, beta = self.to_film(cond).chunk(2, dim = -1)
773 |
774 | text_tokens = x * (self.layernorm_gamma + 1.)
775 |
776 | modality_tokens = x * (gamma + 1.) + beta
777 |
778 | x = torch.where(is_any_modality, modality_tokens, text_tokens)
779 |
780 | # attention or feedforwards
781 |
782 | out = self.fn(x, **kwargs)
783 |
784 | (out, *rest), tree_spec = tree_flatten(out)
785 |
786 | # take care of conditioning output separately for text vs modality
787 |
788 | text_out = out * (self.layerscale + 1.)
789 |
790 | modalities_out = out * self.to_ada_ln_zero(cond).sigmoid()
791 |
792 | conditioned_out = torch.where(is_any_modality, modalities_out, text_out)
793 |
794 | # take care of function returning cache or value residual
795 |
796 | conditioned_out = tree_unflatten((conditioned_out, *rest), tree_spec)
797 |
798 | return conditioned_out
799 |
800 | # attention
801 |
802 | class RMSNorm(Module):
803 | def __init__(self, dim):
804 | super().__init__()
805 | self.scale = dim ** 0.5
806 | self.gamma = nn.Parameter(torch.zeros(dim))
807 |
808 | def forward(self, x):
809 | return l2norm(x) * self.scale * (self.gamma + 1.) # use unit offset from Ohad Rubin
810 |
811 | class GEGLU(Module):
812 | def forward(self, x):
813 | x, gates = x.chunk(2, dim = -1)
814 | return F.gelu(gates) * x
815 |
816 | def FeedForward(
817 | dim,
818 | expansion_factor = 4.,
819 | dropout = 0.
820 | ):
821 | dim_inner = int(dim * expansion_factor * 2 / 3)
822 | return nn.Sequential(
823 | Linear(dim, dim_inner * 2),
824 | GEGLU(),
825 | nn.Dropout(dropout),
826 | Linear(dim_inner, dim)
827 | )
828 |
829 | class Attention(Module):
830 | def __init__(
831 | self,
832 | dim,
833 | dim_head = 64,
834 | heads = 8,
835 | dropout = 0.,
836 | softcap_value = 50.,
837 | use_flex_attn = False,
838 | gate_values = True,
839 | laser = False,
840 | laser_softclamp_value = 15.,
841 | learned_value_residual_mix = False
842 | ):
843 | super().__init__()
844 | self.scale = dim_head ** -0.5
845 | dim_inner = dim_head * heads
846 |
847 | assert not (use_flex_attn and not exists(flex_attention)), 'flex attention is only available on torch 2.5.0 (nightly) onwards'
848 | self.use_flex_attn = use_flex_attn
849 |
850 | self.to_qkv = nn.Sequential(
851 | Linear(dim, dim_inner * 3, bias = False),
852 | Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
853 | )
854 |
855 | self.to_learned_value_residual = nn.Sequential(
856 | nn.Linear(dim, heads),
857 | nn.Sigmoid(),
858 | Rearrange('b n h -> b h n 1') # add head dimension
859 | ) if learned_value_residual_mix else always(0.5)
860 |
861 | self.to_gates = nn.Sequential(
862 | nn.Linear(dim, heads, bias = False),
863 | Rearrange('b n h -> b h n 1', h = heads)
864 | ) if gate_values else None
865 |
866 | self.softcap_value = softcap_value
867 |
868 | self.laser = laser
869 | self.laser_softclamp_value = laser_softclamp_value
870 |
871 | self.dropout = nn.Dropout(dropout)
872 |
873 | self.to_out = nn.Sequential(
874 | Rearrange('b h n d -> b n (h d)'),
875 | Linear(dim_inner, dim, bias = False)
876 | )
877 |
878 | def forward(
879 | self,
880 | x,
881 | attn_mask: Tensor | None = None, # for manual masking
882 | rotary_emb: Tensor | None = None,
883 | cache: Tensor | None = None,
884 | causal = False,
885 | block_mask = None, # only passed in for flex attention
886 | return_kv_cache = False,
887 | return_values = False,
888 | value_residual: Tensor | None = None
889 | ):
890 | device, input_is_cuda, is_decoding_with_cache = x.device, x.is_cuda, exists(cache)
891 |
892 | should_use_flex_attn = self.use_flex_attn and input_is_cuda
893 |
894 | # handle maybe mask
895 | # if receiving kv cache, assume decoding and turn off all masking
896 |
897 | if is_decoding_with_cache:
898 | block_mask = attn_mask = None
899 |
900 | assert not (exists(block_mask) and exists(attn_mask))
901 | assert not (not self.use_flex_attn and exists(block_mask)), 'you cannot pass in the `block_mask` if `use_flex_attn` was not set to be `True`'
902 |
903 | # project to queries, keys, values
904 |
905 | q, k, v = self.to_qkv(x)
906 |
907 | # value residual
908 |
909 | orig_v = v
910 |
911 | if exists(value_residual):
912 | mix = self.to_learned_value_residual(x)
913 | v = v * mix + value_residual * (1. - mix)
914 |
915 | # handle cache being passed in
916 |
917 | if exists(cache):
918 | cached_k, cached_v = cache
919 | k = cat((cached_k, k), dim = -2)
920 | v = cat((cached_v, v), dim = -2)
921 |
922 | # maybe kv cache
923 |
924 | if return_kv_cache:
925 | kv_cache = stack((k, v))
926 |
927 | # rotary embeddings
928 |
929 | if exists(rotary_emb):
930 | q, k = tuple(apply_rotary_emb(rotary_emb, t, freqs_seq_dim = -2) for t in (q, k))
931 |
932 | # laser attention
933 |
934 | if self.laser:
935 | v = softclamp(v, self.laser_softclamp_value)
936 | v = v.exp()
937 |
938 | # whether to use flex attention or not
939 |
940 | if should_use_flex_attn:
941 | assert not causal, 'causal mask should be constructed in transformer'
942 |
943 | flex_attn_kwargs = dict(block_mask = block_mask)
944 |
945 | if self.softcap_value > 0.:
946 | flex_attn_kwargs.update(score_mod = softcap_score_mod(self.softcap_value))
947 |
948 | out = flex_attention(q, k, v, **flex_attn_kwargs)
949 |
950 | else:
951 | q = q * self.scale
952 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
953 |
954 | sim = softclamp(sim, self.softcap_value)
955 |
956 | mask_value = max_neg_value(sim)
957 |
958 | if causal:
959 | i, j = sim.shape[-2:]
960 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
961 | sim = sim.masked_fill(causal_mask, mask_value)
962 |
963 | if exists(attn_mask):
964 | sim = einx.where('b i j, b h i j, -> b h i j', attn_mask, sim, mask_value)
965 |
966 | attn = sim.softmax(dim = -1)
967 |
968 | attn = self.dropout(attn)
969 |
970 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
971 |
972 | # laser attention
973 |
974 | if self.laser:
975 | out = log(out)
976 |
977 | # maybe gate values
978 |
979 | if exists(self.to_gates):
980 | out = out * self.to_gates(x).sigmoid()
981 |
982 | # combine heads and out
983 |
984 | out = self.to_out(out)
985 |
986 | if return_values:
987 | out = (out, orig_v)
988 |
989 | if not return_kv_cache:
990 | return out
991 |
992 | return out, kv_cache
993 |
994 | class Transformer(Module):
995 | @beartype
996 | def __init__(
997 | self,
998 | dim,
999 | *,
1000 | depth,
1001 | dim_head = 64,
1002 | heads = 8,
1003 | dropout = 0.,
1004 | ff_expansion_factor = 4,
1005 | attn_kwargs: dict = dict(),
1006 | ff_kwargs: dict = dict(),
1007 | attn_laser = False,
1008 | unet_skips = True,
1009 | use_flex_attn = False,
1010 | num_residual_streams = 1,
1011 | num_residual_fracs = 4
1012 | ):
1013 | super().__init__()
1014 | self.use_flex_attn = use_flex_attn
1015 |
1016 | self.dim = dim
1017 | self.dim_head = dim_head
1018 |
1019 | self.to_time_cond = nn.Sequential(
1020 | RandomFourierEmbed(dim),
1021 | Linear(dim + 1, dim * 4),
1022 | nn.SiLU()
1023 | )
1024 |
1025 | # hyper connections
1026 |
1027 | counter = count()
1028 |
1029 | init_residual_fn, self.expand_stream, self.reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_fracs = num_residual_fracs)
1030 |
1031 | # layers
1032 |
1033 | layers = ModuleList([])
1034 |
1035 | for ind in range(depth):
1036 | is_first = ind == 0
1037 |
1038 | is_latter_half = ind >= (depth / 2)
1039 |
1040 | skip_proj = Linear(dim * 2, dim, bias = False) if is_latter_half and unet_skips else None
1041 |
1042 | attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, learned_value_residual_mix = not is_first, laser = attn_laser, **attn_kwargs)
1043 |
1044 | ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor, **ff_kwargs)
1045 |
1046 | attn = AdaptiveWrapper(attn, dim = dim, dim_cond = dim * 4)
1047 | ff = AdaptiveWrapper(ff, dim = dim, dim_cond = dim * 4)
1048 |
1049 | attn_residual = init_residual_fn(dim = dim, layer_index = next(counter))
1050 | ff_residual = init_residual_fn(dim = dim, layer_index = next(counter))
1051 |
1052 | layers.append(ModuleList([skip_proj, attn, attn_residual, ff, ff_residual]))
1053 |
1054 | self.layers = layers
1055 | self.norm = RMSNorm(dim)
1056 |
1057 | @typecheck
1058 | def forward(
1059 | self,
1060 | x,
1061 | times: Scalar | Float['b'] | Float['b n'] | None = None,
1062 | attn_mask: Bool['b i j'] | None = None,
1063 | modality_positions: RawModalityPositions | Int['b m 3'] | None = None,
1064 | is_any_modality: bool | Bool['b n'] | None = None,
1065 | rotary_emb: Tensor | None = None,
1066 | cache: Tensor | None = None,
1067 | decode_length: int | None = None,
1068 | modality_only = False,
1069 | causal_mask = False,
1070 | return_kv_cache = False
1071 | ):
1072 | batch, seq_len, device, input_is_cuda = x.shape[0], x.shape[-2], x.device, x.is_cuda
1073 |
1074 | is_decoding_with_cache = exists(cache)
1075 | needs_masking = not is_decoding_with_cache
1076 |
1077 | should_use_flex_attn = input_is_cuda and needs_masking and self.use_flex_attn
1078 |
1079 | assert not (exists(attn_mask) and exists(modality_positions))
1080 |
1081 | # handle time
1082 |
1083 | cond = None
1084 |
1085 | if exists(times):
1086 | if times.ndim == 0:
1087 | times = repeat(times, ' -> b', b = batch)
1088 |
1089 | cond = self.to_time_cond(times)
1090 |
1091 | # create the specialized mask needed for autoregressive text + bidirectional flow attention
1092 |
1093 | attn_mask_kwargs = dict()
1094 |
1095 | if needs_masking:
1096 | if causal_mask:
1097 | if should_use_flex_attn:
1098 | block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True, device = device)
1099 | attn_mask_kwargs.update(block_mask = block_mask)
1100 | else:
1101 | attn_mask_kwargs.update(causal = True)
1102 |
1103 | if exists(modality_positions):
1104 | assert not causal_mask
1105 |
1106 | if should_use_flex_attn:
1107 | transfusion_mask_fn = transfusion_attn_mask(modality_positions)
1108 | block_mask = create_block_mask(transfusion_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True, device = device)
1109 | attn_mask_kwargs.update(block_mask = block_mask)
1110 | else:
1111 | attn_mask = naive_attn_mask(seq_len, modality_positions, device = device)
1112 | attn_mask_kwargs.update(attn_mask = attn_mask)
1113 |
1114 | if not exists(is_any_modality) and exists(modality_positions):
1115 | is_any_modality = modality_positions_to_is_modality_mask(seq_len, modality_positions).any(dim = 1)
1116 | is_any_modality = reduce(is_any_modality, 'b t n -> b n', 'any')
1117 |
1118 | # handle kv caching
1119 |
1120 | if is_decoding_with_cache:
1121 | assert exists(decode_length)
1122 |
1123 | x = x[..., -decode_length:, :]
1124 | cond = cond[..., -decode_length:, :]
1125 |
1126 | if is_tensor(is_any_modality):
1127 | is_any_modality = is_any_modality[..., -decode_length:]
1128 |
1129 | # adaptive layernorm kwargs, which handles text and modality tokens differently
1130 |
1131 | adaptive_kwargs = dict(
1132 | cond = cond,
1133 | modality_only = modality_only,
1134 | is_any_modality = is_any_modality
1135 | )
1136 |
1137 | # handle cache
1138 |
1139 | cache = default(cache, (None,))
1140 | iter_cache = iter(cache)
1141 |
1142 | # expand input into multiple residual streams for maybe hyper connection
1143 |
1144 | x = self.expand_stream(x)
1145 |
1146 | # transformer layers as usual, using mask from above
1147 |
1148 | skips = []
1149 | value_residual = None
1150 |
1151 | new_cache = []
1152 |
1153 | depth = len(self.layers)
1154 |
1155 | for ind, (skip_proj, attn, attn_residual, ff, ff_residual) in enumerate(self.layers):
1156 | layer = ind + 1
1157 |
1158 | # skip connection
1159 |
1160 | is_first_half = layer <= (depth // 2)
1161 | is_later_half = not is_first_half
1162 |
1163 | if is_first_half:
1164 | skips.append(x)
1165 |
1166 | if is_later_half and exists(skip_proj):
1167 | skip = skips.pop()
1168 |
1169 | residual = x
1170 | x = cat((x, skip), dim = -1)
1171 | x = skip_proj(x) + residual
1172 |
1173 | # attention and feedforward
1174 |
1175 | x, add_attn_residual = attn_residual(x)
1176 |
1177 | (attn_out, attn_values), kv_cache = attn(
1178 | x,
1179 | rotary_emb = rotary_emb,
1180 | cache = next(iter_cache, None),
1181 | return_kv_cache = True,
1182 | return_values = True,
1183 | value_residual = value_residual,
1184 | **attn_mask_kwargs,
1185 | **adaptive_kwargs
1186 | )
1187 |
1188 | value_residual = default(value_residual, attn_values)
1189 |
1190 | new_cache.append(kv_cache)
1191 |
1192 | x = add_attn_residual(attn_out)
1193 |
1194 | x, add_ff_residual = ff_residual(x)
1195 |
1196 | ff_out = ff(x, **adaptive_kwargs)
1197 |
1198 | x = add_ff_residual(ff_out)
1199 |
1200 | # reduce multiple residual streams for maybe hyper connection
1201 |
1202 | x = self.reduce_stream(x)
1203 |
1204 | assert len(skips) == 0
1205 |
1206 | out = self.norm(x)
1207 |
1208 | if not return_kv_cache:
1209 | return out
1210 |
1211 | return out, stack(new_cache)
1212 |
1213 | # classes
1214 |
1215 | class Transfusion(Module):
1216 | @beartype
1217 | def __init__(
1218 | self,
1219 | *,
1220 | num_text_tokens,
1221 | transformer: dict | Transformer,
1222 | dim_latent: int | tuple[int, ...] | None = None,
1223 | channel_first_latent: bool | tuple[bool, ...] = False,
1224 | add_pos_emb: bool | tuple[bool, ...] = False,
1225 | modality_encoder: Module | tuple[Module, ...] | None = None,
1226 | modality_decoder: Module | tuple[Module, ...] | None = None,
1227 | pre_post_transformer_enc_dec: tuple[Module, Module] | tuple[tuple[Module, Module], ...] | None = None,
1228 | modality_default_shape: tuple[int, ...] | tuple[tuple[int, ...], ...] | None = None,
1229 | fallback_to_default_shape_if_invalid = False,
1230 | modality_num_dim: int | tuple[int, ...] | None = None,
1231 | to_modality_shape_fn: Callable | tuple[Callable, ...] = default_to_modality_shape_fn,
1232 | ignore_index = -1,
1233 | flow_loss_weight = 1.,
1234 | text_loss_weight = 1.,
1235 | velocity_consistency_loss_weight = 0.1,
1236 | reconstruction_loss_weight = 0.,
1237 |         modality_encoder_decoder_requires_batch_dim = True, # whether the modality encoder / decoder requires a batch dimension; assumed to be needed by default
1238 | odeint_kwargs: dict = dict(
1239 | atol = 1e-5,
1240 | rtol = 1e-5,
1241 | method = 'midpoint'
1242 | ),
1243 | ):
1244 | super().__init__()
1245 |
1246 | # transformer
1247 |
1248 | if isinstance(transformer, dict):
1249 | transformer = Transformer(**transformer)
1250 |
1251 | self.transformer = transformer
1252 | dim = transformer.dim
1253 |
1254 | self.dim = dim
1255 |
1256 | # latent and model dimension not the same
1257 | # make it work for 1 modality for now
1258 |
1259 | dim_latent = default(dim_latent, dim)
1260 |
1261 | self.dim_latents = cast_tuple(dim_latent)
1262 |
1263 | # number of modalities
1264 |
1265 | self.num_modalities = len(self.dim_latents)
1266 |
1267 | # whether the latents are accepted to be channel first or channel last
1268 | # if channel first, will be rearrange(c ... -> ... c -> (...) c)
1269 |
1270 | self.channel_first_latent = cast_tuple(channel_first_latent, self.num_modalities)
1271 | assert len(self.channel_first_latent) == self.num_modalities
1272 |
1273 | # functions for converting the sampled language model meta string back to modality shape of tuple[int, ...]
1274 |
1275 | self.to_modality_shape_fn = cast_tuple(to_modality_shape_fn, self.num_modalities)
1276 |
1277 | # default token lengths for respective modality
1278 | # fallback if the language model does not come up with valid dimensions
1279 |
1280 | if not exists(modality_default_shape) or is_bearable(modality_default_shape, tuple[int, ...]):
1281 | modality_default_shape = (modality_default_shape,) * self.num_modalities
1282 |
1283 | self.modality_default_shape = modality_default_shape
1284 |
1285 | assert len(self.modality_default_shape) == self.num_modalities
1286 |
1287 | self.fallback_to_default_shape_if_invalid = fallback_to_default_shape_if_invalid
1288 |
1289 |         # default `modality_num_dim` to `len(modality_default_shape)` if the latter is specified but the former is not
1290 |
1291 | modality_num_dim = default(modality_num_dim, tuple(len(shape) for shape in self.modality_default_shape))
1292 |
1293 | # specifying the number of dimensions for the modality, which will be hard validated
1294 |
1295 | self.modality_num_dim = cast_tuple(modality_num_dim, self.num_modalities)
1296 |
1297 | assert len(self.modality_num_dim) == self.num_modalities
1298 |
1299 | assert all([not exists(ndim) or not exists(shape) or len(shape) == ndim for ndim, shape in zip(self.modality_num_dim, self.modality_default_shape)])
1300 |
1301 | # whether to add an extra axial positional embedding per modality
1302 |
1303 | self.add_pos_emb = cast_tuple(add_pos_emb, self.num_modalities)
1304 | assert len(self.add_pos_emb) == self.num_modalities
1305 |
1306 | self.pos_emb_mlp = ModuleList([])
1307 |
1308 | for modality_add_pos_emb, modality_ndim in zip(self.add_pos_emb, self.modality_num_dim):
1309 |
1310 | if not modality_add_pos_emb:
1311 | self.pos_emb_mlp.append(None)
1312 | continue
1313 |
1314 | assert exists(modality_ndim), '`modality_num_dim` must be set if you wish to automatically inject axial positional embeddings'
1315 |
1316 | pos_generating_mlp = ContinuousAxialPositionalEmbedding(
1317 | dim = dim,
1318 | num_axial_dims = modality_ndim,
1319 | )
1320 |
1321 | self.pos_emb_mlp.append(pos_generating_mlp)
1322 |
1323 | # modality encoders and decoders
1324 |
1325 | modality_encoder = cast_tuple(modality_encoder, 1 if exists(modality_encoder) else self.num_modalities)
1326 | modality_decoder = cast_tuple(modality_decoder, 1 if exists(modality_decoder) else self.num_modalities)
1327 |
1328 | self.modality_encoder = ModuleList(modality_encoder)
1329 | self.modality_decoder = ModuleList(modality_decoder)
1330 |
1331 | assert len(self.modality_encoder) == self.num_modalities
1332 | assert len(self.modality_decoder) == self.num_modalities
1333 |
1334 | # auto handle batch dimension for modality encoder / decoder
1335 |
1336 | self.maybe_add_temp_batch_dim = add_temp_batch_dim if modality_encoder_decoder_requires_batch_dim else identity
1337 |
1338 | # store number of text tokens
1339 |
1340 | self.num_text_tokens = num_text_tokens
1341 |
1342 | # entire "sentence" start and end id
1343 |
1344 | num_text_special_ids = 2
1345 |
1346 | self.sos_id, self.eos_id = num_text_tokens, (num_text_tokens + 1)
1347 |
1348 | # modality meta, start and end tokens - termed [mom] [som] [eom] in this repo
1349 |
1350 | num_modality_special_ids = self.num_modalities * 2
1351 | som_eom_tensor = torch.arange(num_modality_special_ids) + num_text_tokens + num_text_special_ids # shift to the very end
1352 | som_eom_tensor = rearrange(som_eom_tensor, '(start_end m) -> start_end m', start_end = 2)
1353 |
1354 | # modality meta, start and end ids
1355 |
1356 | self.som_ids, self.eom_ids = som_eom_tensor.tolist()
1357 |
1358 | # char tokenizing for modality meta information
1359 |
1360 | meta_token_offset = num_text_tokens + num_text_special_ids + num_modality_special_ids
1361 |
1362 | self.meta_id = meta_token_offset
1363 |
1364 | self.char_tokenizer = partial(char_tokenize, offset = meta_token_offset + 1)
1365 | self.decode_chars = partial(decode_chars, offset = meta_token_offset + 1)
1366 |
1367 | num_meta_tokens = 128 + 1
1368 |
1369 | # prepare pre-post transformer encoder / decoder, for the learnable unets as in paper
1370 |
1371 | if is_bearable(pre_post_transformer_enc_dec, tuple[Module, Module]):
1372 | pre_post_transformer_enc_dec = (pre_post_transformer_enc_dec,)
1373 |
1374 | pre_post_transformer_enc_dec = cast_tuple(pre_post_transformer_enc_dec, self.num_modalities)
1375 | assert len(pre_post_transformer_enc_dec) == self.num_modalities
1376 |
1377 | # latent to model and back
1378 | # by default will be Linear, with or without rearranges depending on channel_first_latent setting
1379 | # can also be overridden for the unet down/up as in the paper with `pre_post_transformer_enc_dec: tuple[Module, Module]`
1380 |
1381 | latent_to_model_projs = []
1382 | model_to_latent_projs = []
1383 |
1384 | for (
1385 | dim_latent,
1386 | one_channel_first_latent,
1387 | enc_dec,
1388 | ) in zip(
1389 | self.dim_latents,
1390 | self.channel_first_latent,
1391 | pre_post_transformer_enc_dec
1392 | ):
1393 |
1394 | pre_attend_enc, post_attend_dec = default(enc_dec, (None, None))
1395 |
1396 | latent_to_model_proj = Linear(dim_latent, dim) if dim_latent != dim else nn.Identity()
1397 | model_to_latent_proj = Linear(dim, dim_latent, bias = False)
1398 |
1399 | if one_channel_first_latent:
1400 | latent_to_model_proj = nn.Sequential(Rearrange('b d ... -> b ... d'), latent_to_model_proj)
1401 | model_to_latent_proj = nn.Sequential(model_to_latent_proj, Rearrange('b ... d -> b d ...'))
1402 |
1403 | if exists(pre_attend_enc):
1404 | pre_attend_enc = nn.Sequential(pre_attend_enc, Rearrange('b d ... -> b ... d'))
1405 |
1406 | if exists(post_attend_dec):
1407 | post_attend_dec = nn.Sequential(Rearrange('b ... d -> b d ...'), post_attend_dec)
1408 |
1409 | latent_to_model_projs.append(default(pre_attend_enc, latent_to_model_proj))
1410 | model_to_latent_projs.append(default(post_attend_dec, model_to_latent_proj))
1411 |
1412 | self.latent_to_model_projs = ModuleList(latent_to_model_projs)
1413 | self.model_to_latent_projs = ModuleList(model_to_latent_projs)
1414 |
1415 | # relative positions
1416 |
1417 | self.rotary_emb = RotaryEmbedding(transformer.dim_head)
1418 |
1419 | # embeddings and un-embeddings
1420 |
1421 | effective_num_text_tokens = num_text_tokens + num_text_special_ids + num_modality_special_ids + num_meta_tokens
1422 |
1423 | self.text_embed = nn.Embedding(effective_num_text_tokens, dim)
1424 |
1425 | self.to_text_logits = Linear(dim, effective_num_text_tokens, bias = False)
1426 |
1427 | text_only_mask = torch.arange(effective_num_text_tokens) < num_text_tokens
1428 | self.register_buffer('text_only_logits_mask', text_only_mask, persistent = False)
1429 |
1430 | # loss related
1431 |
1432 | self.ignore_index = ignore_index
1433 | self.flow_loss_weight = flow_loss_weight
1434 | self.text_loss_weight = text_loss_weight
1435 |
1436 | # velocity consistency weight - only added if EMA model is passed in during training
1437 |
1438 | self.velocity_consistency_loss_weight = velocity_consistency_loss_weight
1439 |
1440 | # additional reconstruction loss, through the decoder
1441 |
1442 | self.has_recon_loss = reconstruction_loss_weight > 0.
1443 | self.reconstruction_loss_weight = reconstruction_loss_weight
1444 |
1445 | # flow sampling related
1446 |
1447 | self.odeint_fn = partial(odeint, **odeint_kwargs)
1448 |
1449 | # dummy loss
1450 |
1451 | self.register_buffer('zero', tensor(0.), persistent = False)
1452 |
1453 | @property
1454 | def device(self):
1455 | return next(self.parameters()).device
1456 |
1457 | @cache
1458 | def get_modality_info(
1459 | self,
1460 | modality_type: int | None = None
1461 | ) -> ModalityInfo:
1462 |
1463 | modality_type = default(modality_type, 0)
1464 |
1465 | modality_encoder = self.modality_encoder[modality_type]
1466 | modality_decoder = self.modality_decoder[modality_type]
1467 | latent_to_model = self.latent_to_model_projs[modality_type]
1468 | model_to_latent = self.model_to_latent_projs[modality_type]
1469 |
1470 | add_pos_emb = self.add_pos_emb[modality_type]
1471 | pos_emb_mlp = self.pos_emb_mlp[modality_type]
1472 | modality_num_dim = self.modality_num_dim[modality_type]
1473 |
1474 | dim_latent = self.dim_latents[modality_type]
1475 |
1476 | default_shape = self.modality_default_shape[modality_type]
1477 |
1478 | som_id = self.som_ids[modality_type]
1479 | eom_id = self.eom_ids[modality_type]
1480 |
1481 | to_shape_fn = self.to_modality_shape_fn[modality_type]
1482 |
1483 | channel_first_latent = self.channel_first_latent[modality_type]
1484 |
1485 | return ModalityInfo(
1486 | encoder = modality_encoder,
1487 | decoder = modality_decoder,
1488 | latent_to_model = latent_to_model,
1489 | model_to_latent = model_to_latent,
1490 | add_pos_emb = add_pos_emb,
1491 | pos_emb_mlp = pos_emb_mlp,
1492 | num_dim = modality_num_dim,
1493 | dim_latent = dim_latent,
1494 | default_shape = default_shape,
1495 | som_id = som_id,
1496 | eom_id = eom_id,
1497 | to_shape_fn = to_shape_fn,
1498 | channel_first_latent = channel_first_latent,
1499 | modality_type = modality_type
1500 | )
1501 |
1502 | def get_all_modality_info(self) -> list[ModalityInfo]:
1503 | return [self.get_modality_info(i) for i in range(self.num_modalities)]
1504 |
1505 | def get_modality_shape(
1506 | self,
1507 | modality: Float['...'],
1508 | modality_type: int | None = None
1509 | ) -> tuple[int, ...]:
1510 |
1511 | mod = self.get_modality_info(modality_type)
1512 |
1513 | if mod.channel_first_latent:
1514 | modality = rearrange(modality, 'c ... -> ... c')
1515 |
1516 | return tuple(modality.shape[:-1])
1517 |
1518 | def parameters_without_encoder_decoder(self):
1519 | return (
1520 | set(self.parameters()) -
1521 | set(self.modality_encoder.parameters()) -
1522 | set(self.modality_decoder.parameters())
1523 | )
1524 |
1525 | def create_dataloader(
1526 | self,
1527 | *args,
1528 | **kwargs
1529 | ):
1530 | return create_dataloader(*args, **kwargs)
1531 |
1532 | def create_ema(
1533 | self,
1534 | beta = 0.99,
1535 |         **ema_kwargs
1536 | ) -> EMA:
1537 |
1538 | ema = EMA(
1539 | self,
1540 | beta = beta,
1541 | forward_method_names = (
1542 | 'sample',
1543 | 'generate_text_only',
1544 | 'generate_modality_only'
1545 |             ),
1546 |             **ema_kwargs
1547 |         )
1548 | return ema
1549 |
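    # rough usage sketch (illustrative only - the values below are assumptions, not defined here):
    #
    #   text_prompt = torch.randint(0, 256, (10,))   # token ids < num_text_tokens
    #   out = model.sample(prompt = text_prompt, max_length = 1024)
    #
    # `out` is a list interleaving text token tensors with (modality_type, modality tensor) tuples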
1550 | @torch.no_grad()
1551 | @eval_decorator
1552 | @typecheck
1553 | def sample(
1554 | self,
1555 | prompt: ModalitySample | Tensor | tuple[int, Float['...']] | None = None,
1556 | max_length = 2048,
1557 | text_temperature = 1.5,
1558 | text_min_p = 0.1,
1559 | cache_kv = False,
1560 | fixed_modality_shape: tuple[int, ...] | None = None,
1561 | init_modality_noise: Float['n d'] | None = None,
1562 | modality_steps = 16,
1563 | return_unprocessed_modalities = False
1564 | ) -> ModalitySample:
1565 |
1566 | device = self.device
1567 |
1568 | # handle edge case where there are no text tokens
1569 |
1570 | if self.num_text_tokens == 0:
1571 | logger.warning(f'you have `num_text_tokens` set to 0, so `sample` will be forwarded to `generate_modality_only(batch_size: int, modality_type: int)` method')
1572 |
1573 | return self.generate_modality_only(batch_size = 1)
1574 |
1575 | # take care of prompt being a raw tensor, either text or raw modality (image, video, actions, latents, etc)
1576 |
1577 | if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit
1578 | prompt = (0, prompt)
1579 |
1580 | prompt_is_modality = isinstance(prompt, tuple)
1581 |
1582 | if is_tensor(prompt) and prompt.dtype in (torch.int, torch.long): # is text only prompt
1583 | prompt = [prompt]
1584 |
1585 | elif prompt_is_modality:
1586 | modality_type, modality = prompt
1587 |
1588 | mod = self.get_modality_info(modality_type)
1589 |
1590 | if exists(mod.encoder):
1591 | with torch.no_grad():
1592 | mod.encoder.eval()
1593 | modality = self.maybe_add_temp_batch_dim(mod.encoder)(modality).detach()
1594 |
1595 | modality_shape_tuple = self.get_modality_shape(modality, modality_type)
1596 | modality_shape_str = join([*map(str, modality_shape_tuple)], ',')
1597 | modality_meta_info = self.char_tokenizer(modality_shape_str, device = device)
1598 |
1599 | prompt = [
1600 | tensor([self.meta_id]),
1601 | modality_meta_info,
1602 | tensor([mod.som_id]),
1603 | (modality_type, modality),
1604 | tensor([mod.eom_id]),
1605 | ]
1606 |
1607 | # sos
1608 |
1609 | init_text_seq = tensor([self.sos_id], device = device)
1610 |
1611 | # just take care of prompt being zero dimensions
1612 |
1613 | modality_sample = [init_text_seq, *default(prompt, [])]
1614 |
1615 | # take care of moving to device
1616 |
1617 | modality_sample = tree_map_tensor(modality_sample, lambda t: t.to(device))
1618 | modality_sample = tree_map_tensor(modality_sample, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
1619 |
1620 | modality_sample = concat_contiguous_text(modality_sample)
1621 |
1622 | *_, last_modality_sample = modality_sample
1623 |
1624 | curr_length = 0
1625 | curr_modality_id = None
1626 | modality_shape = None
1627 |
1628 | num_past_modalities = int(prompt_is_modality) # either 0 or 1 (if the prompt given is a modality)
1629 |
1630 | text_is_greedy = text_temperature == 0.
1631 | is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected
1632 |
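        # when the last sampled token is a modality start token [som], parse any preceding shape
        # meta tokens into a modality shape (falling back to `modality_default_shape` if invalid)
        # and switch from autoregressive text decoding to flow based modality decoding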
1633 | def maybe_transition_to_modality_decoding(seq):
1634 | nonlocal modality_shape
1635 | nonlocal is_decoding_text
1636 | nonlocal curr_modality_id
1637 |
1638 | sampled_token_id = seq[-1]
1639 |
1640 | if sampled_token_id not in self.som_ids:
1641 | return
1642 |
1643 | curr_modality_id = self.som_ids.index(sampled_token_id)
1644 |
1645 | if exists(fixed_modality_shape):
1646 | modality_shape = fixed_modality_shape
1647 |
1648 | # get the tokens after the modality meta id
1649 |
1650 | maybe_meta_tensor = get_tokens_since_rightmost_id(seq, self.meta_id)
1651 |
1652 | mod = self.get_modality_info(curr_modality_id)
1653 |
1654 | default_shape = mod.default_shape
1655 | maybe_modality_num_dim = mod.num_dim
1656 | meta_str_to_modality_shape = mod.to_shape_fn
1657 |
1658 | if maybe_meta_tensor.numel() > 0:
1659 | meta_tensor = maybe_meta_tensor[:-1]
1660 | meta_str = self.decode_chars(meta_tensor)
1661 |
1662 |                 if not all(dim_str.isdigit() and int(dim_str) > 0 for dim_str in meta_str.split(',')):
1663 |
1664 | assert exists(default_shape), 'invalid modality meta information detected, please set `modality_default_shape` in order to properly fallback'
1665 | modality_shape = default_shape
1666 | else:
1667 | modality_shape = meta_str_to_modality_shape(meta_str)
1668 |
1669 | modality_shape = default(modality_shape, default_shape)
1670 |
1671 | if self.fallback_to_default_shape_if_invalid:
1672 |
1673 | if exists(maybe_modality_num_dim) and len(modality_shape) != maybe_modality_num_dim:
1674 | logger.warning(f'invalid modality shape {modality_shape} for modality {curr_modality_id}. falling back to default shape {default_shape}')
1675 | modality_shape = default_shape
1676 |
1677 | assert exists(modality_shape), f'language model did not produce a proper modality shape for modality type {curr_modality_id} - please set a fallback shape with `modality_default_shape`'
1678 | assert not exists(maybe_modality_num_dim) or maybe_modality_num_dim == len(modality_shape), f'expected modality type {curr_modality_id} to have {maybe_modality_num_dim} dimensions but language model produced a shape of {modality_shape}'
1679 |
1680 | is_decoding_text = False
1681 |
1682 |         # determine whether to transition to modality decoding right from the start
1683 |
1684 | maybe_transition_to_modality_decoding(last_modality_sample)
1685 |
1686 | cache = None
1687 |
1688 | with tqdm(total = max_length) as pbar:
1689 |
1690 | while curr_length <= max_length:
1691 |
1692 | if is_decoding_text:
1693 | pbar.set_description('decoding text')
1694 |
1695 | *_, seq = modality_sample
1696 |
1697 | logits, new_kv_cache = self.forward(
1698 | [modality_sample],
1699 | return_loss = False,
1700 | cache = cache,
1701 | decode_length = 1,
1702 | decoding_text_or_modality = 'text',
1703 | return_kv_cache = True
1704 | )
1705 |
1706 | logits = logits[0][-1]
1707 |
1708 | if text_is_greedy:
1709 | sampled = logits.argmax(dim = -1, keepdim = True)
1710 | else:
1711 | logits = min_p_filter(logits, min_p = text_min_p)
1712 |
1713 | probs = (logits / text_temperature).softmax(dim = -1)
1714 | sampled = torch.multinomial(probs, 1)
1715 |
1716 | seq = torch.cat((seq, sampled), dim = -1)
1717 | modality_sample[-1] = seq
1718 |
1719 | pbar.update(1)
1720 | curr_length += 1
1721 |
1722 | if cache_kv:
1723 | cache = new_kv_cache
1724 |
1725 | sampled_token_id = sampled.item()
1726 |
1727 | if sampled_token_id == self.eos_id:
1728 | logger.info(f'detecting an end of string token [{self.eos_id}], terminating sampling early')
1729 | break
1730 |
1731 | maybe_transition_to_modality_decoding(seq)
1732 |
1733 | else:
1734 | assert exists(curr_modality_id)
1735 | pbar.set_description(f'decoding modality [{curr_modality_id}]')
1736 |
1737 | mod = self.get_modality_info(curr_modality_id)
1738 |
1739 | modality_length = math.prod(modality_shape)
1740 |
1741 | if exists(init_modality_noise):
1742 | noise = init_modality_noise[:modality_length, :mod.dim_latent]
1743 | else:
1744 | assert exists(modality_length)
1745 | noise = torch.randn((modality_length, mod.dim_latent), device = device)
1746 |
1747 | assert noise.shape == (modality_length, mod.dim_latent)
1748 |
1749 | noise = noise.reshape(*modality_shape, mod.dim_latent)
1750 |
1751 | if mod.channel_first_latent:
1752 | noise = rearrange(noise, '... d -> d ...')
1753 |
1754 | new_kv_cache = None
1755 |
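                # one ODE step - run the transformer on the full sequence with the current denoised
                # latents appended, then splice out and project the predicted flow for this modality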
1756 | def ode_step_fn(step_times, denoised):
1757 | nonlocal new_kv_cache
1758 |
1759 | step_times = rearrange(step_times, ' -> 1 1') # batch size of 1
1760 | step_times = F.pad(step_times, (num_past_modalities, 0), value = 1.) # past decoded modalities receive a time conditioning of 1.
1761 |
1762 | (embeds, get_pred_flows), new_kv_cache = self.forward(
1763 | [[*modality_sample, (curr_modality_id, denoised)]],
1764 | times = step_times,
1765 | return_embed = True,
1766 | cache = cache,
1767 | decode_length = modality_length,
1768 | return_kv_cache = True,
1769 | decoding_text_or_modality = 'modality'
1770 | )
1771 |
1772 | parse_embed = get_pred_flows[curr_modality_id][-1]
1773 |
1774 | parsed_embed = parse_embed(embeds, need_splice = not exists(cache))
1775 |
1776 | flow = add_temp_batch_dim(mod.model_to_latent)(parsed_embed)
1777 |
1778 | return flow
1779 |
1780 | times = torch.linspace(0, 1, modality_steps, device = device)
1781 |
1782 | trajectory = self.odeint_fn(ode_step_fn, noise, times)
1783 |
1784 | # add the sampled modality tokens
1785 |
1786 | sampled_modality = trajectory[-1]
1787 |
1788 | modality_sample.append((curr_modality_id, sampled_modality))
1789 |
1790 | # add the appropriate [eom]
1791 |
1792 | eom_id = mod.eom_id
1793 | modality_sample.append(tensor([eom_id], device = device))
1794 |
1795 | # set kv cache if needed
1796 |
1797 | if cache_kv:
1798 | cache = new_kv_cache
1799 |
1800 | # back to decoding text
1801 |
1802 | pbar.update(modality_length)
1803 | curr_length += modality_length
1804 |
1805 | num_past_modalities += 1
1806 | curr_modality_id = None
1807 | modality_length = None
1808 |
1809 | is_decoding_text = True
1810 |
1811 | logger.info(f'sampling stopped at length: {curr_length} / {max_length}')
1812 |
1813 | if return_unprocessed_modalities:
1814 | return modality_sample
1815 |
1816 | # post process modality sample, decoding modality types if they have a decoder
1817 |
1818 | for mod in self.get_all_modality_info():
1819 | decoder_fn = default(mod.decoder, nn.Identity())
1820 |
1821 | with torch.no_grad():
1822 | decoder_fn.eval()
1823 | modality_sample = apply_fn_modality_type(decoder_fn, modality_sample, modality_type = mod.modality_type)
1824 |
1825 | return modality_sample
1826 |
1827 | @typecheck
1828 | def forward_text(
1829 | self,
1830 | text: Int['b n'],
1831 | return_loss = True,
1832 | return_embed = False,
1833 | cache: Tensor | None = None,
1834 | return_kv_cache = False
1835 | ) -> (
1836 | Scalar |
1837 | Float['b n d'] |
1838 | tuple[Float['b n d'], list[Float['...']]]
1839 | ):
1840 |
1841 | device = self.device
1842 | text = text.to(device)
1843 |
1844 | if return_loss:
1845 | text, labels = text[:, :-1], text[:, 1:]
1846 |
1847 | # embed text
1848 |
1849 | text = text.masked_fill(text == -1, 0)
1850 | tokens = self.text_embed(text)
1851 |
1852 | # rotary
1853 |
1854 | seq_len = tokens.shape[-2]
1855 | pos = torch.arange(seq_len, device = device)
1856 |
1857 | rotary_emb = self.rotary_emb(pos)
1858 |
1859 | # attention
1860 |
1861 | embed, kv_cache = self.transformer(
1862 | tokens,
1863 | rotary_emb = rotary_emb,
1864 | causal_mask = True,
1865 | cache = cache,
1866 | return_kv_cache = True
1867 | )
1868 |
1869 | # text unembedding
1870 |
1871 | logits = self.to_text_logits(embed)
1872 |
1873 | if not return_loss:
1874 | if not return_kv_cache:
1875 | return logits
1876 |
1877 | return logits, kv_cache
1878 |
1879 | logits = logits.masked_fill(~self.text_only_logits_mask, max_neg_value(logits))
1880 |
1881 | loss = F.cross_entropy(
1882 | rearrange(logits, 'b n l -> b l n'),
1883 | labels,
1884 | ignore_index = self.ignore_index
1885 | )
1886 |
1887 | return loss
1888 |
1889 | @torch.no_grad()
1890 | @eval_decorator
1891 | @typecheck
1892 | def generate_text_only(
1893 | self,
1894 | prompt: Int['b n'],
1895 | seq_len: int,
1896 | temperature = 1.5,
1897 | min_p = 0.1,
1898 | ) -> Int['b no']:
1899 |
1900 | prompt_seq_len, out = prompt.shape[-1], prompt.clone()
1901 | sample_num_times = max(0, seq_len - prompt_seq_len)
1902 |
1903 | for _ in tqdm(range(sample_num_times)):
1904 | logits = self.forward_text(out, return_loss = False)
1905 | logits = logits[:, -1]
1906 |
1907 | logits = min_p_filter(logits, min_p = min_p)
1908 |
1909 | logits.masked_fill_(~self.text_only_logits_mask, max_neg_value(logits))
1910 |
1911 | sample = gumbel_sample(logits, temperature = temperature, dim = -1)
1912 |
1913 | out = cat((out, sample), dim = -1)
1914 |
1915 | return out[..., prompt_seq_len:]
1916 |
1917 | @typecheck
1918 | def forward_modality(
1919 | self,
1920 | modalities: Float['b ...'],
1921 | times: Float['b'] | None = None,
1922 | modality_type: int | None = None,
1923 | encode_modality: bool = True,
1924 | velocity_consistency_ema_model: Transfusion | None = None,
1925 | velocity_consistency_delta_time = 1e-5,
1926 | return_loss = True,
1927 | return_loss_breakdown = False
1928 | ) -> Scalar | Float['b ...']:
1929 | requires_velocity_consistency = exists(velocity_consistency_ema_model)
1930 |
1931 | modalities = modalities.to(self.device)
1932 | orig_modalities = modalities
1933 |
1934 | if self.num_modalities > 1:
1935 | assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'
1936 |
1937 | modality_type = default(modality_type, 0)
1938 |
1939 | mod = self.get_modality_info(modality_type)
1940 |
1941 | # maybe modality encode
1942 |
1943 | if encode_modality and exists(mod.encoder):
1944 | with torch.no_grad():
1945 | mod.encoder.eval()
1946 | modalities = mod.encoder(modalities).detach()
1947 |
1948 | # shapes and device
1949 |
1950 | tokens = modalities
1951 |
1952 | batch, device = tokens.shape[0], tokens.device
1953 |
1954 | # times
1955 |
1956 | if not exists(times):
1957 | times = torch.rand((batch,), device = device)
1958 |
1959 | if return_loss:
1960 |
1961 | if requires_velocity_consistency:
1962 | orig_times = times.clone()
1963 | times = times * (1. - velocity_consistency_delta_time) # make sure times are max of 1. - small delta, for velocity consistency
1964 |
1965 | padded_times = append_dims(times, tokens.ndim - 1)
1966 |
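            # rectified flow style noising - interpolate between noise (t = 0) and data (t = 1)
            # the regression target is the constant velocity (data - noise)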
1967 | noise = torch.randn_like(tokens)
1968 |
1969 | noised_tokens = padded_times * tokens + (1. - padded_times) * noise
1970 |
1971 | flow = tokens - noise
1972 |
1973 | else:
1974 | noised_tokens = tokens
1975 |
1976 | # from latent to model tokens
1977 |
1978 | noised_tokens = mod.latent_to_model(noised_tokens)
1979 |
1980 | # axial positions
1981 |
1982 | if mod.add_pos_emb:
1983 | assert exists(mod.num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
1984 |
1985 | _, *axial_dims, _ = noised_tokens.shape
1986 |
1987 |             assert len(axial_dims) == mod.num_dim, f'received modalities of ndim {len(axial_dims)} but expected {mod.num_dim}'
1988 |
1989 | # maybe transform
1990 |
1991 | noised_tokens, inverse_pack_axial_dims = pack_one_with_inverse(noised_tokens, 'b * d')
1992 |
1993 | # maybe add axial pos emb
1994 |
1995 | if mod.add_pos_emb:
1996 | axial_pos_emb = mod.pos_emb_mlp(tensor(axial_dims), flatten = True)
1997 | noised_tokens = noised_tokens + axial_pos_emb
1998 |
1999 | # attention
2000 |
2001 | embed = self.transformer(
2002 | noised_tokens,
2003 | times = times,
2004 | modality_only = True,
2005 | )
2006 |
2007 | embed = inverse_pack_axial_dims(embed)
2008 |
2009 | pred_flow = mod.model_to_latent(embed)
2010 |
2011 | if not return_loss:
2012 | return pred_flow
2013 |
2014 | # flow loss
2015 |
2016 | flow_loss = F.mse_loss(pred_flow, flow)
2017 |
2018 | # maybe velocity consistency loss
2019 |
2020 | velocity_loss = self.zero
2021 |
2022 | if requires_velocity_consistency:
2023 |
2024 | with torch.no_grad():
2025 | flow_with_delta_time = velocity_consistency_ema_model.forward_modality(
2026 | modalities = modalities,
2027 | modality_type = modality_type,
2028 | times = orig_times + velocity_consistency_delta_time,
2029 | encode_modality = False, # modality already encoded
2030 | return_loss = False
2031 | )
2032 |
2033 |             velocity_loss = F.mse_loss(pred_flow, flow_with_delta_time)
2034 |
2035 | # maybe recon loss
2036 |
2037 | recon_loss = self.zero
2038 |
2039 | if self.has_recon_loss:
2040 | assert encode_modality
2041 |
2042 | recon = noise + pred_flow * (1. - padded_times)
2043 |
2044 | if exists(mod.decoder):
2045 | with torch.no_grad():
2046 | mod.decoder.eval()
2047 | recon = mod.decoder(recon)
2048 |
2049 | recon_loss = F.mse_loss(
2050 | recon,
2051 | orig_modalities
2052 | )
2053 |
2054 | # total loss
2055 |
2056 | total_loss = (
2057 | flow_loss +
2058 | velocity_loss * self.velocity_consistency_loss_weight +
2059 | recon_loss * self.reconstruction_loss_weight
2060 | )
2061 |
2062 | if not return_loss_breakdown:
2063 | return total_loss
2064 |
2065 | return total_loss, (flow_loss, velocity_loss, recon_loss)
2066 |
2067 | @torch.no_grad()
2068 | @eval_decorator
2069 | @typecheck
2070 | def generate_modality_only(
2071 | self,
2072 | batch_size: int = 1,
2073 | modality_type: int | None = None,
2074 | fixed_modality_shape: tuple[int, ...] | None = None,
2075 | modality_steps = 16,
2076 | return_unprocessed_modalities = False
2077 | ) -> Tensor:
2078 |
2079 | device = self.device
2080 |
2081 | if self.num_modalities > 1:
2082 | assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'
2083 |
2084 | mod = self.get_modality_info(modality_type)
2085 |
2086 | modality_shape = default(fixed_modality_shape, mod.default_shape)
2087 |
2088 | assert exists(modality_shape)
2089 |
2090 | noise = torch.randn((batch_size, *modality_shape, mod.dim_latent), device = device)
2091 |
2092 | if mod.channel_first_latent:
2093 | noise = rearrange(noise, 'b ... d -> b d ...')
2094 |
2095 | def ode_step_fn(step_times, denoised):
2096 |
2097 | step_times = repeat(step_times, ' -> b', b = batch_size)
2098 |
2099 | flow = self.forward_modality(
2100 | denoised,
2101 | times = step_times,
2102 | modality_type = modality_type,
2103 | encode_modality = False,
2104 | return_loss = False
2105 | )
2106 |
2107 | return flow
2108 |
2109 | times = torch.linspace(0., 1., modality_steps, device = device)
2110 | trajectory = self.odeint_fn(ode_step_fn, noise, times)
2111 |
2112 | # add the sampled modality tokens
2113 |
2114 | sampled_modality = trajectory[-1]
2115 |
2116 | # decode
2117 |
2118 | if exists(mod.decoder):
2119 | mod.decoder.eval()
2120 | sampled_modality = mod.decoder(sampled_modality)
2121 |
2122 | return sampled_modality
2123 |
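    # main forward - accepts text only Int['b n'], a single modality batch Float['b ...'],
    # or a list of interleaved samples per batch element, e.g. (an illustrative sketch):
    #
    #   [[tensor([1, 2, 3]), (0, torch.randn(4, 384)), tensor([4, 5])], ...]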
2124 | @typecheck
2125 | def forward(
2126 | self,
2127 | modalities: (
2128 | list[ModalitySample] |
2129 | Int['b n'] |
2130 | Float['b ...']
2131 | ),
2132 | times: Float['b m'] | None = None,
2133 | num_modalities_to_times_fn: Callable[[Int['b']], Float['b m']] | None = None, # allows a researcher to customize the times (noise level) based on the modality lengths in a given sample
2134 | modality_type: int | None = None,
2135 | cache: Tensor | None = None,
2136 | decode_length: int | None = None,
2137 | decoding_text_or_modality: Literal['text', 'modality'] | None = None,
2138 | velocity_consistency_ema_model: Transfusion | EMA | None = None,
2139 | velocity_consistency_delta_time = 1e-3,
2140 | return_only_pred_flows = False,
2141 | return_loss = True,
2142 | return_breakdown = False,
2143 | return_embed = False,
2144 | return_kv_cache = False,
2145 | ) -> (
2146 | Float['b _ l'] |
2147 | tuple[Float['b _ d'], GetPredFlows] |
2148 | tuple[tuple[Float['b _ _'], GetPredFlows], Tensor] |
2149 | Scalar |
2150 | tuple[Scalar, LossBreakdown] |
2151 | list[Float['b _ _']]
2152 | ):
2153 | is_decoding = exists(decoding_text_or_modality)
2154 |
2155 | is_text_only = is_tensor(modalities) and modalities.dtype in (torch.int, torch.long)
2156 | is_modality_only = is_tensor(modalities) and modalities.dtype == torch.float
2157 |
2158 | # handle ema model being passed in for velocity consistency loss
2159 |
2160 | if isinstance(velocity_consistency_ema_model, EMA):
2161 | assert isinstance(velocity_consistency_ema_model.ema_model, Transfusion)
2162 | velocity_consistency_ema_model = velocity_consistency_ema_model.ema_model
2163 |
2164 | need_velocity_matching = not is_decoding and exists(velocity_consistency_ema_model)
2165 |
2166 | # return loss
2167 |
2168 | return_loss &= not (return_embed or is_decoding)
2169 |
2170 | if is_text_only:
2171 |
2172 | forward_text_kwargs = dict(
2173 | return_loss = return_loss,
2174 | return_embed = return_embed,
2175 | cache = cache,
2176 | return_kv_cache = return_kv_cache
2177 | )
2178 |
2179 | return self.forward_text(modalities, **forward_text_kwargs)
2180 |
2181 | if is_modality_only:
2182 | assert return_loss
2183 |
2184 | forward_modality_kwargs = dict(
2185 | modality_type = modality_type,
2186 | velocity_consistency_ema_model = velocity_consistency_ema_model
2187 | )
2188 |
2189 | return self.forward_modality(modalities, **forward_modality_kwargs)
2190 |
2191 | batch = len(modalities)
2192 | device = self.device
2193 | tensor_ = partial(tensor, device = device)
2194 |
2195 | # save a copy for ema model for velocity matching
2196 |
2197 | if need_velocity_matching:
2198 | velocity_modalities = modalities
2199 |
2200 | if isinstance(velocity_modalities, list):
2201 | velocity_modalities = [modality.copy() for modality in velocity_modalities]
2202 |
2203 | # add "sentence" start and end tokens when training
2204 |
2205 | if return_loss or need_velocity_matching:
2206 | modalities = modalities.copy()
2207 |
2208 | for i, modality in enumerate(modalities):
2209 | modalities[i] = [
2210 | tensor_([self.sos_id]),
2211 | *modality,
2212 | tensor_([self.eos_id])
2213 | ]
2214 |
2215 | # need axial pos emb
2216 |
2217 | need_axial_pos_emb = any(self.add_pos_emb)
2218 |
2219 | # standardize modalities to be tuple - type 0 modality is implicit if not given
2220 | # also store modality lengths for determining noising times
2221 |
2222 | num_modalities = []
2223 |
2224 | for batch_modalities in modalities:
2225 | batch_num_modalities = 0
2226 |
2227 | for ind, modality in enumerate(batch_modalities):
2228 |
2229 | if is_tensor(modality) and modality.dtype == torch.float:
2230 | modality = (0, modality)
2231 |
2232 | if not isinstance(modality, tuple):
2233 | continue
2234 |
2235 | modality_type, modality_tensor = modality
2236 | batch_modalities[ind] = modality
2237 | batch_num_modalities += 1
2238 |
2239 | num_modalities.append(batch_num_modalities)
2240 |
2241 | num_modalities = tensor_(num_modalities)
2242 |
2243 | # determine the times
2244 |
2245 | if not exists(times):
2246 | if is_empty(num_modalities) or num_modalities.amax().item() == 0:
2247 | times = torch.empty((batch, 0), device = device, dtype = torch.float)
2248 | else:
2249 | num_modalities_to_times_fn = default(num_modalities_to_times_fn, default_modality_length_to_time_fn)
2250 |
2251 | if exists(num_modalities_to_times_fn):
2252 | times = num_modalities_to_times_fn(num_modalities)
2253 |
2254 |         # if velocity matching is needed, make sure times are in the range of 0 - (1. - velocity_consistency_delta_time)
2255 |
2256 | if need_velocity_matching:
2257 | orig_times = times.clone()
2258 | times = times * (1. - velocity_consistency_delta_time)
2259 |
2260 | # process list of text and modalities interspersed with one another
2261 |
2262 | modality_positions = []
2263 | modality_tokens = []
2264 | modality_pos_emb = []
2265 |
2266 | text = []
2267 |
2268 | # auto move all tensors to device of model
2269 |
2270 | modalities = tree_map_tensor(modalities, lambda t: t.to(device))
2271 |
2272 | # for all modalities, batch process same shaped modalities of the same type
2273 |
2274 | if not is_decoding:
2275 | for mod in self.get_all_modality_info():
2276 | encode_fn = default(mod.encoder, nn.Identity())
2277 |
2278 | with torch.no_grad():
2279 | encode_fn.eval()
2280 | modalities = apply_fn_modality_type(encode_fn, modalities, modality_type = mod.modality_type)
2281 |
2282 | # for parsing out the predicted flow from flattened sequence of tokens coming out of transformer
2283 |
2284 | flows = defaultdict(list) # store flows for loss
2285 |
2286 |         get_pred_flows: GetPredFlows = defaultdict(list) # functions for parsing modalities out of the model output Float['b n d'], back to latents or pixel space
2287 |
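        # builds a closure that slices one modality's span out of the flattened transformer output
        # and unpacks it back to that modality's original (unflattened) shape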
2288 | def model_to_pred_flow(batch_index, start_index, modality_length, unpack_fn):
2289 |
2290 | def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
2291 | embed = embed[batch_index]
2292 |
2293 | if need_splice:
2294 | embed = embed[start_index:(start_index + modality_length)]
2295 |
2296 | embed = unpack_fn(embed)
2297 | return embed
2298 |
2299 | return inner
2300 |
2301 | # for going from predicted flow -> reconstruction
2302 |
2303 |         get_recon_losses: dict[int, list[Callable[[Tensor], Tensor]]] = defaultdict(list)
2304 |
2305 | def get_recon_loss(noise, times, modality):
2306 |
2307 | def inner(pred_flow):
2308 | recon_modality = noise + pred_flow * (1. - times)
2309 | return F.mse_loss(modality, recon_modality)
2310 |
2311 | return inner
2312 |
2313 | # prepare storing of sizes of all modalities that require axial positions, for delayed application for efficiency
2314 |
2315 | pos_emb_max_axial_dims: dict[int, list[Tensor]] = defaultdict(list)
2316 |
2317 | # go through all modality samples and do necessary transform
2318 |
2319 | for batch_index, batch_modalities in enumerate(modalities):
2320 |
2321 | modality_index = 0
2322 | batch_modality_positions = []
2323 | batch_modality_tokens = []
2324 | batch_modality_pos_emb = []
2325 |
2326 | batch_text = []
2327 |
2328 | offset = 0
2329 |
2330 | for modality in batch_modalities:
2331 | # if non-text modality detected and not given as a tuple
2332 | # cast to (int, Tensor) where int is defaulted to type 0 (convenience for one modality)
2333 |
2334 | is_text = not isinstance(modality, tuple)
2335 | is_modality = not is_text
2336 |
2337 | if is_text:
2338 | modality_tensor = modality
2339 | else:
2340 | modality_type, modality_tensor, *_ = modality
2341 |
2342 | # auto move modality tensor to correct device
2343 |
2344 | mod = self.get_modality_info(modality_type)
2345 |
2346 | if is_modality:
2347 | assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. only {self.num_modalities} modalities specified'
2348 |
2349 | channel_dim = 0 if mod.channel_first_latent else -1
2350 |
2351 |                     assert mod.dim_latent == modality_tensor.shape[channel_dim], f'mismatch for modality latent dimension - expected {mod.dim_latent} but received {modality_tensor.shape[channel_dim]} - modality shape is {tuple(modality_tensor.shape)}, perhaps you need to set `channel_first_latent` to the correct value'
2352 | assert mod.num_dim == (len(modality_tensor.shape) - 1), f'mismatch for modality number of dimensions - expected {mod.num_dim} but received {len(modality_tensor.shape) - 1} {modality_tensor.shape}'
2353 |
2354 | # auto ward against scalars (lone start end tokens)
2355 |
2356 | if modality_tensor.dtype in (torch.int, torch.long) and modality_tensor.ndim == 0:
2357 | modality_tensor = rearrange(modality_tensor, '-> 1')
2358 |
2359 | # handle text
2360 |
2361 | if is_text:
2362 | assert modality_tensor.ndim == 1 and modality_tensor.dtype in (torch.int, torch.long)
2363 | text_length = modality_tensor.shape[0]
2364 |
2365 | batch_text.append(modality_tensor)
2366 | zeros = torch.zeros(text_length, self.dim, device = device)
2367 |
2368 | batch_modality_tokens.append(zeros)
2369 |
2370 | offset += text_length
2371 |
2372 | if need_axial_pos_emb:
2373 | batch_modality_pos_emb.append(zeros)
2374 |
2375 | continue
2376 |
2377 | # otherwise handle a modality
2378 |
2379 | # get times for noising the modality
2380 |
2381 | modality_time = times[batch_index, modality_index]
2382 |
2383 | # noise
2384 |
2385 | if return_loss:
2386 | noise = torch.randn_like(modality_tensor)
2387 |
2388 | noised_modality = modality_tensor * modality_time + noise * (1. - modality_time)
2389 |
2390 | # the flow is the (data - noise)
2391 |
2392 | modality_flow = modality_tensor - noise
2393 |
2394 | # append to flow for loss
2395 |
2396 | flows[modality_type].append(modality_flow)
2397 |
2398 | modality_tensor = noised_modality
2399 |
2400 | # store function for deriving reconstruction loss from decoder
2401 |
2402 | get_recon_losses[modality_type].append(get_recon_loss(noise, modality_time, modality_tensor))
2403 |
2404 | # go through maybe encoder
2405 |
2406 | modality_tensor = add_temp_batch_dim(mod.latent_to_model)(modality_tensor)
2407 |
2408 | # gather the modality length
2409 |
2410 | modality_shape_tuple = modality_tensor.shape[:-1]
2411 | modality_length = math.prod(modality_shape_tuple)
2412 |
2413 | text_tensor = torch.full((modality_length,), -1, device = device) # text is all -1 here, so text labels are not learned on
2414 |
2415 | # only add modality meta information when not returning embedding, which only occurs when sampling modality
2416 |
2417 | succeed_modality_tokens = precede_modality_tokens = 0
2418 |
2419 | if not return_embed:
2420 | # add the [som] and [eom] tokens for the modality type
2421 |
2422 | som_id, eom_id = mod.som_id, mod.eom_id
2423 |
2424 | # start by just storing the token length of the modality
2425 |
2426 | modality_shape_str = join([*map(str, modality_shape_tuple)], ',')
2427 | modality_meta_info = self.char_tokenizer(modality_shape_str, device = device)
2428 |
2429 | precede_modality_tokens = len(modality_meta_info) + 2
2430 | succeed_modality_tokens = 1
2431 |
2432 | text_tensor = cat((
2433 | tensor_([self.meta_id]),
2434 | modality_meta_info,
2435 | tensor_([som_id]),
2436 | text_tensor,
2437 | tensor_([eom_id])
2438 | ))
2439 |
2440 | batch_modality_positions.append((modality_type, offset + precede_modality_tokens, modality_length)) # offset + preceding meta tag length (which includes the modality start token)
2441 |
2442 | # store parsing out back to shape
2443 |
2444 | modality_tensor, unpack_modality_shape = pack_one_with_inverse(modality_tensor, '* d')
2445 |
2446 | inverse_fn = model_to_pred_flow(batch_index, offset + precede_modality_tokens, modality_length, unpack_modality_shape)
2447 |
2448 | get_pred_flows[modality_type].append(inverse_fn)
2449 |
2450 | # increment offset
2451 |
2452 | offset += modality_length + precede_modality_tokens + succeed_modality_tokens # +2 due to [som] and [eom] - then account for meta start id and modality shape information (or eventually any meta information about modality)
2453 |
2454 | modality_tensor = F.pad(modality_tensor, (0, 0, precede_modality_tokens, succeed_modality_tokens))
2455 |
2456 | batch_modality_tokens.append(modality_tensor)
2457 |
2458 | batch_text.append(text_tensor)
2459 |
2460 | # handle axial positional embedding
2461 |
2462 | if need_axial_pos_emb:
2463 |
2464 | if exists(mod.pos_emb_mlp):
2465 | pos_emb_max_axial_dims[modality_type].append(tensor(modality_shape_tuple))
2466 | pos_emb = (modality_type, modality_shape_tuple, (precede_modality_tokens, succeed_modality_tokens))
2467 |
2468 | else:
2469 | pos_emb = torch.zeros(text_tensor.shape[0], self.dim, device = device)
2470 |
2471 | batch_modality_pos_emb.append(pos_emb)
2472 |
2473 |                 modality_index += 1
2474 | 
2475 |             text.append(cat(batch_text))
2476 | 
2477 |             if need_axial_pos_emb:
2478 |                 modality_pos_emb.append(batch_modality_pos_emb)
2479 | 
2480 |             modality_tokens.append(cat(batch_modality_tokens))
2481 |             modality_positions.append(batch_modality_positions)
2482 |
2483 | if return_loss:
2484 | total_tokens = sum([t.numel() for t in text])
2485 |
2486 | text = pad_sequence(text, padding_value = -1)
2487 |
2488 | modality_tokens = pad_sequence(modality_tokens, padding_value = 0.)
2489 |
2490 | # handle modality positional embedding
2491 |
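        # factorized axial positional embeddings are computed once per modality type at the largest
        # encountered size, then combined and applied per sample below (the delayed evaluation noted earlier)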
2492 | if need_axial_pos_emb:
2493 | pos_emb_max_axial_dims = {mod_type: stack(sizes, dim = -1).amax(dim = -1) for mod_type, sizes in pos_emb_max_axial_dims.items()}
2494 | factorized_pos_emb = {mod_type: self.get_modality_info(mod_type).pos_emb_mlp(max_size, return_factorized = True) for mod_type, max_size in pos_emb_max_axial_dims.items()}
2495 |
2496 | # lazy evaluate the modality positional embedding from the factorized positional embedding from maximum axial dims
2497 |
2498 | evaluated_pos_emb = []
2499 |
2500 | for batch_modality_pos_emb in modality_pos_emb:
2501 | evaluated_batch_pos_emb = []
2502 |
2503 | for maybe_pos_emb_config in batch_modality_pos_emb:
2504 |
2505 | if is_tensor(maybe_pos_emb_config):
2506 | evaluated_batch_pos_emb.append(maybe_pos_emb_config)
2507 | continue
2508 |
2509 | mod_type, mod_size, padding = maybe_pos_emb_config
2510 |
2511 | mod_info = self.get_modality_info(mod_type)
2512 | mod_factorized_pos_emb = factorized_pos_emb[mod_type]
2513 |
2514 | mod_pos_emb = mod_info.pos_emb_mlp.combine_factorized(mod_factorized_pos_emb, mod_size, flatten = True)
2515 | mod_pos_emb = F.pad(mod_pos_emb, (0, 0, *padding), value = 0.) # handle padding for preceding and succeeding meta tokens
2516 |
2517 | evaluated_batch_pos_emb.append(mod_pos_emb)
2518 |
2519 | evaluated_pos_emb.append(cat(evaluated_batch_pos_emb, dim = -2))
2520 |
2521 | modality_pos_emb = pad_sequence(evaluated_pos_emb, padding_value = 0.)
2522 |
2523 | # handle training mode and removal of last token
2524 |
2525 | if return_loss:
2526 | modality_tokens = modality_tokens[:, :-1]
2527 |
2528 | if need_axial_pos_emb:
2529 | modality_pos_emb = modality_pos_emb[:, :-1]
2530 |
2531 | # if returning loss, split text for next token prediction
2532 |
2533 | if return_loss:
2534 | text, text_labels = text[:, :-1], text[:, 1:]
2535 |
2536 | # derive is_modality mask for flow on the right tokens + flow loss
2537 |
2538 | batch, seq_len, device = *text.shape, text.device
2539 |
2540 | assert len(modality_positions) == batch
2541 |
2542 | if isinstance(modality_positions, list):
2543 | modality_positions = modality_positions_to_tensor(modality_positions, device = device)
2544 |
2545 | if modality_positions.shape[-1] == 2: # Int['b m 2'] -> Int['b m 3'] if type is not given (one modality)
2546 | modality_positions = F.pad(modality_positions, (1, 0))
2547 |
2548 | # for now use dummy padding modality position info if empty (all zeros)
2549 |
2550 | if modality_positions.numel() == 0:
2551 | modality_positions = F.pad(modality_positions, (0, 0, 0, 1))
2552 |
2553 | # sort the modalities tensor and sanitize, readying for noising of modalities
2554 |
2555 | modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions)
2556 |
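        # is_modalities is a boolean mask of shape (batch, modality type, modality instance, seq len)
        # marking which token positions belong to which modality instance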
2557 | is_modalities = modality_positions_to_is_modality_mask(seq_len, modality_positions, num_modalities = self.num_modalities, device = device)
2558 |
2559 | is_any_modality = reduce(is_modalities, 'b t m n -> b n', 'any')
2560 |
2561 | # embed text
2562 |
2563 | text = text.masked_fill(text == -1, 0)
2564 |
2565 | text_tokens = self.text_embed(text)
2566 |
2567 | # maybe add the axial positional embedding
2568 |
2569 | if need_axial_pos_emb:
2570 | modality_tokens = modality_tokens + modality_pos_emb
2571 |
2572 | # intersperse the modalities with the text for the joint transformer + flow system
2573 |
2574 | tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
2575 |
2576 | # derive rotary positions
2577 |
2578 | rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
2579 |
2580 | rotary_emb = self.rotary_emb(rotary_positions)
2581 | rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d')
2582 |
2583 | # take care of cache
2584 |
2585 | is_any_modality_when_decoding = None
2586 |
2587 | if exists(cache):
2588 | assert exists(decode_length), '`decode_length` must be passed in on forward for modality sampling. think of a cleaner way on some future date'
2589 | assert exists(decoding_text_or_modality)
2590 |
2591 | if decoding_text_or_modality == 'text':
2592 | decode_length = 1
2593 |
2594 | is_any_modality_when_decoding = decoding_text_or_modality == 'modality'
2595 | modality_positions = None
2596 |
2597 | # times
2598 |
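        # each token position inherits the noise time of the modality it belongs to - text positions receive 0.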
2599 | times_per_token = einsum(is_modalities.float(), times, 'b t m n, b m -> b t n')
2600 |
2601 | times_cond = reduce(times_per_token, 'b t n -> b n', 'sum')
2602 |
2603 | # attention
2604 |
2605 | embed, kv_cache = self.transformer(
2606 | tokens,
2607 | times = times_cond,
2608 | rotary_emb = rotary_emb,
2609 | modality_positions = modality_positions,
2610 | is_any_modality = is_any_modality_when_decoding,
2611 | cache = cache,
2612 | decode_length = decode_length,
2613 | return_kv_cache = True
2614 | )
2615 |
2616 | # early return for embedding for decoding modality
2617 |
2618 | if return_embed:
2619 | if not return_kv_cache:
2620 | return (embed, get_pred_flows)
2621 |
2622 | return (embed, get_pred_flows), kv_cache
2623 |
2624 | # text unembedding
2625 |
2626 | text_logits = self.to_text_logits(embed)
2627 |
2628 | if not return_loss:
2629 | if not return_kv_cache:
2630 | return text_logits
2631 |
2632 | return text_logits, kv_cache
2633 |
2634 | # flow loss
2635 |
2636 | pred_flows = []
2637 | recon_losses = []
2638 |
2639 | for modality_id in range(self.num_modalities):
2640 | mod = self.get_modality_info(modality_id)
2641 |
2642 | modality_get_pred_flows = get_pred_flows[modality_id]
2643 | modality_get_recon_losses = get_recon_losses[modality_id]
2644 |
2645 | modality_pred_flows = []
2646 | modality_recon_losses = []
2647 |
2648 | for get_pred_flow, get_recon_loss in zip(modality_get_pred_flows, modality_get_recon_losses):
2649 |
2650 | pred_flow = get_pred_flow(embed)
2651 | pred_flow = add_temp_batch_dim(mod.model_to_latent)(pred_flow)
2652 | modality_pred_flows.append(pred_flow)
2653 |
2654 | if not return_loss or not self.has_recon_loss:
2655 | continue
2656 |
2657 | modality_recon_losses.append(get_recon_loss(pred_flow))
2658 |
2659 | pred_flows.append(modality_pred_flows)
2660 | recon_losses.append(modality_recon_losses)
2661 |
2662 | # early return for velocity consistency ema model
2663 |
2664 | if return_only_pred_flows:
2665 | return pred_flows
2666 |
2667 | # text autoregressive loss
2668 |
2669 | text_labels = text_labels.masked_fill(is_any_modality, self.ignore_index)
2670 |
2671 | text_loss = F.cross_entropy(
2672 | rearrange(text_logits, 'b n l -> b l n'),
2673 | text_labels,
2674 | ignore_index = self.ignore_index
2675 | )
2676 |
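        # weight the text loss by the fraction of positions that are actual (non-modality) text tokens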
2677 | text_loss_weight = (text_labels != self.ignore_index).sum() / total_tokens
2678 |
2679 | # calculate flow losses
2680 |
2681 | flow_losses = []
2682 |
2683 | modality_loss_weights = []
2684 |
2685 | for modality_id, (pred_flow, is_one_modality) in enumerate(zip(pred_flows, is_modalities.unbind(dim = 1))):
2686 | mod = self.get_modality_info(modality_id)
2687 |
2688 | is_one_modality = reduce(is_one_modality, 'b m n -> b n', 'any')
2689 | modality_loss_weight = is_one_modality.sum() / total_tokens
2690 |
2691 | modality_flows = flows[modality_id]
2692 |
2693 | pack_pattern = 'd *' if mod.channel_first_latent else '* d'
2694 |
2695 | modality_pred_flow, _ = pack(pred_flow, pack_pattern)
2696 | modality_flows, _ = pack(modality_flows, pack_pattern)
2697 |
2698 | flow_loss = F.mse_loss(
2699 | modality_pred_flow,
2700 | modality_flows
2701 | )
2702 |
2703 | modality_loss_weights.append(modality_loss_weight)
2704 |
2705 | flow_losses.append(flow_loss)
2706 |
2707 | modality_loss_weights = stack(modality_loss_weights)
2708 |
2709 | # only the token positions that are not modalities have autoregressive loss
2710 |
2711 | total_loss = (
2712 | text_loss * text_loss_weight * self.text_loss_weight +
2713 | (stack(flow_losses) * modality_loss_weights).sum() * self.flow_loss_weight
2714 | )
2715 |
2716 | # whether to handle velocity consistency
2717 | # for straightening the flow, from consistency flow matching paper https://arxiv.org/abs/2407.02398
2718 |
2719 | velocity_match_losses = None
2720 |
2721 | if need_velocity_matching:
2722 |
2723 | with torch.no_grad():
2724 | velocity_consistency_ema_model.eval()
2725 |
2726 | ema_pred_flows = velocity_consistency_ema_model(
2727 | velocity_modalities,
2728 | times = orig_times + velocity_consistency_delta_time,
2729 | return_only_pred_flows = True
2730 | )
2731 |
2732 | velocity_match_losses = []
2733 |
2734 |             for modality_id, (ema_pred_flow, pred_flow) in enumerate(zip(ema_pred_flows, pred_flows)):
2735 |                 mod = self.get_modality_info(modality_id)
2736 |                 pack_pattern = 'd *' if mod.channel_first_latent else '* d'
2737 | pred_flow, _ = pack(pred_flow, pack_pattern)
2738 | ema_pred_flow, _ = pack(ema_pred_flow, pack_pattern)
2739 |
2740 | velocity_match_loss = F.mse_loss(
2741 | pred_flow,
2742 | ema_pred_flow
2743 | )
2744 |
2745 | velocity_match_losses.append(velocity_match_loss)
2746 |
2747 | total_loss = (
2748 | total_loss +
2749 | (stack(velocity_match_losses) * modality_loss_weights).sum() * self.velocity_consistency_loss_weight
2750 | )
2751 |
2752 | # maybe reconstruction loss
2753 |
2754 | if self.has_recon_loss:
2755 |
2756 | averaged_recon_losses = []
2757 |
2758 | for modality_recon_loss in recon_losses:
2759 | averaged_recon_losses.append(sum(modality_recon_loss) / len(modality_recon_loss))
2760 |
2761 | total_loss = (
2762 | total_loss +
2763 | (stack(averaged_recon_losses) * modality_loss_weights).sum() * self.reconstruction_loss_weight
2764 | )
2765 |
2766 | # return total loss if no breakdown needed
2767 |
2768 | if not return_breakdown:
2769 | return total_loss
2770 |
2771 | return total_loss, LossBreakdown(total_loss, text_loss, flow_losses, velocity_match_losses, recon_losses)
2772 |
--------------------------------------------------------------------------------