├── flamingo.png
├── flamingo_pytorch
│   ├── __init__.py
│   ├── flamingo_pytorch.py
│   └── flamingo_palm.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
└── README.md
/flamingo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/flamingo-pytorch/HEAD/flamingo.png -------------------------------------------------------------------------------- /flamingo_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from flamingo_pytorch.flamingo_pytorch import PerceiverResampler, GatedCrossAttentionBlock 2 | from flamingo_pytorch.flamingo_palm import FlamingoPaLM 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'flamingo-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.1.2', 7 | license='MIT', 8 | description = 'Flamingo - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/flamingo-pytorch', 12 | long_description_content_type = 'text/markdown', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'visual question answering' 19 | ], 20 | install_requires=[ 21 | 'einops>=0.4', 22 | 'einops-exts', 23 | 'torch>=1.6' 24 | ], 25 | classifiers=[ 26 | 'Development Status :: 4 - Beta', 27 | 'Intended Audience :: Developers', 28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Programming Language :: Python :: 3.6', 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## 🦩 Flamingo - Pytorch 4 | 5 | Implementation of Flamingo, the state-of-the-art few-shot visual question answering attention network, in Pytorch. It includes the perceiver resampler (including the scheme where the learned queries contribute keys / values to be attended to, in addition to the media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks. 6 | 7 | Yannic Kilcher presentation 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install flamingo-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | from flamingo_pytorch import PerceiverResampler 20 | 21 | perceive = PerceiverResampler( 22 | dim = 1024, 23 | depth = 2, 24 | dim_head = 64, 25 | heads = 8, 26 | num_latents = 64, # the number of latents to shrink your media sequence to, perceiver style 27 | num_media_embeds = 4 # say you have 4 images maximum in your dialogue 28 | ) 29 | 30 | medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension) 31 | perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension) 32 | ``` 33 | 34 | Then you insert the `GatedCrossAttentionBlock` at different intervals in your giant language model. Your text will then attend to the perceived media from above. 35 | 36 | The recommended way to derive the `media_locations` boolean tensor is to allocate a special token id to the media and then, at the start of your large language model, do `media_locations = text_id == media_token_id` (an end-to-end sketch of this is given at the end of this section). 37 | 38 | ```python 39 | import torch 40 | from flamingo_pytorch import GatedCrossAttentionBlock 41 | 42 | cross_attn = GatedCrossAttentionBlock( 43 | dim = 1024, 44 | dim_head = 64, 45 | heads = 8 46 | ) 47 | 48 | text = torch.randn(1, 512, 1024) 49 | perceived = torch.randn(1, 2, 64, 1024) 50 | 51 | media_locations = torch.randint(0, 2, (1, 512)).bool() 52 | 53 | text = cross_attn( 54 | text, 55 | perceived, 56 | media_locations = media_locations 57 | ) 58 | ``` 59 | 60 | That's it! 61 | 62 | Attention is all you need.
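To tie the two blocks above together, here is a minimal end-to-end sketch (not part of the library, shown for illustration): it reserves a hypothetical `media_token_id`, derives `media_locations` from the token ids as recommended above, resamples the media with `PerceiverResampler`, and lets the text attend to it through `GatedCrossAttentionBlock`. The `nn.Embedding` stand-in for your language model's token embedding, the vocabulary size of 20000, and the choice of `media_token_id = 3` are assumptions.

```python
import torch
from torch import nn
from flamingo_pytorch import PerceiverResampler, GatedCrossAttentionBlock

media_token_id = 3                            # assumed special token id reserved for [media]
token_emb = nn.Embedding(20000, 1024)         # stand-in for your language model's token embedding

perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,
    num_media_embeds = 4
)

cross_attn = GatedCrossAttentionBlock(
    dim = 1024,
    dim_head = 64,
    heads = 8
)

text_ids = torch.randint(0, 20000, (1, 512))
text_ids[:, 0] = media_token_id               # mark where the first image sits in the token stream
text_ids[:, 256] = media_token_id             # and the second image

media_locations = text_ids == media_token_id  # (batch, seq) boolean, derived exactly as recommended above

medias = torch.randn(1, 2, 256, 1024)         # (batch, time, sequence length, dimension) media features
perceived = perceive(medias)                  # (1, 2, 64, 1024)

text = token_emb(text_ids)                    # (1, 512, 1024)
text = cross_attn(text, perceived, media_locations = media_locations)
```

With the default `only_attend_immediate_media = True`, each text token attends only to the latents of its most recent preceding image; pass `only_attend_immediate_media = False` to let text attend to all previous media.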
63 | 64 | ## Full working example with Flamingo + PaLM 🌴🦩🌴 65 | 66 | Integration with PaLM 67 | 68 | First install `vit-pytorch` for the vision encoder 69 | 70 | ```bash 71 | $ pip install vit-pytorch 72 | ``` 73 | 74 | Then 75 | 76 | ```python 77 | from vit_pytorch.vit import ViT 78 | from vit_pytorch.extractor import Extractor 79 | 80 | vit = ViT( 81 | image_size = 256, 82 | patch_size = 32, 83 | num_classes = 1000, 84 | dim = 1024, 85 | depth = 6, 86 | heads = 16, 87 | mlp_dim = 2048, 88 | dropout = 0.1, 89 | emb_dropout = 0.1 90 | ) 91 | 92 | vit = Extractor(vit, return_embeddings_only = True) 93 | 94 | # first take your trained image encoder and wrap it in an adapter that returns the image embeddings 95 | # here we use the ViT from the vit-pytorch library 96 | 97 | import torch 98 | from flamingo_pytorch import FlamingoPaLM 99 | 100 | # a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence 101 | 102 | flamingo_palm = FlamingoPaLM( 103 | num_tokens = 20000, # number of tokens 104 | dim = 1024, # dimensions 105 | depth = 12, # depth 106 | heads = 8, # attention heads 107 | dim_head = 64, # dimension per attention head 108 | img_encoder = vit, # plug in your image encoder (this can be optional if you pass in the image embeddings separately, but you will probably want to train end to end given the perceiver resampler) 109 | media_token_id = 3, # the token id representing the [media] or [image] 110 | cross_attn_every = 3, # how often to cross attend 111 | perceiver_num_latents = 64, # perceiver number of latents, should be smaller than the sequence length of the image tokens 112 | perceiver_depth = 2 # perceiver resampler depth 113 | ) 114 | 115 | # train your PaLM as usual 116 | 117 | text = torch.randint(0, 20000, (2, 512)) 118 | 119 | palm_logits = flamingo_palm(text) 120 | 121 | # after much training on the regular PaLM logits 122 | # now you are ready to train Flamingo + PaLM 123 | # by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper 124 | 125 | dialogue = torch.randint(0, 20000, (4, 512)) 126 | images = torch.randn(4, 2, 3, 256, 256) 127 | 128 | flamingo_logits = flamingo_palm(dialogue, images = images) 129 | 130 | # do your usual cross entropy loss 131 | ``` 132 | 133 | It is quite evident where this is all headed if you think beyond just images. 134 | 135 | ## Inception 136 | 137 | For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base. 138 | 139 | ## Citations 140 | 141 | ```bibtex 142 | @article{Alayrac2022Flamingo, 143 | title = {Flamingo: a Visual Language Model for Few-Shot Learning}, 144 | author = {Jean-Baptiste Alayrac et al}, 145 | year = {2022} 146 | } 147 | ``` 148 | 149 | ```bibtex 150 | @inproceedings{Chowdhery2022PaLMSL, 151 | title = {PaLM: Scaling Language Modeling with Pathways}, 152 | author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C.
Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel}, 153 | year = {2022} 154 | } 155 | ``` 156 | 157 | -------------------------------------------------------------------------------- /flamingo_pytorch/flamingo_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops_exts import rearrange_many, repeat_many 7 | 8 | def exists(val): 9 | return val is not None 10 | 11 | def FeedForward(dim, mult = 4): 12 | inner_dim = int(dim * mult) 13 | return nn.Sequential( 14 | nn.LayerNorm(dim), 15 | nn.Linear(dim, inner_dim, bias = False), 16 | nn.GELU(), 17 | nn.Linear(inner_dim, dim, bias = False) 18 | ) 19 | 20 | class PerceiverAttention(nn.Module): 21 | def __init__( 22 | self, 23 | *, 24 | dim, 25 | dim_head = 64, 26 | heads = 8 27 | ): 28 | super().__init__() 29 | self.scale = dim_head ** -0.5 30 | self.heads = heads 31 | inner_dim = dim_head * heads 32 | 33 | self.norm_media = nn.LayerNorm(dim) 34 | self.norm_latents = nn.LayerNorm(dim) 35 | 36 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 37 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 38 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 39 | 40 | def forward(self, x, latents): 41 | """ 42 | einstein notation 43 | b - batch 44 | t - time 45 | n - sequence 46 | d - dimension 47 | """ 48 | x = self.norm_media(x) 49 | latents = self.norm_latents(latents) 50 | 51 | b, m, h = *x.shape[:2], self.heads 52 | 53 | q = self.to_q(latents) 54 | 55 | # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to 56 | kv_input = torch.cat((x, latents), dim = -2) 57 | k, v = self.to_kv(kv_input).chunk(2, dim = -1) 58 | 59 | q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h) 60 | 61 | q = q * self.scale 62 | 63 | # attention 64 | 65 | sim = einsum('... i d, ... j d -> ... i j', q, k) 66 | 67 | sim = sim - sim.amax(dim = -1, keepdim = True).detach() 68 | attn = sim.softmax(dim = -1) 69 | 70 | out = einsum('... i j, ... j d -> ... 
i d', attn, v) 71 | out = rearrange(out, 'b h t n d -> b t n (h d)', h = h) 72 | return self.to_out(out) 73 | 74 | class PerceiverResampler(nn.Module): 75 | def __init__( 76 | self, 77 | *, 78 | dim, 79 | depth, 80 | dim_head = 64, 81 | heads = 8, 82 | num_latents = 64, 83 | num_media_embeds = 4, 84 | ff_mult = 4 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim)) 89 | 90 | self.layers = nn.ModuleList([]) 91 | for _ in range(depth): 92 | self.layers.append(nn.ModuleList([ 93 | PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads), 94 | FeedForward(dim = dim, mult = ff_mult) 95 | ])) 96 | 97 | self.norm = nn.LayerNorm(dim) 98 | 99 | def forward(self, x): 100 | if x.ndim == 3: 101 | x = rearrange(x, 'b n d -> b 1 n d') 102 | 103 | times = x.shape[1] 104 | x = x + self.media_pos_emb[:times] 105 | 106 | latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1]) 107 | 108 | for attn, ff in self.layers: 109 | latents = attn(x, latents) + latents 110 | latents = ff(latents) + latents 111 | 112 | return self.norm(latents) 113 | 114 | # gated cross attention 115 | 116 | class MaskedCrossAttention(nn.Module): 117 | def __init__( 118 | self, 119 | *, 120 | dim, 121 | dim_head = 64, 122 | heads = 8, 123 | only_attend_immediate_media = True 124 | ): 125 | super().__init__() 126 | self.scale = dim_head ** -0.5 127 | self.heads = heads 128 | inner_dim = dim_head * heads 129 | 130 | self.norm = nn.LayerNorm(dim) 131 | 132 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 133 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 134 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 135 | 136 | # whether for text to only attend to immediate preceding image, or all images 137 | 138 | self.only_attend_immediate_media = only_attend_immediate_media 139 | 140 | def forward( 141 | self, 142 | x, 143 | media, 144 | media_locations = None 145 | ): 146 | b, t, m = media.shape[:3] 147 | h = self.heads 148 | 149 | x = self.norm(x) 150 | 151 | q = self.to_q(x) 152 | media = rearrange(media, 'b t n d -> b (t n) d') 153 | 154 | k, v = self.to_kv(media).chunk(2, dim = -1) 155 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h) 156 | 157 | q = q * self.scale 158 | 159 | sim = einsum('... i d, ... j d -> ... 
i j', q, k) 160 | 161 | if exists(media_locations): 162 | text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time) 163 | media_time = torch.arange(t, device = x.device) + 1 164 | 165 | # text time must equal media time if only attending to most immediate image 166 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 167 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 168 | 169 | text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m)) 170 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 171 | 172 | sim = sim - sim.amax(dim = -1, keepdim = True).detach() 173 | attn = sim.softmax(dim = -1) 174 | 175 | if exists(media_locations) and self.only_attend_immediate_media: 176 | # any text without a preceding media needs to have attention zeroed out 177 | text_without_media_mask = text_time == 0 178 | text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1') 179 | attn = attn.masked_fill(text_without_media_mask, 0.) 180 | 181 | out = einsum('... i j, ... j d -> ... i d', attn, v) 182 | out = rearrange(out, 'b h n d -> b n (h d)') 183 | return self.to_out(out) 184 | 185 | class GatedCrossAttentionBlock(nn.Module): 186 | def __init__( 187 | self, 188 | *, 189 | dim, 190 | dim_head = 64, 191 | heads = 8, 192 | ff_mult = 4, 193 | only_attend_immediate_media = True 194 | ): 195 | super().__init__() 196 | self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media) 197 | self.attn_gate = nn.Parameter(torch.tensor([0.])) 198 | 199 | self.ff = FeedForward(dim, mult = ff_mult) 200 | self.ff_gate = nn.Parameter(torch.tensor([0.])) 201 | 202 | def forward( 203 | self, 204 | x, 205 | media, # media tensor, encoded by perceiver resample - (batch, time, latents, dim) 206 | media_locations = None # boolean tensor indicating positions of media - (batch, sequence) 207 | ): 208 | x = self.attn(x, media, media_locations = media_locations) * self.attn_gate.tanh() + x 209 | x = self.ff(x) * self.ff_gate.tanh() + x 210 | return x 211 | -------------------------------------------------------------------------------- /flamingo_pytorch/flamingo_palm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from torch import einsum, nn 5 | 6 | from flamingo_pytorch.flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler 7 | 8 | # helper functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # for controlling freezing during training of flamingo 14 | 15 | def set_module_requires_grad_(module, requires_grad): 16 | for param in module.parameters(): 17 | param.requires_grad = requires_grad 18 | 19 | def freeze_all_layers_(module): 20 | set_module_requires_grad_(module, False) 21 | 22 | def unfreeze_all_layers_(module): 23 | set_module_requires_grad_(module, True) 24 | 25 | def freeze_model_and_make_eval_(model): 26 | model.eval() 27 | freeze_all_layers_(model) 28 | 29 | # normalization 30 | # they use layernorm without bias, something that pytorch does not offer 31 | 32 | 33 | class LayerNorm(nn.Module): 34 | def __init__(self, dim): 35 | super().__init__() 36 | self.gamma = nn.Parameter(torch.ones(dim)) 37 | self.register_buffer("beta", 
torch.zeros(dim)) 38 | 39 | def forward(self, x): 40 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 41 | 42 | # residual 43 | 44 | 45 | class Residual(nn.Module): 46 | def __init__(self, fn): 47 | super().__init__() 48 | self.fn = fn 49 | 50 | def forward(self, x): 51 | return self.fn(x) + x 52 | 53 | 54 | # rotary positional embedding 55 | # https://arxiv.org/abs/2104.09864 56 | 57 | 58 | class RotaryEmbedding(nn.Module): 59 | def __init__(self, dim): 60 | super().__init__() 61 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 62 | self.register_buffer("inv_freq", inv_freq) 63 | 64 | def forward(self, max_seq_len, *, device): 65 | seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) 66 | freqs = einsum("i , j -> i j", seq, self.inv_freq) 67 | return torch.cat((freqs, freqs), dim=-1) 68 | 69 | 70 | def rotate_half(x): 71 | x = rearrange(x, "... (j d) -> ... j d", j=2) 72 | x1, x2 = x.unbind(dim=-2) 73 | return torch.cat((-x2, x1), dim=-1) 74 | 75 | 76 | def apply_rotary_pos_emb(pos, t): 77 | return (t * pos.cos()) + (rotate_half(t) * pos.sin()) 78 | 79 | 80 | # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward 81 | # https://arxiv.org/abs/2002.05202 82 | 83 | 84 | class SwiGLU(nn.Module): 85 | def forward(self, x): 86 | x, gate = x.chunk(2, dim=-1) 87 | return F.silu(gate) * x 88 | 89 | 90 | # parallel attention and feedforward with residual 91 | # discovered by Wang et al + EleutherAI from GPT-J fame 92 | 93 | 94 | class ParallelTransformerBlock(nn.Module): 95 | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): 96 | super().__init__() 97 | self.norm = LayerNorm(dim) 98 | 99 | attn_inner_dim = dim_head * heads 100 | ff_inner_dim = dim * ff_mult 101 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) 102 | 103 | self.heads = heads 104 | self.scale = dim_head**-0.5 105 | self.rotary_emb = RotaryEmbedding(dim_head) 106 | 107 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) 108 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) 109 | 110 | self.ff_out = nn.Sequential( 111 | SwiGLU(), 112 | nn.Linear(ff_inner_dim, dim, bias=False) 113 | ) 114 | 115 | # for caching causal mask and rotary embeddings 116 | 117 | self.register_buffer("mask", None, persistent=False) 118 | self.register_buffer("pos_emb", None, persistent=False) 119 | 120 | def get_mask(self, n, device): 121 | if self.mask is not None and self.mask.shape[-1] >= n: 122 | return self.mask[:n, :n] 123 | 124 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) 125 | self.register_buffer("mask", mask, persistent=False) 126 | return mask 127 | 128 | def get_rotary_embedding(self, n, device): 129 | if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: 130 | return self.pos_emb[:n] 131 | 132 | pos_emb = self.rotary_emb(n, device=device) 133 | self.register_buffer("pos_emb", pos_emb, persistent=False) 134 | return pos_emb 135 | 136 | def forward(self, x): 137 | """ 138 | einstein notation 139 | b - batch 140 | h - heads 141 | n, i, j - sequence length (base sequence length, source, target) 142 | d - feature dimension 143 | """ 144 | 145 | n, device, h = x.shape[1], x.device, self.heads 146 | 147 | # pre layernorm 148 | 149 | x = self.norm(x) 150 | 151 | # attention queries, keys, values, and feedforward inner 152 | 153 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) 154 | 155 | # split heads 156 | # they 
use multi-query single-key-value attention, yet another Noam Shazeer paper 157 | # they found no performance loss past a certain scale, and more efficient decoding obviously 158 | # https://arxiv.org/abs/1911.02150 159 | 160 | q = rearrange(q, "b n (h d) -> b h n d", h=h) 161 | 162 | # rotary embeddings 163 | 164 | positions = self.get_rotary_embedding(n, device) 165 | q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) 166 | 167 | # scale 168 | 169 | q = q * self.scale 170 | 171 | # similarity 172 | 173 | sim = einsum("b h i d, b j d -> b h i j", q, k) 174 | 175 | # causal mask 176 | 177 | causal_mask = self.get_mask(n, device) 178 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 179 | 180 | # attention 181 | 182 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 183 | attn = sim.softmax(dim=-1) 184 | 185 | # aggregate values 186 | 187 | out = einsum("b h i j, b j d -> b h i d", attn, v) 188 | 189 | # merge heads 190 | 191 | out = rearrange(out, "b h n d -> b n (h d)") 192 | return self.attn_out(out) + self.ff_out(ff) 193 | 194 | 195 | # transformer 196 | 197 | 198 | class FlamingoPaLM(nn.Module): 199 | def __init__( 200 | self, 201 | *, 202 | dim, 203 | num_tokens, 204 | depth, 205 | dim_head=64, 206 | heads=8, 207 | ff_mult=4, 208 | media_token_id=3, 209 | cross_attn_every=3, 210 | img_encoder=None, 211 | perceiver_num_latents=64, 212 | perceiver_depth=2, 213 | max_video_frames = None, 214 | only_attend_immediate_media=True 215 | ): 216 | super().__init__() 217 | 218 | self.token_emb = nn.Embedding(num_tokens, dim) 219 | self.media_token_id = media_token_id # you need to reserve a special token id for media 220 | 221 | self.video_frame_pos_emb = nn.Parameter(torch.randn(max_video_frames, dim)) if exists(max_video_frames) else None 222 | 223 | self.img_encoder = img_encoder 224 | freeze_model_and_make_eval_(self.img_encoder) 225 | 226 | self.perceiver_resampler = PerceiverResampler( 227 | dim=dim, 228 | depth=perceiver_depth, 229 | dim_head=dim_head, 230 | heads=heads, 231 | num_latents=perceiver_num_latents 232 | ) 233 | 234 | self.layers = nn.ModuleList([]) 235 | for ind in range(depth): 236 | self.layers.append(nn.ModuleList([ 237 | Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)), 238 | GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media) if not (ind % cross_attn_every) else None 239 | ])) 240 | 241 | self.to_logits = nn.Sequential( 242 | LayerNorm(dim), 243 | nn.Linear(dim, num_tokens, bias=False) 244 | ) 245 | 246 | # they used embedding weight tied projection out to logits, not common, but works 247 | self.to_logits[-1].weight = self.token_emb.weight 248 | nn.init.normal_(self.token_emb.weight, std=0.02) 249 | 250 | def forward( 251 | self, 252 | text, 253 | *, 254 | images=None, 255 | videos=None, 256 | embeds=None 257 | ): 258 | batch, device = text.shape[0], text.device 259 | 260 | flamingo_mode = any([exists(t) for t in (images, videos, embeds)]) 261 | 262 | # automatically take care of freezing or unfreezing depending on what is passed in 263 | 264 | if flamingo_mode: 265 | # in flamingo mode, freeze everything but perceiver and gated cross attention 266 | freeze_all_layers_(self) 267 | unfreeze_all_layers_(self.perceiver_resampler) 268 | [unfreeze_all_layers_(cross_attn) for _, cross_attn in self.layers if exists(cross_attn)] 269 | else: 270 | unfreeze_all_layers_(self) 271 | 272 | # derive the media token ids (as a boolean 
tensor), for calculating the masked cross attention 273 | 274 | if flamingo_mode: 275 | media_locations = text == self.media_token_id 276 | 277 | text_tokens = self.token_emb(text) 278 | 279 | assert not (exists(embeds) and (exists(images) or exists(videos))) 280 | 281 | # encode videos or images into embeddings 282 | # with the img_encoder passed in at init 283 | # it can also accept precomputed image embeddings 284 | 285 | if exists(images): 286 | assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding' 287 | images = rearrange(images, 'b t ... -> (b t) ...') 288 | 289 | with torch.no_grad(): 290 | embeds = self.img_encoder(images) 291 | 292 | embeds = rearrange(embeds, '(b t) ... -> b t ...', b = batch) 293 | 294 | if exists(videos): 295 | assert exists(self.img_encoder), 'img_encoder must be passed in for automatic video encoding' 296 | batch, media, num_times, *_ = videos.shape 297 | videos = rearrange(videos, '... c h w -> (...) c h w') 298 | 299 | with torch.no_grad(): 300 | embeds = self.img_encoder(videos) 301 | 302 | embeds = rearrange(embeds, '(b m t) ... -> b m t ...', b = batch, m = media, t = num_times) 303 | 304 | video_time_pos_emb = repeat(self.video_frame_pos_emb[:num_times], 't d -> b m t n d', b = batch, m = media, n = embeds.shape[-2]) 305 | embeds = embeds + video_time_pos_emb 306 | embeds = rearrange(embeds, 'b m t n d -> b m (t n) d') 307 | 308 | if exists(embeds): 309 | embeds = self.perceiver_resampler(embeds) 310 | 311 | 312 | # go through layers 313 | 314 | for attn_ff, flamingo_cross_attn in self.layers: 315 | text_tokens = attn_ff(text_tokens) 316 | 317 | # if image embeds exist and flamingo cross attention set for the layer 318 | # do the cross attention 319 | if exists(flamingo_cross_attn) and exists(embeds): 320 | text_tokens = flamingo_cross_attn( 321 | text_tokens, 322 | embeds, 323 | media_locations = media_locations 324 | ) 325 | 326 | return self.to_logits(text_tokens) 327 | --------------------------------------------------------------------------------
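To round things out, here is a minimal, hypothetical training-step sketch for `FlamingoPaLM`, following the README's "do your usual cross entropy loss" note. The optimizer, learning rate, and the shifted next-token targets are illustrative assumptions, not part of the repository; the model and encoder hyperparameters mirror the README example.

```python
import torch
import torch.nn.functional as F

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor
from flamingo_pytorch import FlamingoPaLM

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

vit = Extractor(vit, return_embeddings_only = True)

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 12,
    heads = 8,
    dim_head = 64,
    img_encoder = vit,
    media_token_id = 3,
    cross_attn_every = 3,
    perceiver_num_latents = 64,
    perceiver_depth = 2
)

# frozen parameters simply receive no gradients in flamingo mode,
# so it is fine to hand the optimizer all parameters
optim = torch.optim.Adam(flamingo_palm.parameters(), lr = 3e-4)

dialogue = torch.randint(0, 20000, (4, 512))   # token ids, with media_token_id marking the image slots
images = torch.randn(4, 2, 3, 256, 256)        # (batch, images per dialogue, channels, height, width)

logits = flamingo_palm(dialogue, images = images)   # (4, 512, 20000)

# standard next-token prediction: the logits at position t predict the token at position t + 1
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.shape[-1]),
    dialogue[:, 1:].reshape(-1)
)

loss.backward()
optim.step()
optim.zero_grad()
```

Because `images` is passed, the forward pass freezes everything except the perceiver resampler and the gated cross attention blocks, so only those receive gradients here; calling `flamingo_palm(dialogue)` without any media unfreezes the full model again for ordinary language-model training.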