├── flamingo.png
├── flamingo_pytorch
│   ├── __init__.py
│   ├── flamingo_pytorch.py
│   └── flamingo_palm.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
└── README.md
/flamingo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/flamingo-pytorch/HEAD/flamingo.png
--------------------------------------------------------------------------------
/flamingo_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from flamingo_pytorch.flamingo_pytorch import PerceiverResampler, GatedCrossAttentionBlock
2 | from flamingo_pytorch.flamingo_palm import FlamingoPaLM
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'flamingo-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.1.2',
7 | license='MIT',
8 | description = 'Flamingo - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/flamingo-pytorch',
12 | long_description_content_type = 'text/markdown',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'visual question answering'
19 | ],
20 | install_requires=[
21 | 'einops>=0.4',
22 | 'einops-exts',
23 | 'torch>=1.6'
24 | ],
25 | classifiers=[
26 | 'Development Status :: 4 - Beta',
27 | 'Intended Audience :: Developers',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
29 | 'License :: OSI Approved :: MIT License',
30 | 'Programming Language :: Python :: 3.6',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## 🦩 Flamingo - Pytorch
4 |
5 | Implementation of Flamingo, a state-of-the-art few-shot visual question answering attention net, in Pytorch. It includes the perceiver resampler (including the scheme in which the learned queries contribute keys / values to be attended to, in addition to the media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks
6 |
7 | Yannic Kilcher presentation
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install flamingo-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from flamingo_pytorch import PerceiverResampler
20 |
21 | perceive = PerceiverResampler(
22 | dim = 1024,
23 | depth = 2,
24 | dim_head = 64,
25 | heads = 8,
26 | num_latents = 64, # the number of latents to shrink your media sequence to, perceiver style
27 |     num_media_embeds = 4 # say you have 4 images maximum in your dialogue
28 | )
29 |
30 | medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
31 | perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)
32 | ```
33 |
34 | Then you insert the `GatedCrossAttentionBlock` at different intervals in your giant language model. Your text would then attend to the perceived media from above
35 |
36 | The recommended way to derive the `media_locations` boolean tensor would be to allocate a special token id to the media, and then, at the start of your large language model, do `media_locations = text_id == media_token_id`
37 |
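For instance, a minimal sketch of that derivation (the token id `3` below is just an illustrative choice):

```python
import torch

media_token_id = 3                            # special token id reserved for [media]
text_id = torch.randint(0, 20000, (1, 512))   # token ids entering your language model
text_id[:, 0] = media_token_id                # pretend the dialogue opens with an image

media_locations = text_id == media_token_id   # (batch, sequence) boolean tensor
```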
38 | ```python
39 | import torch
40 | from flamingo_pytorch import GatedCrossAttentionBlock
41 |
42 | cross_attn = GatedCrossAttentionBlock(
43 | dim = 1024,
44 | dim_head = 64,
45 | heads = 8
46 | )
47 |
48 | text = torch.randn(1, 512, 1024)
49 | perceived = torch.randn(1, 2, 64, 1024)
50 |
51 | media_locations = torch.randint(0, 2, (1, 512)).bool()
52 |
53 | text = cross_attn(
54 | text,
55 | perceived,
56 | media_locations = media_locations
57 | )
58 | ```
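To make "at different intervals" concrete, here is a rough sketch of one way to interleave the blocks inside a decoder-only language model. `DecoderLayer` is a hypothetical stand-in for your own self-attention + feedforward block; the full `FlamingoPaLM` example further down wires this up for you.

```python
import torch
from torch import nn
from flamingo_pytorch import GatedCrossAttentionBlock

class DecoderLayer(nn.Module):
    # hypothetical stand-in for your existing self-attention + feedforward block
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x) + x

dim, depth, cross_attn_every = 1024, 12, 3

layers = nn.ModuleList([
    nn.ModuleList([
        DecoderLayer(dim),
        GatedCrossAttentionBlock(dim = dim) if (ind % cross_attn_every) == 0 else None
    ]) for ind in range(depth)
])

text = torch.randn(1, 512, dim)                         # token embeddings
perceived = torch.randn(1, 2, 64, dim)                  # output of the perceiver resampler
media_locations = torch.randint(0, 2, (1, 512)).bool()  # where the media tokens sit

for self_attn_ff, cross_attn in layers:
    text = self_attn_ff(text)
    if cross_attn is not None:
        text = cross_attn(text, perceived, media_locations = media_locations)
```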
59 |
60 | That's it!
61 |
62 | Attention is all you need.
63 |
64 | ## Full working example with Flamingo + PaLM 🌴🦩🌴
65 |
66 | Integration with PaLM
67 |
68 | First install `vit-pytorch` for the vision encoder
69 |
70 | ```bash
71 | $ pip install vit-pytorch
72 | ```
73 |
74 | Then
75 |
76 | ```python
77 | from vit_pytorch.vit import ViT
78 | from vit_pytorch.extractor import Extractor
79 |
80 | vit = ViT(
81 | image_size = 256,
82 | patch_size = 32,
83 | num_classes = 1000,
84 | dim = 1024,
85 | depth = 6,
86 | heads = 16,
87 | mlp_dim = 2048,
88 | dropout = 0.1,
89 | emb_dropout = 0.1
90 | )
91 |
92 | vit = Extractor(vit, return_embeddings_only = True)
93 |
94 | # first take your trained image encoder and wrap it in an adapter that returns the image embeddings
95 | # here we use the ViT from the vit-pytorch library
96 |
97 | import torch
98 | from flamingo_pytorch import FlamingoPaLM
99 |
100 | # a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence
101 |
102 | flamingo_palm = FlamingoPaLM(
103 | num_tokens = 20000, # number of tokens
104 | dim = 1024, # dimensions
105 | depth = 12, # depth
106 | heads = 8, # attention heads
107 | dim_head = 64, # dimension per attention head
108 |     img_encoder = vit, # plug in your image encoder (this can be optional if you pass in the image embeddings separately, but you probably want to train end to end given the perceiver resampler)
109 | media_token_id = 3, # the token id representing the [media] or [image]
110 | cross_attn_every = 3, # how often to cross attend
111 | perceiver_num_latents = 64, # perceiver number of latents, should be smaller than the sequence length of the image tokens
112 | perceiver_depth = 2 # perceiver resampler depth
113 | )
114 |
115 | # train your PaLM as usual
116 |
117 | text = torch.randint(0, 20000, (2, 512))
118 |
119 | palm_logits = flamingo_palm(text)
120 |
121 | # after much training on the regular PaLM logits
122 | # now you are ready to train Flamingo + PaLM
123 | # by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper
124 |
125 | dialogue = torch.randint(0, 20000, (4, 512))
126 | images = torch.randn(4, 2, 3, 256, 256)
127 |
128 | flamingo_logits = flamingo_palm(dialogue, images)
129 |
130 | # do your usual cross entropy loss
131 | ```
132 |
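The last comment above leaves the loss to you; nothing Flamingo-specific is needed there, just ordinary next-token cross entropy. A minimal sketch, assuming you shift the logits against the dialogue tokens as targets:

```python
import torch.nn.functional as F

logits = flamingo_logits[:, :-1]          # predictions for positions 1 .. n
targets = dialogue[:, 1:]                 # next-token targets

loss = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]), # (batch * seq, num tokens)
    targets.reshape(-1)                   # (batch * seq,)
)

loss.backward()
```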
133 | It is quite evident where this is all headed if you think beyond just images.
134 |
135 | ## Inception
136 |
137 | For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base.
138 |
139 | ## Citations
140 |
141 | ```bibtex
142 | @article{Alayrac2022Flamingo,
143 | title = {Flamingo: a Visual Language Model for Few-Shot Learning},
144 | author = {Jean-Baptiste Alayrac et al},
145 | year = {2022}
146 | }
147 | ```
148 |
149 | ```bibtex
150 | @inproceedings{Chowdhery2022PaLMSL,
151 | title = {PaLM: Scaling Language Modeling with Pathways},
152 | author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
153 | year = {2022}
154 | }
155 | ```
156 |
157 |
--------------------------------------------------------------------------------
/flamingo_pytorch/flamingo_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, repeat
6 | from einops_exts import rearrange_many, repeat_many
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 | def FeedForward(dim, mult = 4):
12 | inner_dim = int(dim * mult)
13 | return nn.Sequential(
14 | nn.LayerNorm(dim),
15 | nn.Linear(dim, inner_dim, bias = False),
16 | nn.GELU(),
17 | nn.Linear(inner_dim, dim, bias = False)
18 | )
19 |
20 | class PerceiverAttention(nn.Module):
21 | def __init__(
22 | self,
23 | *,
24 | dim,
25 | dim_head = 64,
26 | heads = 8
27 | ):
28 | super().__init__()
29 | self.scale = dim_head ** -0.5
30 | self.heads = heads
31 | inner_dim = dim_head * heads
32 |
33 | self.norm_media = nn.LayerNorm(dim)
34 | self.norm_latents = nn.LayerNorm(dim)
35 |
36 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
37 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
38 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
39 |
40 | def forward(self, x, latents):
41 | """
42 | einstein notation
43 | b - batch
44 | t - time
45 | n - sequence
46 | d - dimension
47 | """
48 | x = self.norm_media(x)
49 | latents = self.norm_latents(latents)
50 |
51 | b, m, h = *x.shape[:2], self.heads
52 |
53 | q = self.to_q(latents)
54 |
55 |         # the paper differs from the original Perceiver in that the keys / values derived from the latents are also concatenated to those of the media and attended to
56 | kv_input = torch.cat((x, latents), dim = -2)
57 | k, v = self.to_kv(kv_input).chunk(2, dim = -1)
58 |
59 | q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h)
60 |
61 | q = q * self.scale
62 |
63 | # attention
64 |
65 | sim = einsum('... i d, ... j d -> ... i j', q, k)
66 |
67 | sim = sim - sim.amax(dim = -1, keepdim = True).detach()
68 | attn = sim.softmax(dim = -1)
69 |
70 | out = einsum('... i j, ... j d -> ... i d', attn, v)
71 | out = rearrange(out, 'b h t n d -> b t n (h d)', h = h)
72 | return self.to_out(out)
73 |
74 | class PerceiverResampler(nn.Module):
75 | def __init__(
76 | self,
77 | *,
78 | dim,
79 | depth,
80 | dim_head = 64,
81 | heads = 8,
82 | num_latents = 64,
83 | num_media_embeds = 4,
84 | ff_mult = 4
85 | ):
86 | super().__init__()
87 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
88 | self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim))
89 |
90 | self.layers = nn.ModuleList([])
91 | for _ in range(depth):
92 | self.layers.append(nn.ModuleList([
93 | PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
94 | FeedForward(dim = dim, mult = ff_mult)
95 | ]))
96 |
97 | self.norm = nn.LayerNorm(dim)
98 |
99 | def forward(self, x):
100 | if x.ndim == 3:
101 | x = rearrange(x, 'b n d -> b 1 n d')
102 |
103 | times = x.shape[1]
104 | x = x + self.media_pos_emb[:times]
105 |
106 | latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1])
107 |
108 | for attn, ff in self.layers:
109 | latents = attn(x, latents) + latents
110 | latents = ff(latents) + latents
111 |
112 | return self.norm(latents)
113 |
114 | # gated cross attention
115 |
116 | class MaskedCrossAttention(nn.Module):
117 | def __init__(
118 | self,
119 | *,
120 | dim,
121 | dim_head = 64,
122 | heads = 8,
123 | only_attend_immediate_media = True
124 | ):
125 | super().__init__()
126 | self.scale = dim_head ** -0.5
127 | self.heads = heads
128 | inner_dim = dim_head * heads
129 |
130 | self.norm = nn.LayerNorm(dim)
131 |
132 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
133 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
134 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
135 |
136 |         # whether text attends only to the immediately preceding image, or to all preceding images
137 |
138 | self.only_attend_immediate_media = only_attend_immediate_media
139 |
140 | def forward(
141 | self,
142 | x,
143 | media,
144 | media_locations = None
145 | ):
146 | b, t, m = media.shape[:3]
147 | h = self.heads
148 |
149 | x = self.norm(x)
150 |
151 | q = self.to_q(x)
152 | media = rearrange(media, 'b t n d -> b (t n) d')
153 |
154 | k, v = self.to_kv(media).chunk(2, dim = -1)
155 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
156 |
157 | q = q * self.scale
158 |
159 | sim = einsum('... i d, ... j d -> ... i j', q, k)
160 |
161 | if exists(media_locations):
162 | text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
163 | media_time = torch.arange(t, device = x.device) + 1
164 |
165 |             # text time must equal media time if only attending to the most immediate image
166 |             # otherwise, text time must be greater than or equal to media time (when attending to all previous images / media)
167 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
168 |
169 | text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
170 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
171 |
172 | sim = sim - sim.amax(dim = -1, keepdim = True).detach()
173 | attn = sim.softmax(dim = -1)
174 |
175 | if exists(media_locations) and self.only_attend_immediate_media:
176 | # any text without a preceding media needs to have attention zeroed out
177 | text_without_media_mask = text_time == 0
178 | text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
179 | attn = attn.masked_fill(text_without_media_mask, 0.)
180 |
181 | out = einsum('... i j, ... j d -> ... i d', attn, v)
182 | out = rearrange(out, 'b h n d -> b n (h d)')
183 | return self.to_out(out)
184 |
185 | class GatedCrossAttentionBlock(nn.Module):
186 | def __init__(
187 | self,
188 | *,
189 | dim,
190 | dim_head = 64,
191 | heads = 8,
192 | ff_mult = 4,
193 | only_attend_immediate_media = True
194 | ):
195 | super().__init__()
196 | self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
197 | self.attn_gate = nn.Parameter(torch.tensor([0.]))
198 |
199 | self.ff = FeedForward(dim, mult = ff_mult)
200 | self.ff_gate = nn.Parameter(torch.tensor([0.]))
201 |
202 | def forward(
203 | self,
204 | x,
205 |         media,                  # media tensor, encoded by the perceiver resampler - (batch, time, latents, dim)
206 |         media_locations = None  # boolean tensor indicating positions of media - (batch, sequence)
207 | ):
208 | x = self.attn(x, media, media_locations = media_locations) * self.attn_gate.tanh() + x
209 | x = self.ff(x) * self.ff_gate.tanh() + x
210 | return x
211 |
--------------------------------------------------------------------------------
/flamingo_pytorch/flamingo_palm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange, repeat
4 | from torch import einsum, nn
5 |
6 | from flamingo_pytorch.flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
7 |
8 | # helper functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # for controlling freezing during training of flamingo
14 |
15 | def set_module_requires_grad_(module, requires_grad):
16 | for param in module.parameters():
17 | param.requires_grad = requires_grad
18 |
19 | def freeze_all_layers_(module):
20 | set_module_requires_grad_(module, False)
21 |
22 | def unfreeze_all_layers_(module):
23 | set_module_requires_grad_(module, True)
24 |
25 | def freeze_model_and_make_eval_(model):
26 | model.eval()
27 | freeze_all_layers_(model)
28 |
29 | # normalization
30 | # they use layernorm without bias, something that pytorch does not offer
31 |
32 |
33 | class LayerNorm(nn.Module):
34 | def __init__(self, dim):
35 | super().__init__()
36 | self.gamma = nn.Parameter(torch.ones(dim))
37 | self.register_buffer("beta", torch.zeros(dim))
38 |
39 | def forward(self, x):
40 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
41 |
42 | # residual
43 |
44 |
45 | class Residual(nn.Module):
46 | def __init__(self, fn):
47 | super().__init__()
48 | self.fn = fn
49 |
50 | def forward(self, x):
51 | return self.fn(x) + x
52 |
53 |
54 | # rotary positional embedding
55 | # https://arxiv.org/abs/2104.09864
56 |
57 |
58 | class RotaryEmbedding(nn.Module):
59 | def __init__(self, dim):
60 | super().__init__()
61 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
62 | self.register_buffer("inv_freq", inv_freq)
63 |
64 | def forward(self, max_seq_len, *, device):
65 | seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
66 | freqs = einsum("i , j -> i j", seq, self.inv_freq)
67 | return torch.cat((freqs, freqs), dim=-1)
68 |
69 |
70 | def rotate_half(x):
71 | x = rearrange(x, "... (j d) -> ... j d", j=2)
72 | x1, x2 = x.unbind(dim=-2)
73 | return torch.cat((-x2, x1), dim=-1)
74 |
75 |
76 | def apply_rotary_pos_emb(pos, t):
77 | return (t * pos.cos()) + (rotate_half(t) * pos.sin())
78 |
79 |
80 | # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
81 | # https://arxiv.org/abs/2002.05202
82 |
83 |
84 | class SwiGLU(nn.Module):
85 | def forward(self, x):
86 | x, gate = x.chunk(2, dim=-1)
87 | return F.silu(gate) * x
88 |
89 |
90 | # parallel attention and feedforward with residual
91 | # discovered by Wang et al + EleutherAI from GPT-J fame
92 |
93 |
94 | class ParallelTransformerBlock(nn.Module):
95 | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
96 | super().__init__()
97 | self.norm = LayerNorm(dim)
98 |
99 | attn_inner_dim = dim_head * heads
100 | ff_inner_dim = dim * ff_mult
101 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
102 |
103 | self.heads = heads
104 | self.scale = dim_head**-0.5
105 | self.rotary_emb = RotaryEmbedding(dim_head)
106 |
107 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
108 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
109 |
110 | self.ff_out = nn.Sequential(
111 | SwiGLU(),
112 | nn.Linear(ff_inner_dim, dim, bias=False)
113 | )
114 |
115 | # for caching causal mask and rotary embeddings
116 |
117 | self.register_buffer("mask", None, persistent=False)
118 | self.register_buffer("pos_emb", None, persistent=False)
119 |
120 | def get_mask(self, n, device):
121 | if self.mask is not None and self.mask.shape[-1] >= n:
122 | return self.mask[:n, :n]
123 |
124 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
125 | self.register_buffer("mask", mask, persistent=False)
126 | return mask
127 |
128 | def get_rotary_embedding(self, n, device):
129 | if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
130 | return self.pos_emb[:n]
131 |
132 | pos_emb = self.rotary_emb(n, device=device)
133 | self.register_buffer("pos_emb", pos_emb, persistent=False)
134 | return pos_emb
135 |
136 | def forward(self, x):
137 | """
138 | einstein notation
139 | b - batch
140 | h - heads
141 | n, i, j - sequence length (base sequence length, source, target)
142 | d - feature dimension
143 | """
144 |
145 | n, device, h = x.shape[1], x.device, self.heads
146 |
147 | # pre layernorm
148 |
149 | x = self.norm(x)
150 |
151 | # attention queries, keys, values, and feedforward inner
152 |
153 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
154 |
155 | # split heads
156 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper
157 | # they found no performance loss past a certain scale, and more efficient decoding obviously
158 | # https://arxiv.org/abs/1911.02150
159 |
160 | q = rearrange(q, "b n (h d) -> b h n d", h=h)
161 |
162 | # rotary embeddings
163 |
164 | positions = self.get_rotary_embedding(n, device)
165 | q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
166 |
167 | # scale
168 |
169 | q = q * self.scale
170 |
171 | # similarity
172 |
173 | sim = einsum("b h i d, b j d -> b h i j", q, k)
174 |
175 | # causal mask
176 |
177 | causal_mask = self.get_mask(n, device)
178 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
179 |
180 | # attention
181 |
182 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
183 | attn = sim.softmax(dim=-1)
184 |
185 | # aggregate values
186 |
187 | out = einsum("b h i j, b j d -> b h i d", attn, v)
188 |
189 | # merge heads
190 |
191 | out = rearrange(out, "b h n d -> b n (h d)")
192 | return self.attn_out(out) + self.ff_out(ff)
193 |
194 |
195 | # transformer
196 |
197 |
198 | class FlamingoPaLM(nn.Module):
199 | def __init__(
200 | self,
201 | *,
202 | dim,
203 | num_tokens,
204 | depth,
205 | dim_head=64,
206 | heads=8,
207 | ff_mult=4,
208 | media_token_id=3,
209 | cross_attn_every=3,
210 | img_encoder=None,
211 | perceiver_num_latents=64,
212 | perceiver_depth=2,
213 | max_video_frames = None,
214 | only_attend_immediate_media=True
215 | ):
216 | super().__init__()
217 |
218 | self.token_emb = nn.Embedding(num_tokens, dim)
219 | self.media_token_id = media_token_id # you need to reserve a special token id for media
220 |
221 | self.video_frame_pos_emb = nn.Parameter(torch.randn(max_video_frames, dim)) if exists(max_video_frames) else None
222 |
223 | self.img_encoder = img_encoder
224 |         if exists(self.img_encoder): freeze_model_and_make_eval_(self.img_encoder)
225 |
226 | self.perceiver_resampler = PerceiverResampler(
227 | dim=dim,
228 | depth=perceiver_depth,
229 | dim_head=dim_head,
230 | heads=heads,
231 | num_latents=perceiver_num_latents
232 | )
233 |
234 | self.layers = nn.ModuleList([])
235 | for ind in range(depth):
236 | self.layers.append(nn.ModuleList([
237 | Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
238 | GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media) if not (ind % cross_attn_every) else None
239 | ]))
240 |
241 | self.to_logits = nn.Sequential(
242 | LayerNorm(dim),
243 | nn.Linear(dim, num_tokens, bias=False)
244 | )
245 |
246 | # they used embedding weight tied projection out to logits, not common, but works
247 | self.to_logits[-1].weight = self.token_emb.weight
248 | nn.init.normal_(self.token_emb.weight, std=0.02)
249 |
250 | def forward(
251 | self,
252 | text,
253 | *,
254 | images=None,
255 | videos=None,
256 | embeds=None
257 | ):
258 | batch, device = text.shape[0], text.device
259 |
260 | flamingo_mode = any([exists(t) for t in (images, videos, embeds)])
261 |
262 | # automatically take care of freezing or unfreezing depending on what is passed in
263 |
264 | if flamingo_mode:
265 | # in flamingo mode, freeze everything but perceiver and gated cross attention
266 | freeze_all_layers_(self)
267 | unfreeze_all_layers_(self.perceiver_resampler)
268 | [unfreeze_all_layers_(cross_attn) for _, cross_attn in self.layers if exists(cross_attn)]
269 | else:
270 | unfreeze_all_layers_(self)
271 |
272 | # derive the media token ids (as a boolean tensor), for calculating the masked cross attention
273 |
274 | if flamingo_mode:
275 | media_locations = text == self.media_token_id
276 |
277 | text_tokens = self.token_emb(text)
278 |
279 |         assert not (exists(embeds) and (exists(images) or exists(videos)))
280 |
281 | # encode videos or images into embeddings
282 | # with the img_encoder passed in at init
283 | # it can also accept precomputed image embeddings
284 |
285 | if exists(images):
286 | assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
287 | images = rearrange(images, 'b t ... -> (b t) ...')
288 |
289 | with torch.no_grad():
290 | embeds = self.img_encoder(images)
291 |
292 | embeds = rearrange(embeds, '(b t) ... -> b t ...', b = batch)
293 |
294 | if exists(videos):
295 | assert exists(self.img_encoder), 'img_encoder must be passed in for automatic video encoding'
296 | batch, media, num_times, *_ = videos.shape
297 | videos = rearrange(videos, '... c h w -> (...) c h w')
298 |
299 | with torch.no_grad():
300 | embeds = self.img_encoder(videos)
301 |
302 | embeds = rearrange(embeds, '(b m t) ... -> b m t ...', b = batch, m = media, t = num_times)
303 |
304 | video_time_pos_emb = repeat(self.video_frame_pos_emb[:num_times], 't d -> b m t n d', b = batch, m = media, n = embeds.shape[-2])
305 | embeds = embeds + video_time_pos_emb
306 | embeds = rearrange(embeds, 'b m t n d -> b m (t n) d')
307 |
308 | if exists(embeds):
309 | embeds = self.perceiver_resampler(embeds)
310 |
311 |
312 | # go through layers
313 |
314 | for attn_ff, flamingo_cross_attn in self.layers:
315 | text_tokens = attn_ff(text_tokens)
316 |
317 | # if image embeds exist and flamingo cross attention set for the layer
318 | # do the cross attention
319 | if exists(flamingo_cross_attn) and exists(embeds):
320 | text_tokens = flamingo_cross_attn(
321 | text_tokens,
322 | embeds,
323 | media_locations = media_locations
324 | )
325 |
326 | return self.to_logits(text_tokens)
327 |
--------------------------------------------------------------------------------