├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── agorabanner.png ├── example.py ├── img_embeds.py ├── pali.png ├── pali ├── __init__.py ├── attend.py ├── autoregressive_wrapper.py ├── model.py └── transformer.py ├── pyproject.toml └── requirements.txt /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | zeta/utils/attention/__pycache__ 162 | zeta/utils/__pycache__ 163 | zeta/utils/__pycache__ 164 | zeta/__pycache__ 165 | zeta/utils/module/__pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eternal Reclaimer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf) 2 | 3 | 4 | # PALI: A JOINTLY-SCALED MULTILINGUAL LANGUAGE-IMAGE MODEL 5 | ![pali](pali.png) 6 | 7 | [![GitHub issues](https://img.shields.io/github/issues/kyegomez/pali)](https://github.com/kyegomez/pali/issues) 8 | [![GitHub forks](https://img.shields.io/github/forks/kyegomez/pali)](https://github.com/kyegomez/pali/network) 9 | [![GitHub stars](https://img.shields.io/github/stars/kyegomez/pali)](https://github.com/kyegomez/pali/stargazers) [![GitHub license](https://img.shields.io/github/license/kyegomez/pali)](https://github.com/kyegomez/pali/blob/master/LICENSE) 10 | [![Share on Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Share%20%40kyegomez/pali)](https://twitter.com/intent/tweet?text=Excited%20to%20introduce%20pali,%20the%20all-new%20robotics%20model%20with%20the%20potential%20to%20revolutionize%20automation.%20Join%20us%20on%20this%20journey%20towards%20a%20smarter%20future.%20%23RT1%20%23Robotics&url=https%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali) 11 | [![Share on Facebook](https://img.shields.io/badge/Share-%20facebook-blue)](https://www.facebook.com/sharer/sharer.php?u=https%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali) 12 | [![Share on LinkedIn](https://img.shields.io/badge/Share-%20linkedin-blue)](https://www.linkedin.com/shareArticle?mini=true&url=https%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali&title=Introducing%20pali%2C%20the%20All-New%20Robotics%20Model&summary=pali%20is%20the%20next-generation%20robotics%20model%20that%20promises%20to%20transform%20industries%20with%20its%20intelligence%20and%20efficiency.%20Join%20us%20to%20be%20a%20part%20of%20this%20revolutionary%20journey%20%23RT1%20%23Robotics&source=) 13 | ![Discord](https://img.shields.io/discord/999382051935506503) 14 | [![Share on Reddit](https://img.shields.io/badge/-Share%20on%20Reddit-orange)](https://www.reddit.com/submit?url=https%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali&title=Exciting%20Times%20Ahead%20with%20pali%2C%20the%20All-New%20Robotics%20Model%20%23RT1%20%23Robotics) [![Share on Hacker News](https://img.shields.io/badge/-Share%20on%20Hacker%20News-orange)](https://news.ycombinator.com/submitlink?u=https%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali&t=Exciting%20Times%20Ahead%20with%20pali%2C%20the%20All-New%20Robotics%20Model%20%23RT1%20%23Robotics) 15 | [![Share on Pinterest](https://img.shields.io/badge/-Share%20on%20Pinterest-red)](https://pinterest.com/pin/create/button/?url=https%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali&media=https%3A%2F%2Fexample.com%2Fimage.jpg&description=pali%2C%20the%20Revolutionary%20Robotics%20Model%20that%20will%20Change%20the%20Way%20We%20Work%20%23RT1%20%23Robotics) 16 | [![Share on 
WhatsApp](https://img.shields.io/badge/-Share%20on%20WhatsApp-green)](https://api.whatsapp.com/send?text=I%20just%20discovered%20pali,%20the%20all-new%20robotics%20model%20that%20promises%20to%20revolutionize%20automation.%20Join%20me%20on%20this%20exciting%20journey%20towards%20a%20smarter%20future.%20%23RT1%20%23Robotics%0A%0Ahttps%3A%2F%2Fgithub.com%2Fkyegomez%2Fpali) 17 | 18 | 19 | The open source implementation of the Multi-Modality AI model from ["PaLI: A Jointly-Scaled Multilingual Language-Image Model"](https://arxiv.org/abs/2209.06794). The text architecture is text -> encoder -> decoder -> logits -> text; the vision architecture is image -> ViT -> embeddings -> encoder -> decoder -> logits -> text. 20 | 21 | # **NOTE** 22 | - This is the base model architecture only; no tokenizer or pretrained weights are included 23 | - To train, plug in a tokenizer (for example, TokenMonster) and patchify your images so they are compatible with example.py; see the tokenization sketch below 24 | - We use a UL2 encoder/decoder for the text and a ViT model to embed the image, and the image embeddings are injected into the text encoder-decoder 25 | - If you would like to help train this model and release it open source, please click on the Agora banner and join the lab! 26 | 27 | ## 🌟 Appreciation 28 | Big bear hugs 🐻💖 to *LucidRains* for the fab x_transformers and for championing the open source AI cause. 29 | 30 | ## 🚀 Install 31 | 32 | ```bash 33 | pip install pali-torch 34 | ``` 35 | --- 36 | 37 | ## 🧙 Usage 38 | ```python 39 | import torch # Importing the torch library for tensor operations 40 | from pali import Pali # Importing the Pali class from the pali module 41 | 42 | model = Pali() # Creating an instance of the Pali class and assigning it to the variable 'model' 43 | 44 | img = torch.randn(1, 3, 256, 256) # Creating a random image tensor with shape (1, 3, 256, 256) 45 | # The shape represents (batch_size, channels, height, width) 46 | 47 | prompt = torch.randint(0, 256, (1, 1024)) # Creating a random text integer tensor with shape (1, 1024) 48 | # The shape represents (batch_size, sequence_length) 49 | 50 | output_text = torch.randint(0, 256, (1, 1024)) # Creating a random target text integer tensor with shape (1, 1024) 51 | # The shape represents (batch_size, sequence_length) 52 | 53 | out = model.forward(img, prompt, output_text, mask=None) # Calling the forward method of the 'model' instance 54 | # The forward method takes the image tensor, prompt tensor, output_text tensor, and an optional mask tensor as inputs 55 | # It performs computations and returns the output tensor 56 | 57 | print(out) # Printing the output tensor 58 | 59 | 60 | ``` 61 | 62 |
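In the usage example above, `prompt` and `output_text` are just random placeholder ids. Whatever tokenizer you pair with the model has to produce ids below `enc_num_tokens` / `dec_num_tokens` (256 by default) and sequences no longer than `enc_max_seq_len` / `dec_max_seq_len` (1024 by default). As a minimal illustration only (not shipped with this repo), raw UTF-8 bytes already fit the default 256-token vocabulary; a real tokenizer such as TokenMonster or SentencePiece would replace this for actual training:

```python
import torch

from pali import Pali


def byte_tokenize(text: str, seq_len: int = 1024, pad_id: int = 0) -> torch.Tensor:
    """Toy byte-level tokenizer: UTF-8 bytes are integers in [0, 256),
    which matches Pali's default enc_num_tokens / dec_num_tokens of 256."""
    ids = list(text.encode("utf-8"))[:seq_len]  # truncate to the max sequence length
    ids = ids + [pad_id] * (seq_len - len(ids))  # right-pad to a fixed length
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)  # shape (1, seq_len)


model = Pali()
img = torch.randn(1, 3, 256, 256)

prompt = byte_tokenize("describe the image in one sentence")
output_text = byte_tokenize("a dog runs across a grassy park")

out = model.forward(img, prompt, output_text, mask=None)
print(out)
```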
63 | ## Vit Image Embedder 64 | - To embed your images, you can use the ViT model: 65 | 66 | ```python 67 | from PIL import Image 68 | from torchvision import transforms 69 | 70 | from pali.model import VitModel 71 | 72 | 73 | def img_to_tensor(img: str = "pali.png", img_size: int = 256): 74 | # Load image 75 | image = Image.open(img) 76 | 77 | # Define transforms to convert the image to a tensor and apply preprocessing 78 | transform = transforms.Compose( 79 | [ 80 | transforms.Lambda(lambda image: image.convert("RGB")), 81 | transforms.Resize((img_size, img_size)), # Resize the image to img_size x img_size 82 | transforms.ToTensor(), # Convert the image to a tensor 83 | transforms.Normalize( 84 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 85 | ), # Normalize the pixel values 86 | ] 87 | ) 88 | 89 | # Apply transforms to the image 90 | x = transform(image) 91 | 92 | # print(f"Image shape: {x.shape}") 93 | 94 | # Add batch dimension 95 | x = x.unsqueeze(0) 96 | print(x.shape) 97 | 98 | return x 99 | 100 | 101 | # Convert image to tensor 102 | x = img_to_tensor() 103 | 104 | # Initialize model 105 | model = VitModel() 106 | 107 | # Forward pass 108 | out = model(x) 109 | 110 | # Print output shape 111 | print(out) 112 | 113 | 114 | ``` 115 | ---- 116 | 117 | # Datasets Strategy 118 | The dataset strategy follows the paper as closely as possible. 119 | 120 | Here is a table with metadata and HuggingFace links for the datasets used: 121 | 122 | | Dataset | Description | Size | Languages | Link | 123 | |-|-|-|-|-| 124 | | WebLI | Large-scale web crawled image-text dataset | 10B images, 12B captions | 109 languages | Private | 125 | | CC3M | Conceptual Captions dataset | 3M image-text pairs | English | [Link](https://huggingface.co/datasets/conceptual_captions) | 126 | | CC3M-35L | CC3M translated into 35 additional languages | 105M image-text pairs | 36 languages | Private | 127 | | VQAv2 | VQA dataset built on COCO images | 204K images, 1.1M QA pairs | English | [Link](https://huggingface.co/datasets/vqa_v2) | 128 | | VQ2A-CC3M | VQA dataset built from CC3M | 3M image-text pairs | English | Private | 129 | | VQ2A-CC3M-35L | VQ2A-CC3M translated into 35 additional languages | 105M image-text pairs | 36 languages | Private | 130 | | Open Images | Large scale image dataset | 9M images with labels | English | [Link](https://huggingface.co/datasets/open_images_v4) | 131 | | Visual Genome | Image dataset with dense annotations | 108K images with annotations | English | [Link](https://huggingface.co/datasets/visual_genome) | 132 | | Object365 | Image dataset for object detection | 500K images with labels | English | Private | 133 | 134 | The key datasets used for pre-training PaLI include: 135 | 136 | - WebLI: A large-scale multilingual image-text dataset crawled from the web, comprising 10B images and 12B captions in 109 languages. 137 | 138 | - CC3M-35L: The CC3M Conceptual Captions dataset machine translated into 35 additional languages, totaling 105M image-text pairs in 36 languages. 139 | 140 | - VQ2A-CC3M-35L: A VQA dataset based on CC3M, also translated into 35 languages. 141 | 142 | The model was evaluated on diverse tasks using standard datasets such as VQAv2, Open Images, and COCO Captions; links and details are provided above. 143 |
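For the public entries in the table, the HuggingFace `datasets` library is one convenient way to pull the annotations. The snippet below is only an illustration of that route (it is not part of this repo); the dataset id and the `image_url` / `caption` column names follow the Conceptual Captions card linked above and may change over time:

```python
# Illustrative only: fetch Conceptual Captions (CC3M) annotations from HuggingFace.
# Requires `pip install datasets`; CC3M ships captions plus image URLs, not the pixels themselves.
from datasets import load_dataset

# The validation split is small, so it is a quick way to inspect the schema
cc3m = load_dataset("conceptual_captions", split="validation")

for sample in cc3m.select(range(5)):
    print(sample["caption"], "->", sample["image_url"])
```

Images would still need to be downloaded from their URLs and preprocessed with the same transform used in `img_embeds.py` before being fed to the ViT.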
144 | ---- 145 | 146 | 147 | ---- 148 | 149 | ## 🎉 Features 150 | - **Double the Power**: MT5 for text and ViT for images - Pali's the superhero we didn't know we needed! 💪📖🖼️ 151 | - **Winning Streak**: With roots in the tried-and-true MT5 & ViT, success is in Pali's DNA. 🏆 152 | - **Ready, Set, Go**: No fuss, no muss! Get Pali rolling in no time. ⏱️ 153 | - **Easy-Peasy**: Leave the heavy lifting to Pali and enjoy your smooth sailing. 🛳️ 154 | 155 | 156 | ## 🌆 Real-World Use-Cases 157 | 158 | - **E-commerce**: Jazz up those recs! Understand products inside-out with images & descriptions. 🛍️ 159 | - **Social Media**: Be the smart reply guru for posts with pics & captions. 📱 160 | - **Healthcare**: Boost diagnostics with insights from images & textual data. 🏥 161 | 162 | ---- 163 | 164 | ## 📚 Citation 165 | 166 | ``` 167 | @inproceedings{chen2023pali, 168 | title={PaLI: A Jointly-Scaled Multilingual Language-Image Model}, 169 | author={Chen, Xi and Wang, Xiao and others}, 170 | booktitle={International Conference on Learning Representations (ICLR)}, 171 | year={2023} 172 | } 173 | ``` 174 | 175 | # Todo 176 | 177 | - [x] Make a table of the datasets used in the paper 178 | - [ ] Provide tokenizer integration 179 | - [ ] Provide training script 180 | - [ ] Provide usage/inference scripts 181 | 182 | ---- 183 | 184 | ## 📜 License 185 | MIT 186 | -------------------------------------------------------------------------------- /agorabanner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyegomez/PALI/c30d6910e90e1026ae317b0c64d5e9e9506519c7/agorabanner.png -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch # Importing the torch library for tensor operations 2 | from pali import Pali # Importing the Pali class from the pali module 3 | 4 | model = ( 5 | Pali() 6 | ) # Creating an instance of the Pali class and assigning it to the variable 'model' 7 | 8 | img = torch.randn( 9 | 1, 3, 256, 256 10 | ) # Creating a random image tensor with shape (1, 3, 256, 256) 11 | # The shape represents (batch_size, channels, height, width) 12 | 13 | prompt = torch.randint( 14 | 0, 256, (1, 1024) 15 | ) # Creating a random text integer tensor with shape (1, 1024) 16 | # The shape represents (batch_size, sequence_length) 17 | 18 | output_text = torch.randint( 19 | 0, 256, (1, 1024) 20 | ) # Creating a random target text integer tensor with shape (1, 1024) 21 | # The shape represents (batch_size, sequence_length) 22 | 23 | out = model.forward( 24 | img, prompt, output_text, mask=None 25 | ) # Calling the forward method of the 'model' instance 26 | # The forward method takes the image tensor, prompt tensor, output_text tensor, and an optional mask tensor as inputs 27 | # It performs computations and returns the output tensor 28 | 29 | print(out) # Printing the output tensor 30 | -------------------------------------------------------------------------------- /img_embeds.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | 4 | from pali.model import VitModel 5 | 6 | 7 | def img_to_tensor(img: str = "pali.png", img_size: int = 256): 8 | # Load image 9 | image = Image.open(img) 10 | 11 | # Define transforms to convert the image to a tensor and apply preprocessing 12 | transform = transforms.Compose( 13 | [ 14 | transforms.Lambda(lambda image: image.convert("RGB")), 15 | transforms.Resize((img_size, img_size)), # Resize the image to img_size x img_size 16 | transforms.ToTensor(), # Convert the image to a tensor 17 | transforms.Normalize( 18 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 19 | ), # Normalize the pixel values 20 | ] 21 | ) 22 | 23 | # Apply transforms to the image 24 | x = transform(image) 25 | 26 | # print(f"Image shape: {x.shape}") 27 | 28 | # Add batch dimension 29 | x = x.unsqueeze(0) 30 | print(x.shape) 31 | 32 | return x 33 | 34 | 35 | # Convert image to tensor 36 | x = img_to_tensor() 37 | 38 | # Initialize model 39 | model = VitModel() 40 | 41 | # Forward pass 42 | out = model(x) 43 | 44 | # Print output shape 45 | print(out) 46 |
-------------------------------------------------------------------------------- /pali.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyegomez/PALI/c30d6910e90e1026ae317b0c64d5e9e9506519c7/pali.png -------------------------------------------------------------------------------- /pali/__init__.py: -------------------------------------------------------------------------------- 1 | from pali.model import VitModel, Pali 2 | from pali.transformer import UL2 3 | 4 | 5 | __all__ = [ 6 | "VitModel", 7 | "Pali", 8 | "UL2", 9 | ] 10 | -------------------------------------------------------------------------------- /pali/attend.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch import nn, einsum, Tensor 5 | import torch.nn.functional as F 6 | 7 | from collections import namedtuple 8 | from functools import wraps 9 | from packaging import version 10 | from dataclasses import dataclass 11 | 12 | from einops import rearrange 13 | 14 | # constants 15 | 16 | EfficientAttentionConfig = namedtuple( 17 | "EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] 18 | ) 19 | 20 | 21 | @dataclass 22 | class Intermediates: 23 | qk_similarities: Tensor = None 24 | pre_softmax_attn: Tensor = None 25 | post_softmax_attn: Tensor = None 26 | 27 | def to_tuple(self): 28 | return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn) 29 | 30 | 31 | # helpers 32 | 33 | 34 | def exists(val): 35 | return val is not None 36 | 37 | 38 | def default(val, d): 39 | return val if exists(val) else d 40 | 41 | 42 | def compact(arr): 43 | return [*filter(exists, arr)] 44 | 45 | 46 | def once(fn): 47 | called = False 48 | 49 | @wraps(fn) 50 | def inner(x): 51 | nonlocal called 52 | if called: 53 | return 54 | called = True 55 | return fn(x) 56 | 57 | return inner 58 | 59 | 60 | print_once = once(print) 61 | 62 | # functions for creating causal mask 63 | # need a special one for onnx cpu (no support for .triu) 64 | 65 | 66 | def create_causal_mask(i, j, device): 67 | return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) 68 | 69 | 70 | def onnx_create_causal_mask(i, j, device): 71 | r = torch.arange(i, device=device) 72 | causal_mask = rearrange(r, "i -> i 1") < rearrange(r, "j -> 1 j") 73 | causal_mask = F.pad(causal_mask, (j - i, 0), value=False) 74 | return causal_mask 75 | 76 | 77 | # main class 78 | 79 | 80 | class Attend(nn.Module): 81 | def __init__( 82 | self, 83 | *, 84 | dropout=0.0, 85 | causal=False, 86 | heads=None, 87 | talking_heads=False, 88 | sparse_topk=None, 89 | scale=None, 90 | qk_norm=False, 91 | flash=False, 92 | add_zero_kv=False, 93 | onnxable=False, 94 | ): 95 | super().__init__() 96 | self.scale = scale 97 | self.qk_norm = qk_norm 98 | 99 | self.causal = causal 100 | self.create_causal_mask = ( 101 | onnx_create_causal_mask if onnxable else create_causal_mask 102 | ) 103 | 104 | self.attn_fn = ( 105 | partial(F.softmax, dtype=torch.float32) if not qk_norm else F.softmax 106 | ) 107 | 108 | self.dropout = dropout 109 | self.attn_dropout = nn.Dropout(dropout) 110 | 111 | # talking heads 112 | 113 | assert not ( 114 | flash and talking_heads 115 | ), "talking heads not compatible with flash attention" 116 | 117 | self.talking_heads = talking_heads 118 | if talking_heads: 119 | self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias=False) 120 | 
self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias=False) 121 | 122 | # sparse topk 123 | 124 | assert not ( 125 | flash and sparse_topk 126 | ), "sparse topk not compatible with flash attention" 127 | self.sparse_topk = sparse_topk 128 | 129 | # add a key / value token composed of zeros 130 | # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html 131 | 132 | self.add_zero_kv = add_zero_kv 133 | 134 | # flash attention 135 | 136 | self.flash = flash 137 | assert not ( 138 | flash and version.parse(torch.__version__) < version.parse("2.0.0") 139 | ), "in order to use flash attention, you must be using pytorch 2.0 or above" 140 | 141 | # determine efficient attention configs for cuda and cpu 142 | 143 | self.cpu_config = EfficientAttentionConfig(True, True, True) 144 | self.cuda_config = None 145 | 146 | if not torch.cuda.is_available() or not flash: 147 | return 148 | 149 | device_properties = torch.cuda.get_device_properties(torch.device("cuda")) 150 | 151 | if device_properties.major == 8 and device_properties.minor == 0: 152 | print_once( 153 | "A100 GPU detected, using flash attention if input tensor is on cuda" 154 | ) 155 | self.cuda_config = EfficientAttentionConfig(True, False, False) 156 | else: 157 | print_once( 158 | "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda" 159 | ) 160 | self.cuda_config = EfficientAttentionConfig(False, True, True) 161 | 162 | def flash_attn(self, q, k, v, mask=None, attn_bias=None): 163 | batch, heads, q_len, _, k_len, is_cuda, device = ( 164 | *q.shape, 165 | k.shape[-2], 166 | q.is_cuda, 167 | q.device, 168 | ) 169 | 170 | # Recommended for multi-query single-key-value attention by Tri Dao 171 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) 172 | 173 | if k.ndim == 3: 174 | k = rearrange(k, "b ... -> b 1 ...").expand_as(q) 175 | 176 | if v.ndim == 3: 177 | v = rearrange(v, "b ... 
-> b 1 ...").expand_as(q) 178 | 179 | # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention 180 | 181 | if self.qk_norm: 182 | default_scale = q.shape[-1] ** -0.5 183 | q = q * (default_scale / self.scale) 184 | 185 | # Check if mask exists and expand to compatible shape 186 | # The mask is B L, so it would have to be expanded to B H N L 187 | 188 | causal = self.causal 189 | 190 | if exists(mask): 191 | assert mask.ndim == 4 192 | mask = mask.expand(batch, heads, q_len, k_len) 193 | 194 | # manually handle causal mask, if another mask was given 195 | 196 | if causal: 197 | causal_mask = self.create_causal_mask(q_len, k_len, device=device) 198 | mask = mask & ~causal_mask 199 | causal = False 200 | 201 | # handle alibi positional bias 202 | # convert from bool to float 203 | 204 | if exists(attn_bias): 205 | attn_bias = rearrange(attn_bias, "h i j -> 1 h i j").expand( 206 | batch, heads, -1, -1 207 | ) 208 | 209 | # if mask given, the mask would already contain the causal mask from above logic 210 | # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number 211 | 212 | mask_value = -torch.finfo(q.dtype).max 213 | 214 | if exists(mask): 215 | attn_bias = attn_bias.masked_fill(~mask, mask_value // 2) 216 | elif causal: 217 | causal_mask = self.create_causal_mask(q_len, k_len, device=device) 218 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2) 219 | causal = False 220 | 221 | # scaled_dot_product_attention handles attn_mask either as bool or additive bias 222 | # make it an additive bias here 223 | 224 | mask = attn_bias 225 | 226 | # Check if there is a compatible device for flash attention 227 | 228 | config = self.cuda_config if is_cuda else self.cpu_config 229 | 230 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 231 | 232 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 233 | out = F.scaled_dot_product_attention( 234 | q, 235 | k, 236 | v, 237 | attn_mask=mask, 238 | dropout_p=self.dropout if self.training else 0.0, 239 | is_causal=causal, 240 | ) 241 | 242 | return out, Intermediates() 243 | 244 | def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None): 245 | """ 246 | einstein notation 247 | b - batch 248 | h - heads 249 | n, i, j - sequence length (base sequence length, source, target) 250 | d - feature dimension 251 | """ 252 | 253 | n, device = q.shape[-2], q.device 254 | 255 | scale = default(self.scale, q.shape[-1] ** -0.5) 256 | 257 | if self.add_zero_kv: 258 | k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value=0.0), (k, v)) 259 | 260 | if exists(mask): 261 | mask = F.pad(mask, (1, 0), value=True) 262 | 263 | if exists(attn_bias): 264 | attn_bias = F.pad(attn_bias, (1, 0), value=0.0) 265 | 266 | if self.flash: 267 | assert not exists( 268 | prev_attn 269 | ), "residual attention not compatible with flash attention" 270 | return self.flash_attn(q, k, v, mask=mask, attn_bias=attn_bias) 271 | 272 | kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" 273 | 274 | dots = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale 275 | 276 | if exists(prev_attn): 277 | dots = dots + prev_attn 278 | 279 | qk_similarities = dots.clone() 280 | 281 | if self.talking_heads: 282 | dots = self.pre_softmax_talking_heads(dots) 283 | 284 | if exists(attn_bias): 285 | dots = dots + attn_bias 286 | 287 | i, j, dtype = *dots.shape[-2:], dots.dtype 288 | pre_softmax_attn = dots.clone() 289 | 290 | mask_value = 
-torch.finfo(dots.dtype).max 291 | 292 | if exists(self.sparse_topk) and self.sparse_topk < j: 293 | top_values, _ = dots.topk(self.sparse_topk, dim=-1) 294 | sparse_topk_mask = dots < top_values[..., -1:] 295 | mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask 296 | 297 | if exists(mask): 298 | dots = dots.masked_fill(~mask, mask_value) 299 | 300 | if self.causal: 301 | causal_mask = self.create_causal_mask(i, j, device=device) 302 | dots = dots.masked_fill(causal_mask, mask_value) 303 | 304 | attn = self.attn_fn(dots, dim=-1) 305 | attn = attn.type(dtype) 306 | 307 | post_softmax_attn = attn.clone() 308 | 309 | attn = self.attn_dropout(attn) 310 | 311 | if self.talking_heads: 312 | attn = self.post_softmax_talking_heads(attn) 313 | 314 | out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) 315 | 316 | intermediates = Intermediates( 317 | qk_similarities=qk_similarities, 318 | pre_softmax_attn=pre_softmax_attn, 319 | post_softmax_attn=post_softmax_attn, 320 | ) 321 | 322 | return out, intermediates 323 | 324 | 325 | # cascading heads logic 326 | 327 | 328 | def to_single_heads(t, dim=1): 329 | heads = t.unbind(dim=dim) 330 | return tuple(head.unsqueeze(dim) for head in heads) 331 | 332 | 333 | class CascadingHeads(nn.Module): 334 | def __init__(self, attend: Attend): 335 | super().__init__() 336 | self.attend = attend 337 | 338 | def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None): 339 | assert ( 340 | q.shape[-1] == v.shape[-1] 341 | ), "cascading heads can only be done if query / key and value head dimensions are the same" 342 | 343 | # split inputs into per-head inputs 344 | 345 | heads = q.shape[1] 346 | 347 | queries = to_single_heads(q) 348 | keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads) 349 | values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads) 350 | 351 | mask = (mask,) * heads 352 | 353 | attn_bias = ( 354 | to_single_heads(attn_bias, dim=0) 355 | if exists(attn_bias) 356 | else ((None,) * heads) 357 | ) 358 | prev_attn = ( 359 | to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads) 360 | ) 361 | 362 | # now loop through each head, without output of previous head summed with the next head 363 | # thus cascading 364 | 365 | all_outs = [] 366 | all_intermediates = [] 367 | 368 | prev_head_out = None 369 | 370 | for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip( 371 | queries, keys, values, mask, attn_bias, prev_attn 372 | ): 373 | if exists(prev_head_out): 374 | h_q = h_q + prev_head_out 375 | 376 | out, intermediates = self.attend( 377 | h_q, h_k, h_v, mask=h_mask, attn_bias=h_attn_bias, prev_attn=h_prev_attn 378 | ) 379 | 380 | prev_head_out = out 381 | 382 | all_outs.append(out) 383 | all_intermediates.append(intermediates) 384 | 385 | # cat all output heads 386 | 387 | all_outs = torch.cat(all_outs, dim=1) 388 | 389 | # cat all intermediates, if they exist 390 | 391 | qk_similarities, pre_softmax_attn, post_softmax_attn = zip( 392 | *map(lambda i: i.to_tuple(), all_intermediates) 393 | ) 394 | 395 | qk_similarities, pre_softmax_attn, post_softmax_attn = map( 396 | compact, (qk_similarities, pre_softmax_attn, post_softmax_attn) 397 | ) 398 | 399 | aggregated_intermediates = Intermediates( 400 | qk_similarities=torch.cat(qk_similarities, dim=1) 401 | if len(qk_similarities) > 0 402 | else None, 403 | pre_softmax_attn=torch.cat(pre_softmax_attn, dim=1) 404 | if len(pre_softmax_attn) > 0 405 | else None, 406 | post_softmax_attn=torch.cat(post_softmax_attn, dim=1) 407 | if 
len(post_softmax_attn) > 0 408 | else None, 409 | ) 410 | 411 | return all_outs, aggregated_intermediates 412 | -------------------------------------------------------------------------------- /pali/autoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, pack, unpack 7 | 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | 13 | def eval_decorator(fn): 14 | def inner(self, *args, **kwargs): 15 | was_training = self.training 16 | self.eval() 17 | out = fn(self, *args, **kwargs) 18 | self.train(was_training) 19 | return out 20 | 21 | return inner 22 | 23 | 24 | # nucleus 25 | 26 | 27 | def top_p(logits, thres=0.9): 28 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 29 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 30 | 31 | sorted_indices_to_remove = cum_probs > (1 - thres) 32 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 33 | sorted_indices_to_remove[:, 0] = 0 34 | 35 | sorted_logits[sorted_indices_to_remove] = float("-inf") 36 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 37 | 38 | 39 | # topk 40 | 41 | 42 | def top_k(logits, thres=0.9): 43 | k = ceil((1 - thres) * logits.shape[-1]) 44 | val, ind = torch.topk(logits, k) 45 | probs = torch.full_like(logits, float("-inf")) 46 | probs.scatter_(1, ind, val) 47 | return probs 48 | 49 | 50 | # top_a 51 | 52 | 53 | def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02): 54 | probs = F.softmax(logits, dim=-1) 55 | limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio 56 | logits[probs < limit] = float("-inf") 57 | logits[probs >= limit] = 1 58 | return logits 59 | 60 | 61 | # autoregressive wrapper class 62 | 63 | 64 | class AutoregressiveWrapper(nn.Module): 65 | def __init__(self, net, ignore_index=-100, pad_value=0, mask_prob=0.0): 66 | super().__init__() 67 | self.pad_value = pad_value 68 | self.ignore_index = ignore_index 69 | 70 | self.net = net 71 | self.max_seq_len = net.max_seq_len 72 | 73 | # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432 74 | assert mask_prob < 1.0 75 | self.mask_prob = mask_prob 76 | 77 | @torch.no_grad() 78 | @eval_decorator 79 | def generate( 80 | self, 81 | start_tokens, 82 | seq_len, 83 | eos_token=None, 84 | temperature=1.0, 85 | filter_logits_fn=top_k, 86 | filter_thres=0.9, 87 | min_p_pow=2.0, 88 | min_p_ratio=0.02, 89 | **kwargs 90 | ): 91 | start_tokens, ps = pack([start_tokens], "* n") 92 | 93 | b, t = start_tokens.shape 94 | 95 | out = start_tokens 96 | 97 | for _ in range(seq_len): 98 | x = out[:, -self.max_seq_len :] 99 | 100 | logits = self.net(x, **kwargs)[:, -1] 101 | 102 | if filter_logits_fn in {top_k, top_p}: 103 | filtered_logits = filter_logits_fn(logits, thres=filter_thres) 104 | probs = F.softmax(filtered_logits / temperature, dim=-1) 105 | 106 | elif filter_logits_fn is top_a: 107 | filtered_logits = filter_logits_fn( 108 | logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio 109 | ) 110 | probs = F.softmax(filtered_logits / temperature, dim=-1) 111 | 112 | sample = torch.multinomial(probs, 1) 113 | 114 | out = torch.cat((out, sample), dim=-1) 115 | 116 | if exists(eos_token): 117 | is_eos_tokens = out == eos_token 118 | 119 | if is_eos_tokens.any(dim=-1).all(): 120 | # mask out everything after the eos 
tokens 121 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) 122 | mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1 123 | out = out.masked_fill(mask, self.pad_value) 124 | break 125 | 126 | out = out[:, t:] 127 | 128 | (out,) = unpack(out, ps, "* n") 129 | 130 | return out 131 | 132 | def forward(self, x, **kwargs): 133 | seq, ignore_index = x.shape[1], self.ignore_index 134 | 135 | inp, target = x[:, :-1], x[:, 1:] 136 | inp = torch.where(inp == ignore_index, self.pad_value, inp) 137 | 138 | if self.mask_prob > 0.0: 139 | rand = torch.randn(inp.shape, device=x.device) 140 | rand[:, 0] = -torch.finfo( 141 | rand.dtype 142 | ).max # first token should not be masked out 143 | num_mask = min(int(seq * self.mask_prob), seq - 1) 144 | indices = rand.topk(num_mask, dim=-1).indices 145 | mask = ~torch.zeros_like(inp).scatter(1, indices, 1.0).bool() 146 | kwargs.update(self_attn_context_mask=mask) 147 | 148 | logits = self.net(inp, **kwargs) 149 | 150 | loss = F.cross_entropy( 151 | rearrange(logits, "b n c -> b c n"), target, ignore_index=ignore_index 152 | ) 153 | 154 | return loss 155 | -------------------------------------------------------------------------------- /pali/model.py: -------------------------------------------------------------------------------- 1 | from pali.transformer import ViTransformerWrapper, Encoder, UL2 2 | 3 | 4 | class VitModel: 5 | """ 6 | Vision Transformer Model. 7 | 8 | Args: 9 | image_size (int): The size of the input image (default: 256). 10 | patch_size (int): The size of each patch in the image (default: 32). 11 | dim (int): The dimension of the transformer model (default: 512). 12 | depth (int): The number of transformer layers (default: 6). 13 | heads (int): The number of attention heads in each transformer layer (default: 8). 14 | *args: Variable length argument list. 15 | **kwargs: Arbitrary keyword arguments. 16 | 17 | Attributes: 18 | image_size (int): The size of the input image. 19 | patch_size (int): The size of each patch in the image. 20 | dim (int): The dimension of the transformer model. 21 | depth (int): The number of transformer layers. 22 | heads (int): The number of attention heads in each transformer layer. 23 | vit (ViTransformerWrapper): The Vision Transformer model. 24 | 25 | """ 26 | 27 | def __init__( 28 | self, image_size=256, patch_size=32, dim=512, depth=6, heads=8, *args, **kwargs 29 | ): 30 | self.image_size = image_size 31 | self.patch_size = patch_size 32 | self.dim = dim 33 | 34 | self.depth = depth 35 | self.heads = heads 36 | self.vit = ViTransformerWrapper( 37 | image_size=image_size, 38 | patch_size=patch_size, 39 | attn_layers=Encoder(dim=dim, depth=depth, heads=heads), 40 | ) 41 | 42 | def __call__(self, img): 43 | """ 44 | Perform forward pass through the Vision Transformer model. 45 | 46 | Args: 47 | img (torch.Tensor): The input image tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output embeddings from the Vision Transformer model. 51 | 52 | Raises: 53 | ValueError: If the input image is None or has an incorrect shape. 54 | 55 | """ 56 | if img is None: 57 | raise ValueError("Input image cannot be None") 58 | if img.shape[1:] != (3, self.image_size, self.image_size): 59 | raise ValueError( 60 | "Input image must have the shape [*, 3, {}, {}]".format( 61 | self.image_size, self.image_size 62 | ) 63 | ) 64 | 65 | return self.vit(img, return_embeddings=True) 66 | 67 | def forward(self, img): 68 | """ 69 | Perform forward pass through the Vision Transformer model. 
70 | 71 | Args: 72 | img (torch.Tensor): The input image tensor. 73 | 74 | Returns: 75 | torch.Tensor: The output embeddings from the Vision Transformer model. 76 | 77 | Raises: 78 | ValueError: If the input image is None or has an incorrect shape. 79 | 80 | """ 81 | if img is None: 82 | raise ValueError("Input image cannot be None") 83 | if img.shape[1:] != (3, self.image_size, self.image_size): 84 | raise ValueError( 85 | "Input image must have the shape [*, 3, {}, {}]".format( 86 | self.image_size, self.image_size 87 | ) 88 | ) 89 | 90 | return self.vit(img, return_embeddings=True) 91 | 92 | 93 | class Pali: 94 | """ 95 | Pali class represents the PALI model. 96 | 97 | Args: 98 | model_name (str): The name of the model (optional). 99 | image_size (int): The size of the input image (default: 256). 100 | patch_size (int): The size of each patch in the image (default: 32). 101 | dim (int): The dimensionality of the model (default: 512). 102 | depth (int): The depth of the model (default: 6). 103 | heads (int): The number of attention heads in the model (default: 8). 104 | enc_num_tokens (int): The number of tokens in the encoder (default: 256). 105 | enc_max_seq_len (int): The maximum sequence length for the encoder (default: 1024). 106 | dec_num_tokens (int): The number of tokens in the decoder (default: 256). 107 | dec_max_seq_len (int): The maximum sequence length for the decoder (default: 1024). 108 | enc_depth (int): The depth of the encoder (default: 6). 109 | enc_heads (int): The number of attention heads in the encoder (default: 8). 110 | dec_depth (int): The depth of the decoder (default: 6). 111 | dec_heads (int): The number of attention heads in the decoder (default: 8). 112 | """ 113 | 114 | def __init__( 115 | self, 116 | model_name=None, 117 | image_size=256, 118 | patch_size=32, 119 | dim=512, 120 | depth=6, 121 | heads=8, 122 | enc_num_tokens=256, 123 | enc_max_seq_len=1024, 124 | dec_num_tokens=256, 125 | dec_max_seq_len=1024, 126 | enc_depth=6, 127 | enc_heads=8, 128 | dec_depth=6, 129 | dec_heads=8, 130 | ): 131 | self.tokenizer = None 132 | self.dim = dim 133 | self.vit_model = VitModel( 134 | image_size=image_size, 135 | patch_size=patch_size, 136 | dim=dim, 137 | depth=depth, 138 | heads=heads, 139 | ) 140 | 141 | self.ul = UL2( 142 | dim=dim, 143 | enc_num_tokens=enc_num_tokens, 144 | enc_depth=enc_depth, 145 | enc_heads=enc_heads, 146 | enc_max_seq_len=enc_max_seq_len, 147 | dec_num_tokens=dec_num_tokens, 148 | dec_depth=dec_depth, 149 | dec_heads=dec_heads, 150 | dec_max_seq_len=dec_max_seq_len, 151 | ) 152 | 153 | def forward(self, img, prompt, output, mask): 154 | """Get the image embeddings""" 155 | img_embeds = self.vit_model.forward(img) 156 | 157 | """Get the output text embeddings""" 158 | result = self.ul(prompt, output, mask=mask, src_prepend_embeds=img_embeds) 159 | 160 | # result = OutputHead(self.dim, -1)(result) 161 | 162 | return result 163 | -------------------------------------------------------------------------------- /pali/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from random import random 3 | 4 | import torch 5 | from torch import nn, einsum, Tensor 6 | import torch.nn.functional as F 7 | 8 | from functools import partial, wraps 9 | from inspect import isfunction 10 | from dataclasses import dataclass 11 | from typing import List 12 | 13 | from einops import rearrange, repeat 14 | 15 | from pali.attend import Attend, Intermediates, CascadingHeads 16 | from 
pali.autoregressive_wrapper import AutoregressiveWrapper 17 | 18 | # constants 19 | 20 | DEFAULT_DIM_HEAD = 64 21 | 22 | 23 | @dataclass 24 | class LayerIntermediates: 25 | hiddens: List[Tensor] = None 26 | attn_intermediates: List[Intermediates] = None 27 | 28 | 29 | # helpers 30 | 31 | 32 | def exists(val): 33 | return val is not None 34 | 35 | 36 | def default(val, d): 37 | if exists(val): 38 | return val 39 | return d() if isfunction(d) else d 40 | 41 | 42 | def cast_tuple(val, depth): 43 | return val if isinstance(val, tuple) else (val,) * depth 44 | 45 | 46 | def maybe(fn): 47 | @wraps(fn) 48 | def inner(x, *args, **kwargs): 49 | if not exists(x): 50 | return x 51 | return fn(x, *args, **kwargs) 52 | 53 | return inner 54 | 55 | 56 | class always: 57 | def __init__(self, val): 58 | self.val = val 59 | 60 | def __call__(self, *args, **kwargs): 61 | return self.val 62 | 63 | 64 | class not_equals: 65 | def __init__(self, val): 66 | self.val = val 67 | 68 | def __call__(self, x, *args, **kwargs): 69 | return x != self.val 70 | 71 | 72 | class equals: 73 | def __init__(self, val): 74 | self.val = val 75 | 76 | def __call__(self, x, *args, **kwargs): 77 | return x == self.val 78 | 79 | 80 | # tensor helpers 81 | 82 | 83 | def max_neg_value(tensor): 84 | return -torch.finfo(tensor.dtype).max 85 | 86 | 87 | def l2norm(t, groups=1): 88 | t = rearrange(t, "... (g d) -> ... g d", g=groups) 89 | t = F.normalize(t, p=2, dim=-1) 90 | return rearrange(t, "... g d -> ... (g d)") 91 | 92 | 93 | def pad_at_dim(t, pad, dim=-1, value=0.0): 94 | dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) 95 | zeros = (0, 0) * dims_from_right 96 | return F.pad(t, (*zeros, *pad), value=value) 97 | 98 | 99 | def or_reduce(masks): 100 | head, *body = masks 101 | for rest in body: 102 | head = head | rest 103 | return head 104 | 105 | 106 | # init helpers 107 | 108 | 109 | def init_zero_(layer): 110 | nn.init.constant_(layer.weight, 0.0) 111 | if exists(layer.bias): 112 | nn.init.constant_(layer.bias, 0.0) 113 | 114 | 115 | # keyword argument helpers 116 | 117 | 118 | def pick_and_pop(keys, d): 119 | values = list(map(lambda key: d.pop(key), keys)) 120 | return dict(zip(keys, values)) 121 | 122 | 123 | def group_dict_by_key(cond, d): 124 | return_val = [dict(), dict()] 125 | for key in d.keys(): 126 | match = bool(cond(key)) 127 | ind = int(not match) 128 | return_val[ind][key] = d[key] 129 | return (*return_val,) 130 | 131 | 132 | def string_begins_with(prefix, str): 133 | return str.startswith(prefix) 134 | 135 | 136 | def group_by_key_prefix(prefix, d): 137 | return group_dict_by_key(partial(string_begins_with, prefix), d) 138 | 139 | 140 | def groupby_prefix_and_trim(prefix, d): 141 | kwargs_with_prefix, kwargs = group_dict_by_key( 142 | partial(string_begins_with, prefix), d 143 | ) 144 | kwargs_without_prefix = dict( 145 | map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())) 146 | ) 147 | return kwargs_without_prefix, kwargs 148 | 149 | 150 | # initializations 151 | 152 | 153 | def deepnorm_init( 154 | transformer, beta, module_name_match_list=[".ff.", ".to_v", ".to_out"] 155 | ): 156 | for name, module in transformer.named_modules(): 157 | if type(module) != nn.Linear: 158 | continue 159 | 160 | needs_beta_gain = any( 161 | map(lambda substr: substr in name, module_name_match_list) 162 | ) 163 | gain = beta if needs_beta_gain else 1 164 | nn.init.xavier_normal_(module.weight.data, gain=gain) 165 | 166 | if exists(module.bias): 167 | nn.init.constant_(module.bias.data, 0) 
168 | 169 | 170 | # structured dropout, more effective than traditional attention dropouts 171 | 172 | 173 | def dropout_seq(seq, mask, dropout): 174 | b, n, *_, device = *seq.shape, seq.device 175 | logits = torch.randn(b, n, device=device) 176 | 177 | if exists(mask): 178 | mask_value = max_neg_value(logits) 179 | logits = logits.masked_fill(~mask, mask_value) 180 | 181 | keep_prob = 1.0 - dropout 182 | num_keep = max(1, int(keep_prob * n)) 183 | keep_indices = logits.topk(num_keep, dim=1).indices 184 | 185 | batch_indices = torch.arange(b, device=device) 186 | batch_indices = rearrange(batch_indices, "b -> b 1") 187 | 188 | seq = seq[batch_indices, keep_indices] 189 | 190 | if exists(mask): 191 | seq_counts = mask.sum(dim=-1) 192 | seq_keep_counts = torch.ceil(seq_counts * keep_prob).int() 193 | keep_mask = torch.arange(num_keep, device=device) < rearrange( 194 | seq_keep_counts, "b -> b 1" 195 | ) 196 | 197 | mask = mask[batch_indices, keep_indices] & keep_mask 198 | 199 | return seq, mask 200 | 201 | 202 | # activations 203 | 204 | 205 | class ReluSquared(nn.Module): 206 | def forward(self, x): 207 | return F.relu(x) ** 2 208 | 209 | 210 | # embedding 211 | 212 | 213 | class TokenEmbedding(nn.Module): 214 | def __init__(self, dim, num_tokens, l2norm_embed=False): 215 | super().__init__() 216 | self.l2norm_embed = l2norm_embed 217 | self.emb = nn.Embedding(num_tokens, dim) 218 | 219 | def forward(self, x): 220 | token_emb = self.emb(x) 221 | return l2norm(token_emb) if self.l2norm_embed else token_emb 222 | 223 | 224 | # positional embeddings 225 | 226 | 227 | class AbsolutePositionalEmbedding(nn.Module): 228 | def __init__(self, dim, max_seq_len, l2norm_embed=False): 229 | super().__init__() 230 | self.scale = dim**-0.5 if not l2norm_embed else 1.0 231 | self.max_seq_len = max_seq_len 232 | self.l2norm_embed = l2norm_embed 233 | self.emb = nn.Embedding(max_seq_len, dim) 234 | 235 | def forward(self, x, pos=None): 236 | seq_len, device = x.shape[1], x.device 237 | assert ( 238 | seq_len <= self.max_seq_len 239 | ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" 240 | 241 | if not exists(pos): 242 | pos = torch.arange(seq_len, device=device) 243 | 244 | pos_emb = self.emb(pos) 245 | pos_emb = pos_emb * self.scale 246 | return l2norm(pos_emb) if self.l2norm_embed else pos_emb 247 | 248 | 249 | class ScaledSinusoidalEmbedding(nn.Module): 250 | def __init__(self, dim, theta=10000): 251 | super().__init__() 252 | assert (dim % 2) == 0 253 | self.scale = nn.Parameter(torch.ones(1) * dim**-0.5) 254 | 255 | half_dim = dim // 2 256 | freq_seq = torch.arange(half_dim).float() / half_dim 257 | inv_freq = theta**-freq_seq 258 | self.register_buffer("inv_freq", inv_freq, persistent=False) 259 | 260 | def forward(self, x, pos=None): 261 | seq_len, device = x.shape[1], x.device 262 | 263 | if not exists(pos): 264 | pos = torch.arange(seq_len, device=device) 265 | 266 | emb = einsum("i, j -> i j", pos, self.inv_freq) 267 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 268 | return emb * self.scale 269 | 270 | 271 | class RelativePositionBias(nn.Module): 272 | def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): 273 | super().__init__() 274 | self.scale = scale 275 | self.causal = causal 276 | self.num_buckets = num_buckets 277 | self.max_distance = max_distance 278 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 279 | 280 | @staticmethod 281 | def 
_relative_position_bucket( 282 | relative_position, causal=True, num_buckets=32, max_distance=128 283 | ): 284 | ret = 0 285 | n = -relative_position 286 | if not causal: 287 | num_buckets //= 2 288 | ret += (n < 0).long() * num_buckets 289 | n = torch.abs(n) 290 | else: 291 | n = torch.max(n, torch.zeros_like(n)) 292 | 293 | max_exact = num_buckets // 2 294 | is_small = n < max_exact 295 | 296 | val_if_large = ( 297 | max_exact 298 | + ( 299 | torch.log(n.float() / max_exact) 300 | / math.log(max_distance / max_exact) 301 | * (num_buckets - max_exact) 302 | ).long() 303 | ) 304 | val_if_large = torch.min( 305 | val_if_large, torch.full_like(val_if_large, num_buckets - 1) 306 | ) 307 | 308 | ret += torch.where(is_small, n, val_if_large) 309 | return ret 310 | 311 | @property 312 | def device(self): 313 | return next(self.parameters()).device 314 | 315 | def forward(self, i, j): 316 | device = self.device 317 | q_pos = torch.arange(j - i, j, dtype=torch.long, device=device) 318 | k_pos = torch.arange(j, dtype=torch.long, device=device) 319 | rel_pos = k_pos[None, :] - q_pos[:, None] 320 | rp_bucket = self._relative_position_bucket( 321 | rel_pos, 322 | causal=self.causal, 323 | num_buckets=self.num_buckets, 324 | max_distance=self.max_distance, 325 | ) 326 | values = self.relative_attention_bias(rp_bucket) 327 | bias = rearrange(values, "i j h -> h i j") 328 | return bias * self.scale 329 | 330 | 331 | class DynamicPositionBias(nn.Module): 332 | def __init__(self, dim, *, heads, depth, log_distance=False, norm=False): 333 | super().__init__() 334 | assert ( 335 | depth >= 1 336 | ), "depth for dynamic position bias MLP must be greater or equal to 1" 337 | self.log_distance = log_distance 338 | 339 | self.mlp = nn.ModuleList([]) 340 | 341 | self.mlp.append( 342 | nn.Sequential( 343 | nn.Linear(1, dim), 344 | nn.LayerNorm(dim) if norm else nn.Identity(), 345 | nn.SiLU(), 346 | ) 347 | ) 348 | 349 | for _ in range(depth - 1): 350 | self.mlp.append( 351 | nn.Sequential( 352 | nn.Linear(dim, dim), 353 | nn.LayerNorm(dim) if norm else nn.Identity(), 354 | nn.SiLU(), 355 | ) 356 | ) 357 | 358 | self.mlp.append(nn.Linear(dim, heads)) 359 | 360 | @property 361 | def device(self): 362 | return next(self.parameters()).device 363 | 364 | def forward(self, i, j): 365 | assert i == j 366 | n, device = j, self.device 367 | 368 | # get the (n x n) matrix of distances 369 | seq_arange = torch.arange(n, device=device) 370 | context_arange = torch.arange(n, device=device) 371 | indices = rearrange(seq_arange, "i -> i 1") - rearrange( 372 | context_arange, "j -> 1 j" 373 | ) 374 | indices += n - 1 375 | 376 | # input to continuous positions MLP 377 | pos = torch.arange(-n + 1, n, device=device).float() 378 | pos = rearrange(pos, "... -> ... 
1") 379 | 380 | if self.log_distance: 381 | pos = torch.sign(pos) * torch.log( 382 | pos.abs() + 1 383 | ) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) 384 | 385 | for layer in self.mlp: 386 | pos = layer(pos) 387 | 388 | # get position biases 389 | bias = pos[indices] 390 | bias = rearrange(bias, "i j h -> h i j") 391 | return bias 392 | 393 | 394 | class AlibiPositionalBias(nn.Module): 395 | def __init__(self, heads, total_heads, **kwargs): 396 | super().__init__() 397 | self.heads = heads 398 | self.total_heads = total_heads 399 | 400 | slopes = Tensor(self._get_slopes(heads)) 401 | slopes = rearrange(slopes, "h -> h 1 1") 402 | self.register_buffer("slopes", slopes, persistent=False) 403 | self.register_buffer("bias", None, persistent=False) 404 | 405 | def get_bias(self, i, j, device): 406 | i_arange = torch.arange(j - i, j, device=device) 407 | j_arange = torch.arange(j, device=device) 408 | bias = -torch.abs( 409 | rearrange(j_arange, "j -> 1 1 j") - rearrange(i_arange, "i -> 1 i 1") 410 | ) 411 | return bias 412 | 413 | @staticmethod 414 | def _get_slopes(heads): 415 | def get_slopes_power_of_2(n): 416 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 417 | ratio = start 418 | return [start * ratio**i for i in range(n)] 419 | 420 | if math.log2(heads).is_integer(): 421 | return get_slopes_power_of_2(heads) 422 | 423 | closest_power_of_2 = 2 ** math.floor(math.log2(heads)) 424 | return ( 425 | get_slopes_power_of_2(closest_power_of_2) 426 | + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ 427 | : heads - closest_power_of_2 428 | ] 429 | ) 430 | 431 | @property 432 | def device(self): 433 | return next(self.buffers()).device 434 | 435 | def forward(self, i, j): 436 | h, device = self.total_heads, self.device 437 | 438 | if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i: 439 | return self.bias[..., :i, :j] 440 | 441 | bias = self.get_bias(i, j, device) 442 | bias = bias * self.slopes 443 | 444 | num_heads_unalibied = h - bias.shape[0] 445 | bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0) 446 | self.register_buffer("bias", bias, persistent=False) 447 | 448 | return self.bias 449 | 450 | 451 | class LearnedAlibiPositionalBias(AlibiPositionalBias): 452 | def __init__(self, heads, total_heads): 453 | super().__init__(heads, total_heads) 454 | log_slopes = torch.log(self.slopes) 455 | self.learned_logslopes = nn.Parameter(log_slopes) 456 | 457 | def forward(self, i, j): 458 | h, device = self.heads, self.device 459 | 460 | def get_slopes(param): 461 | return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim=-2) 462 | 463 | if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i: 464 | bias = self.bias[..., :i, :j] 465 | else: 466 | bias = self.get_bias(i, j, device) 467 | self.register_buffer("bias", bias, persistent=False) 468 | 469 | slopes = get_slopes(self.learned_logslopes) 470 | bias = bias * slopes 471 | 472 | return bias 473 | 474 | 475 | class RotaryEmbedding(nn.Module): 476 | def __init__( 477 | self, 478 | dim, 479 | use_xpos=False, 480 | scale_base=512, 481 | interpolation_factor=1.0, 482 | base=10000, 483 | base_rescale_factor=1.0, 484 | ): 485 | super().__init__() 486 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 487 | # has some connection to NTK literature 488 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 489 | base *= base_rescale_factor ** (dim / (dim - 2)) 490 | 491 | 
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 492 | self.register_buffer("inv_freq", inv_freq) 493 | 494 | assert interpolation_factor >= 1.0 495 | self.interpolation_factor = interpolation_factor 496 | 497 | if not use_xpos: 498 | self.register_buffer("scale", None) 499 | return 500 | 501 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) 502 | 503 | self.scale_base = scale_base 504 | self.register_buffer("scale", scale) 505 | 506 | def forward(self, seq_len, device): 507 | t = torch.arange(seq_len, device=device).type_as(self.inv_freq) 508 | t = t / self.interpolation_factor 509 | 510 | freqs = torch.einsum("i , j -> i j", t, self.inv_freq) 511 | freqs = torch.cat((freqs, freqs), dim=-1) 512 | 513 | if not exists(self.scale): 514 | return freqs, 1.0 515 | 516 | power = ( 517 | torch.arange(seq_len, device=device) - (seq_len // 2) 518 | ) / self.scale_base 519 | scale = self.scale ** rearrange(power, "n -> n 1") 520 | scale = torch.cat((scale, scale), dim=-1) 521 | 522 | return freqs, scale 523 | 524 | 525 | def rotate_half(x): 526 | x = rearrange(x, "... (j d) -> ... j d", j=2) 527 | x1, x2 = x.unbind(dim=-2) 528 | return torch.cat((-x2, x1), dim=-1) 529 | 530 | 531 | def apply_rotary_pos_emb(t, freqs, scale=1): 532 | seq_len = t.shape[-2] 533 | freqs = freqs[-seq_len:, :] 534 | return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) 535 | 536 | 537 | # norms 538 | 539 | 540 | class Scale(nn.Module): 541 | def __init__(self, value, fn): 542 | super().__init__() 543 | self.value = value 544 | self.fn = fn 545 | 546 | def forward(self, x, **kwargs): 547 | out = self.fn(x, **kwargs) 548 | 549 | def scale_fn(t): 550 | return t * self.value 551 | 552 | if not isinstance(out, tuple): 553 | return scale_fn(out) 554 | 555 | return (scale_fn(out[0]), *out[1:]) 556 | 557 | 558 | class ScaleNorm(nn.Module): 559 | def __init__(self, dim, eps=1e-5): 560 | super().__init__() 561 | self.eps = eps 562 | self.g = nn.Parameter(torch.ones(1) * (dim**-0.5)) 563 | 564 | def forward(self, x): 565 | norm = torch.norm(x, dim=-1, keepdim=True) 566 | return x / norm.clamp(min=self.eps) * self.g 567 | 568 | 569 | class RMSNorm(nn.Module): 570 | def __init__(self, dim): 571 | super().__init__() 572 | self.scale = dim**0.5 573 | self.g = nn.Parameter(torch.ones(dim)) 574 | 575 | def forward(self, x): 576 | return F.normalize(x, dim=-1) * self.scale * self.g 577 | 578 | 579 | class SimpleRMSNorm(nn.Module): 580 | def __init__(self, dim): 581 | super().__init__() 582 | self.scale = dim**0.5 583 | 584 | def forward(self, x): 585 | return F.normalize(x, dim=-1) * self.scale 586 | 587 | 588 | # residual and residual gates 589 | 590 | 591 | class Residual(nn.Module): 592 | def __init__(self, dim, scale_residual=False, scale_residual_constant=1.0): 593 | super().__init__() 594 | self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None 595 | self.scale_residual_constant = scale_residual_constant 596 | 597 | def forward(self, x, residual): 598 | if exists(self.residual_scale): 599 | residual = residual * self.residual_scale 600 | 601 | if self.scale_residual_constant != 1: 602 | residual = residual * self.scale_residual_constant 603 | 604 | return x + residual 605 | 606 | 607 | class GRUGating(nn.Module): 608 | def __init__(self, dim, scale_residual=False, **kwargs): 609 | super().__init__() 610 | self.gru = nn.GRUCell(dim, dim) 611 | self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None 612 | 613 | def forward(self, x, 
residual): 614 | if exists(self.residual_scale): 615 | residual = residual * self.residual_scale 616 | 617 | gated_output = self.gru( 618 | rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d") 619 | ) 620 | 621 | return gated_output.reshape_as(x) 622 | 623 | 624 | # token shifting 625 | 626 | 627 | def shift(t, amount, mask=None): 628 | if amount == 0: 629 | return t 630 | else: 631 | amount = min(amount, t.shape[1]) 632 | 633 | if exists(mask): 634 | t = t.masked_fill(~mask[..., None], 0.0) 635 | 636 | return pad_at_dim(t, (amount, -amount), dim=-2, value=0.0) 637 | 638 | 639 | class ShiftTokens(nn.Module): 640 | def __init__(self, shifts, fn): 641 | super().__init__() 642 | self.fn = fn 643 | self.shifts = tuple(shifts) 644 | 645 | def forward(self, x, **kwargs): 646 | mask = kwargs.get("mask", None) 647 | shifts = self.shifts 648 | segments = len(shifts) 649 | feats_per_shift = x.shape[-1] // segments 650 | splitted = x.split(feats_per_shift, dim=-1) 651 | segments_to_shift, rest = splitted[:segments], splitted[segments:] 652 | segments_to_shift = list( 653 | map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)) 654 | ) 655 | x = torch.cat((*segments_to_shift, *rest), dim=-1) 656 | return self.fn(x, **kwargs) 657 | 658 | 659 | # feedforward 660 | 661 | 662 | class GLU(nn.Module): 663 | def __init__(self, dim_in, dim_out, activation): 664 | super().__init__() 665 | self.act = activation 666 | self.proj = nn.Linear(dim_in, dim_out * 2) 667 | 668 | def forward(self, x): 669 | x, gate = self.proj(x).chunk(2, dim=-1) 670 | return x * self.act(gate) 671 | 672 | 673 | class FeedForward(nn.Module): 674 | def __init__( 675 | self, 676 | dim, 677 | dim_out=None, 678 | mult=4, 679 | glu=False, 680 | swish=False, 681 | relu_squared=False, 682 | post_act_ln=False, 683 | dropout=0.0, 684 | no_bias=False, 685 | zero_init_output=False, 686 | ): 687 | super().__init__() 688 | inner_dim = int(dim * mult) 689 | dim_out = default(dim_out, dim) 690 | 691 | if relu_squared: 692 | activation = ReluSquared() 693 | elif swish: 694 | activation = nn.SiLU() 695 | else: 696 | activation = nn.GELU() 697 | 698 | project_in = ( 699 | nn.Sequential(nn.Linear(dim, inner_dim, bias=not no_bias), activation) 700 | if not glu 701 | else GLU(dim, inner_dim, activation) 702 | ) 703 | 704 | self.ff = nn.Sequential( 705 | project_in, 706 | nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), 707 | nn.Dropout(dropout), 708 | nn.Linear(inner_dim, dim_out, bias=not no_bias), 709 | ) 710 | 711 | # init last linear layer to 0 712 | if zero_init_output: 713 | init_zero_(self.ff[-1]) 714 | 715 | def forward(self, x): 716 | return self.ff(x) 717 | 718 | 719 | # attention. 
it is all we need 720 | 721 | 722 | class Attention(nn.Module): 723 | def __init__( 724 | self, 725 | dim, 726 | dim_head=DEFAULT_DIM_HEAD, 727 | heads=8, 728 | causal=False, 729 | flash=False, 730 | talking_heads=False, 731 | head_scale=False, 732 | sparse_topk=None, 733 | num_mem_kv=0, 734 | dropout=0.0, 735 | on_attn=False, 736 | gate_values=False, 737 | zero_init_output=False, 738 | max_attend_past=None, 739 | qk_norm=False, 740 | qk_norm_groups=1, 741 | qk_norm_scale=10, 742 | qk_norm_dim_scale=False, 743 | one_kv_head=False, 744 | shared_kv=False, 745 | value_dim_head=None, 746 | tensor_product=False, # https://arxiv.org/abs/2208.06061 747 | cascading_heads=False, 748 | add_zero_kv=False, # same as add_zero_attn in pytorch 749 | onnxable=False, 750 | ): 751 | super().__init__() 752 | self.scale = dim_head**-0.5 753 | 754 | self.heads = heads 755 | self.causal = causal 756 | self.max_attend_past = max_attend_past 757 | 758 | value_dim_head = default(value_dim_head, dim_head) 759 | q_dim = k_dim = dim_head * heads 760 | v_dim = out_dim = value_dim_head * heads 761 | 762 | self.one_kv_head = one_kv_head 763 | if one_kv_head: 764 | k_dim = dim_head 765 | v_dim = value_dim_head 766 | out_dim = v_dim * heads 767 | 768 | self.to_q = nn.Linear(dim, q_dim, bias=False) 769 | self.to_k = nn.Linear(dim, k_dim, bias=False) 770 | 771 | # shared key / values, for further memory savings during inference 772 | assert not ( 773 | shared_kv and value_dim_head != dim_head 774 | ), "key and value head dimensions must be equal for shared key / values" 775 | self.to_v = nn.Linear(dim, v_dim, bias=False) if not shared_kv else None 776 | 777 | # relations projection from tp-attention 778 | self.to_r = nn.Linear(dim, v_dim, bias=False) if tensor_product else None 779 | 780 | # add GLU gating for aggregated values, from alphafold2 781 | self.to_v_gate = None 782 | if gate_values: 783 | self.to_v_gate = nn.Linear(dim, out_dim) 784 | nn.init.constant_(self.to_v_gate.weight, 0) 785 | nn.init.constant_(self.to_v_gate.bias, 1) 786 | 787 | # cosine sim attention 788 | self.qk_norm = qk_norm 789 | self.qk_norm_groups = qk_norm_groups 790 | self.qk_norm_scale = qk_norm_scale 791 | 792 | # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442 793 | self.qk_norm_dim_scale = qk_norm_dim_scale 794 | 795 | self.qk_norm_q_scale = self.qk_norm_k_scale = 1 796 | if qk_norm and qk_norm_dim_scale: 797 | self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head)) 798 | self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head)) 799 | 800 | assert (not qk_norm) or ( 801 | dim_head % qk_norm_groups 802 | ) == 0, "dimension per attention head must be divisible by the qk norm groups" 803 | assert not ( 804 | qk_norm and (dim_head // qk_norm_groups) <= 2 805 | ), "the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)" 806 | 807 | # attend class - includes core attention algorithm + talking heads 808 | 809 | self.attend = Attend( 810 | heads=heads, 811 | causal=causal, 812 | talking_heads=talking_heads, 813 | dropout=dropout, 814 | sparse_topk=sparse_topk, 815 | qk_norm=qk_norm, 816 | scale=qk_norm_scale if qk_norm else self.scale, 817 | add_zero_kv=add_zero_kv, 818 | flash=flash, 819 | onnxable=onnxable, 820 | ) 821 | 822 | if cascading_heads: 823 | # cascading heads - wrap the Attend logic 824 | self.attend = CascadingHeads(self.attend) 825 | 826 | # head scaling 827 | self.head_scale = head_scale 828 | if head_scale: 829 
| self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) 830 | 831 | # explicit topk sparse attention 832 | self.sparse_topk = sparse_topk 833 | 834 | # add memory key / values 835 | self.num_mem_kv = num_mem_kv 836 | if num_mem_kv > 0: 837 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 838 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 839 | 840 | # attention on attention 841 | self.attn_on_attn = on_attn 842 | self.to_out = ( 843 | nn.Sequential(nn.Linear(out_dim, dim * 2, bias=False), nn.GLU()) 844 | if on_attn 845 | else nn.Linear(out_dim, dim, bias=False) 846 | ) 847 | 848 | # init output projection 0 849 | if zero_init_output: 850 | init_zero_(self.to_out) 851 | 852 | def forward( 853 | self, 854 | x, 855 | context=None, 856 | mask=None, 857 | context_mask=None, 858 | attn_mask=None, 859 | rel_pos=None, 860 | rotary_pos_emb=None, 861 | prev_attn=None, 862 | mem=None, 863 | ): 864 | b, n, _, h, head_scale, device, has_context = ( 865 | *x.shape, 866 | self.heads, 867 | self.head_scale, 868 | x.device, 869 | exists(context), 870 | ) 871 | kv_input = default(context, x) 872 | 873 | q_input = x 874 | k_input = kv_input 875 | v_input = kv_input 876 | r_input = x 877 | 878 | if exists(mem): 879 | k_input = torch.cat((mem, k_input), dim=-2) 880 | v_input = torch.cat((mem, v_input), dim=-2) 881 | 882 | q = self.to_q(q_input) 883 | k = self.to_k(k_input) 884 | v = self.to_v(v_input) if exists(self.to_v) else k 885 | r = self.to_r(r_input) if exists(self.to_r) else None 886 | 887 | q = rearrange(q, "b n (h d) -> b h n d", h=h) 888 | 889 | if not self.one_kv_head: 890 | k, v, r = map( 891 | lambda t: maybe(rearrange)(t, "b n (h d) -> b h n d", h=h), (k, v, r) 892 | ) 893 | 894 | if self.qk_norm: 895 | qk_l2norm = partial(l2norm, groups=self.qk_norm_groups) 896 | q, k = map(qk_l2norm, (q, k)) 897 | 898 | q = q * self.qk_norm_q_scale 899 | k = k * self.qk_norm_k_scale 900 | 901 | if exists(rotary_pos_emb) and not has_context: 902 | freqs, xpos_scale = rotary_pos_emb 903 | l = freqs.shape[-1] 904 | 905 | q_xpos_scale, k_xpos_scale = ( 906 | (xpos_scale, xpos_scale**-1.0) if exists(xpos_scale) else (1.0, 1.0) 907 | ) 908 | (ql, qr), (kl, kr), (vl, vr) = map( 909 | lambda t: (t[..., :l], t[..., l:]), (q, k, v) 910 | ) 911 | 912 | ql, kl, vl = map( 913 | lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), 914 | ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)), 915 | ) 916 | q, k, v = map( 917 | lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)) 918 | ) 919 | 920 | input_mask = default(context_mask, mask) 921 | 922 | if self.num_mem_kv > 0: 923 | mem_k, mem_v = map( 924 | lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v) 925 | ) 926 | 927 | if self.qk_norm: 928 | mem_k = l2norm(mem_k) 929 | mem_k = mem_k * self.qk_norm_k_scale 930 | 931 | k = torch.cat((mem_k, k), dim=-2) 932 | v = torch.cat((mem_v, v), dim=-2) 933 | 934 | if exists(input_mask): 935 | input_mask = pad_at_dim( 936 | input_mask, (self.num_mem_kv, 0), dim=-1, value=True 937 | ) 938 | 939 | i, j = map(lambda t: t.shape[-2], (q, k)) 940 | 941 | # determine masking 942 | 943 | max_neg_value(q) 944 | masks = [] 945 | final_attn_mask = None 946 | 947 | if exists(input_mask): 948 | input_mask = rearrange(input_mask, "b j -> b 1 1 j") 949 | masks.append(~input_mask) 950 | 951 | if exists(attn_mask): 952 | assert ( 953 | 2 <= attn_mask.ndim <= 4 954 | ), "attention mask must have greater than 2 dimensions but less than or equal to 4" 955 | 
if attn_mask.ndim == 2: 956 | attn_mask = rearrange(attn_mask, "i j -> 1 1 i j") 957 | elif attn_mask.ndim == 3: 958 | attn_mask = rearrange(attn_mask, "h i j -> 1 h i j") 959 | masks.append(~attn_mask) 960 | 961 | if exists(self.max_attend_past): 962 | range_q = torch.arange(j - i, j, device=device) 963 | range_k = torch.arange(j, device=device) 964 | dist = rearrange(range_q, "i -> 1 1 i 1") - rearrange( 965 | range_k, "j -> 1 1 1 j" 966 | ) 967 | max_attend_past_mask = dist > self.max_attend_past 968 | masks.append(max_attend_past_mask) 969 | 970 | if len(masks) > 0: 971 | final_attn_mask = ~or_reduce(masks) 972 | 973 | # prepare relative positional bias, if needed 974 | 975 | attn_bias = None 976 | if exists(rel_pos): 977 | attn_bias = rel_pos(i, j) 978 | 979 | # attention is all we need 980 | 981 | out, intermediates = self.attend( 982 | q, k, v, mask=final_attn_mask, attn_bias=attn_bias, prev_attn=prev_attn 983 | ) 984 | 985 | # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients 986 | 987 | if exists(r): 988 | out = out * r + out 989 | 990 | # normformer scaling of heads 991 | 992 | if head_scale: 993 | out = out * self.head_scale_params 994 | 995 | # merge heads 996 | 997 | out = rearrange(out, "b h n d -> b n (h d)") 998 | 999 | # alphafold2 styled gating of the values 1000 | 1001 | if exists(self.to_v_gate): 1002 | gates = self.to_v_gate(x) 1003 | out = out * gates.sigmoid() 1004 | 1005 | # combine the heads 1006 | 1007 | out = self.to_out(out) 1008 | 1009 | if exists(mask): 1010 | mask = rearrange(mask, "b n -> b n 1") 1011 | out = out.masked_fill(~mask, 0.0) 1012 | 1013 | return out, intermediates 1014 | 1015 | 1016 | class AttentionLayers(nn.Module): 1017 | def __init__( 1018 | self, 1019 | dim, 1020 | depth, 1021 | heads=8, 1022 | causal=False, 1023 | cross_attend=False, 1024 | only_cross=False, 1025 | use_scalenorm=False, 1026 | use_rmsnorm=False, 1027 | use_simple_rmsnorm=False, 1028 | alibi_pos_bias=False, 1029 | alibi_num_heads=None, 1030 | alibi_learned=False, 1031 | rel_pos_bias=False, 1032 | rel_pos_num_buckets=32, 1033 | rel_pos_max_distance=128, 1034 | dynamic_pos_bias=False, 1035 | dynamic_pos_bias_log_distance=False, 1036 | dynamic_pos_bias_mlp_depth=2, 1037 | dynamic_pos_bias_norm=False, 1038 | rotary_pos_emb=False, 1039 | rotary_emb_dim=None, 1040 | rotary_xpos=False, 1041 | rotary_interpolation_factor=1.0, 1042 | rotary_xpos_scale_base=512, 1043 | rotary_base_rescale_factor=1.0, 1044 | custom_layers=None, 1045 | sandwich_coef=None, 1046 | par_ratio=None, 1047 | residual_attn=False, 1048 | cross_residual_attn=False, 1049 | macaron=False, 1050 | pre_norm=True, 1051 | pre_norm_has_final_norm=True, 1052 | gate_residual=False, 1053 | scale_residual=False, 1054 | scale_residual_constant=1.0, 1055 | deepnorm=False, 1056 | shift_tokens=0, 1057 | sandwich_norm=False, 1058 | resi_dual=False, 1059 | resi_dual_scale=1.0, 1060 | zero_init_branch_output=False, 1061 | layer_dropout=0.0, 1062 | cross_attn_tokens_dropout=0.0, 1063 | **kwargs, 1064 | ): 1065 | super().__init__() 1066 | rotary_pos_emb = rotary_pos_emb or rotary_xpos 1067 | 1068 | ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) 1069 | attn_kwargs, kwargs = groupby_prefix_and_trim("attn_", kwargs) 1070 | 1071 | dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) 1072 | 1073 | self.dim = dim 1074 | self.depth = depth 1075 | self.layers = nn.ModuleList([]) 1076 | 1077 | self.has_pos_emb = rel_pos_bias or rotary_pos_emb 1078 | 1079 | rotary_emb_dim = 
max(default(rotary_emb_dim, dim_head // 2), 32) 1080 | 1081 | assert not ( 1082 | rotary_xpos and not causal 1083 | ), "rotary xpos is not compatible with bidirectional attention" 1084 | self.rotary_pos_emb = ( 1085 | RotaryEmbedding( 1086 | rotary_emb_dim, 1087 | use_xpos=rotary_xpos, 1088 | scale_base=rotary_xpos_scale_base, 1089 | interpolation_factor=rotary_interpolation_factor, 1090 | base_rescale_factor=rotary_base_rescale_factor, 1091 | ) 1092 | if rotary_pos_emb 1093 | else None 1094 | ) 1095 | 1096 | assert not ( 1097 | alibi_pos_bias and rel_pos_bias 1098 | ), "you can only choose Alibi positional bias or T5 relative positional bias, not both" 1099 | assert ( 1100 | rel_pos_num_buckets <= rel_pos_max_distance 1101 | ), "number of relative position buckets must be less than the relative position max distance" 1102 | 1103 | # relative positional bias 1104 | 1105 | flash_attn = attn_kwargs.get("flash", False) 1106 | assert ( 1107 | int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias) 1108 | ) <= 1, "you can only choose up to one of t5, alibi, or dynamic positional bias" 1109 | 1110 | self.rel_pos = None 1111 | if rel_pos_bias: 1112 | assert ( 1113 | not flash_attn 1114 | ), "flash attention not compatible with t5 relative positional bias" 1115 | self.rel_pos = RelativePositionBias( 1116 | scale=dim_head**0.5, 1117 | causal=causal, 1118 | heads=heads, 1119 | num_buckets=rel_pos_num_buckets, 1120 | max_distance=rel_pos_max_distance, 1121 | ) 1122 | elif dynamic_pos_bias: 1123 | assert ( 1124 | not flash_attn 1125 | ), "flash attention not compatible with dynamic positional bias" 1126 | self.rel_pos = DynamicPositionBias( 1127 | dim=dim // 4, 1128 | heads=heads, 1129 | log_distance=dynamic_pos_bias_log_distance, 1130 | depth=dynamic_pos_bias_mlp_depth, 1131 | norm=dynamic_pos_bias_norm, 1132 | ) 1133 | elif alibi_pos_bias: 1134 | alibi_num_heads = default(alibi_num_heads, heads) 1135 | assert ( 1136 | alibi_num_heads <= heads 1137 | ), "number of ALiBi heads must be less than the total number of heads" 1138 | alibi_pos_klass = ( 1139 | LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias 1140 | ) 1141 | self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, total_heads=heads) 1142 | 1143 | # determine deepnorm and residual scale 1144 | 1145 | if deepnorm: 1146 | assert ( 1147 | scale_residual_constant == 1 1148 | ), "scale residual constant is being overridden by deep norm settings" 1149 | pre_norm = sandwich_norm = resi_dual = False 1150 | scale_residual = True 1151 | scale_residual_constant = (2 * depth) ** 0.25 1152 | 1153 | assert ( 1154 | int(sandwich_norm) + int(resi_dual) 1155 | ) <= 1, "either sandwich norm or resiDual is selected, but not both" 1156 | assert not ( 1157 | not pre_norm and sandwich_norm 1158 | ), "sandwich norm cannot be used when not using prenorm" 1159 | 1160 | if resi_dual: 1161 | pre_norm = False 1162 | 1163 | self.pre_norm = pre_norm 1164 | self.sandwich_norm = sandwich_norm 1165 | 1166 | self.resi_dual = resi_dual 1167 | assert ( 1168 | 0 < resi_dual_scale <= 1.0 1169 | ), "resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1." 
1170 | self.resi_dual_scale = resi_dual_scale 1171 | 1172 | self.residual_attn = residual_attn 1173 | self.cross_residual_attn = cross_residual_attn 1174 | assert not ( 1175 | flash_attn and (residual_attn or cross_residual_attn) 1176 | ), "flash attention is not compatible with residual attention" 1177 | 1178 | self.cross_attend = cross_attend 1179 | 1180 | assert ( 1181 | int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm) 1182 | ) <= 1, "you can only use either scalenorm, rmsnorm, or simple rmsnorm" 1183 | 1184 | if use_scalenorm: 1185 | norm_class = ScaleNorm 1186 | elif use_rmsnorm: 1187 | norm_class = RMSNorm 1188 | elif use_simple_rmsnorm: 1189 | norm_class = SimpleRMSNorm 1190 | else: 1191 | norm_class = nn.LayerNorm 1192 | 1193 | norm_fn = partial(norm_class, dim) 1194 | 1195 | if cross_attend and not only_cross: 1196 | default_block = ("a", "c", "f") 1197 | elif cross_attend and only_cross: 1198 | default_block = ("c", "f") 1199 | else: 1200 | default_block = ("a", "f") 1201 | 1202 | if macaron: 1203 | default_block = ("f",) + default_block 1204 | 1205 | # zero init 1206 | 1207 | if zero_init_branch_output: 1208 | attn_kwargs = {**attn_kwargs, "zero_init_output": True} 1209 | ff_kwargs = {**ff_kwargs, "zero_init_output": True} 1210 | 1211 | # calculate layer block order 1212 | 1213 | if exists(custom_layers): 1214 | layer_types = custom_layers 1215 | elif exists(par_ratio): 1216 | par_depth = depth * len(default_block) 1217 | assert 1 < par_ratio <= par_depth, "par ratio out of range" 1218 | default_block = tuple(filter(not_equals("f"), default_block)) 1219 | par_attn = par_depth // par_ratio 1220 | depth_cut = ( 1221 | par_depth * 2 // 3 1222 | ) # 2 / 3 attention layer cutoff suggested by PAR paper 1223 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 1224 | assert ( 1225 | len(default_block) <= par_width 1226 | ), "default block is too large for par_ratio" 1227 | par_block = default_block + ("f",) * (par_width - len(default_block)) 1228 | par_head = par_block * par_attn 1229 | layer_types = par_head + ("f",) * (par_depth - len(par_head)) 1230 | elif exists(sandwich_coef): 1231 | assert ( 1232 | sandwich_coef > 0 and sandwich_coef <= depth 1233 | ), "sandwich coefficient should be less than the depth" 1234 | layer_types = ( 1235 | ("a",) * sandwich_coef 1236 | + default_block * (depth - sandwich_coef) 1237 | + ("f",) * sandwich_coef 1238 | ) 1239 | else: 1240 | layer_types = default_block * depth 1241 | 1242 | self.layer_types = layer_types 1243 | self.num_attn_layers = len(list(filter(equals("a"), layer_types))) 1244 | 1245 | # stochastic depth 1246 | 1247 | self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types)) 1248 | 1249 | # structured dropout for cross attending 1250 | 1251 | self.cross_attn_tokens_dropout = cross_attn_tokens_dropout 1252 | 1253 | # calculate token shifting 1254 | 1255 | shift_tokens = cast_tuple(shift_tokens, len(layer_types)) 1256 | 1257 | # whether it has post norm 1258 | 1259 | self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity() 1260 | 1261 | # iterate and construct layers 1262 | 1263 | for ind, (layer_type, layer_shift_tokens) in enumerate( 1264 | zip(self.layer_types, shift_tokens) 1265 | ): 1266 | ind == (len(self.layer_types) - 1) 1267 | 1268 | if layer_type == "a": 1269 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 1270 | elif layer_type == "c": 1271 | layer = Attention(dim, heads=heads, **attn_kwargs) 1272 | elif layer_type == "f": 1273 | layer = FeedForward(dim, 
**ff_kwargs) 1274 | layer = layer if not macaron else Scale(0.5, layer) 1275 | else: 1276 | raise Exception(f"invalid layer type {layer_type}") 1277 | 1278 | if layer_shift_tokens > 0: 1279 | shift_range_upper = layer_shift_tokens + 1 1280 | shift_range_lower = -layer_shift_tokens if not causal else 0 1281 | layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) 1282 | 1283 | residual_fn = GRUGating if gate_residual else Residual 1284 | residual = residual_fn( 1285 | dim, 1286 | scale_residual=scale_residual, 1287 | scale_residual_constant=scale_residual_constant, 1288 | ) 1289 | 1290 | pre_branch_norm = norm_fn() if pre_norm else None 1291 | post_branch_norm = norm_fn() if sandwich_norm else None 1292 | post_main_norm = norm_fn() if not pre_norm else None 1293 | 1294 | norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm]) 1295 | 1296 | self.layers.append(nn.ModuleList([norms, layer, residual])) 1297 | 1298 | if deepnorm: 1299 | init_gain = (8 * depth) ** -0.25 1300 | deepnorm_init(self, init_gain) 1301 | 1302 | def forward( 1303 | self, 1304 | x, 1305 | context=None, 1306 | mask=None, 1307 | context_mask=None, 1308 | attn_mask=None, 1309 | self_attn_context_mask=None, 1310 | mems=None, 1311 | return_hiddens=False, 1312 | ): 1313 | assert not ( 1314 | self.cross_attend ^ exists(context) 1315 | ), "context must be passed in if cross_attend is set to True" 1316 | 1317 | hiddens = [] 1318 | intermediates = [] 1319 | prev_attn = None 1320 | prev_cross_attn = None 1321 | 1322 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 1323 | 1324 | rotary_pos_emb = None 1325 | if exists(self.rotary_pos_emb): 1326 | max_rotary_emb_length = max( 1327 | list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)) 1328 | ) 1329 | rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) 1330 | 1331 | outer_residual = x * self.resi_dual_scale 1332 | 1333 | for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate( 1334 | zip(self.layer_types, self.layers, self.layer_dropouts) 1335 | ): 1336 | ind == (len(self.layers) - 1) 1337 | 1338 | if self.training and layer_dropout > 0.0 and random() < layer_dropout: 1339 | continue 1340 | 1341 | if layer_type == "a": 1342 | if return_hiddens: 1343 | hiddens.append(x) 1344 | layer_mem = mems.pop(0) if mems else None 1345 | 1346 | if layer_type == "c": 1347 | if self.training and self.cross_attn_tokens_dropout > 0.0: 1348 | context, context_mask = dropout_seq( 1349 | context, context_mask, self.cross_attn_tokens_dropout 1350 | ) 1351 | 1352 | inner_residual = x 1353 | 1354 | pre_norm, post_branch_norm, post_main_norm = norm 1355 | 1356 | if exists(pre_norm): 1357 | x = pre_norm(x) 1358 | 1359 | if layer_type == "a": 1360 | out, inter = block( 1361 | x, 1362 | mask=mask, 1363 | context_mask=self_attn_context_mask, 1364 | attn_mask=attn_mask, 1365 | rel_pos=self.rel_pos, 1366 | rotary_pos_emb=rotary_pos_emb, 1367 | prev_attn=prev_attn, 1368 | mem=layer_mem, 1369 | ) 1370 | elif layer_type == "c": 1371 | out, inter = block( 1372 | x, 1373 | context=context, 1374 | mask=mask, 1375 | context_mask=context_mask, 1376 | prev_attn=prev_cross_attn, 1377 | ) 1378 | elif layer_type == "f": 1379 | out = block(x) 1380 | 1381 | if self.resi_dual: 1382 | outer_residual = outer_residual + out * self.resi_dual_scale 1383 | 1384 | if exists(post_branch_norm): 1385 | out = post_branch_norm(out) 1386 | 1387 | x = residual_fn(out, inner_residual) 1388 | 1389 | if layer_type in ("a", "c") and 
return_hiddens: 1390 | intermediates.append(inter) 1391 | 1392 | if layer_type == "a" and self.residual_attn: 1393 | prev_attn = inter.pre_softmax_attn 1394 | elif layer_type == "c" and self.cross_residual_attn: 1395 | prev_cross_attn = inter.pre_softmax_attn 1396 | 1397 | if exists(post_main_norm): 1398 | x = post_main_norm(x) 1399 | 1400 | if self.resi_dual: 1401 | x = x + self.final_norm(outer_residual) 1402 | else: 1403 | x = self.final_norm(x) 1404 | 1405 | if return_hiddens: 1406 | intermediates = LayerIntermediates( 1407 | hiddens=hiddens, attn_intermediates=intermediates 1408 | ) 1409 | 1410 | return x, intermediates 1411 | 1412 | return x 1413 | 1414 | 1415 | class Encoder(AttentionLayers): 1416 | def __init__(self, **kwargs): 1417 | assert "causal" not in kwargs, "cannot set causality on encoder" 1418 | super().__init__(causal=False, **kwargs) 1419 | 1420 | 1421 | class Decoder(AttentionLayers): 1422 | def __init__(self, **kwargs): 1423 | assert "causal" not in kwargs, "cannot set causality on decoder" 1424 | super().__init__(causal=True, **kwargs) 1425 | 1426 | 1427 | class CrossAttender(AttentionLayers): 1428 | def __init__(self, **kwargs): 1429 | super().__init__(cross_attend=True, only_cross=True, **kwargs) 1430 | 1431 | 1432 | class ViTransformerWrapper(nn.Module): 1433 | def __init__( 1434 | self, 1435 | *, 1436 | image_size, 1437 | patch_size, 1438 | attn_layers, 1439 | channels=3, 1440 | num_classes=None, 1441 | dropout=0.0, 1442 | post_emb_norm=False, 1443 | emb_dropout=0.0, 1444 | ): 1445 | super().__init__() 1446 | assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder" 1447 | assert ( 1448 | image_size % patch_size == 0 1449 | ), "image dimensions must be divisible by the patch size" 1450 | dim = attn_layers.dim 1451 | num_patches = (image_size // patch_size) ** 2 1452 | patch_dim = channels * patch_size**2 1453 | 1454 | self.patch_size = patch_size 1455 | 1456 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) 1457 | 1458 | self.patch_to_embedding = nn.Sequential( 1459 | nn.LayerNorm(patch_dim), nn.Linear(patch_dim, dim), nn.LayerNorm(dim) 1460 | ) 1461 | 1462 | self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity() 1463 | self.dropout = nn.Dropout(emb_dropout) 1464 | 1465 | self.attn_layers = attn_layers 1466 | 1467 | self.mlp_head = ( 1468 | nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity() 1469 | ) 1470 | 1471 | def forward(self, img, return_embeddings=False): 1472 | p = self.patch_size 1473 | 1474 | x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p) 1475 | x = self.patch_to_embedding(x) 1476 | n = x.shape[1] 1477 | 1478 | x = x + self.pos_embedding[:, :n] 1479 | 1480 | x = self.post_emb_norm(x) 1481 | x = self.dropout(x) 1482 | 1483 | x = self.attn_layers(x) 1484 | 1485 | if not exists(self.mlp_head) or return_embeddings: 1486 | return x 1487 | 1488 | x = x.mean(dim=-2) 1489 | return self.mlp_head(x) 1490 | 1491 | 1492 | class TransformerWrapper(nn.Module): 1493 | def __init__( 1494 | self, 1495 | *, 1496 | num_tokens, 1497 | max_seq_len, 1498 | attn_layers, 1499 | emb_dim=None, 1500 | max_mem_len=0, 1501 | shift_mem_down=0, 1502 | emb_dropout=0.0, 1503 | post_emb_norm=False, 1504 | num_memory_tokens=None, 1505 | tie_embedding=False, 1506 | logits_dim=None, 1507 | use_abs_pos_emb=True, 1508 | scaled_sinu_pos_emb=False, 1509 | l2norm_embed=False, 1510 | emb_frac_gradient=1.0, # GLM-130B and Cogview successfully used this, set at 0.1 1511 | ): 1512 | 
super().__init__() 1513 | assert isinstance( 1514 | attn_layers, AttentionLayers 1515 | ), "attention layers must be one of Encoder or Decoder" 1516 | 1517 | dim = attn_layers.dim 1518 | emb_dim = default(emb_dim, dim) 1519 | self.emb_dim = emb_dim 1520 | self.num_tokens = num_tokens 1521 | 1522 | self.max_seq_len = max_seq_len 1523 | self.max_mem_len = max_mem_len 1524 | self.shift_mem_down = shift_mem_down 1525 | 1526 | self.l2norm_embed = l2norm_embed 1527 | self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed=l2norm_embed) 1528 | 1529 | if not (use_abs_pos_emb and not attn_layers.has_pos_emb): 1530 | self.pos_emb = always(0) 1531 | elif scaled_sinu_pos_emb: 1532 | self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) 1533 | else: 1534 | self.pos_emb = AbsolutePositionalEmbedding( 1535 | emb_dim, max_seq_len, l2norm_embed=l2norm_embed 1536 | ) 1537 | 1538 | self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 1539 | 1540 | self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity() 1541 | self.emb_dropout = nn.Dropout(emb_dropout) 1542 | 1543 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 1544 | self.attn_layers = attn_layers 1545 | 1546 | self.init_() 1547 | 1548 | logits_dim = default(logits_dim, num_tokens) 1549 | self.to_logits = ( 1550 | nn.Linear(dim, logits_dim) 1551 | if not tie_embedding 1552 | else lambda t: t @ self.token_emb.emb.weight.t() 1553 | ) 1554 | 1555 | # memory tokens (like [cls]) from Memory Transformers paper 1556 | num_memory_tokens = default(num_memory_tokens, 0) 1557 | self.num_memory_tokens = num_memory_tokens 1558 | if num_memory_tokens > 0: 1559 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 1560 | 1561 | def init_(self): 1562 | if self.l2norm_embed: 1563 | nn.init.normal_(self.token_emb.emb.weight, std=1e-5) 1564 | if not isinstance(self.pos_emb, always): 1565 | nn.init.normal_(self.pos_emb.emb.weight, std=1e-5) 1566 | return 1567 | 1568 | nn.init.kaiming_normal_(self.token_emb.emb.weight) 1569 | 1570 | def forward( 1571 | self, 1572 | x, 1573 | return_embeddings=False, 1574 | return_logits_and_embeddings=False, 1575 | return_intermediates=False, 1576 | mask=None, 1577 | return_mems=False, 1578 | return_attn=False, 1579 | mems=None, 1580 | pos=None, 1581 | prepend_embeds=None, 1582 | sum_embeds=None, 1583 | **kwargs, 1584 | ): 1585 | b, n, device, num_mem, emb_frac_gradient = ( 1586 | *x.shape, 1587 | x.device, 1588 | self.num_memory_tokens, 1589 | self.emb_frac_gradient, 1590 | ) 1591 | return_hiddens = return_mems | return_attn 1592 | 1593 | # absolute positional embedding 1594 | 1595 | external_pos_emb = exists(pos) and pos.dtype != torch.long 1596 | pos_emb = self.pos_emb(x, pos=pos) if not external_pos_emb else pos 1597 | x = self.token_emb(x) + pos_emb 1598 | 1599 | # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training 1600 | 1601 | if exists(sum_embeds): 1602 | x = x + sum_embeds 1603 | 1604 | # post embedding norm, purportedly leads to greater stabilization 1605 | 1606 | x = self.post_emb_norm(x) 1607 | 1608 | # whether to append embeds, as in PaLI, for image embeddings 1609 | 1610 | if exists(prepend_embeds): 1611 | prepend_seq, prepend_dim = prepend_embeds.shape[1:] 1612 | assert ( 1613 | prepend_dim == x.shape[-1] 1614 | ), "prepended embeddings need to have same dimensions as text model dimensions" 1615 | 1616 | x = 
torch.cat((prepend_embeds, x), dim=-2) 1617 | 1618 | # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model 1619 | 1620 | if emb_frac_gradient < 1: 1621 | assert emb_frac_gradient > 0 1622 | x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient) 1623 | 1624 | # embedding dropout 1625 | 1626 | x = self.emb_dropout(x) 1627 | 1628 | x = self.project_emb(x) 1629 | 1630 | if num_mem > 0: 1631 | mem = repeat(self.memory_tokens, "n d -> b n d", b=b) 1632 | x = torch.cat((mem, x), dim=1) 1633 | 1634 | # auto-handle masking after appending memory tokens 1635 | if exists(mask): 1636 | mask = pad_at_dim(mask, (num_mem, 0), dim=-1, value=True) 1637 | 1638 | if self.shift_mem_down and exists(mems): 1639 | mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :] 1640 | mems = [*mems_r, *mems_l] 1641 | 1642 | if return_hiddens: 1643 | x, intermediates = self.attn_layers( 1644 | x, mask=mask, mems=mems, return_hiddens=True, **kwargs 1645 | ) 1646 | else: 1647 | x = self.attn_layers(x, mask=mask, mems=mems, **kwargs) 1648 | 1649 | mem, x = x[:, :num_mem], x[:, num_mem:] 1650 | 1651 | if return_logits_and_embeddings: 1652 | out = (self.to_logits(x), x) 1653 | elif return_embeddings: 1654 | out = x 1655 | else: 1656 | out = self.to_logits(x) 1657 | 1658 | if return_intermediates: 1659 | return out, intermediates 1660 | 1661 | if return_mems: 1662 | hiddens = intermediates.hiddens 1663 | new_mems = ( 1664 | list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) 1665 | if exists(mems) 1666 | else hiddens 1667 | ) 1668 | new_mems = list( 1669 | map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems) 1670 | ) 1671 | return out, new_mems 1672 | 1673 | if return_attn: 1674 | attn_maps = list( 1675 | map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates) 1676 | ) 1677 | return out, attn_maps 1678 | 1679 | return out 1680 | 1681 | 1682 | class UL2(nn.Module): 1683 | def __init__( 1684 | self, 1685 | *, 1686 | dim, 1687 | tie_token_emb=False, 1688 | ignore_index=-100, 1689 | pad_value=0, 1690 | deepnorm=False, 1691 | cross_attn_tokens_dropout=0.0, 1692 | **kwargs, 1693 | ): 1694 | super().__init__() 1695 | enc_kwargs, kwargs = groupby_prefix_and_trim("enc_", kwargs) 1696 | dec_kwargs, kwargs = groupby_prefix_and_trim("dec_", kwargs) 1697 | 1698 | assert ( 1699 | "dim" not in enc_kwargs and "dim" not in dec_kwargs 1700 | ), "dimension of either encoder or decoder must be set with `dim` keyword" 1701 | enc_transformer_kwargs = pick_and_pop(["num_tokens", "max_seq_len"], enc_kwargs) 1702 | enc_transformer_kwargs["emb_dropout"] = enc_kwargs.pop("emb_dropout", 0) 1703 | enc_transformer_kwargs["num_memory_tokens"] = enc_kwargs.pop( 1704 | "num_memory_tokens", None 1705 | ) 1706 | enc_transformer_kwargs["scaled_sinu_pos_emb"] = enc_kwargs.pop( 1707 | "scaled_sinu_pos_emb", False 1708 | ) 1709 | enc_transformer_kwargs["use_abs_pos_emb"] = enc_kwargs.pop( 1710 | "use_abs_pos_emb", True 1711 | ) 1712 | 1713 | dec_transformer_kwargs = pick_and_pop(["num_tokens", "max_seq_len"], dec_kwargs) 1714 | dec_transformer_kwargs["emb_dropout"] = dec_kwargs.pop("emb_dropout", 0) 1715 | dec_transformer_kwargs["scaled_sinu_pos_emb"] = dec_kwargs.pop( 1716 | "scaled_sinu_pos_emb", False 1717 | ) 1718 | dec_transformer_kwargs["use_abs_pos_emb"] = dec_kwargs.pop( 1719 | "use_abs_pos_emb", True 1720 | ) 1721 | 1722 | self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross 
attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories 1723 | 1724 | if deepnorm: 1725 | enc_kwargs["scale_residual"] = True 1726 | dec_kwargs["scale_residual"] = True 1727 | 1728 | enc_depth = enc_kwargs["depth"] 1729 | dec_depth = dec_kwargs["depth"] 1730 | 1731 | enc_kwargs["scale_residual_constant"] = ( 1732 | 0.81 * ((enc_depth**4) * dec_depth) ** 0.0625 1733 | ) 1734 | dec_kwargs["scale_residual_constant"] = (3 * dec_depth) ** 0.25 1735 | 1736 | self.encoder = TransformerWrapper( 1737 | **enc_transformer_kwargs, attn_layers=Encoder(dim=dim, **enc_kwargs) 1738 | ) 1739 | 1740 | self.decoder = TransformerWrapper( 1741 | **dec_transformer_kwargs, 1742 | attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs), 1743 | ) 1744 | 1745 | if deepnorm: 1746 | deepnorm_init( 1747 | self.encoder, 0.87 * ((enc_depth**4) * dec_depth) ** -0.0625 1748 | ) 1749 | deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25) 1750 | 1751 | if tie_token_emb: 1752 | self.decoder.token_emb = self.encoder.token_emb 1753 | 1754 | self.decoder = AutoregressiveWrapper( 1755 | self.decoder, ignore_index=ignore_index, pad_value=pad_value 1756 | ) 1757 | 1758 | @torch.no_grad() 1759 | def generate( 1760 | self, seq_in, seq_out_start, seq_len, mask=None, attn_mask=None, **kwargs 1761 | ): 1762 | encodings = self.encoder( 1763 | seq_in, mask=mask, attn_mask=attn_mask, return_embeddings=True 1764 | ) 1765 | return self.decoder.generate( 1766 | seq_out_start, seq_len, context=encodings, context_mask=mask, **kwargs 1767 | ) 1768 | 1769 | def forward(self, src, tgt, mask=None, attn_mask=None, src_prepend_embeds=None): 1770 | if exists(src_prepend_embeds) and exists(mask): 1771 | mask = pad_at_dim( 1772 | mask, (src_prepend_embeds.shape[-2], 0), dim=-1, value=True 1773 | ) 1774 | 1775 | enc = self.encoder( 1776 | src, 1777 | mask=mask, 1778 | attn_mask=attn_mask, 1779 | prepend_embeds=src_prepend_embeds, 1780 | return_embeddings=True, 1781 | ) 1782 | 1783 | if self.training and self.cross_attn_tokens_dropout > 0: 1784 | enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout) 1785 | 1786 | out = self.decoder(tgt, context=enc, context_mask=mask) 1787 | return out 1788 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pali-torch" 3 | version = "0.1.0" 4 | description = "Pali - PyTorch" 5 | authors = ["Kye Gomez "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/kyegomez/Pali" 9 | repository = "https://github.com/kyegomez/Pali" 10 | keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"] 11 | classifiers = [ 12 | "Development Status :: 4 - Beta", 13 | "Intended Audience :: Developers", 14 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 15 | "License :: OSI Approved :: MIT License", 16 | "Programming Language :: Python :: 3.6", 17 | ] 18 | packages = [ 19 | { include = "pali" }, 20 | { include = "pali/**/*.py" }, 21 | ] 22 | 23 | 24 | [tool.poetry.dependencies] 25 | python = "^3.6" 26 | transformers = "*" 27 | torch = "*" 28 | einops = "*" 29 | dataclasses = "*" 30 | zetascale = "*" 31 | 32 | [tool.poetry.dev-dependencies] 33 | 34 | [build-system] 35 | requires = ["poetry-core"] 36 | build-backend = "poetry.core.masonry.api" 37 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | einops 3 | transformers 4 | dataclasses 5 | zetascale --------------------------------------------------------------------------------
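
The classes in pali/transformer.py above are consumed elsewhere in the package (see pali/model.py and example.py). As a quick orientation, the sketch below shows one plausible way to pair ViTransformerWrapper with a causal TransformerWrapper, using the prepend_embeds hook that transformer.py exposes for image embeddings ("as in PaLI"). The import path, hyperparameters, and shapes are illustrative assumptions, not values taken from this repository, and this is not necessarily how pali/model.py wires the pieces together.

# Minimal usage sketch. Assumptions: the classes are importable from
# pali.transformer, and every hyperparameter below is illustrative only.
import torch
from pali.transformer import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

# Vision encoder: 256x256 images cut into 32x32 patches -> 64 patch embeddings of dim 512.
vit = ViTransformerWrapper(
    image_size=256,
    patch_size=32,
    attn_layers=Encoder(dim=512, depth=6, heads=8),
)

# Text decoder: causal transformer; in this sketch the image embeddings enter
# through prepend_embeds rather than through cross-attention.
decoder = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(dim=512, depth=6, heads=8),
)

img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 128))

img_embeds = vit(img, return_embeddings=True)      # (1, 64, 512) patch embeddings
logits = decoder(text, prepend_embeds=img_embeds)  # (1, 64 + 128, 20000) logits

Note that prepend_embeds concatenates the image embeddings in front of the token embeddings only after the absolute positional embedding and post-embedding norm have been applied, so the text token positions are unaffected by how many image tokens are prepended.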