├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── agorabanner.png ├── cm3 ├── __init__.py ├── model.py └── utils │ ├── __init__.py │ └── stable_adamw.py ├── example.py ├── generate_example.py ├── img_embeds.py ├── pyproject.toml └── requirements.txt /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyegomez/CM3Leon/00faf1fe7685791dbdcc061c856a3f753180045b/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .venv/ 3 | 4 | .env 5 | 6 | image/ 7 | audio/ 8 | video/ 9 | dataframe/ 10 | 11 | static/generated 12 | swarms/__pycache__ 13 | venv 14 | .DS_Store 15 | 16 | .DS_STORE 17 | swarms/agents/.DS_Store 18 | 19 | _build 20 | 21 | 22 | .DS_STORE 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | share/python-wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | MANIFEST 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | *.py,cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | cover/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | .pybuilder/ 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | .DS_Store 107 | # pyenv 108 | # For a library or package, you might want to ignore these files since the code is 109 | # intended to run in multiple environments; otherwise, check them in: 110 | # .python-version 111 | 112 | # pipenv 113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 116 | # install all needed dependencies. 117 | #Pipfile.lock 118 | 119 | # poetry 120 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 121 | # This is especially recommended for binary packages to ensure reproducibility, and is more 122 | # commonly ignored for libraries. 123 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 124 | #poetry.lock 125 | 126 | # pdm 127 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 128 | #pdm.lock 129 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 130 | # in version control. 
131 | # https://pdm.fming.dev/#use-with-ide 132 | .pdm.toml 133 | 134 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 135 | __pypackages__/ 136 | 137 | # Celery stuff 138 | celerybeat-schedule 139 | celerybeat.pid 140 | 141 | # SageMath parsed files 142 | *.sage.py 143 | 144 | # Environments 145 | .env 146 | .venv 147 | env/ 148 | venv/ 149 | ENV/ 150 | env.bak/ 151 | venv.bak/ 152 | 153 | # Spyder project settings 154 | .spyderproject 155 | .spyproject 156 | 157 | # Rope project settings 158 | .ropeproject 159 | 160 | # mkdocs documentation 161 | /site 162 | 163 | # mypy 164 | .mypy_cache/ 165 | .dmypy.json 166 | dmypy.json 167 | 168 | # Pyre type checker 169 | .pyre/ 170 | 171 | # pytype static type analyzer 172 | .pytype/ 173 | 174 | # Cython debug symbols 175 | cython_debug/ 176 | 177 | dist 178 | 179 | # PyCharm 180 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 181 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 182 | # and can be added to the global gitignore or merged into this file. For a more nuclear 183 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 184 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eternal Reclaimer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf) 2 | 3 | 4 | # CM3Leon: Autoregressive Multi-Modal Model for Text and Image Generation (wip) 5 | 6 | [![GitHub issues](https://img.shields.io/github/issues/kyegomez/CM3Leon)](https://github.com/kyegomez/CM3Leon/issues) 7 | [![GitHub forks](https://img.shields.io/github/forks/kyegomez/CM3Leon)](https://github.com/kyegomez/CM3Leon/network) 8 | [![GitHub stars](https://img.shields.io/github/stars/kyegomez/CM3Leon)](https://github.com/kyegomez/CM3Leon/stargazers) [![GitHub license](https://img.shields.io/github/license/kyegomez/CM3Leon)](https://github.com/kyegomez/CM3Leon/blob/master/LICENSE) 9 | [![Share on Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Share%20%40kyegomez/CM3Leon)](https://twitter.com/intent/tweet?text=Excited%20to%20introduce%20CM3Leon,%20the%20all-new%20Multi-Modal%20model%20with%20the%20potential%20to%20revolutionize%20automation.%20Join%20us%20on%20this%20journey%20towards%20a%20smarter%20future.%20%23CM3Leon%20%23Multi-Modal&url=https%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon) 10 | [![Share on Facebook](https://img.shields.io/badge/Share-%20facebook-blue)](https://www.facebook.com/sharer/sharer.php?u=https%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon) 11 | [![Share on LinkedIn](https://img.shields.io/badge/Share-%20linkedin-blue)](https://www.linkedin.com/shareArticle?mini=true&url=https%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon&title=Introducing%20CM3Leon%2C%20the%20All-New%20Multi-Modal%20Model&summary=CM3Leon%20is%20the%20next-generation%20Multi-Modal%20model%20that%20promises%20to%20transform%20industries%20with%20its%20intelligence%20and%20efficiency.%20Join%20us%20to%20be%20a%20part%20of%20this%20revolutionary%20journey%20%23CM3Leon%20%23Multi-Modal&source=) 12 | ![Discord](https://img.shields.io/discord/999382051935506503) 13 | [![Share on Reddit](https://img.shields.io/badge/-Share%20on%20Reddit-orange)](https://www.reddit.com/submit?url=https%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon&title=Exciting%20Times%20Ahead%20with%20CM3Leon%2C%20the%20All-New%20Multi-Modal%20Model%20%23CM3Leon%20%23Multi-Modal) [![Share on Hacker News](https://img.shields.io/badge/-Share%20on%20Hacker%20News-orange)](https://news.ycombinator.com/submitlink?u=https%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon&t=Exciting%20Times%20Ahead%20with%20CM3Leon%2C%20the%20All-New%20Multi-Modal%20Model%20%23CM3Leon%20%23Multi-Modal) 14 | [![Share on Pinterest](https://img.shields.io/badge/-Share%20on%20Pinterest-red)](https://pinterest.com/pin/create/button/?url=https%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon&media=https%3A%2F%2Fexample.com%2Fimage.jpg&description=CM3Leon%2C%20the%20Revolutionary%20Multi-Modal%20Model%20that%20will%20Change%20the%20Way%20We%20Work%20%23CM3Leon%20%23Multi-Modal) 15 | [![Share on WhatsApp](https://img.shields.io/badge/-Share%20on%20WhatsApp-green)](https://api.whatsapp.com/send?text=I%20just%20discovered%20CM3Leon,%20the%20all-new%20Multi-Modal%20model%20that%20promises%20to%20revolutionize%20automation.%20Join%20me%20on%20this%20exciting%20journey%20towards%20a%20smarter%20future.%20%23CM3Leon%20%23Multi-Modal%0A%0Ahttps%3A%2F%2Fgithub.com%2Fkyegomez%2FCM3Leon) 16 | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kyegomez/CM3Leon/blob/main/google_colab.ipynb)


CM3Leon is a transformer-based autoregressive model designed for multi-modal tasks, specifically text and image generation. The model is trained in two stages, using a large, diverse multimodal dataset and retrieval-augmented pretraining. It also implements contrastive decoding to enhance the quality of the generated samples.

[Paper: Scaling Autoregressive Multi-Modal Models: Pretraining and Instruction Tuning](https://ai.meta.com/research/publications/scaling-autoregressive-multi-modal-models-pretraining-and-instruction-tuning/)

* Please help with this open-source implementation in the Agora Discord: ![Discord](https://img.shields.io/discord/999382051935506503)
* This implementation is still a work in progress.

## Install

```pip3 install cm3```

---

## Usage & Example

To get started with CM3Leon in a PyTorch environment:

```python
import torch
from cm3.model import CM3

# a random 256x256 RGB image and a caption of 1024 token ids
# drawn from a 20k-token vocabulary
img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

model = CM3()

# per-position logits over the vocabulary
output = model(img, caption)
print(output.shape)  # (1, 1024, 20000)
```

This repository hosts the open-source implementation of CM3Leon, a state-of-the-art autoregressive multi-modal model for text and image generation. The model was introduced in the paper "Scaling Autoregressive Multi-Modal Models: Pretraining and Instruction Tuning".

---

## Overview

Key features of CM3Leon:

- Retrieval-augmented pretraining on a large, diverse multimodal dataset.
- Two-stage training: pretraining and supervised fine-tuning.
- Contrastive decoding for enhanced sample quality.

CM3Leon sets a new benchmark in text-to-image generation, outperforming comparable models while requiring 5x less training compute.

## Getting Started

The following sections provide a detailed analysis of the model architecture, the necessary resources, and the steps needed to replicate the CM3Leon model.

### Requirements

Replicating CM3Leon involves several critical components and requires proficiency in the following areas:

- Large-scale distributed training of transformer models using a significant number of GPUs/TPUs.
- Efficient data loading and preprocessing to handle extensive multimodal datasets.
- Memory optimization techniques to fit large models in GPU memory.
- Custom tokenizer implementation for both text and image modalities.
- Setting up a retrieval infrastructure for dense retrieval during pretraining.
- Developing a fine-tuning framework to handle mixed text-image tasks.
- Inference optimizations such as compiler-accelerated decoders, lower-precision computing, and batching.

### System Architecture

The CM3Leon implementation comprises:

- A distributed training framework, such as PyTorch or TensorFlow.
- High-performance compute infrastructure (an HPC cluster with GPUs/TPUs).
- A retrieval index and dense retriever module for augmentation (see the sketch after this list).
- Data pipelines for efficient preprocessing and loading.
- Custom code for the tokenizers and the CM3 model architecture.
- A fine-tuning framework and relevant task datasets.
- Serving infrastructure for low-latency inference.
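To make the retrieval component concrete, here is a minimal sketch of dense top-k retrieval over a precomputed embedding memory bank. In CM3Leon the embeddings would come from a CLIP-based bi-encoder; here random tensors stand in for those embeddings, and the function name `retrieve_top_k` is illustrative rather than part of this repo. A production system would use an approximate-nearest-neighbor index instead of brute-force similarity.

```python
import torch
import torch.nn.functional as F

def retrieve_top_k(query_emb: torch.Tensor, memory_bank: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Return the indices of the k memory-bank entries most similar to the query.

    query_emb:   (d,) embedding of the input document
    memory_bank: (n, d) precomputed embeddings of candidate documents
    """
    query = F.normalize(query_emb, dim=-1)  # unit-normalize so dot product = cosine similarity
    bank = F.normalize(memory_bank, dim=-1)
    scores = bank @ query                   # (n,) cosine similarities
    return torch.topk(scores, k).indices

# Toy usage: random embeddings standing in for CLIP bi-encoder outputs.
bank = torch.randn(10_000, 512)
query = torch.randn(512)
print(retrieve_top_k(query, bank, k=2))     # indices of the 2 nearest neighbors
```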
Implementing these components involves challenges such as efficient utilization of large compute clusters, minimizing data-loading and preprocessing bottlenecks, optimizing memory usage during training and inference, and ensuring low-latency serving.

### Model Architecture

The architecture of CM3Leon includes:

- Text and Image Tokenizers: a custom text tokenizer trained on CommonCrawl data, and an image tokenizer that encodes 256x256 images into 1024 tokens.
- Special Tokens: a `<break>` token to indicate modality transitions.
- Retrieval Augmentation: a bi-encoder based on CLIP retrieves relevant text and images from the memory bank.
- Autoregressive Decoder-only Transformer: standard transformer architecture similar to GPT models.
- Two-Stage Training: pretraining with retrieval augmentation, followed by supervised fine-tuning on text-image tasks via instruction tuning.
- Contrastive Decoding: a modified contrastive decoding scheme for better sample quality.

The model size ranges from 350M to 7B parameters.

### Data

Datasets used in the paper, with additional metadata and source links:

| Dataset | Domain | Size | Source |
|-|-|-|-|
| Shutterstock | Images and captions | 3 billion text tokens, licensed image data | Proprietary dataset, described in the paper |
| MS-COCO | Image captioning | 591K image-caption pairs | [Microsoft COCO Captions](https://cocodataset.org/#captions-2015) |
| Flickr30k | Image captioning | 144K image-caption pairs | [Flickr30k Entities](https://www.robots.ox.ac.uk/~vgg/data/flickr30k/) |
| Image Paragraph | Dense image captioning | 14K images with paragraph captions | [Image Paragraph dataset](https://cs.stanford.edu/people/ranjaykrishna/imcap/) |
| Localized Narratives | Image paragraph captioning | 164K images with localized narratives | [Localized Narratives](https://github.com/jponttuset/localizing-narratives) |
| VQA2 | Visual question answering | 1.3M images with question-answer pairs | [VQA2 dataset](https://visualqa.org/download.html) |
| VizWiz | Visual question answering for blind users | 92K images with question-answer pairs | [VizWiz dataset](https://vizwiz.org/) |
| OKVQA | Knowledge-based VQA | 26K images with question-answer pairs | [OK-VQA dataset](https://okvqa.allenai.org/) |
| ScienceQA | Scientific visual QA | 6K images with multiple-choice QA pairs | [ScienceQA](https://allenai.org/data/science-qa) |

The model was trained and evaluated on several datasets, including MS-COCO (Chen et al., 2015) and Flickr30k (Young et al., 2014).

For successful implementation, CM3Leon requires:

- A large (100M+ examples), diverse multimodal dataset like Shutterstock for pretraining.
- A mixture of text and image tasks with accompanying datasets for fine-tuning.
- Efficient, scalable data loading that does not bottleneck model training.
- Preprocessing steps like resizing images to 256x256 pixels and text tokenization.

### Training

CM3Leon's training process involves (a minimal training-step sketch follows after this list):

- Pretraining with retrieval augmentation and the CM3 objective.
- Supervised fine-tuning on text-image tasks.
- Efficient distributed training infrastructure for large-scale model training.
- Hyperparameter tuning for learning rates, batch sizes, optimizers, etc.
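As a concrete reference, here is a minimal single-training-step sketch, assuming the `CM3` forward signature shown in the usage example above (image plus caption tokens in, per-position vocabulary logits out). It uses plain `AdamW`; the repo also ships a `StableAdamW` in `cm3/utils/stable_adamw.py` that could be swapped in. The learning rate matches the 350M entry in the hyperparameter table below.

```python
import torch
import torch.nn.functional as F
from cm3.model import CM3

model = CM3()
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)

# One (image, caption) pair; real training would iterate over a DataLoader.
img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

logits = model(img, caption)  # (1, 1024, 20000)

# Next-token prediction: position i predicts caption token i+1.
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, logits.size(-1)),
    caption[:, 1:].reshape(-1),
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```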
### Inference

For efficient inference, consider:

- Using compiler-accelerated decoders like FasterTransformer.
- Other optimizations such as lower precision (FP16/INT8) and batching.
- An efficient implementation of contrastive decoding.


## Hyperparameters

| Model size | # Layers | d_model | Seq Length | Batch | LR | Warmup Steps | # GPUs | # Tokens |
|-|-|-|-|-|-|-|-|-|
| 350M | 24 | 1024 | 4096 | 8M | 6e-04 | 1500 | 256 | 1.4T |
| 760M | 24 | 1536 | 4096 | 8M | 5e-04 | 1500 | 256 | 1.9T |
| 7B | 32 | 4096 | 4096 | 8M | 1.2e-04 | 1500 | 512 | 2.4T |

## Supervised Fine-Tuning Parameters

| Model | # GPUs | Seq Length | Batch Size | LR | Warm-up Steps | # Tokens |
|-|-|-|-|-|-|-|
| CM3Leon-760m | 64 | 4096 | 2M | 5e-05 | 150 | 30B |
| CM3Leon-7b | 128 | 4096 | 2M | 5e-05 | 150 | 30B |

# Innovations in the paper

* Conditional text + image generation with the CM3 objective function and contrastive top-k decoding.

* Multi-modal models need to be dynamic: they can't just generate the types of data they were trained on. They need to adapt to user needs, so multi-modal models should be conditional: if prompted, the model generates text and/or images.

## Contributing

This repository welcomes contributions. Feel free to submit pull requests, create issues, or suggest any enhancements.

## Support

If you encounter any issues or need further clarification, please create an issue in the GitHub issue tracker.

## License

CM3Leon is open-sourced under the [MIT license](LICENSE).

# Roadmap

* Implement the CM3 objective function, where multi-modal inputs are transformed into an infilling instance by masking specific spans and relocating them to the end of the sequence.

* Implement a next-token prediction loss, -log p(x_input).

* Implement top-p (nucleus) sampling (a sketch follows below).
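A minimal top-p (nucleus) sampling sketch in PyTorch, operating on the logits for the final sequence position; the function name and the default cutoff of 0.9 are illustrative choices, not taken from the paper or this repo.

```python
import torch
import torch.nn.functional as F

def sample_top_p(logits: torch.Tensor, p: float = 0.9) -> torch.Tensor:
    """Sample one token id per batch row from the smallest set of tokens
    whose cumulative probability exceeds p.

    logits: (batch, vocab_size) logits for the last position in the sequence.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens once the mass *before* them already exceeds p,
    # so the single most likely token is always kept.
    sorted_probs[cumulative - sorted_probs > p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)  # renormalize
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, next_sorted)  # map back to vocabulary ids

# e.g., with the model output from the usage example:
# next_token = sample_top_p(output[:, -1, :], p=0.9)
```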
* Implement classifier-free guidance (CFG): directing an unconditional sample towards a conditional sample. Replace the input text with the `<mask>` token from the CM3 objective for unconditional sampling, so that during inference two concurrent token streams are generated: a conditional stream, which is contingent on the input text, and an unconditional stream, which is conditioned on a mask token:

```python
# CFG pseudocode: mix the conditional and unconditional streams
logits_cond = T(t_y | t_x)
logits_uncond = T(t_y | <mask>)
logits_cf = logits_uncond + a_c * (logits_cond - logits_uncond)

# T = transformer
# t_y = output tokens
# t_x = conditional input text
# <mask> = no input text; the input is replaced with a mask token
# a_c = scaling factor
```

* Implement contrastive decoding top-k, which restricts sampling to the candidate set:

```
V(t_y<i) = {t_yi in V : p_exp(t_yi | t_y<i) >= a * max_w p_exp(w | t_y<i)}
```

where V is the vocabulary, p_exp is the expert model's next-token distribution, and a controls how far below the most probable token a candidate may fall while staying in the allowed set.

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.poetry]
name = "cm3"
license = "MIT"

[tool.poetry.dependencies]
python = "^3.8"
torch = { version = "*", source = "torch_nightly" }
lion-pytorch = "*"
numpy = "*"
einops = "*"
accelerate = "*"
transformers = "*"
SentencePiece = "*"
bitsandbytes = "*"
datasets = "*"
triton = "*"
deepspeed = "*"
memory-profiler = "*"
zetascale = "*"
classifier-free-guidance-pytorch = "*"

[[tool.poetry.source]]
name = "torch_nightly"
url = "https://download.pytorch.org/whl/nightly/cu118/torch_nightly.html"
secondary = true

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch -f https://download.pytorch.org/whl/nightly/cu118/torch_nightly.html
lion-pytorch
numpy
# colt5_attention
einops
# local_attention
accelerate
transformers
SentencePiece
bitsandbytes
datasets
triton
deepspeed
memory-profiler
zetascale
clipq
classifier-free-guidance-pytorch

git+https://github.com/Qiyuan-Ge/PaintMind.git
--------------------------------------------------------------------------------