├── .gitignore ├── LICENSE ├── README.md ├── assets ├── arch.png ├── comparison.png └── demo.png ├── autoencode.py ├── emu3 ├── mllm │ ├── __init__.py │ ├── configuration_emu3.py │ ├── modeling_emu3.py │ ├── processing_emu3.py │ ├── tokenization_emu3.py │ └── utils_emu3.py ├── tokenizer │ ├── __init__.py │ ├── configuration_emu3visionvq.py │ ├── image_processing_emu3visionvq.py │ └── modeling_emu3visionvq.py └── train │ ├── __init__.py │ ├── datasets.py │ ├── prepare_data.py │ └── train.py ├── gradio_demo.py ├── image_generation.py ├── multimodal_understanding.py ├── replicate_demo ├── cog.yaml ├── predict_chat.py └── predict_gen.py ├── requirements.txt └── scripts ├── t2i_sft.sh ├── t2i_sft_offload.sh ├── zero3.json └── zero3_offload.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Emu3: Next-Token Prediction is All You Need 3 |

4 | 5 | [Emu3 Team, BAAI](https://www.baai.ac.cn/english.html) 6 | 7 | | [Project Page](https://emu.baai.ac.cn) | [Paper](https://arxiv.org/pdf/2409.18869) | [🤗HF Models](https://huggingface.co/collections/BAAI/emu3-66f4e64f70850ff358a2e60f) | [Modelscope](https://modelscope.cn/collections/Emu3-9eacc8668b1043) | [Demo](https://huggingface.co/spaces/BAAI/Emu3) | 8 | 9 | 10 |
11 | 12 |
13 | arch. 14 |
15 | 16 | We introduce **Emu3**, a new suite of state-of-the-art multimodal models trained solely with **next-token prediction**! By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences. 17 | 18 | ### Emu3 excels in both generation and perception 19 | **Emu3** outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship open models such as SDXL, LLaVA-1.6 and OpenSora-1.2, while eliminating the need for diffusion or compositional architectures. 20 | 21 |
22 | comparison. 23 |
24 | 25 | ### Highlights 26 | 27 | - **Emu3** is capable of generating high-quality images following the text input, by simply predicting the next vision token. The model naturally supports flexible resolutions and styles. 28 | - **Emu3** shows strong vision-language understanding capabilities to see the physical world and provides coherent text responses. Notably, this capability is achieved without depending on a CLIP and a pretrained LLM. 29 | - **Emu3** simply generates a video causally by predicting the next token in a video sequence, unlike the video diffusion model as in Sora. With a video in context, Emu3 can also naturally extend the video and predict what will happen next. 30 | 31 | ## News 32 | - 2024.10 We release the image pretrained model **[Emu3-Stage1](https://huggingface.co/BAAI/Emu3-Stage1)** and the sft scripts. The model supports image captioning and can generate images at a resolution of 512x512. You can use our training scripts for further instruction tuning for more image generation and perception tasks. 🔥🔥🔥 33 | - 2024.09 We relase **[Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)** and **[Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)** which are post training models separately for vision-language understanding and vision generation. 34 | - 2024.09 We introduce Emu3, a new suite of state-of-the-art multimodal models trained solely with next-token prediction. 35 | 36 | 37 | ### TODO 38 | 39 | - [X] Release model weights of tokenizer, Emu3-Chat and Emu3-Gen 40 | - [X] Release the inference code. 41 | - [ ] Release the evaluation code. 42 | - [X] Release training scripts for sft. 43 | - [ ] Release training scripts for pretrain and dpo. 44 | 45 | 46 | ### Setup 47 | 48 | Clone this repository and install required packages: 49 | 50 | ```shell 51 | git clone https://github.com/baaivision/Emu3 52 | cd Emu3 53 | 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | ### Model Weights 58 | 59 | | Model name | HF Weight | Modelscope | Wisemodel | 60 | | ------------------------ | -------------------------------------------------------------- | ------------------------------------------------------------------------- | ----------------------------------------------------------------------- | 61 | | **Emu3-Stage1** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Stage1) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-Stage1) | | 62 | | **Emu3-Chat** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Chat) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-Chat) | [Wisemodel link](https://wisemodel.cn/models/BAAI/Emu3-Chat) | 63 | | **Emu3-Gen** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Gen) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-Gen) | [Wisemodel link](https://wisemodel.cn/models/BAAI/Emu3-Gen) | 64 | | **Emu3-VisionTokenizer** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-VisionTokenizer) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-VisionTokenizer) | [Wisemodel link](https://wisemodel.cn/models/BAAI/Emu3-VisionTokenizer) | 65 | 66 | ### Quickstart 67 | 68 | #### Use 🤗Transformers to run Emu3-Gen/Stage1 for image generation 69 | ```python 70 | from PIL import Image 71 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM 72 | from transformers.generation.configuration_utils import GenerationConfig 73 | from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor 74 | import torch 75 | 76 | from emu3.mllm.processing_emu3 import Emu3Processor 77 | 78 | 79 | # model path 80 | EMU_HUB = "BAAI/Emu3-Gen" 81 | VQ_HUB = "BAAI/Emu3-VisionTokenizer" 82 | 83 | # prepare model and processor 84 | model = AutoModelForCausalLM.from_pretrained( 85 | EMU_HUB, 86 | device_map="cuda:0", 87 | torch_dtype=torch.bfloat16, 88 | attn_implementation="flash_attention_2", 89 | trust_remote_code=True, 90 | ) 91 | 92 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left") 93 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True) 94 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval() 95 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) 96 | 97 | # prepare input 98 | POSITIVE_PROMPT = " masterpiece, film grained, best quality." 99 | NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." 100 | 101 | classifier_free_guidance = 3.0 102 | prompt = "a portrait of young girl." 103 | prompt += POSITIVE_PROMPT 104 | 105 | kwargs = dict( 106 | mode='G', 107 | ratio="1:1", 108 | image_area=model.config.image_area, 109 | return_tensors="pt", 110 | padding="longest", 111 | ) 112 | pos_inputs = processor(text=prompt, **kwargs) 113 | neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs) 114 | 115 | # prepare hyper parameters 116 | GENERATION_CONFIG = GenerationConfig( 117 | use_cache=True, 118 | eos_token_id=model.config.eos_token_id, 119 | pad_token_id=model.config.pad_token_id, 120 | max_new_tokens=40960, 121 | do_sample=True, 122 | top_k=2048, 123 | ) 124 | 125 | h = pos_inputs.image_size[:, 0] 126 | w = pos_inputs.image_size[:, 1] 127 | constrained_fn = processor.build_prefix_constrained_fn(h, w) 128 | logits_processor = LogitsProcessorList([ 129 | UnbatchedClassifierFreeGuidanceLogitsProcessor( 130 | classifier_free_guidance, 131 | model, 132 | unconditional_ids=neg_inputs.input_ids.to("cuda:0"), 133 | ), 134 | PrefixConstrainedLogitsProcessor( 135 | constrained_fn , 136 | num_beams=1, 137 | ), 138 | ]) 139 | 140 | # generate 141 | outputs = model.generate( 142 | pos_inputs.input_ids.to("cuda:0"), 143 | GENERATION_CONFIG, 144 | logits_processor=logits_processor, 145 | attention_mask=pos_inputs.attention_mask.to("cuda:0"), 146 | ) 147 | 148 | mm_list = processor.decode(outputs[0]) 149 | for idx, im in enumerate(mm_list): 150 | if not isinstance(im, Image.Image): 151 | continue 152 | im.save(f"result_{idx}.png") 153 | ``` 154 | 155 | #### Use 🤗Transformers to run Emu3-Chat/Stage1 for vision-language understanding 156 | 157 | ```python 158 | from PIL import Image 159 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM 160 | from transformers.generation.configuration_utils import GenerationConfig 161 | import torch 162 | 163 | from emu3.mllm.processing_emu3 import Emu3Processor 164 | 165 | 166 | # model path 167 | EMU_HUB = "BAAI/Emu3-Chat" 168 | VQ_HUB = "BAAI/Emu3-VisionTokenizer" 169 | 170 | # prepare model and processor 171 | model = AutoModelForCausalLM.from_pretrained( 172 | EMU_HUB, 173 | device_map="cuda:0", 174 | torch_dtype=torch.bfloat16, 175 | attn_implementation="flash_attention_2", 176 | trust_remote_code=True, 177 | ) 178 | 179 | # used for Emu3-Chat 180 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left") 181 | # used for Emu3-Stage1 182 | # tokenizer = AutoTokenizer.from_pretrained( 183 | # EMU_HUB, 184 | # trust_remote_code=True, 185 | # chat_template="{image_prompt}{text_prompt}", 186 | # padding_side="left", 187 | # ) 188 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True) 189 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval() 190 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) 191 | 192 | # prepare input 193 | text = "Please describe the image" 194 | image = Image.open("assets/demo.png") 195 | 196 | inputs = processor( 197 | text=text, 198 | image=image, 199 | mode='U', 200 | return_tensors="pt", 201 | padding="longest", 202 | ) 203 | 204 | # prepare hyper parameters 205 | GENERATION_CONFIG = GenerationConfig( 206 | pad_token_id=tokenizer.pad_token_id, 207 | bos_token_id=tokenizer.bos_token_id, 208 | eos_token_id=tokenizer.eos_token_id, 209 | max_new_tokens=1024, 210 | ) 211 | 212 | # generate 213 | outputs = model.generate( 214 | inputs.input_ids.to("cuda:0"), 215 | GENERATION_CONFIG, 216 | attention_mask=inputs.attention_mask.to("cuda:0"), 217 | ) 218 | 219 | outputs = outputs[:, inputs.input_ids.shape[-1]:] 220 | print(processor.batch_decode(outputs, skip_special_tokens=True)[0]) 221 | ``` 222 | 223 | #### Use 🤗Transformers to run Emu3-VisionTokenzier for vision encoding and decoding 224 | ```python 225 | import os 226 | import os.path as osp 227 | 228 | from PIL import Image 229 | import torch 230 | from transformers import AutoModel, AutoImageProcessor 231 | 232 | MODEL_HUB = "BAAI/Emu3-VisionTokenizer" 233 | 234 | model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda() 235 | processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True) 236 | 237 | # TODO: you need to modify the path here 238 | VIDEO_FRAMES_PATH = "YOUR_VIDEO_FRAMES_PATH" 239 | 240 | video = os.listdir(VIDEO_FRAMES_PATH) 241 | video.sort() 242 | video = [Image.open(osp.join(VIDEO_FRAMES_PATH, v)) for v in video] 243 | 244 | images = processor(video, return_tensors="pt")["pixel_values"] 245 | images = images.unsqueeze(0).cuda() 246 | 247 | # image autoencode 248 | image = images[:, 0] 249 | print(image.shape) 250 | with torch.no_grad(): 251 | # encode 252 | codes = model.encode(image) 253 | # decode 254 | recon = model.decode(codes) 255 | 256 | recon = recon.view(-1, *recon.shape[2:]) 257 | recon_image = processor.postprocess(recon)["pixel_values"][0] 258 | recon_image.save("recon_image.png") 259 | 260 | # video autoencode 261 | images = images.view( 262 | -1, 263 | model.config.temporal_downsample_factor, 264 | *images.shape[2:], 265 | ) 266 | 267 | print(images.shape) 268 | with torch.no_grad(): 269 | # encode 270 | codes = model.encode(images) 271 | # decode 272 | recon = model.decode(codes) 273 | 274 | recon = recon.view(-1, *recon.shape[2:]) 275 | recon_images = processor.postprocess(recon)["pixel_values"] 276 | for idx, im in enumerate(recon_images): 277 | im.save(f"recon_video_{idx}.png") 278 | ``` 279 | 280 | ## Acknowledgement 281 | 282 | We thank the great work from [Emu Series](https://github.com/baaivision/Emu), [QWen2-VL](https://github.com/QwenLM/Qwen2-VL) and [MoVQGAN](https://github.com/ai-forever/MoVQGAN) 283 | 284 | ## Citation 285 | 286 | If you find Emu3 useful for your research and applications, please consider starring this repository and citing: 287 | 288 | ``` 289 | @article{wang2024emu3, 290 | title={Emu3: Next-Token Prediction is All You Need}, 291 | author={Wang, Xinlong and Zhang, Xiaosong and Luo, Zhengxiong and Sun, Quan and Cui, Yufeng and Wang, Jinsheng and Zhang, Fan and Wang, Yueze and Li, Zhen and Yu, Qiying and others}, 292 | journal={arXiv preprint arXiv:2409.18869}, 293 | year={2024} 294 | } 295 | ``` 296 | 297 | 298 | -------------------------------------------------------------------------------- /assets/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/assets/arch.png -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/assets/comparison.png -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/assets/demo.png -------------------------------------------------------------------------------- /autoencode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import os.path as osp 5 | 6 | from PIL import Image 7 | import torch 8 | from transformers import AutoModel, AutoImageProcessor 9 | 10 | MODEL_HUB = "BAAI/Emu3-VisionTokenizer" 11 | 12 | model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda() 13 | processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True) 14 | 15 | # TODO: you need to modify the path here 16 | VIDEO_FRAMES_PATH = "YOUR_VIDEO_FRAMES_PATH" 17 | 18 | video = os.listdir(VIDEO_FRAMES_PATH) 19 | video.sort() 20 | video = [Image.open(osp.join(VIDEO_FRAMES_PATH, v)) for v in video] 21 | 22 | images = processor(video, return_tensors="pt")["pixel_values"] 23 | images = images.unsqueeze(0).cuda() 24 | 25 | # image autoencode 26 | image = images[:, 0] 27 | print(image.shape) 28 | with torch.no_grad(): 29 | # encode 30 | codes = model.encode(image) 31 | # decode 32 | recon = model.decode(codes) 33 | 34 | recon = recon.view(-1, *recon.shape[2:]) 35 | recon_image = processor.postprocess(recon)["pixel_values"][0] 36 | recon_image.save("recon_image.png") 37 | 38 | # video autoencode 39 | images = images.view( 40 | -1, 41 | model.config.temporal_downsample_factor, 42 | *images.shape[2:], 43 | ) 44 | 45 | print(images.shape) 46 | with torch.no_grad(): 47 | # encode 48 | codes = model.encode(images) 49 | # decode 50 | recon = model.decode(codes) 51 | 52 | recon = recon.view(-1, *recon.shape[2:]) 53 | recon_images = processor.postprocess(recon)["pixel_values"] 54 | for idx, im in enumerate(recon_images): 55 | im.save(f"recon_video_{idx}.png") 56 | -------------------------------------------------------------------------------- /emu3/mllm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 BAAI and the HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_torch_available, 20 | ) 21 | 22 | 23 | _import_structure = { 24 | "configuration_emu3": ["Emu3Config"], 25 | "tokenization_emu3": ["Emu3Tokenizer"], 26 | "processing_emu3": ["Emu3Processor"], 27 | } 28 | 29 | try: 30 | if not is_torch_available(): 31 | raise OptionalDependencyNotAvailable() 32 | except OptionalDependencyNotAvailable: 33 | pass 34 | else: 35 | _import_structure["modeling_emu3"] = [ 36 | "Emu3Model", 37 | "Emu3PretrainedModel", 38 | "Emu3ForCausalLM", 39 | ] 40 | 41 | if TYPE_CHECKING: 42 | from .configuration_emu3 import Emu3Config 43 | from .tokenization_emu3 import Emu3Tokenizer 44 | from .processing_emu3 import Emu3Processor 45 | 46 | try: 47 | if not is_torch_available(): 48 | raise OptionalDependencyNotAvailable() 49 | except OptionalDependencyNotAvailable: 50 | pass 51 | else: 52 | from .modeling_emu3 import ( 53 | Emu3Model, 54 | Emu3PretrainedModel, 55 | Emu3ForCausalLM, 56 | ) 57 | 58 | else: 59 | import sys 60 | 61 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) 62 | -------------------------------------------------------------------------------- /emu3/mllm/configuration_emu3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ Emu3 model configuration""" 21 | 22 | from typing import Optional 23 | 24 | from transformers.configuration_utils import PretrainedConfig 25 | from transformers.utils import logging 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | EMU3_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 31 | 32 | 33 | class Emu3Config(PretrainedConfig): 34 | r""" 35 | This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate an Emu3 36 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 37 | defaults will yield a similar configuration to that of the Emu3-8B. 38 | 39 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 40 | documentation from [`PretrainedConfig`] for more information. 41 | 42 | 43 | Args: 44 | vocab_size (`int`, *optional*, defaults to 184622): 45 | Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the 46 | `inputs_ids` passed when calling [`Emu3Model`] 47 | hidden_size (`int`, *optional*, defaults to 4096): 48 | Dimension of the hidden representations. 49 | intermediate_size (`int`, *optional*, defaults to 14336): 50 | Dimension of the MLP representations. 51 | num_hidden_layers (`int`, *optional*, defaults to 32): 52 | Number of hidden layers in the Transformer decoder. 53 | num_attention_heads (`int`, *optional*, defaults to 32): 54 | Number of attention heads for each attention layer in the Transformer decoder. 55 | num_key_value_heads (`int`, *optional*, defaults to 8): 56 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 57 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 58 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 59 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 60 | by meanpooling all the original heads within that group. For more details checkout [this 61 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 62 | `num_attention_heads`. 63 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 64 | The non-linear activation function (function or string) in the decoder. 65 | max_position_embeddings (`int`, *optional*, defaults to 9216): 66 | The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens, 67 | initializer_range (`float`, *optional*, defaults to 0.02): 68 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 69 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 70 | The epsilon used by the rms normalization layers. 71 | use_cache (`bool`, *optional*, defaults to `True`): 72 | Whether or not the model should return the last key/values attentions (not used by all models). Only 73 | relevant if `config.is_decoder=True`. 74 | pad_token_id (`int`, *optional*, 151643): 75 | Padding token id. 76 | bos_token_id (`int`, *optional*, defaults to 151849): 77 | Beginning of stream token id. 78 | eos_token_id (`int`, *optional*, defaults to 151850): 79 | End of stream token id. 80 | img_token_id (`int`, *optional*, defaults to 151851): 81 | image token id. 82 | boi_token_id (`int`, *optional*, defaults to 151852): 83 | Beginning of image token id. 84 | eoi_token_id (`int`, *optional*, defaults to 151853): 85 | End of image token id. 86 | eol_token_id (`int`, *optional*, defaults to 151846): 87 | End of line token id. 88 | eof_token_id (`int`, *optional*, defaults to 151847): 89 | End of line token id. 90 | image_area (`int`, *optional*, defaults to 720 * 720) 91 | generated image area (image area used in training) 92 | pretraining_tp (`int`, *optional*, defaults to 1): 93 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 94 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 95 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 96 | issue](https://github.com/pytorch/pytorch/issues/76232). 97 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 98 | Whether to tie weight embeddings 99 | rope_theta (`float`, *optional*, defaults to 1_000_000.0): 100 | The base period of the RoPE embeddings. 101 | rope_scaling (`Dict`, *optional*): 102 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 103 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 104 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 105 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 106 | these scaling strategies behave: 107 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 108 | experimental feature, subject to breaking API changes in future versions. 109 | attention_dropout (`float`, *optional*, defaults to 0.1): 110 | The dropout ratio for the attention probabilities. 111 | 112 | ```python 113 | >>> from transformers import Emu3Model, Emu3Config 114 | 115 | >>> # Initializing a Emu3-8b style configuration 116 | >>> configuration = Emu3Config() 117 | 118 | >>> # Initializing a model from the Emu3-8b style configuration 119 | >>> model = Emu3Model(configuration) 120 | 121 | >>> # Accessing the model configuration 122 | >>> configuration = model.config 123 | ```""" 124 | 125 | model_type = "Emu3" 126 | keys_to_ignore_at_inference = ["past_key_values"] 127 | 128 | def __init__( 129 | self, 130 | vocab_size: int = 184622, 131 | hidden_size: int = 4096, 132 | intermediate_size: int = 14336, 133 | num_hidden_layers: int = 32, 134 | num_attention_heads: int = 32, 135 | num_key_value_heads: Optional[int] = 8, 136 | hidden_act: str = "silu", 137 | max_position_embeddings: int = 9216, 138 | initializer_range: float = 0.02, 139 | rms_norm_eps: float = 1e-5, 140 | use_cache: bool = True, 141 | pad_token_id: int = 151643, 142 | bos_token_id: int = 151849, 143 | eos_token_id: int = 151850, 144 | img_token_id: int = 151851, 145 | boi_token_id: int = 151852, 146 | eoi_token_id: int = 151853, 147 | eol_token_id: int = 151846, 148 | eof_token_id: int = 151847, 149 | image_area: int = 720 * 720, 150 | pretraining_tp: int = 1, 151 | tie_word_embeddings: bool = False, 152 | rope_theta: float = 1000000.0, 153 | rope_scaling: Optional = None, 154 | attention_dropout: float = 0.1, 155 | **kwargs, 156 | ): 157 | self.vocab_size = vocab_size 158 | self.max_position_embeddings = max_position_embeddings 159 | self.hidden_size = hidden_size 160 | self.intermediate_size = intermediate_size 161 | self.num_hidden_layers = num_hidden_layers 162 | self.num_attention_heads = num_attention_heads 163 | 164 | # for backward compatibility 165 | if num_key_value_heads is None: 166 | num_key_value_heads = num_attention_heads 167 | 168 | self.num_key_value_heads = num_key_value_heads 169 | self.hidden_act = hidden_act 170 | self.initializer_range = initializer_range 171 | self.rms_norm_eps = rms_norm_eps 172 | self.pretraining_tp = pretraining_tp 173 | self.use_cache = use_cache 174 | self.rope_theta = rope_theta 175 | self.rope_scaling = rope_scaling 176 | self._rope_scaling_validation() 177 | self.attention_dropout = attention_dropout 178 | 179 | self.img_token_id = img_token_id 180 | self.boi_token_id = boi_token_id 181 | self.eoi_token_id = eoi_token_id 182 | self.eol_token_id = eol_token_id 183 | self.eof_token_id = eof_token_id 184 | self.image_area = image_area 185 | 186 | super().__init__( 187 | pad_token_id=pad_token_id, 188 | bos_token_id=bos_token_id, 189 | eos_token_id=eos_token_id, 190 | tie_word_embeddings=tie_word_embeddings, 191 | **kwargs, 192 | ) 193 | 194 | def _rope_scaling_validation(self): 195 | """ 196 | Validate the `rope_scaling` configuration. 197 | """ 198 | if self.rope_scaling is None: 199 | return 200 | 201 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 202 | raise ValueError( 203 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 204 | f"got {self.rope_scaling}" 205 | ) 206 | rope_scaling_type = self.rope_scaling.get("type", None) 207 | rope_scaling_factor = self.rope_scaling.get("factor", None) 208 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 209 | raise ValueError( 210 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 211 | ) 212 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 213 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") 214 | -------------------------------------------------------------------------------- /emu3/mllm/processing_emu3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Processor class for Emu3. """ 16 | 17 | from math import ceil 18 | import re 19 | from typing import List, Optional, Sequence, Union 20 | from functools import partial 21 | 22 | from PIL import Image 23 | import torch 24 | from torch.nn import functional as F 25 | from transformers.feature_extraction_utils import BatchFeature 26 | from transformers.image_utils import ImageInput, get_image_size, to_numpy_array 27 | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin 28 | from transformers.tokenization_utils_base import TextInput, PreTokenizedInput 29 | from transformers.utils import logging 30 | 31 | from .utils_emu3 import Emu3PrefixConstrainedLogitsHelper 32 | 33 | 34 | logger = logging.get_logger(__name__) 35 | 36 | 37 | class Emu3Processor(ProcessorMixin): 38 | r""" 39 | Constructs an Emu3 processor which wraps an Emu3 image processor and an Emu3 vision vq model and an Emu3 tokenizer into a single processor. 40 | 41 | [`Emu3Processor`] offers all the functionalities of [`Emu3VisionVQModel`] and [`Emu3Tokenizer`]. See the 42 | [`~Emu3Processor.__call__`], [`~Emu3Processor.decode`], [`~Emu3Processor.vision_encode`], [`~Emu3Processor.vision_decode`] 43 | for more information. 44 | 45 | Args: 46 | image_processor ([`Emu3VisionVQImageProcessor`]): 47 | The image processor is a required input. 48 | vision_tokenizer ([`Emu3VisionVQModel`]): 49 | The vision tokenizer is a required input. 50 | tokenizer ([`Emu3Tokenizer`]): 51 | The tokenizer is a required input. 52 | prefix_template(`str`, *optional*): 53 | The prefix template for image tokens 54 | visual_template(`Tuple[str, ...]`, *optional*): 55 | The visual token template for image tokens 56 | """ 57 | 58 | attributes = ["image_processor", "tokenizer"] 59 | valid_kwargs = ["vision_tokenizer", "prefix_template", "visual_template"] 60 | image_processor_class = "AutoImageProcessor" 61 | tokenizer_class = "AutoTokenizer" 62 | 63 | def __init__( 64 | self, 65 | image_processor=None, 66 | vision_tokenizer=None, 67 | tokenizer=None, 68 | chat_template="You are a helpful assistant. USER: {image_prompt}{text_prompt}. ASSISTANT:", 69 | prefix_template="{H}*{W}", 70 | visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>"), 71 | **kwargs, 72 | ): 73 | assert vision_tokenizer is not None, "image tokenizer can not be None" 74 | 75 | self.vision_tokenizer = vision_tokenizer 76 | self.prefix_template = prefix_template 77 | self.visual_template = visual_template 78 | self.vis_tok_spatial_factor = 2 ** (len(self.vision_tokenizer.config.ch_mult) - 1) 79 | 80 | super().__init__(image_processor, tokenizer, chat_template=chat_template) 81 | self.const_helper = self.build_const_helper() 82 | 83 | @torch.no_grad() 84 | def __call__( 85 | self, 86 | text: Optional[TextInput | PreTokenizedInput] = None, 87 | image: Optional[Image.Image | List[Image.Image]] = None, 88 | *, 89 | mode: str = "G", 90 | ratio: str | List[str] = "1:1", 91 | image_area: int = 518400, 92 | padding_image: bool = False, 93 | **kwargs, 94 | ) -> BatchFeature: 95 | """ 96 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 97 | and `kwargs` arguments to Emu3Tokenizer's [`~Emu3Tokenizer.__call__`] to encode the text. 98 | To prepare the image(s), this method forwards the `image` argument to 99 | Emu3VisionVQImageProcessor's [`~Emu3VisionVQImageProcessor.__call__`] and Emu3VisionVQModel's [`~EmuVideoVQModel.encode`] 100 | if `image` is not `None`. Please refer to the doctsring of the above two methods for more information. 101 | 102 | Args: 103 | text (`str` or `List[str]`): 104 | The sequence or a batch of sequence to be encoded. A sequence is a string. 105 | image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*): 106 | The image or a batch of images to be prepared. An image is a PIL image. 107 | mode (`str`, *optional*, in `G` or `U`): 108 | task mode, `G` for generation and `U` for understanding 109 | ratio (`str`, *optional*): 110 | the image width-height ratio for generation 111 | image_area (`int`, *optional*): 112 | image area used to calcualte the generated image height and width 113 | padding_image (`bool`, *optional*): 114 | whether pad images to same size for fast preprocessing if they have different sizes 115 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 116 | If set, will return tensors of a particular framework. Acceptable values are: 117 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 118 | - `'np'`: Return NumPy `np.ndarray` objects. 119 | 120 | Returns: 121 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 122 | 123 | - **input_ids** -- List of token ids to be fed to a model. 124 | - **image_size** -- List of image size of input images or generated images. 125 | """ 126 | assert mode in ('G', 'U'), "mode must be 'G' or 'U'." 127 | if isinstance(text, str): 128 | text = [text] 129 | 130 | if isinstance(image, Image.Image): 131 | image = [image] 132 | 133 | if not isinstance(text[0], str): 134 | raise ValueError("`text` must be string or list of string") 135 | 136 | image_tokens = None 137 | if mode == 'G': 138 | if image is not None: 139 | raise ValueError("You have to specify only `text` in generation mode") 140 | 141 | if isinstance(ratio, str): 142 | ratio = [ratio] * len(text) 143 | 144 | if len(ratio) != len(text): 145 | raise ValueError("ratio number must match text number") 146 | else: 147 | if image is None: 148 | raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.") 149 | 150 | if not isinstance(image, Sequence) and not isinstance(image, Image.Image): 151 | raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].") 152 | 153 | if isinstance(image, Sequence) and not isinstance(image[0], Image.Image): 154 | raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].") 155 | 156 | image_tokens = self.tokenize_image(image, padding_image=padding_image) 157 | if len(text) != len(image_tokens): 158 | raise ValueError("number of image must match number of text prompt") 159 | 160 | prompt_list, size_list = [], [] 161 | for idx, text_prompt in enumerate(text): 162 | prompt = self.tokenizer.bos_token 163 | if mode == 'U': 164 | h, w = image_tokens[idx].shape 165 | imgstr = self.to_imgstr(image_tokens[idx]) 166 | image_prompt = ( 167 | self.tokenizer.boi_token + 168 | self.prefix_template.format(H=h, W=w) + 169 | self.tokenizer.img_token + 170 | imgstr + 171 | self.tokenizer.eol_token + 172 | self.tokenizer.eof_token + 173 | self.tokenizer.eoi_token 174 | ) 175 | prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt) 176 | else: 177 | h, w = self.calculate_generate_size(ratio[idx], image_area, self.vision_tokenizer.spatial_scale_factor) 178 | image_prompt = ( 179 | self.tokenizer.boi_token + 180 | self.prefix_template.format(H=h, W=w) + 181 | self.tokenizer.img_token 182 | ) 183 | prompt += (text_prompt + image_prompt) 184 | 185 | prompt_list.append(prompt) 186 | size_list.append([h, w]) 187 | 188 | text_inputs = self.tokenizer(prompt_list, **kwargs) 189 | return BatchFeature(data={**text_inputs, "image_size": size_list}, tensor_type=kwargs.get("return_tensors")) 190 | 191 | @torch.no_grad() 192 | def batch_decode(self, *args, **kwargs): 193 | docs = self.tokenizer.batch_decode(*args, **kwargs) 194 | return [self.multimodal_decode(d) for d in docs] 195 | 196 | @torch.no_grad() 197 | def decode(self, *args, **kwargs): 198 | doc = self.tokenizer.decode(*args, **kwargs) 199 | return self.multimodal_decode(doc) 200 | 201 | @torch.no_grad() 202 | def vision_encode(self, *args, **kwargs): 203 | return self.vision_tokenizer.encode(*args, **kwargs) 204 | 205 | @torch.no_grad() 206 | def vision_decode(self, *args, **kwargs): 207 | return self.vision_tokenizer.decode(*args, **kwargs) 208 | 209 | @torch.no_grad() 210 | def multimodal_decode(self, doc): 211 | multimodal_output = [] 212 | pattern = rf'({re.escape(self.tokenizer.boi_token)}.*?{re.escape(self.tokenizer.eoi_token)})' 213 | chunks = re.split(pattern, doc) 214 | for c in chunks: 215 | if len(c) == 0: 216 | continue 217 | 218 | if self.tokenizer.boi_token in c: 219 | image = [] 220 | image_rows = re.split(re.escape(self.tokenizer.eol_token), c) 221 | for r in image_rows: 222 | token_ids = re.findall(self.visual_template[1], r) 223 | if len(token_ids) > 0: 224 | row_token = [int(m) for m in token_ids] 225 | image.append(row_token) 226 | image = torch.tensor(image, dtype=torch.long, device=self.vision_tokenizer.device) 227 | image = self.vision_tokenizer.decode(image[None]).float() 228 | image = self.image_processor.postprocess(image)["pixel_values"][0] 229 | multimodal_output.append(image) 230 | else: 231 | multimodal_output.append(c) 232 | 233 | return multimodal_output if len(multimodal_output) > 1 else multimodal_output[0] 234 | 235 | @property 236 | def model_input_names(self): 237 | tokenizer_input_names = self.tokenizer.model_input_names 238 | image_processor_input_names = self.image_processor.model_input_names 239 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 240 | 241 | def to_imgstr(self, image_tokens): 242 | image_tokens = image_tokens.cpu().numpy().tolist() 243 | image_token_str = [ 244 | [ 245 | self.visual_template[0].format(token_id=token_id) 246 | for token_id in token_row 247 | ] 248 | for token_row in image_tokens 249 | ] 250 | image_row_str = ["".join(token_row) for token_row in image_token_str] 251 | imgstr = self.tokenizer.eol_token.join(image_row_str) 252 | return imgstr 253 | 254 | def calculate_generate_size(self, ratio, image_area, spatial_scale_factor): 255 | w, h = map(int, ratio.split(":")) 256 | current_area = h * w 257 | target_ratio = (image_area / current_area) ** 0.5 258 | 259 | th = int(round(h * target_ratio / spatial_scale_factor)) 260 | tw = int(round(w * target_ratio / spatial_scale_factor)) 261 | return th, tw 262 | 263 | def tokenize_image(self, image: List[Image.Image], *, padding_image: bool = False): 264 | is_all_same_size, prev_size = True, None 265 | for im in image: 266 | if prev_size is not None: 267 | is_all_same_size &= (prev_size == im.size) 268 | prev_size = im.size 269 | 270 | if is_all_same_size: 271 | image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"] 272 | image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype) 273 | image_tokens = self.vision_tokenizer.encode(image_inputs) 274 | elif padding_image: 275 | image_inputs = [self.image_processor(im, return_tensors="pt")["pixel_values"] for im in image] 276 | image_shapes = [im.shape[2:] for im in image_inputs] 277 | max_shape = ( 278 | max([im_shape[0] for im_shape in image_shapes]), 279 | max([im_shape[1] for im_shape in image_shapes]), 280 | ) 281 | image_inputs = [ 282 | F.pad(im_inp, (0, max_shape[1] - im_shape[1], 0, max_shape[0] - im_shape[0])) 283 | for im_inp, im_shape in zip(image_inputs, image_shapes) 284 | ] 285 | image_inputs = torch.cat(image_inputs, dim=0).to(self.vision_tokenizer.device, self.vision_tokenizer.dtype) 286 | image_tokens = self.vision_tokenizer.encode(image_inputs) 287 | image_tokens = [ 288 | im_tok[:ceil(im_shape[0] / self.vis_tok_spatial_factor), :ceil(im_shape[1] / self.vis_tok_spatial_factor)] 289 | for im_tok, im_shape in zip(image_tokens, image_shapes) 290 | ] 291 | else: 292 | image_tokens = [] 293 | for im in image: 294 | image_input = self.image_processor(im, return_tensors="pt")["pixel_values"] 295 | image_input = image_input.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype) 296 | image_tokens.append(self.vision_tokenizer.encode(image_input).squeeze(0)) 297 | 298 | return image_tokens 299 | 300 | def build_const_helper(self): 301 | ( 302 | img_token, 303 | eoi_token, 304 | eos_token, 305 | eol_token, 306 | eof_token, 307 | pad_token, 308 | vis_start, 309 | vis_end, 310 | ) = self.tokenizer.encode([ 311 | self.tokenizer.img_token, 312 | self.tokenizer.eoi_token, 313 | self.tokenizer.eos_token, 314 | self.tokenizer.eol_token, 315 | self.tokenizer.eof_token, 316 | self.tokenizer.pad_token, 317 | self.visual_template[0].format(token_id=0), 318 | self.visual_template[0].format(token_id=self.vision_tokenizer.config.codebook_size - 1), 319 | ]) 320 | 321 | const_helper = partial( 322 | Emu3PrefixConstrainedLogitsHelper, 323 | img_token=img_token, 324 | eoi_token=eoi_token, 325 | eos_token=eos_token, 326 | eol_token=eol_token, 327 | eof_token=eof_token, 328 | pad_token=pad_token, 329 | visual_tokens=list(range(vis_start, vis_end + 1)), 330 | ) 331 | return const_helper 332 | 333 | def build_prefix_constrained_fn(self, height, width): 334 | helper = self.const_helper(height=height, width=width) 335 | return helper 336 | -------------------------------------------------------------------------------- /emu3/mllm/tokenization_emu3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for Emu3.""" 16 | 17 | import base64 18 | import logging 19 | import os 20 | import unicodedata 21 | from typing import Collection, Dict, List, Optional, Set, Tuple, Union 22 | 23 | import tiktoken 24 | from transformers import PreTrainedTokenizer, AddedToken 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | VOCAB_FILES_NAMES = { 30 | "vocab_file": "emu3.tiktoken", 31 | "special_tokens_file": "emu3_vision_tokens.txt", 32 | } 33 | 34 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 35 | ENDOFTEXT = "<|endoftext|>" 36 | IMSTART = "<|im_start|>" 37 | IMEND = "<|im_end|>" 38 | # as the default behavior is changed to allow special tokens in 39 | # regular texts, the surface forms of special tokens need to be 40 | # as different as possible to minimize the impact 41 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) 42 | # changed to use actual index to avoid misconfiguration with vocabulary expansion 43 | SPECIAL_START_ID = 151643 44 | 45 | 46 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: 47 | with open(tiktoken_bpe_file, "rb") as f: 48 | contents = f.read() 49 | return { 50 | base64.b64decode(token): int(rank) 51 | for token, rank in (line.split() for line in contents.splitlines() if line) 52 | } 53 | 54 | 55 | class Emu3Tokenizer(PreTrainedTokenizer): 56 | """Emu3 tokenizer.""" 57 | 58 | vocab_files_names = VOCAB_FILES_NAMES 59 | 60 | def __init__( 61 | self, 62 | vocab_file, 63 | special_tokens_file, 64 | errors="replace", 65 | bos_token = "<|extra_203|>", 66 | eos_token = "<|extra_204|>", 67 | pad_token = "<|endoftext|>", 68 | img_token = "<|image token|>", 69 | boi_token = "<|image start|>", 70 | eoi_token = "<|image end|>", 71 | eol_token = "<|extra_200|>", 72 | eof_token = "<|extra_201|>", 73 | **kwargs, 74 | ): 75 | super().__init__(**kwargs) 76 | 77 | # how to handle errors in decoding UTF-8 byte sequences 78 | # use ignore if you are in streaming inference 79 | self.errors = errors 80 | 81 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) 82 | 83 | vision_tokens = [t.strip() for t in open(special_tokens_file).readlines() if len(t.strip()) > 0] 84 | SPECIAL_TOKENS = tuple( 85 | enumerate( 86 | ( 87 | ( 88 | ENDOFTEXT, 89 | IMSTART, 90 | IMEND, 91 | ) 92 | + EXTRAS 93 | + tuple(vision_tokens) 94 | ), 95 | start=SPECIAL_START_ID, 96 | ) 97 | ) 98 | self.special_tokens = {token: index for index, token in SPECIAL_TOKENS} 99 | self.special_tokens_set = set(t for _, t in SPECIAL_TOKENS) 100 | 101 | enc = tiktoken.Encoding( 102 | "Emu3", 103 | pat_str=PAT_STR, 104 | mergeable_ranks=self.mergeable_ranks, 105 | special_tokens=self.special_tokens, 106 | ) 107 | 108 | assert ( 109 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab 110 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" 111 | 112 | self.decoder = { 113 | v: k for k, v in self.mergeable_ranks.items() 114 | } 115 | self.decoder.update({v: k for k, v in self.special_tokens.items()}) 116 | 117 | self.tokenizer = enc 118 | 119 | self.eod_id = self.tokenizer.eot_token 120 | self.bos_token = bos_token 121 | self.eos_token = eos_token 122 | self.pad_token = pad_token 123 | self.img_token = img_token 124 | self.boi_token = boi_token 125 | self.eoi_token = eoi_token 126 | self.eol_token = eol_token 127 | self.eof_token = eof_token 128 | 129 | def __getstate__(self): 130 | # for pickle lovers 131 | state = self.__dict__.copy() 132 | del state["tokenizer"] 133 | return state 134 | 135 | def __setstate__(self, state): 136 | # tokenizer is not python native; don't pass it; rebuild it 137 | self.__dict__.update(state) 138 | enc = tiktoken.Encoding( 139 | "Emu3", 140 | pat_str=PAT_STR, 141 | mergeable_ranks=self.mergeable_ranks, 142 | special_tokens=self.special_tokens, 143 | ) 144 | self.tokenizer = enc 145 | 146 | def __len__(self) -> int: 147 | return self.tokenizer.n_vocab 148 | 149 | def get_vocab(self) -> Dict[bytes, int]: 150 | return self.mergeable_ranks 151 | 152 | def convert_tokens_to_ids( 153 | self, tokens: Union[bytes, str, List[Union[bytes, str]]] 154 | ) -> List[int]: 155 | if isinstance(tokens, (str, bytes)): 156 | if tokens in self.special_tokens: 157 | return self.special_tokens[tokens] 158 | else: 159 | return self.mergeable_ranks.get(tokens) 160 | 161 | ids = [] 162 | for token in tokens: 163 | if token in self.special_tokens: 164 | ids.append(self.special_tokens[token]) 165 | else: 166 | ids.append(self.mergeable_ranks.get(token)) 167 | return ids 168 | 169 | def _add_tokens( 170 | self, 171 | new_tokens: Union[List[str], List[AddedToken]], 172 | special_tokens: bool = False, 173 | ) -> int: 174 | if not special_tokens and new_tokens: 175 | raise ValueError("Adding regular tokens is not supported") 176 | 177 | for token in new_tokens: 178 | surface_form = token.content if isinstance(token, AddedToken) else token 179 | if surface_form not in self.special_tokens_set: 180 | raise ValueError("Adding unknown special tokens is not supported") 181 | 182 | return 0 183 | 184 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: 185 | """ 186 | Save only the vocabulary of the tokenizer (vocabulary). 187 | 188 | Returns: 189 | `Tuple(str)`: Paths to the files saved. 190 | """ 191 | regular_file_path = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) 192 | with open(regular_file_path,'w', encoding="utf8") as w: 193 | for k, v in self.mergeable_ranks.items(): 194 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" 195 | w.write(line) 196 | 197 | excluded_special_tokens = set((ENDOFTEXT, IMSTART, IMEND,) + EXTRAS) 198 | special_file_path = os.path.join(save_directory, self.vocab_files_names["special_tokens_file"]) 199 | with open(special_file_path, 'w', encoding="utf8") as w: 200 | for k in self.special_tokens: 201 | if k not in excluded_special_tokens: 202 | print(k, file=w) 203 | 204 | return (regular_file_path, special_file_path) 205 | 206 | def tokenize( 207 | self, 208 | text: str, 209 | allowed_special: Union[Set, str] = "all", 210 | disallowed_special: Union[Collection, str] = (), 211 | **kwargs, 212 | ) -> List[Union[bytes, str]]: 213 | """ 214 | Converts a string in a sequence of tokens. 215 | 216 | Args: 217 | text (`str`): 218 | The sequence to be encoded. 219 | allowed_special (`Literal["all"]` or `set`): 220 | The surface forms of the tokens to be encoded as special tokens in regular texts. 221 | Default to "all". 222 | disallowed_special (`Literal["all"]` or `Collection`): 223 | The surface forms of the tokens that should not be in regular texts and trigger errors. 224 | Default to an empty tuple. 225 | 226 | kwargs (additional keyword arguments, *optional*): 227 | Will be passed to the underlying model specific encode method. 228 | 229 | Returns: 230 | `List[bytes|str]`: The list of tokens. 231 | """ 232 | tokens = [] 233 | text = unicodedata.normalize("NFC", text) 234 | 235 | # this implementation takes a detour: text -> token id -> token surface forms 236 | for t in self.tokenizer.encode( 237 | text, allowed_special=allowed_special, disallowed_special=disallowed_special 238 | ): 239 | tokens.append(self.decoder[t]) 240 | 241 | return tokens 242 | 243 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: 244 | """ 245 | Converts a sequence of tokens in a single string. 246 | """ 247 | text = "" 248 | temp = b"" 249 | for t in tokens: 250 | if isinstance(t, str): 251 | if temp: 252 | text += temp.decode("utf-8", errors=self.errors) 253 | temp = b"" 254 | text += t 255 | elif isinstance(t, bytes): 256 | temp += t 257 | else: 258 | raise TypeError("token should only be of type types or str") 259 | if temp: 260 | text += temp.decode("utf-8", errors=self.errors) 261 | return text 262 | 263 | @property 264 | def vocab_size(self): 265 | return self.tokenizer.n_vocab 266 | 267 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]: 268 | """Converts an id to a token, special tokens included""" 269 | if index in self.decoder: 270 | return self.decoder[index] 271 | raise ValueError("unknown ids") 272 | 273 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int: 274 | """Converts a token to an id using the vocab, special tokens included""" 275 | if token in self.special_tokens: 276 | return self.special_tokens[token] 277 | if token in self.mergeable_ranks: 278 | return self.mergeable_ranks[token] 279 | raise ValueError("unknown token") 280 | 281 | def _decode( 282 | self, 283 | token_ids: Union[int, List[int]], 284 | skip_special_tokens: bool = False, 285 | errors: Optional[str] = None, 286 | **kwargs, 287 | ) -> str: 288 | if isinstance(token_ids, int): 289 | token_ids = [token_ids] 290 | 291 | if skip_special_tokens: 292 | token_ids = [i for i in token_ids if i < self.eod_id] 293 | 294 | return self.tokenizer.decode(token_ids, errors=errors or self.errors) 295 | -------------------------------------------------------------------------------- /emu3/mllm/utils_emu3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logits Processor Helper class for Emu3. """ 16 | 17 | import torch 18 | 19 | class Emu3PrefixConstrainedLogitsHelper: 20 | 21 | def __init__( 22 | self, 23 | height, 24 | width, 25 | img_token, 26 | eoi_token, 27 | eos_token, 28 | eol_token, 29 | eof_token, 30 | pad_token, 31 | visual_tokens, 32 | ): 33 | self.height = height 34 | self.width = width 35 | self.img_token = img_token 36 | self.eoi_token = eoi_token 37 | self.eos_token = eos_token 38 | self.eol_token = eol_token 39 | self.eof_token = eof_token 40 | self.pad_token = pad_token 41 | self.visual_tokens = visual_tokens 42 | 43 | self.offset_cache = {} 44 | 45 | def __call__(self, batch_id, input_ids): 46 | if batch_id not in self.offset_cache: 47 | position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0] 48 | self.offset_cache[batch_id] = position 49 | 50 | height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0] 51 | width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0] 52 | 53 | offset = input_ids.shape[0] - self.offset_cache[batch_id] 54 | height = height.to(offset.device) 55 | width = width.to(offset.device) 56 | 57 | if offset % (width + 1) == 0: 58 | return (self.eol_token, ) 59 | elif offset == (width + 1) * height + 1: 60 | return (self.eof_token, ) 61 | elif offset == (width + 1) * height + 2: 62 | return (self.eoi_token, ) 63 | elif offset == (width + 1) * height + 3: 64 | return (self.eos_token, ) 65 | elif offset > (width + 1) * height + 3: 66 | return (self.pad_token, ) 67 | else: 68 | return self.visual_tokens 69 | -------------------------------------------------------------------------------- /emu3/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 BAAI and the HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_torch_available, 20 | is_vision_available, 21 | ) 22 | 23 | 24 | _import_structure = {"configuration_emu3visionvq": ["Emu3VisionVQConfig"]} 25 | 26 | try: 27 | if not is_torch_available(): 28 | raise OptionalDependencyNotAvailable() 29 | except OptionalDependencyNotAvailable: 30 | pass 31 | else: 32 | _import_structure["modeling_emu3visionvq"] = [ 33 | "Emu3VisionVQModel", 34 | "Emu3VisionVQPretrainedModel", 35 | ] 36 | 37 | try: 38 | if not is_vision_available(): 39 | raise OptionalDependencyNotAvailable() 40 | except OptionalDependencyNotAvailable: 41 | pass 42 | else: 43 | _import_structure["image_processing_emu3visionvq"] = ["Emu3VisionVQImageProcessor"] 44 | 45 | if TYPE_CHECKING: 46 | from .configuration_emu3visionvq import Emu3VisionVQConfig 47 | 48 | try: 49 | if not is_torch_available(): 50 | raise OptionalDependencyNotAvailable() 51 | except OptionalDependencyNotAvailable: 52 | pass 53 | else: 54 | from .modeling_emu3visionvq import ( 55 | Emu3VisionVQModel, 56 | Emu3VisionVQPretrainedModel, 57 | ) 58 | 59 | try: 60 | if not is_vision_available(): 61 | raise OptionalDependencyNotAvailable() 62 | except OptionalDependencyNotAvailable: 63 | pass 64 | else: 65 | from .image_processing_emu3visionvq import Emu3VisionVQImageProcessor 66 | 67 | else: 68 | import sys 69 | 70 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) 71 | -------------------------------------------------------------------------------- /emu3/tokenizer/configuration_emu3visionvq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Emu3VisionVQ model configuration """ 16 | 17 | from typing import List 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | 26 | class Emu3VisionVQConfig(PretrainedConfig): 27 | r""" 28 | This is the configuration class to store the configuration of a [`Emu3VisionVQ`]. It is used to instantiate an video movq 29 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 30 | defaults will yield a configuration to the VQ model presented in Emu3 paper. 31 | 32 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 33 | documentation from [`PretrainedConfig`] for more information. 34 | 35 | 36 | Args: 37 | codebook_size (`int`, *optional*, defaults to 32768): 38 | Codebook size of the VQ model. 39 | embed_dim (`int`, *optional*, defaults to 4): 40 | Dimension of the quantized vector in codebook. 41 | z_channels (`int`, *optional*, defaults to 4): 42 | Dimension of the output channel of encoder and the input channel of decoder 43 | double_z (`bool`, *optional*, defaults to False): 44 | Whether double the output dim of the encoder. 45 | in_channels (`int`, *optional*, defaults to 3): 46 | Input channel of encoder. 47 | out_channels (`int`, *optional*, defaults to 3): 48 | Output channel of decoder. 49 | temporal_downsample_factor (`int`, *optional*, defaults to 4): 50 | Temporal downsample factor. 51 | ch (`int`, *optional*, defaults to 256): 52 | Basic channel number of the intermediate blocks. 53 | ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`): 54 | Channel scaling factor of the intermediate blocks. 55 | num_res_blocks (`int`, *optional*, defaults to 2): 56 | Residual block number in each stage. 57 | attn_resolutions (`List[int]`, *optional*, defaults to 3): 58 | Stage indices to apply attention. 59 | dropout (`float`, *optional*, defaults to 0.0): 60 | Dropout probability. 61 | 62 | ```python 63 | >>> from transformers import Emu3VisionVQ, Emu3VisionVQConfig 64 | 65 | >>> # Initializing a video VQ model of Emu3 configuration 66 | >>> configuration = Emu3VisionVQConfig() 67 | 68 | >>> # Initializing a model from the Emu3 VQ model style configuration 69 | >>> model = Emu3VisionVQModel(configuration) 70 | 71 | >>> # Accessing the model configuration 72 | >>> configuration = model.config 73 | ```""" 74 | 75 | model_type = "Emu3VisionVQ" 76 | 77 | def __init__( 78 | self, 79 | codebook_size: int = 32768, 80 | embed_dim: int = 4, 81 | z_channels: int = 4, 82 | double_z: bool = False, 83 | in_channels: int = 3, 84 | out_channels: int = 3, 85 | temporal_downsample_factor: int = 4, 86 | ch: int = 256, 87 | ch_mult: List[int] = [1, 2, 2, 4], 88 | num_res_blocks: int = 2, 89 | attn_resolutions: List[int] = [3], 90 | dropout: float = 0.0, 91 | **kwargs, 92 | ): 93 | super().__init__(**kwargs) 94 | 95 | self.codebook_size = codebook_size 96 | self.embed_dim = embed_dim 97 | self.z_channels = z_channels 98 | self.double_z = double_z 99 | self.in_channels = in_channels 100 | self.out_channels = out_channels 101 | self.temporal_downsample_factor = temporal_downsample_factor 102 | self.ch = ch 103 | self.ch_mult = ch_mult 104 | self.num_res_blocks = num_res_blocks 105 | self.attn_resolutions = attn_resolutions 106 | self.dropout = dropout 107 | -------------------------------------------------------------------------------- /emu3/tokenizer/image_processing_emu3visionvq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Image processor class for Emu3VisionVQ.""" 16 | 17 | 18 | import math 19 | from typing import Dict, List, Optional, Union 20 | 21 | import numpy as np 22 | 23 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 24 | from transformers.image_transforms import ( 25 | convert_to_rgb, 26 | resize, 27 | to_channel_dimension_format, 28 | ) 29 | from transformers.image_utils import ( 30 | IMAGENET_STANDARD_MEAN, 31 | IMAGENET_STANDARD_STD, 32 | ChannelDimension, 33 | ImageInput, 34 | PILImageResampling, 35 | get_image_size, 36 | infer_channel_dimension_format, 37 | is_scaled_image, 38 | make_list_of_images, 39 | to_numpy_array, 40 | valid_images, 41 | validate_preprocess_arguments, 42 | ) 43 | from transformers.utils import TensorType, is_vision_available, logging 44 | 45 | 46 | logger = logging.get_logger(__name__) 47 | 48 | 49 | if is_vision_available(): 50 | from PIL import Image 51 | 52 | 53 | def smart_resize( 54 | height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024 55 | ): 56 | """Rescales the image so that the following conditions are met: 57 | 58 | 1. Both dimensions (height and width) are divisible by 'factor'. 59 | 60 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 61 | 62 | 3. The aspect ratio of the image is maintained as closely as possible. 63 | 64 | """ 65 | if height < factor or width < factor: 66 | raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") 67 | elif max(height, width) / min(height, width) > 5: 68 | raise ValueError( 69 | f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}" 70 | ) 71 | 72 | h_bar = round(height / factor) * factor 73 | w_bar = round(width / factor) * factor 74 | if h_bar * w_bar > max_pixels: 75 | beta = math.sqrt((height * width) / max_pixels) 76 | h_bar = math.floor(height / beta / factor) * factor 77 | w_bar = math.floor(width / beta / factor) * factor 78 | elif h_bar * w_bar < min_pixels: 79 | beta = math.sqrt(min_pixels / (height * width)) 80 | h_bar = math.ceil(height * beta / factor) * factor 81 | w_bar = math.ceil(width * beta / factor) * factor 82 | 83 | return h_bar, w_bar 84 | 85 | 86 | class Emu3VisionVQImageProcessor(BaseImageProcessor): 87 | r""" 88 | Constructs a Emu3VisionVQ image processor that dynamically resizes images based on the original images. 89 | 90 | Args: 91 | do_resize (`bool`, *optional*, defaults to `True`): 92 | Whether to resize the image's (height, width) dimensions. 93 | resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): 94 | Resampling filter to use when resizing the image. 95 | do_rescale (`bool`, *optional*, defaults to `True`): 96 | Whether to rescale the image by the specified scale `rescale_factor`. 97 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): 98 | Scale factor to use if rescaling the image. 99 | do_normalize (`bool`, *optional*, defaults to `True`): 100 | Whether to normalize the image. 101 | image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): 102 | Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. 103 | image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): 104 | Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. 105 | do_convert_rgb (`bool`, *optional*, defaults to `True`): 106 | Whether to convert the image to RGB. 107 | min_pixels (`int`, *optional*, defaults to `512 * 512`): 108 | The min pixels of the image to resize the image. 109 | max_pixels (`int`, *optional*, defaults to `1024 * 1024`): 110 | The max pixels of the image to resize the image. 111 | spatial_factor (`int`, *optional*, defautls to 8): 112 | The spatial downsample factor the image will be downsampled in feature extracting phase 113 | """ 114 | 115 | model_input_names = ["pixel_values"] 116 | 117 | def __init__( 118 | self, 119 | do_resize: bool = True, 120 | resample: PILImageResampling = PILImageResampling.BICUBIC, 121 | do_rescale: bool = True, 122 | rescale_factor: Union[int, float] = 1 / 255, 123 | do_normalize: bool = True, 124 | image_mean: Optional[Union[float, List[float]]] = None, 125 | image_std: Optional[Union[float, List[float]]] = None, 126 | do_convert_rgb: bool = True, 127 | min_pixels: int = 512 * 512, 128 | max_pixels: int = 1024 * 1024, 129 | spatial_factor: int = 8, 130 | **kwargs, 131 | ) -> None: 132 | super().__init__(**kwargs) 133 | self.do_resize = do_resize 134 | self.resample = resample 135 | self.do_rescale = do_rescale 136 | self.rescale_factor = rescale_factor 137 | self.do_normalize = do_normalize 138 | self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN 139 | self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD 140 | self.min_pixels = min_pixels 141 | self.max_pixels = max_pixels 142 | self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} 143 | self.do_convert_rgb = do_convert_rgb 144 | self.spatial_factor = spatial_factor 145 | 146 | def _preprocess( 147 | self, 148 | images: ImageInput, 149 | do_resize: Optional[bool] = None, 150 | resample: PILImageResampling = None, 151 | do_rescale: Optional[bool] = None, 152 | rescale_factor: Optional[float] = None, 153 | do_normalize: Optional[bool] = None, 154 | image_mean: Optional[Union[float, List[float]]] = None, 155 | image_std: Optional[Union[float, List[float]]] = None, 156 | do_convert_rgb: Optional[bool] = None, 157 | spatial_factor: Optional[int] = None, 158 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 159 | output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST, 160 | ): 161 | """ 162 | Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. 163 | 164 | Args: 165 | images (`ImageInput`): 166 | Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. 167 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 168 | Whether to resize the image. 169 | resample (`PILImageResampling`, *optional*, defaults to `self.resample`): 170 | Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. 171 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 172 | Whether to rescale the image. 173 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 174 | Scale factor to use if rescaling the image. 175 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 176 | Whether to normalize the image. 177 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 178 | Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. 179 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 180 | Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. 181 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 182 | Whether to convert the image to RGB. 183 | spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`): 184 | The spatial downsample factor the image will be downsampled in feature extracting phase 185 | input_data_format (`ChannelDimension` or `str`, *optional*): 186 | The channel dimension format for the input image. Can be one of: 187 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 188 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 189 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 190 | output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): 191 | The channel dimension format for the output image. Can be one of: 192 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 193 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 194 | - Unset: Use the channel dimension format of the input image. 195 | """ 196 | spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor 197 | 198 | images = make_list_of_images(images) 199 | if do_convert_rgb: 200 | images = [convert_to_rgb(image) for image in images] 201 | 202 | # All transformations expect numpy arrays. 203 | images = [to_numpy_array(image) for image in images] 204 | 205 | if is_scaled_image(images[0]) and do_rescale: 206 | logger.warning_once( 207 | "It looks like you are trying to rescale already rescaled images. If the input" 208 | "pixel_values.append()images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 209 | ) 210 | 211 | if input_data_format is None: 212 | # We assume that all images have the same channel dimension format. 213 | input_data_format = infer_channel_dimension_format(images[0]) 214 | 215 | height, width = get_image_size(images[0], channel_dim=input_data_format) 216 | resized_height, resized_width = height, width 217 | processed_images = [] 218 | for image in images: 219 | if do_resize: 220 | resized_height, resized_width = smart_resize( 221 | height, 222 | width, 223 | factor=spatial_factor, 224 | min_pixels=self.min_pixels, 225 | max_pixels=self.max_pixels, 226 | ) 227 | image = resize( 228 | image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format 229 | ) 230 | 231 | if do_rescale: 232 | image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) 233 | 234 | if do_normalize: 235 | image = self.normalize( 236 | image=image, mean=image_mean, std=image_std, input_data_format=input_data_format 237 | ) 238 | 239 | image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format) 240 | processed_images.append(image) 241 | 242 | image = np.array(processed_images) 243 | return image 244 | 245 | def preprocess( 246 | self, 247 | images: ImageInput, 248 | do_resize: Optional[bool] = None, 249 | resample: PILImageResampling = None, 250 | do_rescale: Optional[bool] = None, 251 | rescale_factor: Optional[float] = None, 252 | do_normalize: Optional[bool] = None, 253 | image_mean: Optional[Union[float, List[float]]] = None, 254 | image_std: Optional[Union[float, List[float]]] = None, 255 | do_convert_rgb: Optional[bool] = None, 256 | spatial_factor: Optional[int] = None, 257 | return_tensors: Optional[Union[str, TensorType]] = None, 258 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 259 | output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST, 260 | ): 261 | """ 262 | Args: 263 | images (`ImageInput`): 264 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 265 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 266 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 267 | Whether to resize the image. 268 | resample (`int`, *optional*, defaults to `self.resample`): 269 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only 270 | has an effect if `do_resize` is set to `True`. 271 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 272 | Whether to rescale the image. 273 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 274 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 275 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 276 | Whether to normalize the image. 277 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 278 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 279 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 280 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`. 281 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 282 | Whether to convert the image to RGB. 283 | spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`): 284 | The spatial downsample factor the image will be downsampled in feature extracting phase 285 | return_tensors (`str` or `TensorType`, *optional*): 286 | The type of tensors to return. Can be one of: 287 | - Unset: Return a list of `np.ndarray`. 288 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 289 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 290 | input_data_format (`ChannelDimension` or `str`, *optional*): 291 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 292 | from the input image. Can be one of: 293 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 294 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 295 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 296 | output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 297 | The channel dimension format for the output image. Can be one of: 298 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 299 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 300 | - Unset: Use the channel dimension format of the input image. 301 | """ 302 | do_resize = do_resize if do_resize is not None else self.do_resize 303 | resample = resample if resample is not None else self.resample 304 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 305 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 306 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 307 | image_mean = image_mean if image_mean is not None else self.image_mean 308 | image_std = image_std if image_std is not None else self.image_std 309 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 310 | spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor 311 | 312 | images = make_list_of_images(images) 313 | if images is None or not valid_images(images): 314 | raise ValueError( 315 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 316 | "torch.Tensor, tf.Tensor or jax.ndarray." 317 | ) 318 | 319 | validate_preprocess_arguments( 320 | rescale_factor=rescale_factor, 321 | do_normalize=do_normalize, 322 | image_mean=image_mean, 323 | image_std=image_std, 324 | do_resize=do_resize, 325 | size=self.size, 326 | resample=resample, 327 | ) 328 | 329 | pixel_values = [] 330 | for image in images: 331 | norm_image = self._preprocess( 332 | image, 333 | do_resize=do_resize, 334 | resample=resample, 335 | do_rescale=do_rescale, 336 | rescale_factor=rescale_factor, 337 | do_normalize=do_normalize, 338 | image_mean=image_mean, 339 | image_std=image_std, 340 | do_convert_rgb=do_convert_rgb, 341 | spatial_factor=spatial_factor, 342 | input_data_format=input_data_format, 343 | output_data_format=output_data_format, 344 | ) 345 | pixel_values.extend(norm_image) 346 | pixel_values = np.array(pixel_values) 347 | data = {"pixel_values": pixel_values} 348 | 349 | return BatchFeature(data=data, tensor_type=return_tensors) 350 | 351 | def postprocess( 352 | self, 353 | images: ImageInput, 354 | do_rescale: Optional[bool] = None, 355 | rescale_factor: Optional[float] = None, 356 | do_normalize: Optional[bool] = None, 357 | image_mean: Optional[Union[float, List[float]]] = None, 358 | image_std: Optional[Union[float, List[float]]] = None, 359 | return_tensors: str | TensorType = "PIL.Image.Image", 360 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 361 | ): 362 | """ 363 | Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess. 364 | The parameters should be same as in preprocess. 365 | 366 | Args: 367 | images (`ImageInput`): 368 | Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1. 369 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 370 | Whether to rescale the image. 371 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 372 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 373 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 374 | Whether to normalize the image. 375 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 376 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 377 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 378 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`. 379 | return_tensors (`str` or `TensorType`, *optional*): 380 | The type of tensors to return. Can be one of: 381 | - Unset: Return a list of `np.ndarray`. 382 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 383 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 384 | input_data_format (`ChannelDimension` or `str`, *optional*): 385 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 386 | from the input image. Can be one of: 387 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 388 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 389 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 390 | """ 391 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 392 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 393 | rescale_factor = 1 / rescale_factor 394 | 395 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 396 | image_mean = image_mean if image_mean is not None else self.image_mean 397 | image_std = image_std if image_std is not None else self.image_std 398 | image_mean, image_std = self.inverse_meanstd(image_mean, image_std) 399 | 400 | images = make_list_of_images(images) 401 | if isinstance(images[0], Image.Image): 402 | return images if len(images) > 1 else images[0] 403 | 404 | if input_data_format is None: 405 | # We assume that all images have the same channel dimension format. 406 | input_data_format = infer_channel_dimension_format(images[0]) 407 | 408 | pixel_values = [] 409 | for image in images: 410 | image = to_numpy_array(image) 411 | if do_normalize: 412 | image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 413 | 414 | if do_rescale: 415 | image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) 416 | image = image.clip(0, 255).astype(np.uint8) 417 | 418 | if do_normalize and do_rescale and return_tensors == "PIL.Image.Image": 419 | image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format) 420 | pixel_values.append(Image.fromarray(image)) 421 | else: 422 | pixel_values.extend(image) 423 | 424 | data = {"pixel_values": pixel_values} 425 | return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None 426 | 427 | return BatchFeature(data=data, tensor_type=return_tensors) 428 | 429 | def inverse_meanstd(self, image_mean, image_std): 430 | image_mean = self.to_tuple(image_mean) 431 | image_std = self.to_tuple(image_std) 432 | 433 | rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std)) 434 | rev_image_std = tuple(1 / s for s in image_std) 435 | 436 | return rev_image_mean, rev_image_std 437 | 438 | def to_tuple(self, value, dim=3): 439 | if isinstance(value, int | float): 440 | return (value,) * dim 441 | 442 | return tuple(value) 443 | -------------------------------------------------------------------------------- /emu3/tokenizer/modeling_emu3visionvq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Emu3VisionVQ model """ 16 | 17 | import math 18 | from typing import Optional, Tuple, Union 19 | 20 | import torch 21 | from torch import nn 22 | from torch.nn import functional as F 23 | from transformers.modeling_utils import PreTrainedModel 24 | 25 | from .configuration_emu3visionvq import Emu3VisionVQConfig 26 | 27 | 28 | class Emu3VisionVQActivation(nn.Module): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def __call__(self, x: torch.Tensor): 34 | return x * torch.sigmoid(x) 35 | 36 | 37 | class Emu3VisionVQUpsample(nn.Module): 38 | 39 | def __init__(self, in_channels: int): 40 | super().__init__() 41 | self.conv = nn.Conv2d( 42 | in_channels, 43 | in_channels, 44 | kernel_size=3, 45 | stride=1, 46 | padding=1, 47 | ) 48 | 49 | def forward(self, x: torch.Tensor): 50 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class Emu3VisionVQDownsample(nn.Module): 56 | 57 | def __init__(self, in_channels: int): 58 | super().__init__() 59 | self.conv = nn.Conv2d( 60 | in_channels, 61 | in_channels, 62 | kernel_size=3, 63 | stride=2, 64 | padding=0, 65 | ) 66 | 67 | def forward(self, x: torch.Tensor): 68 | pad = (0, 1, 0, 1) 69 | x = F.pad(x, pad, mode="constant", value=0) 70 | x = self.conv(x) 71 | return x 72 | 73 | 74 | class Emu3VisionVQCausalConv3d(nn.Module): 75 | 76 | def __init__( 77 | self, 78 | in_channel: int, 79 | out_channel: int, 80 | kernel_size: Union[int, Tuple[int, ...]] = (3, 1, 1), 81 | stride: Union[int, Tuple[int, ...]] = (1, 1, 1), 82 | ): 83 | super().__init__() 84 | 85 | if isinstance(kernel_size, int): 86 | kernel_size = (kernel_size,) * 3 87 | if isinstance(stride, int): 88 | stride = (stride,) * 3 89 | 90 | hw_pad = [k - s for k, s in zip(kernel_size[1:], stride[1:])] 91 | self.padding = tuple() 92 | for p in hw_pad[::-1]: 93 | self.padding += (p // 2 + p % 2, p // 2) 94 | self.padding += (2, 0) 95 | 96 | self.conv = nn.Conv3d( 97 | in_channel, 98 | out_channel, 99 | kernel_size, 100 | stride=stride, 101 | ) 102 | 103 | def forward(self, x: torch.Tensor): 104 | x = F.pad(x, self.padding) 105 | x = self.conv(x) 106 | return x 107 | 108 | 109 | class Emu3VisionVQResnetTemporalBlock(nn.Module): 110 | 111 | def __init__( 112 | self, 113 | in_channels: int, 114 | out_channels: Optional[int] = None, 115 | conv_shortcut: bool = False, 116 | dropout: float = 0.0, 117 | ): 118 | super().__init__() 119 | self.in_channels = in_channels 120 | out_channels = in_channels if out_channels is None else out_channels 121 | self.out_channels = out_channels 122 | self.use_conv_shortcut = conv_shortcut 123 | 124 | stride = (1, 1, 1) 125 | kernel_size = (3, 3, 3) 126 | 127 | self.norm1 = nn.BatchNorm3d(in_channels) 128 | self.conv1 = Emu3VisionVQCausalConv3d( 129 | in_channels, 130 | out_channels, 131 | kernel_size=kernel_size, 132 | stride=stride, 133 | ) 134 | self.norm2 = nn.BatchNorm3d(out_channels) 135 | self.dropout = nn.Dropout(dropout) 136 | self.conv2 = Emu3VisionVQCausalConv3d( 137 | out_channels, 138 | out_channels, 139 | kernel_size=kernel_size, 140 | stride=stride, 141 | ) 142 | self.act = Emu3VisionVQActivation() 143 | 144 | if self.in_channels != self.out_channels: 145 | if self.use_conv_shortcut: 146 | self.conv_shortcut = Emu3VisionVQCausalConv3d( 147 | in_channels, 148 | out_channels, 149 | kernel_size=kernel_size, 150 | stride=stride, 151 | ) 152 | else: 153 | self.nin_shortcut = nn.Conv3d( 154 | in_channels, 155 | out_channels, 156 | kernel_size=1, 157 | stride=1, 158 | padding=0, 159 | ) 160 | 161 | def forward(self, x: torch.Tensor): 162 | h = self.norm1(x) 163 | h = self.act(h) 164 | h = self.conv1(h) 165 | 166 | h = self.norm2(h) 167 | h = self.act(h) 168 | h = self.dropout(h) 169 | h = self.conv2(h) 170 | 171 | if self.in_channels != self.out_channels: 172 | if self.use_conv_shortcut: 173 | x = self.conv_shortcut(x) 174 | else: 175 | x = self.nin_shortcut(x) 176 | 177 | return x + h 178 | 179 | 180 | class Emu3VisionVQSpatialNorm(nn.Module): 181 | 182 | def __init__( 183 | self, 184 | f_channels: int, 185 | zq_channels: int, 186 | norm_layer: nn.Module = nn.GroupNorm, 187 | add_conv: bool = False, 188 | num_groups: int = 32, 189 | eps: float = 1e-6, 190 | affine: bool = True, 191 | ): 192 | super().__init__() 193 | self.norm_layer = norm_layer( 194 | num_channels=f_channels, 195 | num_groups=num_groups, 196 | eps=eps, 197 | affine=affine, 198 | ) 199 | 200 | self.add_conv = add_conv 201 | if self.add_conv: 202 | self.conv = nn.Conv2d( 203 | zq_channels, 204 | zq_channels, 205 | kernel_size=3, 206 | stride=1, 207 | padding=1, 208 | ) 209 | 210 | self.conv_y = nn.Conv2d( 211 | zq_channels, 212 | f_channels, 213 | kernel_size=1, 214 | stride=1, 215 | padding=0, 216 | ) 217 | self.conv_b = nn.Conv2d( 218 | zq_channels, 219 | f_channels, 220 | kernel_size=1, 221 | stride=1, 222 | padding=0, 223 | ) 224 | 225 | def forward(self, x: torch.Tensor, zq: torch.Tensor): 226 | zq = F.interpolate(zq, size=x.shape[-2:], mode="nearest") 227 | 228 | if self.add_conv: 229 | zq = self.conv(zq) 230 | 231 | x = self.norm_layer(x) 232 | x = x * self.conv_y(zq) + self.conv_b(zq) 233 | return x 234 | 235 | 236 | class Emu3VisionVQResnetBlock(nn.Module): 237 | 238 | def __init__( 239 | self, 240 | in_channels: int, 241 | out_channels: Optional[int] = None, 242 | conv_shortcut: bool = False, 243 | dropout: float = 0.0, 244 | zq_ch: Optional[int] = None, 245 | add_conv: bool = False, 246 | ): 247 | super().__init__() 248 | self.in_channels = in_channels 249 | out_channels = in_channels if out_channels is None else out_channels 250 | self.out_channels = out_channels 251 | self.use_conv_shortcut = conv_shortcut 252 | self.zq_ch = zq_ch 253 | 254 | if zq_ch is None: 255 | norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True) 256 | self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs) 257 | self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs) 258 | else: 259 | self.norm1 = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv) 260 | self.norm2 = Emu3VisionVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv) 261 | 262 | self.conv1 = nn.Conv2d( 263 | in_channels, 264 | out_channels, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | ) 269 | 270 | self.dropout = nn.Dropout(dropout) 271 | self.conv2 = nn.Conv2d( 272 | out_channels, 273 | out_channels, 274 | kernel_size=3, 275 | stride=1, 276 | padding=1, 277 | ) 278 | 279 | self.act = Emu3VisionVQActivation() 280 | 281 | if self.in_channels != self.out_channels: 282 | if self.use_conv_shortcut: 283 | self.conv_shortcut = nn.Conv2d( 284 | in_channels, 285 | out_channels, 286 | kernel_size=3, 287 | stride=1, 288 | padding=1, 289 | ) 290 | else: 291 | self.nin_shortcut = nn.Conv2d( 292 | in_channels, 293 | out_channels, 294 | kernel_size=1, 295 | stride=1, 296 | padding=0, 297 | ) 298 | 299 | def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None): 300 | norm_args = tuple() if self.zq_ch is None else (zq, ) 301 | 302 | h = self.norm1(x, *norm_args) 303 | h = self.act(h) 304 | h = self.conv1(h) 305 | 306 | h = self.norm2(h, *norm_args) 307 | h = self.act(h) 308 | h = self.dropout(h) 309 | h = self.conv2(h) 310 | 311 | if self.in_channels != self.out_channels: 312 | if self.use_conv_shortcut: 313 | x = self.conv_shortcut(x) 314 | else: 315 | x = self.nin_shortcut(x) 316 | 317 | return x + h 318 | 319 | 320 | class Emu3VisionVQAttnBlock(nn.Module): 321 | 322 | def __init__( 323 | self, 324 | in_channels: int, 325 | zq_ch: Optional[int] = None, 326 | add_conv: bool = False 327 | ): 328 | super().__init__() 329 | self.in_channels = in_channels 330 | self.zq_ch = zq_ch 331 | 332 | if zq_ch is None: 333 | norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True) 334 | self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs) 335 | else: 336 | self.norm = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv) 337 | 338 | self.q = nn.Conv2d( 339 | in_channels, 340 | in_channels, 341 | kernel_size=1, 342 | stride=1, 343 | padding=0, 344 | ) 345 | self.k = nn.Conv2d( 346 | in_channels, 347 | in_channels, 348 | kernel_size=1, 349 | stride=1, 350 | padding=0, 351 | ) 352 | self.v = nn.Conv2d( 353 | in_channels, 354 | in_channels, 355 | kernel_size=1, 356 | stride=1, 357 | padding=0, 358 | ) 359 | self.proj_out = nn.Conv2d( 360 | in_channels, 361 | in_channels, 362 | kernel_size=1, 363 | stride=1, 364 | padding=0, 365 | ) 366 | 367 | def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None): 368 | norm_args = tuple() if self.zq_ch is None else (zq, ) 369 | 370 | nx = self.norm(x, *norm_args) 371 | q = self.q(nx) 372 | k = self.k(nx) 373 | v = self.v(nx) 374 | 375 | # compute attention 376 | b, c, h, w = q.shape 377 | q = q.reshape(b, c, h * w) 378 | k = k.reshape(b, c, h * w) 379 | score = torch.bmm(q.permute(0, 2, 1), k) 380 | score = score / (c ** 0.5) 381 | score = F.softmax(score, dim=2) 382 | 383 | # attend to values 384 | v = v.reshape(b, c, h * w) 385 | v = torch.bmm(v, score.permute(0, 2, 1)) 386 | v = v.reshape(b, c, h, w) 387 | 388 | v = self.proj_out(v) 389 | 390 | return x + v 391 | 392 | 393 | class Emu3VisionVQTemporalUpsample(nn.Module): 394 | 395 | def __init__( 396 | self, 397 | in_channel: int, 398 | out_channel: int, 399 | kernel_size: Tuple[int, ...] = (3, 3, 3), 400 | stride: Tuple[int, ...] = (1, 1, 1) 401 | ): 402 | super().__init__() 403 | self.in_channel = in_channel 404 | self.out_channel = out_channel 405 | self.conv = Emu3VisionVQCausalConv3d( 406 | in_channel, 407 | out_channel, 408 | kernel_size, 409 | stride=stride, 410 | ) 411 | 412 | def forward(self, x: torch.Tensor): 413 | b, c, t, h, w = x.shape 414 | x = x.permute(0, 1, 3, 4, 2).contiguous().view(b, -1, t) 415 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 416 | x = x.view(b, c, h, w, -1).permute(0, 1, 4, 2, 3).contiguous() 417 | x = self.conv(x) 418 | return x 419 | 420 | 421 | class Emu3VisionVQTemporalDownsample(nn.Module): 422 | 423 | def __init__( 424 | self, 425 | in_channel: int, 426 | out_channel: int, 427 | kernel_size: Tuple[int, ...] = (4, 3, 3), 428 | stride: Tuple[int, ...] = (2, 1, 1), 429 | ): 430 | super().__init__() 431 | self.in_channel = in_channel 432 | self.out_channel = out_channel 433 | self.kernel_size = kernel_size 434 | 435 | self.conv = Emu3VisionVQCausalConv3d( 436 | in_channel, 437 | out_channel, 438 | kernel_size=kernel_size, 439 | stride=stride, 440 | ) 441 | 442 | def forward(self, x: torch.Tensor): 443 | x = self.conv(x) 444 | return x 445 | 446 | 447 | class Emu3VisionVQVectorQuantizer(nn.Module): 448 | 449 | def __init__(self, config: Emu3VisionVQConfig): 450 | super().__init__() 451 | self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) 452 | self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) 453 | 454 | def forward(self, x: torch.Tensor): 455 | # b t c h w -> b t h w c 456 | b, t, c, h, w = x.shape 457 | x = x.permute(0, 1, 3, 4, 2).contiguous() 458 | x_flattened = x.view(-1, c) 459 | 460 | codebook = self.embedding.weight 461 | 462 | d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \ 463 | torch.sum(codebook ** 2, dim=1) - 2 * \ 464 | torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0)) 465 | 466 | indices = torch.argmin(d, dim=1) 467 | indices = indices.view(b, t, h, w) 468 | return indices 469 | 470 | 471 | class Emu3VisionVQEncoder(nn.Module): 472 | 473 | def __init__(self, config: Emu3VisionVQConfig): 474 | super().__init__() 475 | self.ch = config.ch 476 | self.num_resolutions = len(config.ch_mult) 477 | self.num_res_blocks = config.num_res_blocks 478 | self.in_channels = config.in_channels 479 | 480 | # downsampling 481 | self.conv_in = nn.Conv2d( 482 | self.in_channels, 483 | self.ch, 484 | kernel_size=3, 485 | stride=1, 486 | padding=1 487 | ) 488 | 489 | in_ch_mult = (1,) + tuple(config.ch_mult) 490 | self.down = nn.ModuleList() 491 | for i_level in range(self.num_resolutions): 492 | block = nn.ModuleList() 493 | attn = nn.ModuleList() 494 | block_in = config.ch * in_ch_mult[i_level] 495 | block_out = config.ch * config.ch_mult[i_level] 496 | for i_block in range(self.num_res_blocks): 497 | block.append( 498 | Emu3VisionVQResnetBlock( 499 | in_channels=block_in, 500 | out_channels=block_out, 501 | dropout=config.dropout, 502 | ) 503 | ) 504 | block_in = block_out 505 | if i_level in config.attn_resolutions: 506 | attn.append(Emu3VisionVQAttnBlock(block_in)) 507 | 508 | down = nn.Module() 509 | down.block = block 510 | down.attn = attn 511 | if i_level != self.num_resolutions - 1: 512 | down.downsample = Emu3VisionVQDownsample(block_in) 513 | 514 | self.down.append(down) 515 | 516 | # middle 517 | self.mid = nn.Module() 518 | self.mid.block_1 = Emu3VisionVQResnetBlock( 519 | in_channels=block_in, 520 | out_channels=block_in, 521 | dropout=config.dropout, 522 | ) 523 | self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in) 524 | self.mid.block_2 = Emu3VisionVQResnetBlock( 525 | in_channels=block_in, 526 | out_channels=block_in, 527 | dropout=config.dropout, 528 | ) 529 | 530 | # end 531 | self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True) 532 | 533 | out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels 534 | self.conv_out = nn.Conv2d( 535 | block_in, 536 | out_z_channels, 537 | kernel_size=3, 538 | stride=1, 539 | padding=1, 540 | ) 541 | 542 | temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) 543 | self.time_conv = nn.ModuleList() 544 | 545 | for i in range(temporal_down_blocks): 546 | conv = Emu3VisionVQTemporalDownsample(out_z_channels, out_z_channels) 547 | self.time_conv.append(conv) 548 | 549 | self.time_res_stack = nn.Sequential(*[ 550 | Emu3VisionVQResnetTemporalBlock( 551 | in_channels=out_z_channels, 552 | out_channels=out_z_channels, 553 | dropout=config.dropout, 554 | ) for _ in range(self.num_res_blocks) 555 | ]) 556 | 557 | self.act = Emu3VisionVQActivation() 558 | 559 | def forward(self, x: torch.Tensor): 560 | t = x.shape[1] 561 | x = x.reshape(-1, *x.shape[2:]) 562 | 563 | # downsampling 564 | h = self.conv_in(x) 565 | for i_level in range(self.num_resolutions): 566 | for i_block in range(self.num_res_blocks): 567 | h = self.down[i_level].block[i_block](h) 568 | if len(self.down[i_level].attn) > 0: 569 | h = self.down[i_level].attn[i_block](h) 570 | 571 | if i_level != self.num_resolutions - 1: 572 | h = self.down[i_level].downsample(h) 573 | 574 | h = self.mid.block_1(h) 575 | h = self.mid.attn_1(h) 576 | h = self.mid.block_2(h) 577 | 578 | # end 579 | h = self.norm_out(h) 580 | h = self.act(h) 581 | 582 | h = self.conv_out(h) 583 | 584 | h = h.reshape(-1, t, *h.shape[1:]) 585 | h = h.permute(0, 2, 1, 3, 4) 586 | 587 | for conv in self.time_conv: 588 | h = self.act(conv(h)) 589 | 590 | h = self.time_res_stack(h) 591 | h = h.permute(0, 2, 1, 3, 4) 592 | 593 | return h 594 | 595 | 596 | class Emu3VisionVQDecoder(nn.Module): 597 | 598 | def __init__(self, config: Emu3VisionVQConfig): 599 | super().__init__() 600 | self.ch = config.ch 601 | self.num_resolutions = len(config.ch_mult) 602 | self.num_res_blocks = config.num_res_blocks 603 | 604 | in_ch_mult = (1,) + tuple(config.ch_mult) 605 | zq_ch = config.embed_dim 606 | 607 | block_in = config.ch * config.ch_mult[-1] 608 | self.time_res_stack = nn.Sequential(*[ 609 | Emu3VisionVQResnetTemporalBlock( 610 | in_channels=config.z_channels, 611 | out_channels=config.z_channels, 612 | dropout=config.dropout, 613 | ) for _ in range(config.num_res_blocks) 614 | ]) 615 | 616 | tempo_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) 617 | self.time_conv = nn.ModuleList() 618 | for i in range(tempo_upsample_block_num): 619 | conv = Emu3VisionVQTemporalUpsample(config.z_channels, config.z_channels) 620 | self.time_conv.append(conv) 621 | 622 | self.conv_in = nn.Conv2d( 623 | config.z_channels, 624 | block_in, 625 | kernel_size=3, 626 | stride=1, 627 | padding=1, 628 | ) 629 | 630 | # middle 631 | self.mid = nn.Module() 632 | self.mid.block_1 = Emu3VisionVQResnetBlock( 633 | in_channels=block_in, 634 | out_channels=block_in, 635 | dropout=config.dropout, 636 | zq_ch=zq_ch, 637 | ) 638 | self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in, zq_ch) 639 | self.mid.block_2 = Emu3VisionVQResnetBlock( 640 | in_channels=block_in, 641 | out_channels=block_in, 642 | dropout=config.dropout, 643 | zq_ch=zq_ch, 644 | ) 645 | 646 | # upsampling 647 | self.up = nn.ModuleList() 648 | for i_level in reversed(range(self.num_resolutions)): 649 | block = nn.ModuleList() 650 | attn = nn.ModuleList() 651 | block_out = config.ch * config.ch_mult[i_level] 652 | for i_block in range(self.num_res_blocks + 1): 653 | block.append( 654 | Emu3VisionVQResnetBlock( 655 | in_channels=block_in, 656 | out_channels=block_out, 657 | dropout=config.dropout, 658 | zq_ch=zq_ch, 659 | ) 660 | ) 661 | block_in = block_out 662 | if i_level in config.attn_resolutions: 663 | attn.append(Emu3VisionVQAttnBlock(block_in, zq_ch)) 664 | 665 | up = nn.Module() 666 | up.block = block 667 | up.attn = attn 668 | if i_level != 0: 669 | up.upsample = Emu3VisionVQUpsample(block_in) 670 | 671 | self.up.insert(0, up) 672 | 673 | self.act = Emu3VisionVQActivation() 674 | 675 | self.norm_out = Emu3VisionVQSpatialNorm(block_in, zq_ch) 676 | self.conv_out = nn.Conv2d( 677 | block_in, 678 | config.out_channels, 679 | kernel_size=3, 680 | stride=1, 681 | padding=1, 682 | ) 683 | 684 | def forward(self, z: torch.Tensor, zq: torch.Tensor): 685 | z_zq = torch.cat((z, zq), dim=0) 686 | z_zq = z_zq.permute(0, 2, 1, 3, 4) 687 | z_zq = self.time_res_stack(z_zq) 688 | 689 | for conv in self.time_conv: 690 | z_zq = self.act(conv(z_zq)) 691 | 692 | z_zq = z_zq.permute(0, 2, 1, 3, 4) 693 | 694 | h, zq = torch.chunk(z_zq, 2, dim=0) 695 | 696 | h = h.reshape(-1, *h.shape[2:]) 697 | zq = zq.reshape(-1, *zq.shape[2:]) 698 | 699 | h = self.conv_in(h) 700 | 701 | # middle 702 | h = self.mid.block_1(h, zq) 703 | h = self.mid.attn_1(h, zq) 704 | h = self.mid.block_2(h, zq) 705 | 706 | # upsampling 707 | for i_level in reversed(range(self.num_resolutions)): 708 | for i_block in range(self.num_res_blocks+1): 709 | h = self.up[i_level].block[i_block](h, zq) 710 | if len(self.up[i_level].attn) > 0: 711 | h = self.up[i_level].attn[i_block](h, zq) 712 | 713 | if i_level != 0: 714 | h = self.up[i_level].upsample(h) 715 | 716 | h = self.norm_out(h, zq) 717 | h = self.act(h) 718 | h = self.conv_out(h) 719 | 720 | return h 721 | 722 | 723 | class Emu3VisionVQPretrainedModel(PreTrainedModel): 724 | """ 725 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 726 | models. 727 | """ 728 | 729 | config_class = Emu3VisionVQConfig 730 | base_model_prefix = "emuvideovq" 731 | main_input_name = "pixel_values" 732 | _no_split_modules = ["Emu3VisionVQResnetBlock", "Emu3VisionVQAttnBlock", "Emu3VisionVQResnetTemporalBlock"] 733 | 734 | def _init_weights(self, module): 735 | if isinstance(module, (nn.Conv2d, nn.Conv3d)): 736 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 737 | # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. 738 | elif isinstance(module, nn.Linear): 739 | nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) 740 | if module.bias is not None: 741 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) 742 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 743 | nn.init.uniform_(module.bias, -bound, bound) 744 | elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): 745 | nn.init.constant_(module.weight, 1) 746 | nn.init.constant_(module.bias, 0) 747 | 748 | 749 | class Emu3VisionVQModel(Emu3VisionVQPretrainedModel): 750 | 751 | def __init__(self, config): 752 | super().__init__(config) 753 | self.config = config 754 | 755 | self.encoder = Emu3VisionVQEncoder(config) 756 | self.decoder = Emu3VisionVQDecoder(config) 757 | self.quantize = Emu3VisionVQVectorQuantizer(config) 758 | 759 | self.quant_conv = Emu3VisionVQCausalConv3d(config.z_channels, config.embed_dim) 760 | self.post_quant_conv = Emu3VisionVQCausalConv3d(config.embed_dim, config.z_channels) 761 | 762 | self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1) 763 | 764 | self.post_init() 765 | 766 | def encode(self, x: torch.Tensor): 767 | ndim = x.ndim 768 | if ndim == 4: 769 | t = self.config.temporal_downsample_factor 770 | b, c, h, w = x.shape 771 | x = x.unsqueeze(1).repeat(1, t, 1, 1, 1) 772 | elif ndim == 5: 773 | b, t, c, h, w = x.shape 774 | 775 | h = self.encoder(x) 776 | 777 | # b t c h w -> b c t h w 778 | h = h.permute(0, 2, 1, 3, 4) 779 | h = self.quant_conv(h) 780 | # b c t h w -> b t c h w 781 | h = h.permute(0, 2, 1, 3, 4) 782 | 783 | codes = self.quantize(h) 784 | 785 | if ndim == 4: 786 | codes = codes.squeeze(1) 787 | 788 | return codes 789 | 790 | def decode(self, x: torch.Tensor): 791 | ndim = x.ndim 792 | if ndim == 3: 793 | x = x.unsqueeze(1) 794 | 795 | b, t, h, w = x.shape 796 | quant = self.quantize.embedding(x.flatten()) 797 | c = quant.shape[-1] 798 | quant = quant.view(b, t, h, w, c).permute(0, 4, 1, 2, 3).contiguous() 799 | quant2 = self.post_quant_conv(quant) 800 | 801 | quant = quant.permute(0, 2, 1, 3, 4) 802 | quant2 = quant2.permute(0, 2, 1, 3, 4) 803 | 804 | video = self.decoder(quant2, quant) 805 | video = video.reshape( 806 | b, 807 | t * self.config.temporal_downsample_factor, 808 | self.config.out_channels, 809 | h * self.spatial_scale_factor, 810 | w * self.spatial_scale_factor, 811 | ) 812 | if ndim == 3: 813 | return video[:, 0] 814 | return video 815 | 816 | @property 817 | def device(self): 818 | return next(self.parameters()).device 819 | 820 | @property 821 | def dtype(self): 822 | return next(self.parameters()).dtype 823 | -------------------------------------------------------------------------------- /emu3/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/emu3/train/__init__.py -------------------------------------------------------------------------------- /emu3/train/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import os.path as osp 5 | import random 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class Emu3FeatureDataset(Dataset): 12 | 13 | def __init__(self, args: "DataArguments", tokenizer: "Emu3Tokenizer"): 14 | super().__init__() 15 | 16 | self.args = args 17 | with open(args.data_path) as f: 18 | d = json.load(f) 19 | 20 | self.path_prefix = d["prefix"] 21 | self.filelist = d["path_list"] 22 | 23 | self.tokenizer = tokenizer 24 | self.bov = tokenizer.encode(args.visual_token_pattern.format(token_id=0))[0] 25 | self.eov = tokenizer.encode(args.visual_token_pattern.format(token_id=args.codebook_size - 1))[0] 26 | 27 | def __len__(self): 28 | return len(self.filelist) 29 | 30 | def __getitem__(self, index: int): 31 | path = osp.join(self.path_prefix, self.filelist[index]) 32 | data = torch.load(path) 33 | 34 | image_tokens = data["images"] 35 | image_prompt = self.format_image_prompt(image_tokens) 36 | 37 | p_prob = random.random() 38 | if p_prob < self.args.null_prompt_prob: 39 | prompt = "" 40 | else: 41 | prompt = data["texts"] 42 | 43 | input = self.tokenizer.bos_token + prompt + image_prompt 44 | sample = self.tokenizer( 45 | input, 46 | padding="max_length", 47 | return_token_type_ids=False, 48 | return_tensors="pt", 49 | ) 50 | 51 | labels = sample["input_ids"] 52 | if self.args.apply_loss_on_only_vision: 53 | labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, self.args.ignore_index) 54 | 55 | sample["labels"] = labels 56 | for k, v in sample.items(): 57 | sample[k] = v.squeeze(0) 58 | 59 | return sample 60 | 61 | def format_image_prompt(self, image_tokens): 62 | h, w = image_tokens.shape 63 | imgstr = self.to_imgstr(image_tokens) 64 | 65 | image_prompt = ( 66 | self.tokenizer.boi_token + 67 | f"{h}*{w}" + 68 | self.tokenizer.img_token + 69 | imgstr + 70 | self.tokenizer.eol_token + 71 | self.tokenizer.eof_token + 72 | self.tokenizer.eoi_token 73 | ) 74 | 75 | return image_prompt 76 | 77 | def to_imgstr(self, image_tokens): 78 | image_token_str = [ 79 | [ 80 | self.args.visual_token_pattern.format(token_id=token_id) 81 | for token_id in token_row 82 | ] 83 | for token_row in image_tokens 84 | ] 85 | image_row_str = ["".join(token_row) for token_row in image_token_str] 86 | imgstr = self.tokenizer.eol_token.join(image_row_str) 87 | return imgstr 88 | 89 | -------------------------------------------------------------------------------- /emu3/train/prepare_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import json 5 | import os 6 | 7 | from PIL import Image 8 | import torch 9 | 10 | from emu3.tokenizer import Emu3VisionVQModel, Emu3VisionVQImageProcessor 11 | 12 | 13 | def prepare_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--model-path', type=str, help='vision tokenizer path') 16 | parser.add_argument('--data-path', type=str, help='data path') 17 | parser.add_argument('--output-path', type=str, help='tokenized data save path') 18 | parser.add_argument('--image-area', type=int, default=720 * 720) 19 | 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def smart_resize(image, image_area: int = 720 * 720): 25 | w, h = image.size 26 | current_area = h * w 27 | target_ratio = (image_area / current_area) ** 0.5 28 | 29 | th = int(round(h * target_ratio)) 30 | tw = int(round(w * target_ratio)) 31 | 32 | image = image.resize((tw, th)) 33 | return image 34 | 35 | 36 | def main(): 37 | args = prepare_args() 38 | 39 | image_processor = Emu3VisionVQImageProcessor.from_pretrained(args.model_path) 40 | image_tokenizer = Emu3VisionVQModel.from_pretrained(args.model_path, device_map="cuda:0") 41 | image_tokenizer.eval() 42 | 43 | os.makedirs(f"{args.output_path}/feature", exist_ok=True) 44 | os.makedirs(f"{args.output_path}/list", exist_ok=True) 45 | 46 | datalist = { 47 | "prefix": f"{args.output_path}/feature", 48 | "path_list": [] 49 | } 50 | 51 | with open(args.data_path) as f: 52 | input_data = json.load(f) 53 | 54 | for inp in input_data: 55 | name = inp["name"] 56 | prompt = inp["text"] 57 | 58 | image = Image.open(inp["image"]).convert("RGB") 59 | image = smart_resize(image, args.image_area) 60 | 61 | image = image_processor(image, return_tensors="pt")["pixel_values"] 62 | with torch.no_grad(): 63 | image = image.cuda() 64 | token_ids = image_tokenizer.encode(image) 65 | 66 | token_ids = token_ids.squeeze(0).cpu().numpy() 67 | data = { 68 | "name": name, 69 | "images": token_ids, 70 | "texts": prompt 71 | } 72 | 73 | torch.save(data, f"{args.output_path}/feature/{name}.pth") 74 | datalist["path_list"].append(f"{name}.pth") 75 | 76 | with open(f"{args.output_path}/list/train.json", 'w') as f: 77 | json.dump(datalist, f) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /emu3/train/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from dataclasses import dataclass, field 4 | import os 5 | import os.path as osp 6 | import pathlib 7 | from typing import Optional, List 8 | 9 | import transformers as tf 10 | import torch 11 | 12 | from emu3.mllm import Emu3Config, Emu3Tokenizer, Emu3ForCausalLM 13 | from emu3.train.datasets import Emu3FeatureDataset 14 | 15 | 16 | @dataclass 17 | class ModelArguments: 18 | model_name_or_path: Optional[str] = field(default="BAAI/Emu3-Gen") 19 | 20 | 21 | @dataclass 22 | class DataArguments: 23 | data_path: Optional[str] = field(default=None) 24 | null_prompt_prob: float = field(default=0.05) 25 | apply_loss_on_only_vision: bool = field(default=True) 26 | apply_loss_on_only_text: bool = field(default=False) 27 | ignore_index: int = field(default=-100) 28 | visual_token_pattern: str = field(default="<|visual token {token_id:0>6d}|>") 29 | codebook_size: Optional[int] = field(default=32768) 30 | 31 | 32 | @dataclass 33 | class TrainingArguments(tf.TrainingArguments): 34 | report_to: List[str] = field(default_factory=list) 35 | remove_unused_columns: bool = field(default=False) 36 | min_learning_rate: Optional[float] = field(default=None) 37 | attn_type: Optional[str] = field(default="fa2") 38 | image_area: Optional[int] = field(default=None) 39 | max_position_embeddings: Optional[int] = field(default=None) 40 | 41 | 42 | def update_configs(model_config, args, fields): 43 | cross_update = lambda a, b, field_name: ( 44 | setattr(b, field_name, getattr(a, field_name)) 45 | if getattr(b, field_name, None) is None else 46 | setattr(a, field_name, getattr(b, field_name)) 47 | ) 48 | 49 | for f in fields: 50 | cross_update(model_config, args, f) 51 | 52 | 53 | def train(): 54 | parser = tf.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 55 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 56 | 57 | model_config = Emu3Config.from_pretrained(model_args.model_name_or_path) 58 | update_configs(model_config, training_args, ["image_area", "max_position_embeddings"]) 59 | if training_args.min_learning_rate is not None: 60 | training_args.lr_scheduler_kwargs["min_lr"] = training_args.min_learning_rate 61 | 62 | os.environ["WANDB_DIR"] = osp.join(training_args.output_dir, "wandb") 63 | 64 | model = Emu3ForCausalLM.from_pretrained( 65 | model_args.model_name_or_path, 66 | config=model_config, 67 | attn_implementation="flash_attention_2" if training_args.attn_type == "fa2" else None, 68 | torch_dtype=torch.bfloat16 if training_args.bf16 else None, 69 | ) 70 | 71 | tokenizer = Emu3Tokenizer.from_pretrained( 72 | model_args.model_name_or_path, 73 | model_max_length=training_args.max_position_embeddings, 74 | padding_side="right", 75 | use_fast=False, 76 | ) 77 | 78 | train_dataset = Emu3FeatureDataset(data_args, tokenizer=tokenizer) 79 | 80 | trainer = tf.Trainer( 81 | model=model, 82 | args=training_args, 83 | train_dataset=train_dataset, 84 | ) 85 | 86 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 87 | trainer.train(resume_from_checkpoint=True) 88 | else: 89 | trainer.train() 90 | trainer.save_state() 91 | 92 | torch.cuda.synchronize() 93 | trainer.save_model(training_args.output_dir) 94 | 95 | 96 | if __name__ == "__main__": 97 | train() 98 | -------------------------------------------------------------------------------- /gradio_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import base64 4 | import io 5 | from PIL import Image 6 | 7 | import gradio as gr 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoModelForCausalLM, 11 | AutoImageProcessor, 12 | AutoModel, 13 | ) 14 | from transformers.generation.configuration_utils import GenerationConfig 15 | from transformers.generation import ( 16 | LogitsProcessorList, 17 | PrefixConstrainedLogitsProcessor, 18 | UnbatchedClassifierFreeGuidanceLogitsProcessor, 19 | ) 20 | import torch 21 | 22 | from emu3.mllm.processing_emu3 import Emu3Processor 23 | 24 | def image2str(image): 25 | buf = io.BytesIO() 26 | image.save(buf, format="PNG") 27 | i_str = base64.b64encode(buf.getvalue()).decode() 28 | return f'
' 29 | 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | 32 | # Model paths 33 | EMU_GEN_HUB = "BAAI/Emu3-Gen" 34 | EMU_CHAT_HUB = "BAAI/Emu3-Chat" 35 | VQ_HUB = "BAAI/Emu3-VisionTokenizer" 36 | 37 | # Prepare models and processors 38 | gen_model = AutoModelForCausalLM.from_pretrained( 39 | EMU_GEN_HUB, 40 | device_map="cpu", 41 | torch_dtype=torch.bfloat16, 42 | attn_implementation="flash_attention_2", 43 | trust_remote_code=True, 44 | ).eval() 45 | 46 | chat_model = AutoModelForCausalLM.from_pretrained( 47 | EMU_CHAT_HUB, 48 | device_map="cpu", 49 | torch_dtype=torch.bfloat16, 50 | attn_implementation="flash_attention_2", 51 | trust_remote_code=True, 52 | ).eval() 53 | 54 | tokenizer = AutoTokenizer.from_pretrained( 55 | EMU_CHAT_HUB, trust_remote_code=True, padding_side="left", 56 | ) 57 | image_processor = AutoImageProcessor.from_pretrained( 58 | VQ_HUB, trust_remote_code=True, 59 | ) 60 | image_tokenizer = AutoModel.from_pretrained( 61 | VQ_HUB, device_map="cpu", trust_remote_code=True, 62 | ).eval() 63 | 64 | image_tokenizer.to(device) 65 | 66 | processor = Emu3Processor( 67 | image_processor, image_tokenizer, tokenizer 68 | ) 69 | 70 | def generate_image(prompt): 71 | POSITIVE_PROMPT = " masterpiece, film grained, best quality." 72 | NEGATIVE_PROMPT = ( 73 | "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, " 74 | "fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, " 75 | "signature, watermark, username, blurry." 76 | ) 77 | 78 | classifier_free_guidance = 3.0 79 | full_prompt = prompt + POSITIVE_PROMPT 80 | 81 | kwargs = dict( 82 | mode="G", 83 | ratio="1:1", 84 | image_area=gen_model.config.image_area, 85 | return_tensors="pt", 86 | ) 87 | pos_inputs = processor(text=full_prompt, **kwargs) 88 | neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs) 89 | 90 | # Prepare hyperparameters 91 | GENERATION_CONFIG = GenerationConfig( 92 | use_cache=True, 93 | eos_token_id=gen_model.config.eos_token_id, 94 | pad_token_id=gen_model.config.pad_token_id, 95 | max_new_tokens=40960, 96 | do_sample=True, 97 | top_k=2048, 98 | ) 99 | 100 | torch.cuda.empty_cache() 101 | gen_model.to(device) 102 | 103 | h = pos_inputs.image_size[:, 0] 104 | w = pos_inputs.image_size[:, 1] 105 | constrained_fn = processor.build_prefix_constrained_fn(h, w) 106 | logits_processor = LogitsProcessorList([ 107 | UnbatchedClassifierFreeGuidanceLogitsProcessor( 108 | classifier_free_guidance, 109 | gen_model, 110 | unconditional_ids=neg_inputs.input_ids.to(device), 111 | ), 112 | PrefixConstrainedLogitsProcessor( 113 | constrained_fn, 114 | num_beams=1, 115 | ), 116 | ]) 117 | 118 | # Generate 119 | outputs = gen_model.generate( 120 | pos_inputs.input_ids.to(device), 121 | generation_config=GENERATION_CONFIG, 122 | logits_processor=logits_processor, 123 | attention_mask=pos_inputs.attention_mask.to(device), 124 | ) 125 | 126 | mm_list = processor.decode(outputs[0]) 127 | result = None 128 | for idx, im in enumerate(mm_list): 129 | if isinstance(im, Image.Image): 130 | result = im 131 | break 132 | 133 | gen_model.cpu() 134 | torch.cuda.empty_cache() 135 | 136 | return result 137 | 138 | def vision_language_understanding(image, text): 139 | inputs = processor( 140 | text=text, 141 | image=image, 142 | mode="U", 143 | padding="longest", 144 | return_tensors="pt", 145 | ) 146 | 147 | # Prepare hyperparameters 148 | GENERATION_CONFIG = GenerationConfig( 149 | pad_token_id=tokenizer.pad_token_id, 150 | bos_token_id=tokenizer.bos_token_id, 151 | eos_token_id=tokenizer.eos_token_id, 152 | max_new_tokens=1024, 153 | ) 154 | 155 | torch.cuda.empty_cache() 156 | chat_model.to(device) 157 | 158 | # Generate 159 | outputs = chat_model.generate( 160 | inputs.input_ids.to(device), 161 | generation_config=GENERATION_CONFIG, 162 | attention_mask=inputs.attention_mask.to(device), 163 | ) 164 | 165 | outputs = outputs[:, inputs.input_ids.shape[-1] :] 166 | response = processor.batch_decode(outputs, skip_special_tokens=True)[0] 167 | 168 | chat_model.cpu() 169 | torch.cuda.empty_cache() 170 | 171 | return response 172 | 173 | 174 | def chat(history, user_input, user_image): 175 | if user_image is not None: 176 | # Use Emu3-Chat for vision-language understanding 177 | response = vision_language_understanding(user_image, user_input) 178 | # Append the user input and response to the history 179 | history = history + [(image2str(user_image) + "
" + user_input, response)] 180 | else: 181 | # Use Emu3-Gen for image generation 182 | generated_image = generate_image(user_input) 183 | if generated_image is not None: 184 | # Append the user input and generated image to the history 185 | history = history + [(user_input, image2str(generated_image))] 186 | else: 187 | # If image generation failed, respond with an error message 188 | history = history + [ 189 | (user_input, "Sorry, I could not generate an image.") 190 | ] 191 | 192 | return history, history, gr.update(value=None) 193 | 194 | 195 | def clear_input(): 196 | return gr.update(value="") 197 | 198 | 199 | with gr.Blocks() as demo: 200 | gr.Markdown("# Emu3 Chatbot Demo") 201 | gr.Markdown( 202 | "This is a chatbot demo for image generation and vision-language understanding using Emu3 models." 203 | ) 204 | gr.Markdown( 205 | "Please provide only text input for image generation (\~600s) and both image and text for vision-language understanding (\~20s)" 206 | ) 207 | 208 | state = gr.State([]) 209 | with gr.Row(): 210 | with gr.Column(scale=0.2): 211 | user_input = gr.Textbox( 212 | show_label=False, placeholder="Type your message here...", lines=10, container=False, 213 | ) 214 | user_image = gr.Image( 215 | sources="upload", type="pil", label="Upload an image (optional)" 216 | ) 217 | submit_btn = gr.Button("Send") 218 | 219 | with gr.Column(scale=0.8): 220 | chatbot = gr.Chatbot(height=800) 221 | 222 | submit_btn.click( 223 | chat, 224 | inputs=[state, user_input, user_image], 225 | outputs=[chatbot, state, user_image], 226 | ).then(fn=clear_input, inputs=[], outputs=user_input, queue=False) 227 | user_input.submit( 228 | chat, 229 | inputs=[state, user_input, user_image], 230 | outputs=[chatbot, state, user_image], 231 | ).then(fn=clear_input, inputs=[], outputs=user_input, queue=False) 232 | 233 | demo.launch(max_threads=1).queue() 234 | -------------------------------------------------------------------------------- /image_generation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from PIL import Image 3 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM 4 | from transformers.generation.configuration_utils import GenerationConfig 5 | from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor 6 | import torch 7 | 8 | from emu3.mllm.processing_emu3 import Emu3Processor 9 | 10 | 11 | # model path 12 | EMU_HUB = "BAAI/Emu3-Gen" 13 | VQ_HUB = "BAAI/Emu3-VisionTokenizer" 14 | 15 | # prepare model and processor 16 | model = AutoModelForCausalLM.from_pretrained( 17 | EMU_HUB, 18 | device_map="cuda:0", 19 | torch_dtype=torch.bfloat16, 20 | attn_implementation="flash_attention_2", 21 | trust_remote_code=True, 22 | ) 23 | model.eval() 24 | 25 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left") 26 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True) 27 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval() 28 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) 29 | 30 | # prepare input 31 | POSITIVE_PROMPT = " masterpiece, film grained, best quality." 32 | NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." 33 | 34 | classifier_free_guidance = 3.0 35 | prompt = ["a portrait of young girl.", "a shiba inu"] 36 | prompt = [p + POSITIVE_PROMPT for p in prompt] 37 | 38 | kwargs = dict( 39 | mode='G', 40 | ratio=["1:1", "16:9"], 41 | image_area=model.config.image_area, 42 | return_tensors="pt", 43 | padding="longest", 44 | ) 45 | pos_inputs = processor(text=prompt, **kwargs) 46 | neg_inputs = processor(text=[NEGATIVE_PROMPT] * len(prompt), **kwargs) 47 | 48 | # prepare hyper parameters 49 | GENERATION_CONFIG = GenerationConfig( 50 | use_cache=True, 51 | eos_token_id=model.config.eos_token_id, 52 | pad_token_id=model.config.pad_token_id, 53 | max_new_tokens=40960, 54 | do_sample=True, 55 | top_k=2048, 56 | ) 57 | 58 | h = pos_inputs.image_size[:, 0] 59 | w = pos_inputs.image_size[:, 1] 60 | constrained_fn = processor.build_prefix_constrained_fn(h, w) 61 | logits_processor = LogitsProcessorList([ 62 | UnbatchedClassifierFreeGuidanceLogitsProcessor( 63 | classifier_free_guidance, 64 | model, 65 | unconditional_ids=neg_inputs.input_ids.to("cuda:0"), 66 | ), 67 | PrefixConstrainedLogitsProcessor( 68 | constrained_fn , 69 | num_beams=1, 70 | ), 71 | ]) 72 | 73 | # generate 74 | outputs = model.generate( 75 | pos_inputs.input_ids.to("cuda:0"), 76 | GENERATION_CONFIG, 77 | logits_processor=logits_processor, 78 | attention_mask=pos_inputs.attention_mask.to("cuda:0"), 79 | ) 80 | 81 | for idx_i, out in enumerate(outputs): 82 | mm_list = processor.decode(out) 83 | for idx_j, im in enumerate(mm_list): 84 | if not isinstance(im, Image.Image): 85 | continue 86 | im.save(f"result_{idx_i}_{idx_j}.png") 87 | -------------------------------------------------------------------------------- /multimodal_understanding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from PIL import Image 3 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM 4 | from transformers.generation.configuration_utils import GenerationConfig 5 | import torch 6 | 7 | from emu3.mllm.processing_emu3 import Emu3Processor 8 | 9 | 10 | # model path 11 | EMU_HUB = "BAAI/Emu3-Chat" 12 | VQ_HUB = "BAAI/Emu3-VisionTokenizer" 13 | 14 | # prepare model and processor 15 | model = AutoModelForCausalLM.from_pretrained( 16 | EMU_HUB, 17 | device_map="cuda:0", 18 | torch_dtype=torch.bfloat16, 19 | attn_implementation="flash_attention_2", 20 | trust_remote_code=True, 21 | ) 22 | model.eval() 23 | 24 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left") 25 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True) 26 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval() 27 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) 28 | 29 | # prepare input 30 | text = ["Please describe the image", "Please describe the image"] 31 | image = Image.open("assets/demo.png") 32 | image = [image, image] 33 | 34 | inputs = processor( 35 | text=text, 36 | image=image, 37 | mode='U', 38 | padding_image=True, 39 | padding="longest", 40 | return_tensors="pt", 41 | ) 42 | 43 | # prepare hyper parameters 44 | GENERATION_CONFIG = GenerationConfig(pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id) 45 | 46 | # generate 47 | outputs = model.generate( 48 | inputs.input_ids.to("cuda:0"), 49 | GENERATION_CONFIG, 50 | max_new_tokens=1024, 51 | attention_mask=inputs.attention_mask.to("cuda:0"), 52 | ) 53 | 54 | outputs = outputs[:, inputs.input_ids.shape[-1]:] 55 | answers = processor.batch_decode(outputs, skip_special_tokens=True) 56 | for ans in answers: 57 | print(ans) 58 | -------------------------------------------------------------------------------- /replicate_demo/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | 8 | # a list of ubuntu apt packages to install 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.11" 15 | 16 | # a list of packages in the format == 17 | python_packages: 18 | # - packaging 19 | - torch==2.2.1 20 | - transformers==4.44.0 21 | - tiktoken==0.6.0 22 | - accelerate 23 | - numpy<2 24 | run: 25 | - pip install flash-attn==2.5.7 26 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 27 | 28 | # predict.py defines how predictions are run on your model 29 | predict: "predict_chat.py:Predictor" 30 | # predict: "predict_gen.py:Predictor" 31 | -------------------------------------------------------------------------------- /replicate_demo/predict_chat.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://cog.run/python 3 | 4 | import os 5 | import time 6 | import subprocess 7 | from PIL import Image 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoModel, 11 | AutoImageProcessor, 12 | AutoModelForCausalLM, 13 | ) 14 | from transformers.generation.configuration_utils import GenerationConfig 15 | import torch 16 | from cog import BasePredictor, Input, Path 17 | 18 | from emu3.mllm.processing_emu3 import Emu3Processor 19 | 20 | 21 | MODEL_CACHE = "model_cache_chat" 22 | MODEL_URL = ( 23 | f"https://weights.replicate.delivery/default/baaivision/Emu3/{MODEL_CACHE}.tar" 24 | ) 25 | os.environ["HF_DATASETS_OFFLINE"] = "1" 26 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 27 | os.environ["HF_HOME"] = MODEL_CACHE 28 | os.environ["TORCH_HOME"] = MODEL_CACHE 29 | os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE 30 | os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE 31 | os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE 32 | 33 | TORCH_TYPE = torch.bfloat16 34 | DEVICE = "cuda:0" 35 | 36 | 37 | def download_weights(url, dest): 38 | start = time.time() 39 | print("downloading url: ", url) 40 | print("downloading to: ", dest) 41 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 42 | print("downloading took: ", time.time() - start) 43 | 44 | 45 | class Predictor(BasePredictor): 46 | def setup(self) -> None: 47 | """Load the model into memory to make running multiple predictions efficient""" 48 | 49 | if not os.path.exists(MODEL_CACHE): 50 | download_weights(MODEL_URL, MODEL_CACHE) 51 | 52 | # prepare model and processor 53 | self.model = AutoModelForCausalLM.from_pretrained( 54 | f"{MODEL_CACHE}/Emu3-Chat", # "BAAI/Emu3-Chat" 55 | device_map="cuda:0", 56 | torch_dtype=torch.bfloat16, 57 | attn_implementation="flash_attention_2", 58 | trust_remote_code=True, 59 | ) 60 | 61 | tokenizer = AutoTokenizer.from_pretrained( 62 | f"{MODEL_CACHE}/Emu3-Chat", trust_remote_code=True 63 | ) # "BAAI/Emu3-Chat" 64 | image_processor = AutoImageProcessor.from_pretrained( 65 | f"{MODEL_CACHE}/Emu3-VisionTokenizer", trust_remote_code=True 66 | ) # "BAAI/Emu3-VisionTokenizer" 67 | image_tokenizer = AutoModel.from_pretrained( 68 | f"{MODEL_CACHE}/Emu3-VisionTokenizer", 69 | device_map="cuda:0", 70 | trust_remote_code=True, 71 | ).eval() # "BAAI/Emu3-VisionTokenizer" 72 | self.processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) 73 | # prepare hyper parameters 74 | self.generation_config = GenerationConfig( 75 | pad_token_id=tokenizer.pad_token_id, 76 | bos_token_id=tokenizer.bos_token_id, 77 | eos_token_id=tokenizer.eos_token_id, 78 | ) 79 | 80 | def predict( 81 | self, 82 | text: str = Input( 83 | description="Input prompt", 84 | default="Please describe the image.", 85 | ), 86 | image: Path = Input( 87 | default="Input image", 88 | ), 89 | temperature: float = Input( 90 | description="Controls randomness. Lower values make the model more deterministic, higher values make it more random.", 91 | default=0.7, 92 | ge=0.0, 93 | le=1.0, 94 | ), 95 | top_p: float = Input( 96 | description="Controls diversity of the output. Valid when temperature > 0. Lower values make the output more focused, higher values make it more diverse.", 97 | default=0.9, 98 | ge=0.0, 99 | le=1.0, 100 | ), 101 | max_new_tokens: int = Input( 102 | description="Maximum number of tokens to generate", default=256, ge=1 103 | ), 104 | ) -> str: 105 | """Run a single prediction on the model""" 106 | 107 | img = Image.open(str(image)) 108 | 109 | inputs = self.processor( 110 | text=text, 111 | image=img, 112 | mode="U", 113 | padding_side="left", 114 | padding="longest", 115 | return_tensors="pt", 116 | ) 117 | 118 | outputs = self.model.generate( 119 | inputs.input_ids.to("cuda:0"), 120 | self.generation_config, 121 | max_new_tokens=max_new_tokens, 122 | temperature=temperature, 123 | top_p=top_p, 124 | ) 125 | 126 | outputs = outputs[:, inputs.input_ids.shape[-1] :] 127 | return self.processor.batch_decode(outputs, skip_special_tokens=True)[0] 128 | -------------------------------------------------------------------------------- /replicate_demo/predict_gen.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://cog.run/python 3 | 4 | import os 5 | import time 6 | import subprocess 7 | from PIL import Image 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoModel, 11 | AutoImageProcessor, 12 | AutoModelForCausalLM, 13 | ) 14 | from transformers.generation.configuration_utils import GenerationConfig 15 | from transformers.generation import ( 16 | LogitsProcessorList, 17 | PrefixConstrainedLogitsProcessor, 18 | UnbatchedClassifierFreeGuidanceLogitsProcessor, 19 | ) 20 | import torch 21 | from cog import BasePredictor, Input, Path 22 | 23 | from emu3.mllm.processing_emu3 import Emu3Processor 24 | 25 | 26 | MODEL_CACHE = "model_cache" 27 | MODEL_URL = ( 28 | f"https://weights.replicate.delivery/default/baaivision/Emu3/{MODEL_CACHE}.tar" 29 | ) 30 | os.environ["HF_DATASETS_OFFLINE"] = "1" 31 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 32 | os.environ["HF_HOME"] = MODEL_CACHE 33 | os.environ["TORCH_HOME"] = MODEL_CACHE 34 | os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE 35 | os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE 36 | os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE 37 | 38 | TORCH_TYPE = torch.bfloat16 39 | DEVICE = "cuda:0" 40 | 41 | 42 | def download_weights(url, dest): 43 | start = time.time() 44 | print("downloading url: ", url) 45 | print("downloading to: ", dest) 46 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 47 | print("downloading took: ", time.time() - start) 48 | 49 | 50 | class Predictor(BasePredictor): 51 | def setup(self) -> None: 52 | """Load the model into memory to make running multiple predictions efficient""" 53 | 54 | if not os.path.exists(MODEL_CACHE): 55 | download_weights(MODEL_URL, MODEL_CACHE) 56 | 57 | # prepare model and processor 58 | self.model = AutoModelForCausalLM.from_pretrained( 59 | f"{MODEL_CACHE}/Emu3-Gen", # "BAAI/Emu3-Gen" 60 | device_map="cuda:0", 61 | torch_dtype=torch.bfloat16, 62 | attn_implementation="flash_attention_2", 63 | trust_remote_code=True, 64 | ) 65 | 66 | tokenizer = AutoTokenizer.from_pretrained( 67 | f"{MODEL_CACHE}/Emu3-Gen", trust_remote_code=True 68 | ) # "BAAI/Emu3-Gen" 69 | image_processor = AutoImageProcessor.from_pretrained( 70 | f"{MODEL_CACHE}/Emu3-VisionTokenizer", trust_remote_code=True 71 | ) # "BAAI/Emu3-VisionTokenizer" 72 | image_tokenizer = AutoModel.from_pretrained( 73 | f"{MODEL_CACHE}/Emu3-VisionTokenizer", 74 | device_map="cuda:0", 75 | trust_remote_code=True, 76 | ).eval() # "BAAI/Emu3-VisionTokenizer" 77 | self.processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) 78 | 79 | self.kwargs = dict( 80 | mode="G", 81 | ratio="1:1", 82 | image_area=self.model.config.image_area, 83 | return_tensors="pt", 84 | ) 85 | 86 | # prepare hyper parameters 87 | self.generation_config = GenerationConfig( 88 | use_cache=True, 89 | eos_token_id=self.model.config.eos_token_id, 90 | pad_token_id=self.model.config.pad_token_id, 91 | max_new_tokens=40960, 92 | do_sample=True, 93 | top_k=2048, 94 | ) 95 | 96 | def predict( 97 | self, 98 | prompt: str = Input( 99 | description="Input prompt", 100 | default="a portrait of young girl.", 101 | ), 102 | positive_prompt: str = Input( 103 | default="masterpiece, film grained, best quality.", 104 | ), 105 | negative_prompt: str = Input( 106 | description="Specify things to not see in the output", 107 | default="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.", 108 | ), 109 | guidance_scale: float = Input( 110 | description="Scale for classifier-free guidance", ge=1, le=20, default=3 111 | ), 112 | ) -> Path: 113 | """Run a single prediction on the model""" 114 | 115 | pos_inputs = self.processor( 116 | text=prompt + " " + positive_prompt.strip(), **self.kwargs 117 | ) 118 | neg_inputs = self.processor(text=negative_prompt, **self.kwargs) 119 | 120 | h, w = pos_inputs.image_size[0] 121 | constrained_fn = self.processor.build_prefix_constrained_fn(h, w) 122 | logits_processor = LogitsProcessorList( 123 | [ 124 | UnbatchedClassifierFreeGuidanceLogitsProcessor( 125 | guidance_scale, 126 | self.model, 127 | unconditional_ids=neg_inputs.input_ids.to("cuda:0"), 128 | ), 129 | PrefixConstrainedLogitsProcessor( 130 | constrained_fn, 131 | num_beams=1, 132 | ), 133 | ] 134 | ) 135 | 136 | # generate 137 | outputs = self.model.generate( 138 | pos_inputs.input_ids.to("cuda:0"), 139 | self.generation_config, 140 | logits_processor=logits_processor, 141 | ) 142 | 143 | out_path = "/tmp/out.png" 144 | 145 | mm_list = self.processor.decode(outputs[0]) 146 | print(len(mm_list)) 147 | print(mm_list) 148 | for idx, im in enumerate(mm_list): 149 | if not isinstance(im, Image.Image): 150 | continue 151 | im.save(out_path) 152 | return Path(out_path) 153 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.1 2 | transformers==4.44.0 3 | tiktoken==0.6.0 4 | flash-attn==2.5.7 5 | pillow 6 | gradio==4.44.0 7 | -------------------------------------------------------------------------------- /scripts/t2i_sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | WORLD_SIZE=${WORLD_SIZE:-1} 4 | RANK=${RANK:-0} 5 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 6 | MASTER_PORT=${MASTER_PORT:-23456} 7 | NGPUS=$(python -c "import torch; print(torch.cuda.device_count())") 8 | 9 | export PYTHONPATH=$(pwd) 10 | 11 | DATAPATH="your data path (json file)" 12 | EXP_NAME="Emu3-T2I-SFT-Trial" 13 | torchrun \ 14 | --nproc_per_node=${NGPUS} \ 15 | --nnodes=${WORLD_SIZE} \ 16 | --node_rank=${RANK} \ 17 | --master_addr=${MASTER_ADDR} \ 18 | --master_port=${MASTER_PORT} \ 19 | emu3/train/train.py \ 20 | --model_name_or_path BAAI/Emu3-Gen \ 21 | --deepspeed scripts/zero3.json \ 22 | --data_path ${DATAPATH} \ 23 | --null_prompt_prob 0.05 \ 24 | --apply_loss_on_only_vision True \ 25 | --apply_loss_on_only_text False \ 26 | --image_area 518400 \ 27 | --max_position_embeddings 10240 \ 28 | --output_dir "logs/"${EXP_NAME} \ 29 | --bf16 True \ 30 | --tf32 True \ 31 | --num_train_epochs 4 \ 32 | --per_device_train_batch_size 2 \ 33 | --gradient_accumulation_steps 4 \ 34 | --eval_strategy no \ 35 | --save_strategy steps \ 36 | --save_steps 500 \ 37 | --save_total_limit 10 \ 38 | --learning_rate 1e-5 \ 39 | --min_learning_rate 1e-6 \ 40 | --weight_decay 0.1 \ 41 | --max_grad_norm 5.0 \ 42 | --adam_beta1 0.9 \ 43 | --adam_beta2 0.95 \ 44 | --adam_epsilon 1e-6 \ 45 | --warmup_steps 30 \ 46 | --lr_scheduler_type "cosine_with_min_lr" \ 47 | --logging_steps 1 \ 48 | --gradient_checkpointing True \ 49 | --dataloader_num_workers 4 \ 50 | --report_to wandb tensorboard \ 51 | --run_name ${EXP_NAME} 52 | -------------------------------------------------------------------------------- /scripts/t2i_sft_offload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | WORLD_SIZE=${WORLD_SIZE:-1} 4 | RANK=${RANK:-0} 5 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 6 | MASTER_PORT=${MASTER_PORT:-23456} 7 | NGPUS=$(python -c "import torch; print(torch.cuda.device_count())") 8 | 9 | export PYTHONPATH=$(pwd) 10 | 11 | DATAPATH="your data path" 12 | EXP_NAME="Emu3-T2I-SFT-Trial" 13 | torchrun \ 14 | --nproc_per_node=${NGPUS} \ 15 | --nnodes=${WORLD_SIZE} \ 16 | --node_rank=${RANK} \ 17 | --master_addr=${MASTER_ADDR} \ 18 | --master_port=${MASTER_PORT} \ 19 | emu3/train/train.py \ 20 | --model_name_or_path BAAI/Emu3-Gen \ 21 | --deepspeed scripts/zero3_offload.json \ 22 | --data_path ${DATAPATH} \ 23 | --null_prompt_prob 0.05 \ 24 | --apply_loss_on_only_vision True \ 25 | --apply_loss_on_only_text False \ 26 | --image_area 518400 \ 27 | --max_position_embeddings 10240 \ 28 | --output_dir "logs/"${EXP_NAME} \ 29 | --bf16 True \ 30 | --tf32 True \ 31 | --num_train_epochs 4 \ 32 | --per_device_train_batch_size 2 \ 33 | --gradient_accumulation_steps 4 \ 34 | --eval_strategy no \ 35 | --save_strategy steps \ 36 | --save_steps 500 \ 37 | --save_total_limit 10 \ 38 | --learning_rate 1e-5 \ 39 | --min_learning_rate 1e-6 \ 40 | --weight_decay 0.1 \ 41 | --max_grad_norm 5.0 \ 42 | --adam_beta1 0.9 \ 43 | --adam_beta2 0.95 \ 44 | --adam_epsilon 1e-6 \ 45 | --warmup_steps 30 \ 46 | --lr_scheduler_type "cosine_with_min_lr" \ 47 | --logging_steps 1 \ 48 | --gradient_checkpointing True \ 49 | --dataloader_num_workers 4 \ 50 | --report_to wandb tensorboard \ 51 | --run_name ${EXP_NAME} 52 | -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true, 27 | "offload_optimizer": { 28 | "device": "cpu", 29 | "pin_memory": true 30 | }, 31 | "offload_param": { 32 | "device": "cpu", 33 | "pin_memory": true 34 | } 35 | } 36 | } 37 | --------------------------------------------------------------------------------