├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── arch.png
│   ├── comparison.png
│   └── demo.png
├── autoencode.py
├── emu3
│   ├── mllm
│   │   ├── __init__.py
│   │   ├── configuration_emu3.py
│   │   ├── modeling_emu3.py
│   │   ├── processing_emu3.py
│   │   ├── tokenization_emu3.py
│   │   └── utils_emu3.py
│   ├── tokenizer
│   │   ├── __init__.py
│   │   ├── configuration_emu3visionvq.py
│   │   ├── image_processing_emu3visionvq.py
│   │   └── modeling_emu3visionvq.py
│   └── train
│       ├── __init__.py
│       ├── datasets.py
│       ├── prepare_data.py
│       └── train.py
├── gradio_demo.py
├── image_generation.py
├── multimodal_understanding.py
├── replicate_demo
│   ├── cog.yaml
│   ├── predict_chat.py
│   └── predict_gen.py
├── requirements.txt
└── scripts
    ├── t2i_sft.sh
    ├── t2i_sft_offload.sh
    ├── zero3.json
    └── zero3_offload.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Emu3: Next-Token Prediction is All You Need
3 |
4 |
5 | [Emu3 Team, BAAI](https://www.baai.ac.cn/english.html)
6 |
7 | | [Project Page](https://emu.baai.ac.cn) | [Paper](https://arxiv.org/pdf/2409.18869) | [🤗HF Models](https://huggingface.co/collections/BAAI/emu3-66f4e64f70850ff358a2e60f) | [Modelscope](https://modelscope.cn/collections/Emu3-9eacc8668b1043) | [Demo](https://huggingface.co/spaces/BAAI/Emu3) |
8 |
9 |
10 |
11 |
12 |
13 | ![Emu3 architecture](assets/arch.png)
14 |
15 |
16 | We introduce **Emu3**, a new suite of state-of-the-art multimodal models trained solely with **next-token prediction**! By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences.
17 |
18 | ### Emu3 excels in both generation and perception
19 | **Emu3** outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship open models such as SDXL, LLaVA-1.6 and OpenSora-1.2, while eliminating the need for diffusion or compositional architectures.
20 |
21 |
22 | ![Comparison of Emu3 with task-specific models](assets/comparison.png)
23 |
24 |
25 | ### Highlights
26 |
27 | - **Emu3** is capable of generating high-quality images from text input by simply predicting the next vision token. The model naturally supports flexible resolutions and styles.
28 | - **Emu3** shows strong vision-language understanding capabilities, perceiving the physical world and providing coherent text responses. Notably, this capability is achieved without relying on a CLIP encoder or a pretrained LLM.
29 | - **Emu3** generates videos causally by simply predicting the next token in a video sequence, unlike video diffusion models such as Sora. With a video in context, Emu3 can also naturally extend the video and predict what will happen next.
30 |
31 | ## News
32 | - 2024.10 We release the image-pretrained model **[Emu3-Stage1](https://huggingface.co/BAAI/Emu3-Stage1)** and the SFT scripts. The model supports image captioning and can generate images at a resolution of 512x512. You can use our training scripts to further instruction-tune it for more image generation and perception tasks. 🔥🔥🔥
33 | - 2024.09 We release **[Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)** and **[Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)**, which are post-trained models for vision-language understanding and visual generation, respectively.
34 | - 2024.09 We introduce Emu3, a new suite of state-of-the-art multimodal models trained solely with next-token prediction.
35 |
36 |
37 | ### TODO
38 |
39 | - [X] Release model weights of tokenizer, Emu3-Chat and Emu3-Gen
40 | - [X] Release the inference code.
41 | - [ ] Release the evaluation code.
42 | - [X] Release training scripts for SFT.
43 | - [ ] Release training scripts for pretraining and DPO.
44 |
45 |
46 | ### Setup
47 |
48 | Clone this repository and install required packages:
49 |
50 | ```shell
51 | git clone https://github.com/baaivision/Emu3
52 | cd Emu3
53 |
54 | pip install -r requirements.txt
55 | ```
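
The quickstart examples below load the language model with `attn_implementation="flash_attention_2"`, which assumes the `flash-attn` package and a compatible GPU are available. If it is not already pulled in by `requirements.txt` in your environment, it can usually be installed with:

```shell
pip install flash-attn --no-build-isolation
```

Alternatively, dropping the `attn_implementation` argument falls back to the default attention implementation supported by the remote modeling code.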
56 |
57 | ### Model Weights
58 |
59 | | Model name | HF Weight | Modelscope | Wisemodel |
60 | | ------------------------ | -------------------------------------------------------------- | ------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
61 | | **Emu3-Stage1** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Stage1) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-Stage1) | |
62 | | **Emu3-Chat** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Chat) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-Chat) | [Wisemodel link](https://wisemodel.cn/models/BAAI/Emu3-Chat) |
63 | | **Emu3-Gen** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Gen) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-Gen) | [Wisemodel link](https://wisemodel.cn/models/BAAI/Emu3-Gen) |
64 | | **Emu3-VisionTokenizer** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-VisionTokenizer) | [Modelscope link](https://modelscope.cn/models/BAAI/Emu3-VisionTokenizer) | [Wisemodel link](https://wisemodel.cn/models/BAAI/Emu3-VisionTokenizer) |
65 |
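If you prefer to fetch the weights ahead of time rather than letting `from_pretrained` download them on first use, a minimal sketch using `huggingface_hub` (assuming it is installed; the local paths below are only examples) is:

```python
from huggingface_hub import snapshot_download

# Download local copies of the chat model and the vision tokenizer.
snapshot_download(repo_id="BAAI/Emu3-Chat", local_dir="weights/Emu3-Chat")
snapshot_download(repo_id="BAAI/Emu3-VisionTokenizer", local_dir="weights/Emu3-VisionTokenizer")
```

The local directory can then be passed to `from_pretrained` in place of the hub id.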
66 | ### Quickstart
67 |
68 | #### Use 🤗Transformers to run Emu3-Gen/Stage1 for image generation
69 | ```python
70 | from PIL import Image
71 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
72 | from transformers.generation.configuration_utils import GenerationConfig
73 | from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
74 | import torch
75 |
76 | from emu3.mllm.processing_emu3 import Emu3Processor
77 |
78 |
79 | # model path
80 | EMU_HUB = "BAAI/Emu3-Gen"
81 | VQ_HUB = "BAAI/Emu3-VisionTokenizer"
82 |
83 | # prepare model and processor
84 | model = AutoModelForCausalLM.from_pretrained(
85 | EMU_HUB,
86 | device_map="cuda:0",
87 | torch_dtype=torch.bfloat16,
88 | attn_implementation="flash_attention_2",
89 | trust_remote_code=True,
90 | )
91 |
92 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left")
93 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
94 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
95 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
96 |
97 | # prepare input
98 | POSITIVE_PROMPT = " masterpiece, film grained, best quality."
99 | NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
100 |
101 | classifier_free_guidance = 3.0
102 | prompt = "a portrait of young girl."
103 | prompt += POSITIVE_PROMPT
104 |
105 | kwargs = dict(
106 | mode='G',
107 | ratio="1:1",
108 | image_area=model.config.image_area,
109 | return_tensors="pt",
110 | padding="longest",
111 | )
112 | pos_inputs = processor(text=prompt, **kwargs)
113 | neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
114 |
115 | # prepare hyper parameters
116 | GENERATION_CONFIG = GenerationConfig(
117 | use_cache=True,
118 | eos_token_id=model.config.eos_token_id,
119 | pad_token_id=model.config.pad_token_id,
120 | max_new_tokens=40960,
121 | do_sample=True,
122 | top_k=2048,
123 | )
124 |
125 | h = pos_inputs.image_size[:, 0]
126 | w = pos_inputs.image_size[:, 1]
127 | constrained_fn = processor.build_prefix_constrained_fn(h, w)
128 | logits_processor = LogitsProcessorList([
129 | UnbatchedClassifierFreeGuidanceLogitsProcessor(
130 | classifier_free_guidance,
131 | model,
132 | unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
133 | ),
134 | PrefixConstrainedLogitsProcessor(
135 | constrained_fn,
136 | num_beams=1,
137 | ),
138 | ])
139 |
140 | # generate
141 | outputs = model.generate(
142 | pos_inputs.input_ids.to("cuda:0"),
143 | GENERATION_CONFIG,
144 | logits_processor=logits_processor,
145 | attention_mask=pos_inputs.attention_mask.to("cuda:0"),
146 | )
147 |
148 | mm_list = processor.decode(outputs[0])
149 | for idx, im in enumerate(mm_list):
150 | if not isinstance(im, Image.Image):
151 | continue
152 | im.save(f"result_{idx}.png")
153 | ```
154 |
155 | #### Use 🤗Transformers to run Emu3-Chat/Stage1 for vision-language understanding
156 |
157 | ```python
158 | from PIL import Image
159 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
160 | from transformers.generation.configuration_utils import GenerationConfig
161 | import torch
162 |
163 | from emu3.mllm.processing_emu3 import Emu3Processor
164 |
165 |
166 | # model path
167 | EMU_HUB = "BAAI/Emu3-Chat"
168 | VQ_HUB = "BAAI/Emu3-VisionTokenizer"
169 |
170 | # prepare model and processor
171 | model = AutoModelForCausalLM.from_pretrained(
172 | EMU_HUB,
173 | device_map="cuda:0",
174 | torch_dtype=torch.bfloat16,
175 | attn_implementation="flash_attention_2",
176 | trust_remote_code=True,
177 | )
178 |
179 | # used for Emu3-Chat
180 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left")
181 | # used for Emu3-Stage1
182 | # tokenizer = AutoTokenizer.from_pretrained(
183 | # EMU_HUB,
184 | # trust_remote_code=True,
185 | # chat_template="{image_prompt}{text_prompt}",
186 | # padding_side="left",
187 | # )
188 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
189 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
190 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
191 |
192 | # prepare input
193 | text = "Please describe the image"
194 | image = Image.open("assets/demo.png")
195 |
196 | inputs = processor(
197 | text=text,
198 | image=image,
199 | mode='U',
200 | return_tensors="pt",
201 | padding="longest",
202 | )
203 |
204 | # prepare hyper parameters
205 | GENERATION_CONFIG = GenerationConfig(
206 | pad_token_id=tokenizer.pad_token_id,
207 | bos_token_id=tokenizer.bos_token_id,
208 | eos_token_id=tokenizer.eos_token_id,
209 | max_new_tokens=1024,
210 | )
211 |
212 | # generate
213 | outputs = model.generate(
214 | inputs.input_ids.to("cuda:0"),
215 | GENERATION_CONFIG,
216 | attention_mask=inputs.attention_mask.to("cuda:0"),
217 | )
218 |
219 | outputs = outputs[:, inputs.input_ids.shape[-1]:]
220 | print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
221 | ```
222 |
223 | #### Use 🤗Transformers to run Emu3-VisionTokenizer for vision encoding and decoding
224 | ```python
225 | import os
226 | import os.path as osp
227 |
228 | from PIL import Image
229 | import torch
230 | from transformers import AutoModel, AutoImageProcessor
231 |
232 | MODEL_HUB = "BAAI/Emu3-VisionTokenizer"
233 |
234 | model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda()
235 | processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True)
236 |
237 | # TODO: you need to modify the path here
238 | VIDEO_FRAMES_PATH = "YOUR_VIDEO_FRAMES_PATH"
239 |
240 | video = os.listdir(VIDEO_FRAMES_PATH)
241 | video.sort()
242 | video = [Image.open(osp.join(VIDEO_FRAMES_PATH, v)) for v in video]
243 |
244 | images = processor(video, return_tensors="pt")["pixel_values"]
245 | images = images.unsqueeze(0).cuda()
246 |
247 | # image autoencode
248 | image = images[:, 0]
249 | print(image.shape)
250 | with torch.no_grad():
251 | # encode
252 | codes = model.encode(image)
253 | # decode
254 | recon = model.decode(codes)
255 |
256 | recon = recon.view(-1, *recon.shape[2:])
257 | recon_image = processor.postprocess(recon)["pixel_values"][0]
258 | recon_image.save("recon_image.png")
259 |
260 | # video autoencode
261 | images = images.view(
262 | -1,
263 | model.config.temporal_downsample_factor,
264 | *images.shape[2:],
265 | )
266 |
267 | print(images.shape)
268 | with torch.no_grad():
269 | # encode
270 | codes = model.encode(images)
271 | # decode
272 | recon = model.decode(codes)
273 |
274 | recon = recon.view(-1, *recon.shape[2:])
275 | recon_images = processor.postprocess(recon)["pixel_values"]
276 | for idx, im in enumerate(recon_images):
277 | im.save(f"recon_video_{idx}.png")
278 | ```
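
The tokenizer example above expects `VIDEO_FRAMES_PATH` to point to a directory of frame images. If you start from a video file instead, a minimal sketch for dumping frames with OpenCV (an assumption; `opencv-python` is not necessarily listed in `requirements.txt`, and the paths are placeholders) is:

```python
import os
import cv2  # pip install opencv-python

VIDEO_PATH = "YOUR_VIDEO.mp4"          # placeholder input video
FRAMES_DIR = "YOUR_VIDEO_FRAMES_PATH"  # directory consumed by the example above

os.makedirs(FRAMES_DIR, exist_ok=True)
cap = cv2.VideoCapture(VIDEO_PATH)
idx = 0
while True:
    ok, frame = cap.read()
    if not ok:
        break
    # zero-padded names keep the frames in order after `video.sort()`
    cv2.imwrite(os.path.join(FRAMES_DIR, f"{idx:06d}.png"), frame)
    idx += 1
cap.release()
```

Note that the video autoencoding branch reshapes the frames into groups of `model.config.temporal_downsample_factor`, so the number of extracted frames should be a multiple of that factor.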
279 |
280 | ## Acknowledgement
281 |
282 | We thank the great work from [Emu Series](https://github.com/baaivision/Emu), [Qwen2-VL](https://github.com/QwenLM/Qwen2-VL), and [MoVQGAN](https://github.com/ai-forever/MoVQGAN).
283 |
284 | ## Citation
285 |
286 | If you find Emu3 useful for your research and applications, please consider starring this repository and citing:
287 |
288 | ```
289 | @article{wang2024emu3,
290 | title={Emu3: Next-Token Prediction is All You Need},
291 | author={Wang, Xinlong and Zhang, Xiaosong and Luo, Zhengxiong and Sun, Quan and Cui, Yufeng and Wang, Jinsheng and Zhang, Fan and Wang, Yueze and Li, Zhen and Yu, Qiying and others},
292 | journal={arXiv preprint arXiv:2409.18869},
293 | year={2024}
294 | }
295 | ```
296 |
297 |
298 |
--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/assets/arch.png
--------------------------------------------------------------------------------
/assets/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/assets/comparison.png
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/assets/demo.png
--------------------------------------------------------------------------------
/autoencode.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import os.path as osp
5 |
6 | from PIL import Image
7 | import torch
8 | from transformers import AutoModel, AutoImageProcessor
9 |
10 | MODEL_HUB = "BAAI/Emu3-VisionTokenizer"
11 |
12 | model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda()
13 | processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True)
14 |
15 | # TODO: you need to modify the path here
16 | VIDEO_FRAMES_PATH = "YOUR_VIDEO_FRAMES_PATH"
17 |
18 | video = os.listdir(VIDEO_FRAMES_PATH)
19 | video.sort()
20 | video = [Image.open(osp.join(VIDEO_FRAMES_PATH, v)) for v in video]
21 |
22 | images = processor(video, return_tensors="pt")["pixel_values"]
23 | images = images.unsqueeze(0).cuda()
24 |
25 | # image autoencode
26 | image = images[:, 0]
27 | print(image.shape)
28 | with torch.no_grad():
29 | # encode
30 | codes = model.encode(image)
31 | # decode
32 | recon = model.decode(codes)
33 |
34 | recon = recon.view(-1, *recon.shape[2:])
35 | recon_image = processor.postprocess(recon)["pixel_values"][0]
36 | recon_image.save("recon_image.png")
37 |
38 | # video autoencode
39 | images = images.view(
40 | -1,
41 | model.config.temporal_downsample_factor,
42 | *images.shape[2:],
43 | )
44 |
45 | print(images.shape)
46 | with torch.no_grad():
47 | # encode
48 | codes = model.encode(images)
49 | # decode
50 | recon = model.decode(codes)
51 |
52 | recon = recon.view(-1, *recon.shape[2:])
53 | recon_images = processor.postprocess(recon)["pixel_values"]
54 | for idx, im in enumerate(recon_images):
55 | im.save(f"recon_video_{idx}.png")
56 |
--------------------------------------------------------------------------------
/emu3/mllm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 BAAI and the HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import (
17 | OptionalDependencyNotAvailable,
18 | _LazyModule,
19 | is_torch_available,
20 | )
21 |
22 |
23 | _import_structure = {
24 | "configuration_emu3": ["Emu3Config"],
25 | "tokenization_emu3": ["Emu3Tokenizer"],
26 | "processing_emu3": ["Emu3Processor"],
27 | }
28 |
29 | try:
30 | if not is_torch_available():
31 | raise OptionalDependencyNotAvailable()
32 | except OptionalDependencyNotAvailable:
33 | pass
34 | else:
35 | _import_structure["modeling_emu3"] = [
36 | "Emu3Model",
37 | "Emu3PretrainedModel",
38 | "Emu3ForCausalLM",
39 | ]
40 |
41 | if TYPE_CHECKING:
42 | from .configuration_emu3 import Emu3Config
43 | from .tokenization_emu3 import Emu3Tokenizer
44 | from .processing_emu3 import Emu3Processor
45 |
46 | try:
47 | if not is_torch_available():
48 | raise OptionalDependencyNotAvailable()
49 | except OptionalDependencyNotAvailable:
50 | pass
51 | else:
52 | from .modeling_emu3 import (
53 | Emu3Model,
54 | Emu3PretrainedModel,
55 | Emu3ForCausalLM,
56 | )
57 |
58 | else:
59 | import sys
60 |
61 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
62 |
--------------------------------------------------------------------------------
/emu3/mllm/configuration_emu3.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ Emu3 model configuration"""
21 |
22 | from typing import Optional
23 |
24 | from transformers.configuration_utils import PretrainedConfig
25 | from transformers.utils import logging
26 |
27 |
28 | logger = logging.get_logger(__name__)
29 |
30 | EMU3_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
31 |
32 |
33 | class Emu3Config(PretrainedConfig):
34 | r"""
35 | This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate an Emu3
36 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
37 | defaults will yield a similar configuration to that of the Emu3-8B.
38 |
39 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40 | documentation from [`PretrainedConfig`] for more information.
41 |
42 |
43 | Args:
44 | vocab_size (`int`, *optional*, defaults to 184622):
45 | Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the
46 | `inputs_ids` passed when calling [`Emu3Model`]
47 | hidden_size (`int`, *optional*, defaults to 4096):
48 | Dimension of the hidden representations.
49 | intermediate_size (`int`, *optional*, defaults to 14336):
50 | Dimension of the MLP representations.
51 | num_hidden_layers (`int`, *optional*, defaults to 32):
52 | Number of hidden layers in the Transformer decoder.
53 | num_attention_heads (`int`, *optional*, defaults to 32):
54 | Number of attention heads for each attention layer in the Transformer decoder.
55 | num_key_value_heads (`int`, *optional*, defaults to 8):
56 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
57 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
58 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
59 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
60 | by meanpooling all the original heads within that group. For more details checkout [this
61 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
62 | `num_attention_heads`.
63 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
64 | The non-linear activation function (function or string) in the decoder.
65 | max_position_embeddings (`int`, *optional*, defaults to 9216):
66 | The maximum sequence length that this model might ever be used with. Emu3 supports up to 9216 tokens.
67 | initializer_range (`float`, *optional*, defaults to 0.02):
68 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69 | rms_norm_eps (`float`, *optional*, defaults to 1e-05):
70 | The epsilon used by the rms normalization layers.
71 | use_cache (`bool`, *optional*, defaults to `True`):
72 | Whether or not the model should return the last key/values attentions (not used by all models). Only
73 | relevant if `config.is_decoder=True`.
74 | pad_token_id (`int`, *optional*, defaults to 151643):
75 | Padding token id.
76 | bos_token_id (`int`, *optional*, defaults to 151849):
77 | Beginning of stream token id.
78 | eos_token_id (`int`, *optional*, defaults to 151850):
79 | End of stream token id.
80 | img_token_id (`int`, *optional*, defaults to 151851):
81 | Image token id.
82 | boi_token_id (`int`, *optional*, defaults to 151852):
83 | Beginning of image token id.
84 | eoi_token_id (`int`, *optional*, defaults to 151853):
85 | End of image token id.
86 | eol_token_id (`int`, *optional*, defaults to 151846):
87 | End of line token id.
88 | eof_token_id (`int`, *optional*, defaults to 151847):
89 | End of file token id.
90 | image_area (`int`, *optional*, defaults to 720 * 720):
91 | Generated image area (the image area used in training).
92 | pretraining_tp (`int`, *optional*, defaults to 1):
93 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
94 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
95 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
96 | issue](https://github.com/pytorch/pytorch/issues/76232).
97 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
98 | Whether to tie weight embeddings
99 | rope_theta (`float`, *optional*, defaults to 1_000_000.0):
100 | The base period of the RoPE embeddings.
101 | rope_scaling (`Dict`, *optional*):
102 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
103 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
104 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
105 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
106 | these scaling strategies behave:
107 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
108 | experimental feature, subject to breaking API changes in future versions.
109 | attention_dropout (`float`, *optional*, defaults to 0.1):
110 | The dropout ratio for the attention probabilities.
111 |
112 | ```python
113 | >>> from transformers import Emu3Model, Emu3Config
114 |
115 | >>> # Initializing a Emu3-8b style configuration
116 | >>> configuration = Emu3Config()
117 |
118 | >>> # Initializing a model from the Emu3-8b style configuration
119 | >>> model = Emu3Model(configuration)
120 |
121 | >>> # Accessing the model configuration
122 | >>> configuration = model.config
123 | ```"""
124 |
125 | model_type = "Emu3"
126 | keys_to_ignore_at_inference = ["past_key_values"]
127 |
128 | def __init__(
129 | self,
130 | vocab_size: int = 184622,
131 | hidden_size: int = 4096,
132 | intermediate_size: int = 14336,
133 | num_hidden_layers: int = 32,
134 | num_attention_heads: int = 32,
135 | num_key_value_heads: Optional[int] = 8,
136 | hidden_act: str = "silu",
137 | max_position_embeddings: int = 9216,
138 | initializer_range: float = 0.02,
139 | rms_norm_eps: float = 1e-5,
140 | use_cache: bool = True,
141 | pad_token_id: int = 151643,
142 | bos_token_id: int = 151849,
143 | eos_token_id: int = 151850,
144 | img_token_id: int = 151851,
145 | boi_token_id: int = 151852,
146 | eoi_token_id: int = 151853,
147 | eol_token_id: int = 151846,
148 | eof_token_id: int = 151847,
149 | image_area: int = 720 * 720,
150 | pretraining_tp: int = 1,
151 | tie_word_embeddings: bool = False,
152 | rope_theta: float = 1000000.0,
153 | rope_scaling: Optional[dict] = None,
154 | attention_dropout: float = 0.1,
155 | **kwargs,
156 | ):
157 | self.vocab_size = vocab_size
158 | self.max_position_embeddings = max_position_embeddings
159 | self.hidden_size = hidden_size
160 | self.intermediate_size = intermediate_size
161 | self.num_hidden_layers = num_hidden_layers
162 | self.num_attention_heads = num_attention_heads
163 |
164 | # for backward compatibility
165 | if num_key_value_heads is None:
166 | num_key_value_heads = num_attention_heads
167 |
168 | self.num_key_value_heads = num_key_value_heads
169 | self.hidden_act = hidden_act
170 | self.initializer_range = initializer_range
171 | self.rms_norm_eps = rms_norm_eps
172 | self.pretraining_tp = pretraining_tp
173 | self.use_cache = use_cache
174 | self.rope_theta = rope_theta
175 | self.rope_scaling = rope_scaling
176 | self._rope_scaling_validation()
177 | self.attention_dropout = attention_dropout
178 |
179 | self.img_token_id = img_token_id
180 | self.boi_token_id = boi_token_id
181 | self.eoi_token_id = eoi_token_id
182 | self.eol_token_id = eol_token_id
183 | self.eof_token_id = eof_token_id
184 | self.image_area = image_area
185 |
186 | super().__init__(
187 | pad_token_id=pad_token_id,
188 | bos_token_id=bos_token_id,
189 | eos_token_id=eos_token_id,
190 | tie_word_embeddings=tie_word_embeddings,
191 | **kwargs,
192 | )
193 |
194 | def _rope_scaling_validation(self):
195 | """
196 | Validate the `rope_scaling` configuration.
197 | """
198 | if self.rope_scaling is None:
199 | return
200 |
201 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
202 | raise ValueError(
203 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
204 | f"got {self.rope_scaling}"
205 | )
206 | rope_scaling_type = self.rope_scaling.get("type", None)
207 | rope_scaling_factor = self.rope_scaling.get("factor", None)
208 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
209 | raise ValueError(
210 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
211 | )
212 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
213 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
214 |
--------------------------------------------------------------------------------
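A minimal usage sketch for the configuration above, assuming the repository root is on `PYTHONPATH` (the printed values are the documented defaults), showing how `_rope_scaling_validation` accepts and rejects `rope_scaling` values:

```python
from emu3.mllm import Emu3Config

# Valid: "type" is one of {"linear", "dynamic"} and "factor" is a float > 1.0.
cfg = Emu3Config(rope_scaling={"type": "linear", "factor": 2.0})
print(cfg.hidden_size, cfg.max_position_embeddings)  # 4096 9216

# Invalid: a factor <= 1.0 raises the ValueError from _rope_scaling_validation.
try:
    Emu3Config(rope_scaling={"type": "dynamic", "factor": 0.5})
except ValueError as err:
    print(err)
```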
/emu3/mllm/processing_emu3.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Processor class for Emu3. """
16 |
17 | from math import ceil
18 | import re
19 | from typing import List, Optional, Sequence, Union
20 | from functools import partial
21 |
22 | from PIL import Image
23 | import torch
24 | from torch.nn import functional as F
25 | from transformers.feature_extraction_utils import BatchFeature
26 | from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
27 | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
28 | from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
29 | from transformers.utils import logging
30 |
31 | from .utils_emu3 import Emu3PrefixConstrainedLogitsHelper
32 |
33 |
34 | logger = logging.get_logger(__name__)
35 |
36 |
37 | class Emu3Processor(ProcessorMixin):
38 | r"""
39 | Constructs an Emu3 processor which wraps an Emu3 image processor, an Emu3 vision VQ model, and an Emu3 tokenizer into a single processor.
40 |
41 | [`Emu3Processor`] offers all the functionalities of [`Emu3VisionVQModel`] and [`Emu3Tokenizer`]. See the
42 | [`~Emu3Processor.__call__`], [`~Emu3Processor.decode`], [`~Emu3Processor.vision_encode`], [`~Emu3Processor.vision_decode`]
43 | for more information.
44 |
45 | Args:
46 | image_processor ([`Emu3VisionVQImageProcessor`]):
47 | The image processor is a required input.
48 | vision_tokenizer ([`Emu3VisionVQModel`]):
49 | The vision tokenizer is a required input.
50 | tokenizer ([`Emu3Tokenizer`]):
51 | The tokenizer is a required input.
52 | prefix_template (`str`, *optional*):
53 | The prefix template for image tokens.
54 | visual_template (`Tuple[str, ...]`, *optional*):
55 | The visual token template for image tokens.
56 | """
57 |
58 | attributes = ["image_processor", "tokenizer"]
59 | valid_kwargs = ["vision_tokenizer", "prefix_template", "visual_template"]
60 | image_processor_class = "AutoImageProcessor"
61 | tokenizer_class = "AutoTokenizer"
62 |
63 | def __init__(
64 | self,
65 | image_processor=None,
66 | vision_tokenizer=None,
67 | tokenizer=None,
68 | chat_template="You are a helpful assistant. USER: {image_prompt}{text_prompt}. ASSISTANT:",
69 | prefix_template="{H}*{W}",
70 | visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>"),
71 | **kwargs,
72 | ):
73 | assert vision_tokenizer is not None, "image tokenizer can not be None"
74 |
75 | self.vision_tokenizer = vision_tokenizer
76 | self.prefix_template = prefix_template
77 | self.visual_template = visual_template
78 | self.vis_tok_spatial_factor = 2 ** (len(self.vision_tokenizer.config.ch_mult) - 1)
79 |
80 | super().__init__(image_processor, tokenizer, chat_template=chat_template)
81 | self.const_helper = self.build_const_helper()
82 |
83 | @torch.no_grad()
84 | def __call__(
85 | self,
86 | text: Optional[TextInput | PreTokenizedInput] = None,
87 | image: Optional[Image.Image | List[Image.Image]] = None,
88 | *,
89 | mode: str = "G",
90 | ratio: str | List[str] = "1:1",
91 | image_area: int = 518400,
92 | padding_image: bool = False,
93 | **kwargs,
94 | ) -> BatchFeature:
95 | """
96 | Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
97 | and `kwargs` arguments to Emu3Tokenizer's [`~Emu3Tokenizer.__call__`] to encode the text.
98 | To prepare the image(s), this method forwards the `image` argument to
99 | Emu3VisionVQImageProcessor's [`~Emu3VisionVQImageProcessor.__call__`] and Emu3VisionVQModel's [`~Emu3VisionVQModel.encode`]
100 | if `image` is not `None`. Please refer to the docstring of the above two methods for more information.
101 |
102 | Args:
103 | text (`str` or `List[str]`):
104 | The sequence or a batch of sequences to be encoded. A sequence is a string.
105 | image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
106 | The image or a batch of images to be prepared. An image is a PIL image.
107 | mode (`str`, *optional*, either `G` or `U`):
108 | Task mode: `G` for generation and `U` for understanding.
109 | ratio (`str`, *optional*):
110 | The image width:height ratio for generation, e.g. `"1:1"`.
111 | image_area (`int`, *optional*):
112 | Image area used to calculate the generated image height and width.
113 | padding_image (`bool`, *optional*):
114 | Whether to pad images to the same size for fast preprocessing if they have different sizes.
115 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
116 | If set, will return tensors of a particular framework. Acceptable values are:
117 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
118 | - `'np'`: Return NumPy `np.ndarray` objects.
119 |
120 | Returns:
121 | [`BatchFeature`]: A [`BatchFeature`] with the following fields:
122 |
123 | - **input_ids** -- List of token ids to be fed to a model.
124 | - **image_size** -- List of image size of input images or generated images.
125 | """
126 | assert mode in ('G', 'U'), "mode must be 'G' or 'U'."
127 | if isinstance(text, str):
128 | text = [text]
129 |
130 | if isinstance(image, Image.Image):
131 | image = [image]
132 |
133 | if not isinstance(text[0], str):
134 | raise ValueError("`text` must be string or list of string")
135 |
136 | image_tokens = None
137 | if mode == 'G':
138 | if image is not None:
139 | raise ValueError("You have to specify only `text` in generation mode")
140 |
141 | if isinstance(ratio, str):
142 | ratio = [ratio] * len(text)
143 |
144 | if len(ratio) != len(text):
145 | raise ValueError("ratio number must match text number")
146 | else:
147 | if image is None:
148 | raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")
149 |
150 | if not isinstance(image, Sequence) and not isinstance(image, Image.Image):
151 | raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
152 |
153 | if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
154 | raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
155 |
156 | image_tokens = self.tokenize_image(image, padding_image=padding_image)
157 | if len(text) != len(image_tokens):
158 | raise ValueError("number of image must match number of text prompt")
159 |
160 | prompt_list, size_list = [], []
161 | for idx, text_prompt in enumerate(text):
162 | prompt = self.tokenizer.bos_token
163 | if mode == 'U':
164 | h, w = image_tokens[idx].shape
165 | imgstr = self.to_imgstr(image_tokens[idx])
166 | image_prompt = (
167 | self.tokenizer.boi_token +
168 | self.prefix_template.format(H=h, W=w) +
169 | self.tokenizer.img_token +
170 | imgstr +
171 | self.tokenizer.eol_token +
172 | self.tokenizer.eof_token +
173 | self.tokenizer.eoi_token
174 | )
175 | prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
176 | else:
177 | h, w = self.calculate_generate_size(ratio[idx], image_area, self.vision_tokenizer.spatial_scale_factor)
178 | image_prompt = (
179 | self.tokenizer.boi_token +
180 | self.prefix_template.format(H=h, W=w) +
181 | self.tokenizer.img_token
182 | )
183 | prompt += (text_prompt + image_prompt)
184 |
185 | prompt_list.append(prompt)
186 | size_list.append([h, w])
187 |
188 | text_inputs = self.tokenizer(prompt_list, **kwargs)
189 | return BatchFeature(data={**text_inputs, "image_size": size_list}, tensor_type=kwargs.get("return_tensors"))
190 |
191 | @torch.no_grad()
192 | def batch_decode(self, *args, **kwargs):
193 | docs = self.tokenizer.batch_decode(*args, **kwargs)
194 | return [self.multimodal_decode(d) for d in docs]
195 |
196 | @torch.no_grad()
197 | def decode(self, *args, **kwargs):
198 | doc = self.tokenizer.decode(*args, **kwargs)
199 | return self.multimodal_decode(doc)
200 |
201 | @torch.no_grad()
202 | def vision_encode(self, *args, **kwargs):
203 | return self.vision_tokenizer.encode(*args, **kwargs)
204 |
205 | @torch.no_grad()
206 | def vision_decode(self, *args, **kwargs):
207 | return self.vision_tokenizer.decode(*args, **kwargs)
208 |
209 | @torch.no_grad()
210 | def multimodal_decode(self, doc):
211 | multimodal_output = []
212 | pattern = rf'({re.escape(self.tokenizer.boi_token)}.*?{re.escape(self.tokenizer.eoi_token)})'
213 | chunks = re.split(pattern, doc)
214 | for c in chunks:
215 | if len(c) == 0:
216 | continue
217 |
218 | if self.tokenizer.boi_token in c:
219 | image = []
220 | image_rows = re.split(re.escape(self.tokenizer.eol_token), c)
221 | for r in image_rows:
222 | token_ids = re.findall(self.visual_template[1], r)
223 | if len(token_ids) > 0:
224 | row_token = [int(m) for m in token_ids]
225 | image.append(row_token)
226 | image = torch.tensor(image, dtype=torch.long, device=self.vision_tokenizer.device)
227 | image = self.vision_tokenizer.decode(image[None]).float()
228 | image = self.image_processor.postprocess(image)["pixel_values"][0]
229 | multimodal_output.append(image)
230 | else:
231 | multimodal_output.append(c)
232 |
233 | return multimodal_output if len(multimodal_output) > 1 else multimodal_output[0]
234 |
235 | @property
236 | def model_input_names(self):
237 | tokenizer_input_names = self.tokenizer.model_input_names
238 | image_processor_input_names = self.image_processor.model_input_names
239 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
240 |
241 | def to_imgstr(self, image_tokens):
242 | image_tokens = image_tokens.cpu().numpy().tolist()
243 | image_token_str = [
244 | [
245 | self.visual_template[0].format(token_id=token_id)
246 | for token_id in token_row
247 | ]
248 | for token_row in image_tokens
249 | ]
250 | image_row_str = ["".join(token_row) for token_row in image_token_str]
251 | imgstr = self.tokenizer.eol_token.join(image_row_str)
252 | return imgstr
253 |
254 | def calculate_generate_size(self, ratio, image_area, spatial_scale_factor):
255 | w, h = map(int, ratio.split(":"))
256 | current_area = h * w
257 | target_ratio = (image_area / current_area) ** 0.5
258 |
259 | th = int(round(h * target_ratio / spatial_scale_factor))
260 | tw = int(round(w * target_ratio / spatial_scale_factor))
261 | return th, tw
262 |
263 | def tokenize_image(self, image: List[Image.Image], *, padding_image: bool = False):
264 | is_all_same_size, prev_size = True, None
265 | for im in image:
266 | if prev_size is not None:
267 | is_all_same_size &= (prev_size == im.size)
268 | prev_size = im.size
269 |
270 | if is_all_same_size:
271 | image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
272 | image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
273 | image_tokens = self.vision_tokenizer.encode(image_inputs)
274 | elif padding_image:
275 | image_inputs = [self.image_processor(im, return_tensors="pt")["pixel_values"] for im in image]
276 | image_shapes = [im.shape[2:] for im in image_inputs]
277 | max_shape = (
278 | max([im_shape[0] for im_shape in image_shapes]),
279 | max([im_shape[1] for im_shape in image_shapes]),
280 | )
281 | image_inputs = [
282 | F.pad(im_inp, (0, max_shape[1] - im_shape[1], 0, max_shape[0] - im_shape[0]))
283 | for im_inp, im_shape in zip(image_inputs, image_shapes)
284 | ]
285 | image_inputs = torch.cat(image_inputs, dim=0).to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
286 | image_tokens = self.vision_tokenizer.encode(image_inputs)
287 | image_tokens = [
288 | im_tok[:ceil(im_shape[0] / self.vis_tok_spatial_factor), :ceil(im_shape[1] / self.vis_tok_spatial_factor)]
289 | for im_tok, im_shape in zip(image_tokens, image_shapes)
290 | ]
291 | else:
292 | image_tokens = []
293 | for im in image:
294 | image_input = self.image_processor(im, return_tensors="pt")["pixel_values"]
295 | image_input = image_input.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
296 | image_tokens.append(self.vision_tokenizer.encode(image_input).squeeze(0))
297 |
298 | return image_tokens
299 |
300 | def build_const_helper(self):
301 | (
302 | img_token,
303 | eoi_token,
304 | eos_token,
305 | eol_token,
306 | eof_token,
307 | pad_token,
308 | vis_start,
309 | vis_end,
310 | ) = self.tokenizer.encode([
311 | self.tokenizer.img_token,
312 | self.tokenizer.eoi_token,
313 | self.tokenizer.eos_token,
314 | self.tokenizer.eol_token,
315 | self.tokenizer.eof_token,
316 | self.tokenizer.pad_token,
317 | self.visual_template[0].format(token_id=0),
318 | self.visual_template[0].format(token_id=self.vision_tokenizer.config.codebook_size - 1),
319 | ])
320 |
321 | const_helper = partial(
322 | Emu3PrefixConstrainedLogitsHelper,
323 | img_token=img_token,
324 | eoi_token=eoi_token,
325 | eos_token=eos_token,
326 | eol_token=eol_token,
327 | eof_token=eof_token,
328 | pad_token=pad_token,
329 | visual_tokens=list(range(vis_start, vis_end + 1)),
330 | )
331 | return const_helper
332 |
333 | def build_prefix_constrained_fn(self, height, width):
334 | helper = self.const_helper(height=height, width=width)
335 | return helper
336 |
--------------------------------------------------------------------------------
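As a worked example of `Emu3Processor.calculate_generate_size` above, the sketch below re-derives the same arithmetic standalone; the spatial downsampling factor of 8 is an assumed example value (in the processor it comes from the vision tokenizer configuration):

```python
# Standalone re-derivation of calculate_generate_size (same arithmetic as above).
def calculate_generate_size(ratio: str, image_area: int, spatial_scale_factor: int):
    w, h = map(int, ratio.split(":"))
    target_ratio = (image_area / (h * w)) ** 0.5
    th = int(round(h * target_ratio / spatial_scale_factor))
    tw = int(round(w * target_ratio / spatial_scale_factor))
    return th, tw

print(calculate_generate_size("1:1", 720 * 720, 8))   # (90, 90)  -> 720x720 pixels after decoding
print(calculate_generate_size("16:9", 720 * 720, 8))  # (68, 120) -> roughly 544x960 pixels
```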
/emu3/mllm/tokenization_emu3.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for Emu3."""
16 |
17 | import base64
18 | import logging
19 | import os
20 | import unicodedata
21 | from typing import Collection, Dict, List, Optional, Set, Tuple, Union
22 |
23 | import tiktoken
24 | from transformers import PreTrainedTokenizer, AddedToken
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | VOCAB_FILES_NAMES = {
30 | "vocab_file": "emu3.tiktoken",
31 | "special_tokens_file": "emu3_vision_tokens.txt",
32 | }
33 |
34 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
35 | ENDOFTEXT = "<|endoftext|>"
36 | IMSTART = "<|im_start|>"
37 | IMEND = "<|im_end|>"
38 | # as the default behavior is changed to allow special tokens in
39 | # regular texts, the surface forms of special tokens need to be
40 | # as different as possible to minimize the impact
41 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
42 | # changed to use actual index to avoid misconfiguration with vocabulary expansion
43 | SPECIAL_START_ID = 151643
44 |
45 |
46 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
47 | with open(tiktoken_bpe_file, "rb") as f:
48 | contents = f.read()
49 | return {
50 | base64.b64decode(token): int(rank)
51 | for token, rank in (line.split() for line in contents.splitlines() if line)
52 | }
53 |
54 |
55 | class Emu3Tokenizer(PreTrainedTokenizer):
56 | """Emu3 tokenizer."""
57 |
58 | vocab_files_names = VOCAB_FILES_NAMES
59 |
60 | def __init__(
61 | self,
62 | vocab_file,
63 | special_tokens_file,
64 | errors="replace",
65 | bos_token = "<|extra_203|>",
66 | eos_token = "<|extra_204|>",
67 | pad_token = "<|endoftext|>",
68 | img_token = "<|image token|>",
69 | boi_token = "<|image start|>",
70 | eoi_token = "<|image end|>",
71 | eol_token = "<|extra_200|>",
72 | eof_token = "<|extra_201|>",
73 | **kwargs,
74 | ):
75 | super().__init__(**kwargs)
76 |
77 | # how to handle errors in decoding UTF-8 byte sequences
78 | # use ignore if you are in streaming inference
79 | self.errors = errors
80 |
81 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
82 |
83 | vision_tokens = [t.strip() for t in open(special_tokens_file).readlines() if len(t.strip()) > 0]
84 | SPECIAL_TOKENS = tuple(
85 | enumerate(
86 | (
87 | (
88 | ENDOFTEXT,
89 | IMSTART,
90 | IMEND,
91 | )
92 | + EXTRAS
93 | + tuple(vision_tokens)
94 | ),
95 | start=SPECIAL_START_ID,
96 | )
97 | )
98 | self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
99 | self.special_tokens_set = set(t for _, t in SPECIAL_TOKENS)
100 |
101 | enc = tiktoken.Encoding(
102 | "Emu3",
103 | pat_str=PAT_STR,
104 | mergeable_ranks=self.mergeable_ranks,
105 | special_tokens=self.special_tokens,
106 | )
107 |
108 | assert (
109 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
110 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
111 |
112 | self.decoder = {
113 | v: k for k, v in self.mergeable_ranks.items()
114 | }
115 | self.decoder.update({v: k for k, v in self.special_tokens.items()})
116 |
117 | self.tokenizer = enc
118 |
119 | self.eod_id = self.tokenizer.eot_token
120 | self.bos_token = bos_token
121 | self.eos_token = eos_token
122 | self.pad_token = pad_token
123 | self.img_token = img_token
124 | self.boi_token = boi_token
125 | self.eoi_token = eoi_token
126 | self.eol_token = eol_token
127 | self.eof_token = eof_token
128 |
129 | def __getstate__(self):
130 | # for pickle lovers
131 | state = self.__dict__.copy()
132 | del state["tokenizer"]
133 | return state
134 |
135 | def __setstate__(self, state):
136 | # tokenizer is not python native; don't pass it; rebuild it
137 | self.__dict__.update(state)
138 | enc = tiktoken.Encoding(
139 | "Emu3",
140 | pat_str=PAT_STR,
141 | mergeable_ranks=self.mergeable_ranks,
142 | special_tokens=self.special_tokens,
143 | )
144 | self.tokenizer = enc
145 |
146 | def __len__(self) -> int:
147 | return self.tokenizer.n_vocab
148 |
149 | def get_vocab(self) -> Dict[bytes, int]:
150 | return self.mergeable_ranks
151 |
152 | def convert_tokens_to_ids(
153 | self, tokens: Union[bytes, str, List[Union[bytes, str]]]
154 | ) -> List[int]:
155 | if isinstance(tokens, (str, bytes)):
156 | if tokens in self.special_tokens:
157 | return self.special_tokens[tokens]
158 | else:
159 | return self.mergeable_ranks.get(tokens)
160 |
161 | ids = []
162 | for token in tokens:
163 | if token in self.special_tokens:
164 | ids.append(self.special_tokens[token])
165 | else:
166 | ids.append(self.mergeable_ranks.get(token))
167 | return ids
168 |
169 | def _add_tokens(
170 | self,
171 | new_tokens: Union[List[str], List[AddedToken]],
172 | special_tokens: bool = False,
173 | ) -> int:
174 | if not special_tokens and new_tokens:
175 | raise ValueError("Adding regular tokens is not supported")
176 |
177 | for token in new_tokens:
178 | surface_form = token.content if isinstance(token, AddedToken) else token
179 | if surface_form not in self.special_tokens_set:
180 | raise ValueError("Adding unknown special tokens is not supported")
181 |
182 | return 0
183 |
184 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
185 | """
186 | Save only the vocabulary of the tokenizer (the BPE merge ranks and the vision special tokens).
187 |
188 | Returns:
189 | `Tuple[str]`: Paths to the files saved.
190 | """
191 | regular_file_path = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
192 | with open(regular_file_path,'w', encoding="utf8") as w:
193 | for k, v in self.mergeable_ranks.items():
194 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
195 | w.write(line)
196 |
197 | excluded_special_tokens = set((ENDOFTEXT, IMSTART, IMEND,) + EXTRAS)
198 | special_file_path = os.path.join(save_directory, self.vocab_files_names["special_tokens_file"])
199 | with open(special_file_path, 'w', encoding="utf8") as w:
200 | for k in self.special_tokens:
201 | if k not in excluded_special_tokens:
202 | print(k, file=w)
203 |
204 | return (regular_file_path, special_file_path)
205 |
206 | def tokenize(
207 | self,
208 | text: str,
209 | allowed_special: Union[Set, str] = "all",
210 | disallowed_special: Union[Collection, str] = (),
211 | **kwargs,
212 | ) -> List[Union[bytes, str]]:
213 | """
214 | Converts a string into a sequence of tokens.
215 |
216 | Args:
217 | text (`str`):
218 | The sequence to be encoded.
219 | allowed_special (`Literal["all"]` or `set`):
220 | The surface forms of the tokens to be encoded as special tokens in regular texts.
221 | Defaults to `"all"`.
222 | disallowed_special (`Literal["all"]` or `Collection`):
223 | The surface forms of the tokens that should not be in regular texts and trigger errors.
224 | Defaults to an empty tuple.
225 |
226 | kwargs (additional keyword arguments, *optional*):
227 | Will be passed to the underlying model specific encode method.
228 |
229 | Returns:
230 | `List[bytes|str]`: The list of tokens.
231 | """
232 | tokens = []
233 | text = unicodedata.normalize("NFC", text)
234 |
235 | # this implementation takes a detour: text -> token id -> token surface forms
236 | for t in self.tokenizer.encode(
237 | text, allowed_special=allowed_special, disallowed_special=disallowed_special
238 | ):
239 | tokens.append(self.decoder[t])
240 |
241 | return tokens
242 |
243 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
244 | """
245 | Converts a sequence of tokens into a single string.
246 | """
247 | text = ""
248 | temp = b""
249 | for t in tokens:
250 | if isinstance(t, str):
251 | if temp:
252 | text += temp.decode("utf-8", errors=self.errors)
253 | temp = b""
254 | text += t
255 | elif isinstance(t, bytes):
256 | temp += t
257 | else:
258 | raise TypeError("token should only be of type bytes or str")
259 | if temp:
260 | text += temp.decode("utf-8", errors=self.errors)
261 | return text
262 |
263 | @property
264 | def vocab_size(self):
265 | return self.tokenizer.n_vocab
266 |
267 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
268 | """Converts an id to a token, special tokens included"""
269 | if index in self.decoder:
270 | return self.decoder[index]
271 | raise ValueError("unknown ids")
272 |
273 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
274 | """Converts a token to an id using the vocab, special tokens included"""
275 | if token in self.special_tokens:
276 | return self.special_tokens[token]
277 | if token in self.mergeable_ranks:
278 | return self.mergeable_ranks[token]
279 | raise ValueError("unknown token")
280 |
281 | def _decode(
282 | self,
283 | token_ids: Union[int, List[int]],
284 | skip_special_tokens: bool = False,
285 | errors: Optional[str] = None,
286 | **kwargs,
287 | ) -> str:
288 | if isinstance(token_ids, int):
289 | token_ids = [token_ids]
290 |
291 | if skip_special_tokens:
292 | token_ids = [i for i in token_ids if i < self.eod_id]
293 |
294 | return self.tokenizer.decode(token_ids, errors=errors or self.errors)
295 |
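A minimal, hedged construction example for `Emu3Tokenizer`: the two vocabulary paths are placeholders for the `emu3.tiktoken` and `emu3_vision_tokens.txt` files shipped with a checkpoint, and the snippet assumes a `transformers` version compatible with this slow-tokenizer implementation.

```python
# Placeholder paths; replace with the vocab files from an Emu3 checkpoint.
from emu3.mllm.tokenization_emu3 import Emu3Tokenizer

tokenizer = Emu3Tokenizer(
    vocab_file="path/to/emu3.tiktoken",
    special_tokens_file="path/to/emu3_vision_tokens.txt",
)
ids = tokenizer.encode("a cat sitting on a mat")
print(ids)                    # BPE ids from the tiktoken vocabulary
print(tokenizer.decode(ids))  # round-trips back to the input string
```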
--------------------------------------------------------------------------------
/emu3/mllm/utils_emu3.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Logits Processor Helper class for Emu3. """
16 |
17 | import torch
18 |
19 | class Emu3PrefixConstrainedLogitsHelper:
20 |
21 | def __init__(
22 | self,
23 | height,
24 | width,
25 | img_token,
26 | eoi_token,
27 | eos_token,
28 | eol_token,
29 | eof_token,
30 | pad_token,
31 | visual_tokens,
32 | ):
33 | self.height = height
34 | self.width = width
35 | self.img_token = img_token
36 | self.eoi_token = eoi_token
37 | self.eos_token = eos_token
38 | self.eol_token = eol_token
39 | self.eof_token = eof_token
40 | self.pad_token = pad_token
41 | self.visual_tokens = visual_tokens
42 |
43 | self.offset_cache = {}
44 |
45 | def __call__(self, batch_id, input_ids):
46 | if batch_id not in self.offset_cache:
47 | position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
48 | self.offset_cache[batch_id] = position
49 |
50 | height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0]
51 | width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]
52 |
53 | offset = input_ids.shape[0] - self.offset_cache[batch_id]
54 | height = height.to(offset.device)
55 | width = width.to(offset.device)
56 |
57 | if offset % (width + 1) == 0:
58 | return (self.eol_token, )
59 | elif offset == (width + 1) * height + 1:
60 | return (self.eof_token, )
61 | elif offset == (width + 1) * height + 2:
62 | return (self.eoi_token, )
63 | elif offset == (width + 1) * height + 3:
64 | return (self.eos_token, )
65 | elif offset > (width + 1) * height + 3:
66 | return (self.pad_token, )
67 | else:
68 | return self.visual_tokens
69 |
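A small self-contained check of the layout this helper enforces, using toy token ids (100-105) and a 3x4 token grid; it is an illustration sketch, not repo code. Each row should come out as four visual tokens followed by `eol`, and after three rows the helper forces `eof`, `eoi`, `eos`, then `pad`.

```python
import torch
from emu3.mllm.utils_emu3 import Emu3PrefixConstrainedLogitsHelper

# toy ids: 100=<image>, 101=eoi, 102=eos, 103=eol, 104=eof, 105=pad, 0-9=visual
helper = Emu3PrefixConstrainedLogitsHelper(
    height=torch.tensor([3]),
    width=torch.tensor([4]),
    img_token=100, eoi_token=101, eos_token=102,
    eol_token=103, eof_token=104, pad_token=105,
    visual_tokens=list(range(10)),
)

seq = [100]  # start right after the image token
for _ in range(20):
    allowed = helper(0, torch.tensor(seq))
    seq.append(allowed[0] if len(allowed) == 1 else 0)
print(seq)  # rows of 4 visual tokens + 103 (eol), then 104, 101, 102, 105, ...
```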
--------------------------------------------------------------------------------
/emu3/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 BAAI and the HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import (
17 | OptionalDependencyNotAvailable,
18 | _LazyModule,
19 | is_torch_available,
20 | is_vision_available,
21 | )
22 |
23 |
24 | _import_structure = {"configuration_emu3visionvq": ["Emu3VisionVQConfig"]}
25 |
26 | try:
27 | if not is_torch_available():
28 | raise OptionalDependencyNotAvailable()
29 | except OptionalDependencyNotAvailable:
30 | pass
31 | else:
32 | _import_structure["modeling_emu3visionvq"] = [
33 | "Emu3VisionVQModel",
34 | "Emu3VisionVQPretrainedModel",
35 | ]
36 |
37 | try:
38 | if not is_vision_available():
39 | raise OptionalDependencyNotAvailable()
40 | except OptionalDependencyNotAvailable:
41 | pass
42 | else:
43 | _import_structure["image_processing_emu3visionvq"] = ["Emu3VisionVQImageProcessor"]
44 |
45 | if TYPE_CHECKING:
46 | from .configuration_emu3visionvq import Emu3VisionVQConfig
47 |
48 | try:
49 | if not is_torch_available():
50 | raise OptionalDependencyNotAvailable()
51 | except OptionalDependencyNotAvailable:
52 | pass
53 | else:
54 | from .modeling_emu3visionvq import (
55 | Emu3VisionVQModel,
56 | Emu3VisionVQPretrainedModel,
57 | )
58 |
59 | try:
60 | if not is_vision_available():
61 | raise OptionalDependencyNotAvailable()
62 | except OptionalDependencyNotAvailable:
63 | pass
64 | else:
65 | from .image_processing_emu3visionvq import Emu3VisionVQImageProcessor
66 |
67 | else:
68 | import sys
69 |
70 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
71 |
--------------------------------------------------------------------------------
/emu3/tokenizer/configuration_emu3visionvq.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Emu3VisionVQ model configuration """
16 |
17 | from typing import List
18 |
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 |
22 |
23 | logger = logging.get_logger(__name__)
24 |
25 |
26 | class Emu3VisionVQConfig(PretrainedConfig):
27 | r"""
28 | This is the configuration class to store the configuration of an [`Emu3VisionVQModel`]. It is used to instantiate a video MoVQ
29 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30 | defaults will yield a configuration similar to that of the VQ model presented in the Emu3 paper.
31 |
32 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33 | documentation from [`PretrainedConfig`] for more information.
34 |
35 |
36 | Args:
37 | codebook_size (`int`, *optional*, defaults to 32768):
38 | Codebook size of the VQ model.
39 | embed_dim (`int`, *optional*, defaults to 4):
40 | Dimension of the quantized vector in codebook.
41 | z_channels (`int`, *optional*, defaults to 4):
42 | Dimension of the output channels of the encoder and the input channels of the decoder.
43 | double_z (`bool`, *optional*, defaults to `False`):
44 | Whether to double the output dim of the encoder.
45 | in_channels (`int`, *optional*, defaults to 3):
46 | Input channel of encoder.
47 | out_channels (`int`, *optional*, defaults to 3):
48 | Output channel of decoder.
49 | temporal_downsample_factor (`int`, *optional*, defaults to 4):
50 | Temporal downsample factor.
51 | ch (`int`, *optional*, defaults to 256):
52 | Basic channel number of the intermediate blocks.
53 | ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
54 | Channel scaling factor of the intermediate blocks.
55 | num_res_blocks (`int`, *optional*, defaults to 2):
56 | Residual block number in each stage.
57 | attn_resolutions (`List[int]`, *optional*, defaults to `[3]`):
58 | Stage indices at which to apply attention.
59 | dropout (`float`, *optional*, defaults to 0.0):
60 | Dropout probability.
61 |
62 | ```python
63 | >>> from emu3.tokenizer import Emu3VisionVQModel, Emu3VisionVQConfig
64 |
65 | >>> # Initializing a video VQ model of Emu3 configuration
66 | >>> configuration = Emu3VisionVQConfig()
67 |
68 | >>> # Initializing a model from the Emu3 VQ model style configuration
69 | >>> model = Emu3VisionVQModel(configuration)
70 |
71 | >>> # Accessing the model configuration
72 | >>> configuration = model.config
73 | ```"""
74 |
75 | model_type = "Emu3VisionVQ"
76 |
77 | def __init__(
78 | self,
79 | codebook_size: int = 32768,
80 | embed_dim: int = 4,
81 | z_channels: int = 4,
82 | double_z: bool = False,
83 | in_channels: int = 3,
84 | out_channels: int = 3,
85 | temporal_downsample_factor: int = 4,
86 | ch: int = 256,
87 | ch_mult: List[int] = [1, 2, 2, 4],
88 | num_res_blocks: int = 2,
89 | attn_resolutions: List[int] = [3],
90 | dropout: float = 0.0,
91 | **kwargs,
92 | ):
93 | super().__init__(**kwargs)
94 |
95 | self.codebook_size = codebook_size
96 | self.embed_dim = embed_dim
97 | self.z_channels = z_channels
98 | self.double_z = double_z
99 | self.in_channels = in_channels
100 | self.out_channels = out_channels
101 | self.temporal_downsample_factor = temporal_downsample_factor
102 | self.ch = ch
103 | self.ch_mult = ch_mult
104 | self.num_res_blocks = num_res_blocks
105 | self.attn_resolutions = attn_resolutions
106 | self.dropout = dropout
107 |
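A hedged sanity check relating the defaults above to the token grid: the encoder in `modeling_emu3visionvq.py` downsamples spatially once per `ch_mult` stage except the last, and temporally `log2(temporal_downsample_factor)` times, so the defaults give an 8x spatial and 4x temporal reduction.

```python
from emu3.tokenizer import Emu3VisionVQConfig

cfg = Emu3VisionVQConfig()
spatial_factor = 2 ** (len(cfg.ch_mult) - 1)      # 3 stride-2 downsamples -> 8
temporal_factor = cfg.temporal_downsample_factor  # 4
print(spatial_factor, temporal_factor, cfg.codebook_size)
# e.g. a 512x512 frame would be quantized into a 64x64 grid of codebook indices
```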
--------------------------------------------------------------------------------
/emu3/tokenizer/image_processing_emu3visionvq.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Image processor class for Emu3VisionVQ."""
16 |
17 |
18 | import math
19 | from typing import Dict, List, Optional, Union
20 |
21 | import numpy as np
22 |
23 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
24 | from transformers.image_transforms import (
25 | convert_to_rgb,
26 | resize,
27 | to_channel_dimension_format,
28 | )
29 | from transformers.image_utils import (
30 | IMAGENET_STANDARD_MEAN,
31 | IMAGENET_STANDARD_STD,
32 | ChannelDimension,
33 | ImageInput,
34 | PILImageResampling,
35 | get_image_size,
36 | infer_channel_dimension_format,
37 | is_scaled_image,
38 | make_list_of_images,
39 | to_numpy_array,
40 | valid_images,
41 | validate_preprocess_arguments,
42 | )
43 | from transformers.utils import TensorType, is_vision_available, logging
44 |
45 |
46 | logger = logging.get_logger(__name__)
47 |
48 |
49 | if is_vision_available():
50 | from PIL import Image
51 |
52 |
53 | def smart_resize(
54 | height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
55 | ):
56 | """Rescales the image so that the following conditions are met:
57 |
58 | 1. Both dimensions (height and width) are divisible by 'factor'.
59 |
60 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
61 |
62 | 3. The aspect ratio of the image is maintained as closely as possible.
63 |
64 | """
65 | if height < factor or width < factor:
66 | raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
67 | elif max(height, width) / min(height, width) > 5:
68 | raise ValueError(
69 | f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
70 | )
71 |
72 | h_bar = round(height / factor) * factor
73 | w_bar = round(width / factor) * factor
74 | if h_bar * w_bar > max_pixels:
75 | beta = math.sqrt((height * width) / max_pixels)
76 | h_bar = math.floor(height / beta / factor) * factor
77 | w_bar = math.floor(width / beta / factor) * factor
78 | elif h_bar * w_bar < min_pixels:
79 | beta = math.sqrt(min_pixels / (height * width))
80 | h_bar = math.ceil(height * beta / factor) * factor
81 | w_bar = math.ceil(width * beta / factor) * factor
82 |
83 | return h_bar, w_bar
84 |
85 |
86 | class Emu3VisionVQImageProcessor(BaseImageProcessor):
87 | r"""
88 | Constructs an Emu3VisionVQ image processor that dynamically resizes images based on the original images.
89 |
90 | Args:
91 | do_resize (`bool`, *optional*, defaults to `True`):
92 | Whether to resize the image's (height, width) dimensions.
93 | resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
94 | Resampling filter to use when resizing the image.
95 | do_rescale (`bool`, *optional*, defaults to `True`):
96 | Whether to rescale the image by the specified scale `rescale_factor`.
97 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
98 | Scale factor to use if rescaling the image.
99 | do_normalize (`bool`, *optional*, defaults to `True`):
100 | Whether to normalize the image.
101 | image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
102 | Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
103 | image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
104 | Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
105 | do_convert_rgb (`bool`, *optional*, defaults to `True`):
106 | Whether to convert the image to RGB.
107 | min_pixels (`int`, *optional*, defaults to `512 * 512`):
108 | The minimum number of pixels allowed for the resized image.
109 | max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
110 | The maximum number of pixels allowed for the resized image.
111 | spatial_factor (`int`, *optional*, defaults to 8):
112 | The spatial downsample factor by which the image will be downsampled in the feature extraction phase.
113 | """
114 |
115 | model_input_names = ["pixel_values"]
116 |
117 | def __init__(
118 | self,
119 | do_resize: bool = True,
120 | resample: PILImageResampling = PILImageResampling.BICUBIC,
121 | do_rescale: bool = True,
122 | rescale_factor: Union[int, float] = 1 / 255,
123 | do_normalize: bool = True,
124 | image_mean: Optional[Union[float, List[float]]] = None,
125 | image_std: Optional[Union[float, List[float]]] = None,
126 | do_convert_rgb: bool = True,
127 | min_pixels: int = 512 * 512,
128 | max_pixels: int = 1024 * 1024,
129 | spatial_factor: int = 8,
130 | **kwargs,
131 | ) -> None:
132 | super().__init__(**kwargs)
133 | self.do_resize = do_resize
134 | self.resample = resample
135 | self.do_rescale = do_rescale
136 | self.rescale_factor = rescale_factor
137 | self.do_normalize = do_normalize
138 | self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
139 | self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
140 | self.min_pixels = min_pixels
141 | self.max_pixels = max_pixels
142 | self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
143 | self.do_convert_rgb = do_convert_rgb
144 | self.spatial_factor = spatial_factor
145 |
146 | def _preprocess(
147 | self,
148 | images: ImageInput,
149 | do_resize: Optional[bool] = None,
150 | resample: PILImageResampling = None,
151 | do_rescale: Optional[bool] = None,
152 | rescale_factor: Optional[float] = None,
153 | do_normalize: Optional[bool] = None,
154 | image_mean: Optional[Union[float, List[float]]] = None,
155 | image_std: Optional[Union[float, List[float]]] = None,
156 | do_convert_rgb: Optional[bool] = None,
157 | spatial_factor: Optional[int] = None,
158 | input_data_format: Optional[Union[str, ChannelDimension]] = None,
159 | output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
160 | ):
161 | """
162 | Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
163 |
164 | Args:
165 | images (`ImageInput`):
166 | Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
167 | do_resize (`bool`, *optional*, defaults to `self.do_resize`):
168 | Whether to resize the image.
169 | resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
170 | Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
171 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
172 | Whether to rescale the image.
173 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
174 | Scale factor to use if rescaling the image.
175 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
176 | Whether to normalize the image.
177 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
178 | Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
179 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
180 | Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
181 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
182 | Whether to convert the image to RGB.
183 | spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
184 | The spatial downsample factor by which the image will be downsampled in the feature extraction phase.
185 | input_data_format (`ChannelDimension` or `str`, *optional*):
186 | The channel dimension format for the input image. Can be one of:
187 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
188 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
189 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
190 | output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
191 | The channel dimension format for the output image. Can be one of:
192 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
193 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
194 | - Unset: Use the channel dimension format of the input image.
195 | """
196 | spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
197 |
198 | images = make_list_of_images(images)
199 | if do_convert_rgb:
200 | images = [convert_to_rgb(image) for image in images]
201 |
202 | # All transformations expect numpy arrays.
203 | images = [to_numpy_array(image) for image in images]
204 |
205 | if is_scaled_image(images[0]) and do_rescale:
206 | logger.warning_once(
207 | "It looks like you are trying to rescale already rescaled images. If the input"
208 | "pixel_values.append()images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
209 | )
210 |
211 | if input_data_format is None:
212 | # We assume that all images have the same channel dimension format.
213 | input_data_format = infer_channel_dimension_format(images[0])
214 |
215 | height, width = get_image_size(images[0], channel_dim=input_data_format)
216 | resized_height, resized_width = height, width
217 | processed_images = []
218 | for image in images:
219 | if do_resize:
220 | resized_height, resized_width = smart_resize(
221 | height,
222 | width,
223 | factor=spatial_factor,
224 | min_pixels=self.min_pixels,
225 | max_pixels=self.max_pixels,
226 | )
227 | image = resize(
228 | image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
229 | )
230 |
231 | if do_rescale:
232 | image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
233 |
234 | if do_normalize:
235 | image = self.normalize(
236 | image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
237 | )
238 |
239 | image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
240 | processed_images.append(image)
241 |
242 | image = np.array(processed_images)
243 | return image
244 |
245 | def preprocess(
246 | self,
247 | images: ImageInput,
248 | do_resize: Optional[bool] = None,
249 | resample: PILImageResampling = None,
250 | do_rescale: Optional[bool] = None,
251 | rescale_factor: Optional[float] = None,
252 | do_normalize: Optional[bool] = None,
253 | image_mean: Optional[Union[float, List[float]]] = None,
254 | image_std: Optional[Union[float, List[float]]] = None,
255 | do_convert_rgb: Optional[bool] = None,
256 | spatial_factor: Optional[int] = None,
257 | return_tensors: Optional[Union[str, TensorType]] = None,
258 | input_data_format: Optional[Union[str, ChannelDimension]] = None,
259 | output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
260 | ):
261 | """
262 | Args:
263 | images (`ImageInput`):
264 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
265 | passing in images with pixel values between 0 and 1, set `do_rescale=False`.
266 | do_resize (`bool`, *optional*, defaults to `self.do_resize`):
267 | Whether to resize the image.
268 | resample (`int`, *optional*, defaults to `self.resample`):
269 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
270 | has an effect if `do_resize` is set to `True`.
271 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
272 | Whether to rescale the image.
273 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
274 | Rescale factor to rescale the image by if `do_rescale` is set to `True`.
275 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
276 | Whether to normalize the image.
277 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
278 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
279 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
280 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
281 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
282 | Whether to convert the image to RGB.
283 | spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
284 | The spatial downsample factor by which the image will be downsampled in the feature extraction phase.
285 | return_tensors (`str` or `TensorType`, *optional*):
286 | The type of tensors to return. Can be one of:
287 | - Unset: Return a list of `np.ndarray`.
288 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
289 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
290 | input_data_format (`ChannelDimension` or `str`, *optional*):
291 | The channel dimension format for the input image. If unset, the channel dimension format is inferred
292 | from the input image. Can be one of:
293 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
294 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
295 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
296 | output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
297 | The channel dimension format for the output image. Can be one of:
298 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
299 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
300 | - Unset: Use the channel dimension format of the input image.
301 | """
302 | do_resize = do_resize if do_resize is not None else self.do_resize
303 | resample = resample if resample is not None else self.resample
304 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale
305 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
306 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize
307 | image_mean = image_mean if image_mean is not None else self.image_mean
308 | image_std = image_std if image_std is not None else self.image_std
309 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
310 | spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
311 |
312 | images = make_list_of_images(images)
313 | if images is None or not valid_images(images):
314 | raise ValueError(
315 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
316 | "torch.Tensor, tf.Tensor or jax.ndarray."
317 | )
318 |
319 | validate_preprocess_arguments(
320 | rescale_factor=rescale_factor,
321 | do_normalize=do_normalize,
322 | image_mean=image_mean,
323 | image_std=image_std,
324 | do_resize=do_resize,
325 | size=self.size,
326 | resample=resample,
327 | )
328 |
329 | pixel_values = []
330 | for image in images:
331 | norm_image = self._preprocess(
332 | image,
333 | do_resize=do_resize,
334 | resample=resample,
335 | do_rescale=do_rescale,
336 | rescale_factor=rescale_factor,
337 | do_normalize=do_normalize,
338 | image_mean=image_mean,
339 | image_std=image_std,
340 | do_convert_rgb=do_convert_rgb,
341 | spatial_factor=spatial_factor,
342 | input_data_format=input_data_format,
343 | output_data_format=output_data_format,
344 | )
345 | pixel_values.extend(norm_image)
346 | pixel_values = np.array(pixel_values)
347 | data = {"pixel_values": pixel_values}
348 |
349 | return BatchFeature(data=data, tensor_type=return_tensors)
350 |
351 | def postprocess(
352 | self,
353 | images: ImageInput,
354 | do_rescale: Optional[bool] = None,
355 | rescale_factor: Optional[float] = None,
356 | do_normalize: Optional[bool] = None,
357 | image_mean: Optional[Union[float, List[float]]] = None,
358 | image_std: Optional[Union[float, List[float]]] = None,
359 | return_tensors: str | TensorType = "PIL.Image.Image",
360 | input_data_format: Optional[Union[str, ChannelDimension]] = None,
361 | ):
362 | """
363 | Postprocess an image or a batch of image tensors. Postprocessing is the reverse of preprocessing.
364 | The parameters should be the same as in `preprocess`.
365 |
366 | Args:
367 | images (`ImageInput`):
368 | Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
369 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
370 | Whether to rescale the image.
371 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
372 | Rescale factor to rescale the image by if `do_rescale` is set to `True`.
373 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
374 | Whether to normalize the image.
375 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
376 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
377 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
378 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
379 | return_tensors (`str` or `TensorType`, *optional*):
380 | The type of tensors to return. Can be one of:
381 | - Unset: Return a list of `np.ndarray`.
382 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
383 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
384 | input_data_format (`ChannelDimension` or `str`, *optional*):
385 | The channel dimension format for the input image. If unset, the channel dimension format is inferred
386 | from the input image. Can be one of:
387 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
388 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
389 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
390 | """
391 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale
392 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
393 | rescale_factor = 1 / rescale_factor
394 |
395 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize
396 | image_mean = image_mean if image_mean is not None else self.image_mean
397 | image_std = image_std if image_std is not None else self.image_std
398 | image_mean, image_std = self.inverse_meanstd(image_mean, image_std)
399 |
400 | images = make_list_of_images(images)
401 | if isinstance(images[0], Image.Image):
402 | return images if len(images) > 1 else images[0]
403 |
404 | if input_data_format is None:
405 | # We assume that all images have the same channel dimension format.
406 | input_data_format = infer_channel_dimension_format(images[0])
407 |
408 | pixel_values = []
409 | for image in images:
410 | image = to_numpy_array(image)
411 | if do_normalize:
412 | image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
413 |
414 | if do_rescale:
415 | image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
416 | image = image.clip(0, 255).astype(np.uint8)
417 |
418 | if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
419 | image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
420 | pixel_values.append(Image.fromarray(image))
421 | else:
422 | pixel_values.extend(image)
423 |
424 | data = {"pixel_values": pixel_values}
425 | return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
426 |
427 | return BatchFeature(data=data, tensor_type=return_tensors)
428 |
429 | def inverse_meanstd(self, image_mean, image_std):
430 | image_mean = self.to_tuple(image_mean)
431 | image_std = self.to_tuple(image_std)
432 |
433 | rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
434 | rev_image_std = tuple(1 / s for s in image_std)
435 |
436 | return rev_image_mean, rev_image_std
437 |
438 | def to_tuple(self, value, dim=3):
439 | if isinstance(value, (int, float)):
440 | return (value,) * dim
441 |
442 | return tuple(value)
443 |
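A hedged usage sketch for `smart_resize` and the processor defaults; the concrete sizes are illustrative, and the expected shapes follow from the factor-of-8 rounding and the min/max pixel bounds above.

```python
from PIL import Image
from emu3.tokenizer import Emu3VisionVQImageProcessor
from emu3.tokenizer.image_processing_emu3visionvq import smart_resize

print(smart_resize(700, 1000, factor=8))   # expected (704, 1000): both divisible by 8, area within bounds
print(smart_resize(4000, 3000, factor=8))  # area above max_pixels, so both sides are scaled down

processor = Emu3VisionVQImageProcessor()
batch = processor(Image.new("RGB", (1000, 700)), return_tensors="pt")
print(batch["pixel_values"].shape)         # expected torch.Size([1, 3, 704, 1000])
```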
--------------------------------------------------------------------------------
/emu3/tokenizer/modeling_emu3visionvq.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Emu3VisionVQ model """
16 |
17 | import math
18 | from typing import Optional, Tuple, Union
19 |
20 | import torch
21 | from torch import nn
22 | from torch.nn import functional as F
23 | from transformers.modeling_utils import PreTrainedModel
24 |
25 | from .configuration_emu3visionvq import Emu3VisionVQConfig
26 |
27 |
28 | class Emu3VisionVQActivation(nn.Module):
29 |
30 | def __init__(self):
31 | super().__init__()
32 |
33 | def __call__(self, x: torch.Tensor):
34 | return x * torch.sigmoid(x)
35 |
36 |
37 | class Emu3VisionVQUpsample(nn.Module):
38 |
39 | def __init__(self, in_channels: int):
40 | super().__init__()
41 | self.conv = nn.Conv2d(
42 | in_channels,
43 | in_channels,
44 | kernel_size=3,
45 | stride=1,
46 | padding=1,
47 | )
48 |
49 | def forward(self, x: torch.Tensor):
50 | x = F.interpolate(x, scale_factor=2.0, mode="nearest")
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class Emu3VisionVQDownsample(nn.Module):
56 |
57 | def __init__(self, in_channels: int):
58 | super().__init__()
59 | self.conv = nn.Conv2d(
60 | in_channels,
61 | in_channels,
62 | kernel_size=3,
63 | stride=2,
64 | padding=0,
65 | )
66 |
67 | def forward(self, x: torch.Tensor):
68 | pad = (0, 1, 0, 1)
69 | x = F.pad(x, pad, mode="constant", value=0)
70 | x = self.conv(x)
71 | return x
72 |
73 |
74 | class Emu3VisionVQCausalConv3d(nn.Module):
75 |
76 | def __init__(
77 | self,
78 | in_channel: int,
79 | out_channel: int,
80 | kernel_size: Union[int, Tuple[int, ...]] = (3, 1, 1),
81 | stride: Union[int, Tuple[int, ...]] = (1, 1, 1),
82 | ):
83 | super().__init__()
84 |
85 | if isinstance(kernel_size, int):
86 | kernel_size = (kernel_size,) * 3
87 | if isinstance(stride, int):
88 | stride = (stride,) * 3
89 |
90 | hw_pad = [k - s for k, s in zip(kernel_size[1:], stride[1:])]
91 | self.padding = tuple()
92 | for p in hw_pad[::-1]:
93 | self.padding += (p // 2 + p % 2, p // 2)
94 | self.padding += (2, 0)
95 |
96 | self.conv = nn.Conv3d(
97 | in_channel,
98 | out_channel,
99 | kernel_size,
100 | stride=stride,
101 | )
102 |
103 | def forward(self, x: torch.Tensor):
104 | x = F.pad(x, self.padding)
105 | x = self.conv(x)
106 | return x
107 |
108 |
109 | class Emu3VisionVQResnetTemporalBlock(nn.Module):
110 |
111 | def __init__(
112 | self,
113 | in_channels: int,
114 | out_channels: Optional[int] = None,
115 | conv_shortcut: bool = False,
116 | dropout: float = 0.0,
117 | ):
118 | super().__init__()
119 | self.in_channels = in_channels
120 | out_channels = in_channels if out_channels is None else out_channels
121 | self.out_channels = out_channels
122 | self.use_conv_shortcut = conv_shortcut
123 |
124 | stride = (1, 1, 1)
125 | kernel_size = (3, 3, 3)
126 |
127 | self.norm1 = nn.BatchNorm3d(in_channels)
128 | self.conv1 = Emu3VisionVQCausalConv3d(
129 | in_channels,
130 | out_channels,
131 | kernel_size=kernel_size,
132 | stride=stride,
133 | )
134 | self.norm2 = nn.BatchNorm3d(out_channels)
135 | self.dropout = nn.Dropout(dropout)
136 | self.conv2 = Emu3VisionVQCausalConv3d(
137 | out_channels,
138 | out_channels,
139 | kernel_size=kernel_size,
140 | stride=stride,
141 | )
142 | self.act = Emu3VisionVQActivation()
143 |
144 | if self.in_channels != self.out_channels:
145 | if self.use_conv_shortcut:
146 | self.conv_shortcut = Emu3VisionVQCausalConv3d(
147 | in_channels,
148 | out_channels,
149 | kernel_size=kernel_size,
150 | stride=stride,
151 | )
152 | else:
153 | self.nin_shortcut = nn.Conv3d(
154 | in_channels,
155 | out_channels,
156 | kernel_size=1,
157 | stride=1,
158 | padding=0,
159 | )
160 |
161 | def forward(self, x: torch.Tensor):
162 | h = self.norm1(x)
163 | h = self.act(h)
164 | h = self.conv1(h)
165 |
166 | h = self.norm2(h)
167 | h = self.act(h)
168 | h = self.dropout(h)
169 | h = self.conv2(h)
170 |
171 | if self.in_channels != self.out_channels:
172 | if self.use_conv_shortcut:
173 | x = self.conv_shortcut(x)
174 | else:
175 | x = self.nin_shortcut(x)
176 |
177 | return x + h
178 |
179 |
180 | class Emu3VisionVQSpatialNorm(nn.Module):
181 |
182 | def __init__(
183 | self,
184 | f_channels: int,
185 | zq_channels: int,
186 | norm_layer: nn.Module = nn.GroupNorm,
187 | add_conv: bool = False,
188 | num_groups: int = 32,
189 | eps: float = 1e-6,
190 | affine: bool = True,
191 | ):
192 | super().__init__()
193 | self.norm_layer = norm_layer(
194 | num_channels=f_channels,
195 | num_groups=num_groups,
196 | eps=eps,
197 | affine=affine,
198 | )
199 |
200 | self.add_conv = add_conv
201 | if self.add_conv:
202 | self.conv = nn.Conv2d(
203 | zq_channels,
204 | zq_channels,
205 | kernel_size=3,
206 | stride=1,
207 | padding=1,
208 | )
209 |
210 | self.conv_y = nn.Conv2d(
211 | zq_channels,
212 | f_channels,
213 | kernel_size=1,
214 | stride=1,
215 | padding=0,
216 | )
217 | self.conv_b = nn.Conv2d(
218 | zq_channels,
219 | f_channels,
220 | kernel_size=1,
221 | stride=1,
222 | padding=0,
223 | )
224 |
225 | def forward(self, x: torch.Tensor, zq: torch.Tensor):
226 | zq = F.interpolate(zq, size=x.shape[-2:], mode="nearest")
227 |
228 | if self.add_conv:
229 | zq = self.conv(zq)
230 |
231 | x = self.norm_layer(x)
232 | x = x * self.conv_y(zq) + self.conv_b(zq)
233 | return x
234 |
235 |
236 | class Emu3VisionVQResnetBlock(nn.Module):
237 |
238 | def __init__(
239 | self,
240 | in_channels: int,
241 | out_channels: Optional[int] = None,
242 | conv_shortcut: bool = False,
243 | dropout: float = 0.0,
244 | zq_ch: Optional[int] = None,
245 | add_conv: bool = False,
246 | ):
247 | super().__init__()
248 | self.in_channels = in_channels
249 | out_channels = in_channels if out_channels is None else out_channels
250 | self.out_channels = out_channels
251 | self.use_conv_shortcut = conv_shortcut
252 | self.zq_ch = zq_ch
253 |
254 | if zq_ch is None:
255 | norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
256 | self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
257 | self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
258 | else:
259 | self.norm1 = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
260 | self.norm2 = Emu3VisionVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
261 |
262 | self.conv1 = nn.Conv2d(
263 | in_channels,
264 | out_channels,
265 | kernel_size=3,
266 | stride=1,
267 | padding=1,
268 | )
269 |
270 | self.dropout = nn.Dropout(dropout)
271 | self.conv2 = nn.Conv2d(
272 | out_channels,
273 | out_channels,
274 | kernel_size=3,
275 | stride=1,
276 | padding=1,
277 | )
278 |
279 | self.act = Emu3VisionVQActivation()
280 |
281 | if self.in_channels != self.out_channels:
282 | if self.use_conv_shortcut:
283 | self.conv_shortcut = nn.Conv2d(
284 | in_channels,
285 | out_channels,
286 | kernel_size=3,
287 | stride=1,
288 | padding=1,
289 | )
290 | else:
291 | self.nin_shortcut = nn.Conv2d(
292 | in_channels,
293 | out_channels,
294 | kernel_size=1,
295 | stride=1,
296 | padding=0,
297 | )
298 |
299 | def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
300 | norm_args = tuple() if self.zq_ch is None else (zq, )
301 |
302 | h = self.norm1(x, *norm_args)
303 | h = self.act(h)
304 | h = self.conv1(h)
305 |
306 | h = self.norm2(h, *norm_args)
307 | h = self.act(h)
308 | h = self.dropout(h)
309 | h = self.conv2(h)
310 |
311 | if self.in_channels != self.out_channels:
312 | if self.use_conv_shortcut:
313 | x = self.conv_shortcut(x)
314 | else:
315 | x = self.nin_shortcut(x)
316 |
317 | return x + h
318 |
319 |
320 | class Emu3VisionVQAttnBlock(nn.Module):
321 |
322 | def __init__(
323 | self,
324 | in_channels: int,
325 | zq_ch: Optional[int] = None,
326 | add_conv: bool = False
327 | ):
328 | super().__init__()
329 | self.in_channels = in_channels
330 | self.zq_ch = zq_ch
331 |
332 | if zq_ch is None:
333 | norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
334 | self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
335 | else:
336 | self.norm = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
337 |
338 | self.q = nn.Conv2d(
339 | in_channels,
340 | in_channels,
341 | kernel_size=1,
342 | stride=1,
343 | padding=0,
344 | )
345 | self.k = nn.Conv2d(
346 | in_channels,
347 | in_channels,
348 | kernel_size=1,
349 | stride=1,
350 | padding=0,
351 | )
352 | self.v = nn.Conv2d(
353 | in_channels,
354 | in_channels,
355 | kernel_size=1,
356 | stride=1,
357 | padding=0,
358 | )
359 | self.proj_out = nn.Conv2d(
360 | in_channels,
361 | in_channels,
362 | kernel_size=1,
363 | stride=1,
364 | padding=0,
365 | )
366 |
367 | def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
368 | norm_args = tuple() if self.zq_ch is None else (zq, )
369 |
370 | nx = self.norm(x, *norm_args)
371 | q = self.q(nx)
372 | k = self.k(nx)
373 | v = self.v(nx)
374 |
375 | # compute attention
376 | b, c, h, w = q.shape
377 | q = q.reshape(b, c, h * w)
378 | k = k.reshape(b, c, h * w)
379 | score = torch.bmm(q.permute(0, 2, 1), k)
380 | score = score / (c ** 0.5)
381 | score = F.softmax(score, dim=2)
382 |
383 | # attend to values
384 | v = v.reshape(b, c, h * w)
385 | v = torch.bmm(v, score.permute(0, 2, 1))
386 | v = v.reshape(b, c, h, w)
387 |
388 | v = self.proj_out(v)
389 |
390 | return x + v
391 |
392 |
393 | class Emu3VisionVQTemporalUpsample(nn.Module):
394 |
395 | def __init__(
396 | self,
397 | in_channel: int,
398 | out_channel: int,
399 | kernel_size: Tuple[int, ...] = (3, 3, 3),
400 | stride: Tuple[int, ...] = (1, 1, 1)
401 | ):
402 | super().__init__()
403 | self.in_channel = in_channel
404 | self.out_channel = out_channel
405 | self.conv = Emu3VisionVQCausalConv3d(
406 | in_channel,
407 | out_channel,
408 | kernel_size,
409 | stride=stride,
410 | )
411 |
412 | def forward(self, x: torch.Tensor):
413 | b, c, t, h, w = x.shape
414 | x = x.permute(0, 1, 3, 4, 2).contiguous().view(b, -1, t)
415 | x = F.interpolate(x, scale_factor=2.0, mode="nearest")
416 | x = x.view(b, c, h, w, -1).permute(0, 1, 4, 2, 3).contiguous()
417 | x = self.conv(x)
418 | return x
419 |
420 |
421 | class Emu3VisionVQTemporalDownsample(nn.Module):
422 |
423 | def __init__(
424 | self,
425 | in_channel: int,
426 | out_channel: int,
427 | kernel_size: Tuple[int, ...] = (4, 3, 3),
428 | stride: Tuple[int, ...] = (2, 1, 1),
429 | ):
430 | super().__init__()
431 | self.in_channel = in_channel
432 | self.out_channel = out_channel
433 | self.kernel_size = kernel_size
434 |
435 | self.conv = Emu3VisionVQCausalConv3d(
436 | in_channel,
437 | out_channel,
438 | kernel_size=kernel_size,
439 | stride=stride,
440 | )
441 |
442 | def forward(self, x: torch.Tensor):
443 | x = self.conv(x)
444 | return x
445 |
446 |
447 | class Emu3VisionVQVectorQuantizer(nn.Module):
448 |
449 | def __init__(self, config: Emu3VisionVQConfig):
450 | super().__init__()
451 | self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
452 | self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
453 |
454 | def forward(self, x: torch.Tensor):
455 | # b t c h w -> b t h w c
456 | b, t, c, h, w = x.shape
457 | x = x.permute(0, 1, 3, 4, 2).contiguous()
458 | x_flattened = x.view(-1, c)
459 |
460 | codebook = self.embedding.weight
461 |
462 | d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
463 | torch.sum(codebook ** 2, dim=1) - 2 * \
464 | torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
465 |
466 | indices = torch.argmin(d, dim=1)
467 | indices = indices.view(b, t, h, w)
468 | return indices
469 |
470 |
471 | class Emu3VisionVQEncoder(nn.Module):
472 |
473 | def __init__(self, config: Emu3VisionVQConfig):
474 | super().__init__()
475 | self.ch = config.ch
476 | self.num_resolutions = len(config.ch_mult)
477 | self.num_res_blocks = config.num_res_blocks
478 | self.in_channels = config.in_channels
479 |
480 | # downsampling
481 | self.conv_in = nn.Conv2d(
482 | self.in_channels,
483 | self.ch,
484 | kernel_size=3,
485 | stride=1,
486 | padding=1
487 | )
488 |
489 | in_ch_mult = (1,) + tuple(config.ch_mult)
490 | self.down = nn.ModuleList()
491 | for i_level in range(self.num_resolutions):
492 | block = nn.ModuleList()
493 | attn = nn.ModuleList()
494 | block_in = config.ch * in_ch_mult[i_level]
495 | block_out = config.ch * config.ch_mult[i_level]
496 | for i_block in range(self.num_res_blocks):
497 | block.append(
498 | Emu3VisionVQResnetBlock(
499 | in_channels=block_in,
500 | out_channels=block_out,
501 | dropout=config.dropout,
502 | )
503 | )
504 | block_in = block_out
505 | if i_level in config.attn_resolutions:
506 | attn.append(Emu3VisionVQAttnBlock(block_in))
507 |
508 | down = nn.Module()
509 | down.block = block
510 | down.attn = attn
511 | if i_level != self.num_resolutions - 1:
512 | down.downsample = Emu3VisionVQDownsample(block_in)
513 |
514 | self.down.append(down)
515 |
516 | # middle
517 | self.mid = nn.Module()
518 | self.mid.block_1 = Emu3VisionVQResnetBlock(
519 | in_channels=block_in,
520 | out_channels=block_in,
521 | dropout=config.dropout,
522 | )
523 | self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in)
524 | self.mid.block_2 = Emu3VisionVQResnetBlock(
525 | in_channels=block_in,
526 | out_channels=block_in,
527 | dropout=config.dropout,
528 | )
529 |
530 | # end
531 | self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
532 |
533 | out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
534 | self.conv_out = nn.Conv2d(
535 | block_in,
536 | out_z_channels,
537 | kernel_size=3,
538 | stride=1,
539 | padding=1,
540 | )
541 |
542 | temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
543 | self.time_conv = nn.ModuleList()
544 |
545 | for i in range(temporal_down_blocks):
546 | conv = Emu3VisionVQTemporalDownsample(out_z_channels, out_z_channels)
547 | self.time_conv.append(conv)
548 |
549 | self.time_res_stack = nn.Sequential(*[
550 | Emu3VisionVQResnetTemporalBlock(
551 | in_channels=out_z_channels,
552 | out_channels=out_z_channels,
553 | dropout=config.dropout,
554 | ) for _ in range(self.num_res_blocks)
555 | ])
556 |
557 | self.act = Emu3VisionVQActivation()
558 |
559 | def forward(self, x: torch.Tensor):
560 | t = x.shape[1]
561 | x = x.reshape(-1, *x.shape[2:])
562 |
563 | # downsampling
564 | h = self.conv_in(x)
565 | for i_level in range(self.num_resolutions):
566 | for i_block in range(self.num_res_blocks):
567 | h = self.down[i_level].block[i_block](h)
568 | if len(self.down[i_level].attn) > 0:
569 | h = self.down[i_level].attn[i_block](h)
570 |
571 | if i_level != self.num_resolutions - 1:
572 | h = self.down[i_level].downsample(h)
573 |
574 | h = self.mid.block_1(h)
575 | h = self.mid.attn_1(h)
576 | h = self.mid.block_2(h)
577 |
578 | # end
579 | h = self.norm_out(h)
580 | h = self.act(h)
581 |
582 | h = self.conv_out(h)
583 |
584 | h = h.reshape(-1, t, *h.shape[1:])
585 | h = h.permute(0, 2, 1, 3, 4)
586 |
587 | for conv in self.time_conv:
588 | h = self.act(conv(h))
589 |
590 | h = self.time_res_stack(h)
591 | h = h.permute(0, 2, 1, 3, 4)
592 |
593 | return h
594 |
595 |
596 | class Emu3VisionVQDecoder(nn.Module):
597 |
598 | def __init__(self, config: Emu3VisionVQConfig):
599 | super().__init__()
600 | self.ch = config.ch
601 | self.num_resolutions = len(config.ch_mult)
602 | self.num_res_blocks = config.num_res_blocks
603 |
604 | in_ch_mult = (1,) + tuple(config.ch_mult)
605 | zq_ch = config.embed_dim
606 |
607 | block_in = config.ch * config.ch_mult[-1]
608 | self.time_res_stack = nn.Sequential(*[
609 | Emu3VisionVQResnetTemporalBlock(
610 | in_channels=config.z_channels,
611 | out_channels=config.z_channels,
612 | dropout=config.dropout,
613 | ) for _ in range(config.num_res_blocks)
614 | ])
615 |
616 | tempo_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
617 | self.time_conv = nn.ModuleList()
618 | for i in range(tempo_upsample_block_num):
619 | conv = Emu3VisionVQTemporalUpsample(config.z_channels, config.z_channels)
620 | self.time_conv.append(conv)
621 |
622 | self.conv_in = nn.Conv2d(
623 | config.z_channels,
624 | block_in,
625 | kernel_size=3,
626 | stride=1,
627 | padding=1,
628 | )
629 |
630 | # middle
631 | self.mid = nn.Module()
632 | self.mid.block_1 = Emu3VisionVQResnetBlock(
633 | in_channels=block_in,
634 | out_channels=block_in,
635 | dropout=config.dropout,
636 | zq_ch=zq_ch,
637 | )
638 | self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in, zq_ch)
639 | self.mid.block_2 = Emu3VisionVQResnetBlock(
640 | in_channels=block_in,
641 | out_channels=block_in,
642 | dropout=config.dropout,
643 | zq_ch=zq_ch,
644 | )
645 |
646 | # upsampling
647 | self.up = nn.ModuleList()
648 | for i_level in reversed(range(self.num_resolutions)):
649 | block = nn.ModuleList()
650 | attn = nn.ModuleList()
651 | block_out = config.ch * config.ch_mult[i_level]
652 | for i_block in range(self.num_res_blocks + 1):
653 | block.append(
654 | Emu3VisionVQResnetBlock(
655 | in_channels=block_in,
656 | out_channels=block_out,
657 | dropout=config.dropout,
658 | zq_ch=zq_ch,
659 | )
660 | )
661 | block_in = block_out
662 | if i_level in config.attn_resolutions:
663 | attn.append(Emu3VisionVQAttnBlock(block_in, zq_ch))
664 |
665 | up = nn.Module()
666 | up.block = block
667 | up.attn = attn
668 | if i_level != 0:
669 | up.upsample = Emu3VisionVQUpsample(block_in)
670 |
671 | self.up.insert(0, up)
672 |
673 | self.act = Emu3VisionVQActivation()
674 |
675 | self.norm_out = Emu3VisionVQSpatialNorm(block_in, zq_ch)
676 | self.conv_out = nn.Conv2d(
677 | block_in,
678 | config.out_channels,
679 | kernel_size=3,
680 | stride=1,
681 | padding=1,
682 | )
683 |
684 | def forward(self, z: torch.Tensor, zq: torch.Tensor):
685 | z_zq = torch.cat((z, zq), dim=0)
686 | z_zq = z_zq.permute(0, 2, 1, 3, 4)
687 | z_zq = self.time_res_stack(z_zq)
688 |
689 | for conv in self.time_conv:
690 | z_zq = self.act(conv(z_zq))
691 |
692 | z_zq = z_zq.permute(0, 2, 1, 3, 4)
693 |
694 | h, zq = torch.chunk(z_zq, 2, dim=0)
695 |
696 | h = h.reshape(-1, *h.shape[2:])
697 | zq = zq.reshape(-1, *zq.shape[2:])
698 |
699 | h = self.conv_in(h)
700 |
701 | # middle
702 | h = self.mid.block_1(h, zq)
703 | h = self.mid.attn_1(h, zq)
704 | h = self.mid.block_2(h, zq)
705 |
706 | # upsampling
707 | for i_level in reversed(range(self.num_resolutions)):
708 | for i_block in range(self.num_res_blocks+1):
709 | h = self.up[i_level].block[i_block](h, zq)
710 | if len(self.up[i_level].attn) > 0:
711 | h = self.up[i_level].attn[i_block](h, zq)
712 |
713 | if i_level != 0:
714 | h = self.up[i_level].upsample(h)
715 |
716 | h = self.norm_out(h, zq)
717 | h = self.act(h)
718 | h = self.conv_out(h)
719 |
720 | return h
721 |
722 |
723 | class Emu3VisionVQPretrainedModel(PreTrainedModel):
724 | """
725 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
726 | models.
727 | """
728 |
729 | config_class = Emu3VisionVQConfig
730 | base_model_prefix = "emuvideovq"
731 | main_input_name = "pixel_values"
732 | _no_split_modules = ["Emu3VisionVQResnetBlock", "Emu3VisionVQAttnBlock", "Emu3VisionVQResnetTemporalBlock"]
733 |
734 | def _init_weights(self, module):
735 | if isinstance(module, (nn.Conv2d, nn.Conv3d)):
736 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
737 | # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
738 | elif isinstance(module, nn.Linear):
739 | nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
740 | if module.bias is not None:
741 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
742 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
743 | nn.init.uniform_(module.bias, -bound, bound)
744 | elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
745 | nn.init.constant_(module.weight, 1)
746 | nn.init.constant_(module.bias, 0)
747 |
748 |
749 | class Emu3VisionVQModel(Emu3VisionVQPretrainedModel):
750 |
751 | def __init__(self, config):
752 | super().__init__(config)
753 | self.config = config
754 |
755 | self.encoder = Emu3VisionVQEncoder(config)
756 | self.decoder = Emu3VisionVQDecoder(config)
757 | self.quantize = Emu3VisionVQVectorQuantizer(config)
758 |
759 | self.quant_conv = Emu3VisionVQCausalConv3d(config.z_channels, config.embed_dim)
760 | self.post_quant_conv = Emu3VisionVQCausalConv3d(config.embed_dim, config.z_channels)
761 |
762 | self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
763 |
764 | self.post_init()
765 |
766 | def encode(self, x: torch.Tensor):
767 | ndim = x.ndim
768 | if ndim == 4:
769 | t = self.config.temporal_downsample_factor
770 | b, c, h, w = x.shape
771 | x = x.unsqueeze(1).repeat(1, t, 1, 1, 1)
772 | elif ndim == 5:
773 | b, t, c, h, w = x.shape
774 |
775 | h = self.encoder(x)
776 |
777 | # b t c h w -> b c t h w
778 | h = h.permute(0, 2, 1, 3, 4)
779 | h = self.quant_conv(h)
780 | # b c t h w -> b t c h w
781 | h = h.permute(0, 2, 1, 3, 4)
782 |
783 | codes = self.quantize(h)
784 |
785 | if ndim == 4:
786 | codes = codes.squeeze(1)
787 |
788 | return codes
789 |
790 | def decode(self, x: torch.Tensor):
791 | ndim = x.ndim
792 | if ndim == 3:
793 | x = x.unsqueeze(1)
794 |
795 | b, t, h, w = x.shape
796 | quant = self.quantize.embedding(x.flatten())
797 | c = quant.shape[-1]
798 | quant = quant.view(b, t, h, w, c).permute(0, 4, 1, 2, 3).contiguous()
799 | quant2 = self.post_quant_conv(quant)
800 |
801 | quant = quant.permute(0, 2, 1, 3, 4)
802 | quant2 = quant2.permute(0, 2, 1, 3, 4)
803 |
804 | video = self.decoder(quant2, quant)
805 | video = video.reshape(
806 | b,
807 | t * self.config.temporal_downsample_factor,
808 | self.config.out_channels,
809 | h * self.spatial_scale_factor,
810 | w * self.spatial_scale_factor,
811 | )
812 | if ndim == 3:
813 | return video[:, 0]
814 | return video
815 |
816 | @property
817 | def device(self):
818 | return next(self.parameters()).device
819 |
820 | @property
821 | def dtype(self):
822 | return next(self.parameters()).dtype
823 |
--------------------------------------------------------------------------------
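The model above exposes `encode`/`decode` directly on pixel tensors and token grids. Below is a minimal round-trip sketch, assuming the `BAAI/Emu3-VisionTokenizer` weights and the bundled `assets/demo.png`; the shape comments are inferred from the `encode`/`decode` methods above, and mapping the reconstruction back to a viewable image is left to the image processor.

from PIL import Image
import torch

from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel

processor = Emu3VisionVQImageProcessor.from_pretrained("BAAI/Emu3-VisionTokenizer")
model = Emu3VisionVQModel.from_pretrained("BAAI/Emu3-VisionTokenizer", device_map="cuda:0").eval()

image = Image.open("assets/demo.png").convert("RGB")
pixel_values = processor(image, return_tensors="pt")["pixel_values"].cuda()  # (1, 3, H, W)

with torch.no_grad():
    codes = model.encode(pixel_values)  # (1, H/s, W/s) token grid, s = model.spatial_scale_factor
    recon = model.decode(codes)         # (1, out_channels, H, W) reconstructed pixels
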
/emu3/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/Emu3/e93b6b4bb4944ef3f13e58397ba11f98c877df91/emu3/train/__init__.py
--------------------------------------------------------------------------------
/emu3/train/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import json
4 | import os.path as osp
5 | import random
6 |
7 | import torch
8 | from torch.utils.data import Dataset
9 |
10 |
11 | class Emu3FeatureDataset(Dataset):
12 |
13 | def __init__(self, args: "DataArguments", tokenizer: "Emu3Tokenizer"):
14 | super().__init__()
15 |
16 | self.args = args
17 | with open(args.data_path) as f:
18 | d = json.load(f)
19 |
20 | self.path_prefix = d["prefix"]
21 | self.filelist = d["path_list"]
22 |
23 | self.tokenizer = tokenizer
24 | self.bov = tokenizer.encode(args.visual_token_pattern.format(token_id=0))[0]
25 | self.eov = tokenizer.encode(args.visual_token_pattern.format(token_id=args.codebook_size - 1))[0]
26 |
27 | def __len__(self):
28 | return len(self.filelist)
29 |
30 | def __getitem__(self, index: int):
31 | path = osp.join(self.path_prefix, self.filelist[index])
32 | data = torch.load(path)
33 |
34 | image_tokens = data["images"]
35 | image_prompt = self.format_image_prompt(image_tokens)
36 |
37 | p_prob = random.random()
38 | if p_prob < self.args.null_prompt_prob:
39 | prompt = ""
40 | else:
41 | prompt = data["texts"]
42 |
43 | input = self.tokenizer.bos_token + prompt + image_prompt
44 | sample = self.tokenizer(
45 | input,
46 | padding="max_length",
47 | return_token_type_ids=False,
48 | return_tensors="pt",
49 | )
50 |
51 | labels = sample["input_ids"]
52 | if self.args.apply_loss_on_only_vision:
53 | labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, self.args.ignore_index)
54 |
55 | sample["labels"] = labels
56 | for k, v in sample.items():
57 | sample[k] = v.squeeze(0)
58 |
59 | return sample
60 |
61 | def format_image_prompt(self, image_tokens):
62 | h, w = image_tokens.shape
63 | imgstr = self.to_imgstr(image_tokens)
64 |
65 | image_prompt = (
66 | self.tokenizer.boi_token +
67 | f"{h}*{w}" +
68 | self.tokenizer.img_token +
69 | imgstr +
70 | self.tokenizer.eol_token +
71 | self.tokenizer.eof_token +
72 | self.tokenizer.eoi_token
73 | )
74 |
75 | return image_prompt
76 |
77 | def to_imgstr(self, image_tokens):
78 | image_token_str = [
79 | [
80 | self.args.visual_token_pattern.format(token_id=token_id)
81 | for token_id in token_row
82 | ]
83 | for token_row in image_tokens
84 | ]
85 | image_row_str = ["".join(token_row) for token_row in image_token_str]
86 | imgstr = self.tokenizer.eol_token.join(image_row_str)
87 | return imgstr
88 |
89 |
--------------------------------------------------------------------------------
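Emu3FeatureDataset expects `args.data_path` to point at a JSON of the form {"prefix": ..., "path_list": [...]}, where each listed .pth file is a torch-saved dict with `images` (an HxW grid of visual token ids) and `texts` (the prompt), which is exactly what prepare_data.py below writes. A minimal sketch of hand-building such a sample; the 32x32 grid and codebook size 32768 are illustrative values, not requirements.

import json
import os

import numpy as np
import torch

os.makedirs("toy_data/feature", exist_ok=True)
os.makedirs("toy_data/list", exist_ok=True)

# Toy feature file with the keys __getitem__ reads ("images", "texts").
token_grid = np.random.randint(0, 32768, size=(32, 32))
torch.save({"name": "sample0", "images": token_grid, "texts": "a shiba inu"}, "toy_data/feature/sample0.pth")

# Data-list JSON referenced by DataArguments.data_path.
with open("toy_data/list/train.json", "w") as f:
    json.dump({"prefix": "toy_data/feature", "path_list": ["sample0.pth"]}, f)
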
/emu3/train/prepare_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import json
5 | import os
6 |
7 | from PIL import Image
8 | import torch
9 |
10 | from emu3.tokenizer import Emu3VisionVQModel, Emu3VisionVQImageProcessor
11 |
12 |
13 | def prepare_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--model-path', type=str, help='vision tokenizer path')
16 | parser.add_argument('--data-path', type=str, help='data path')
17 | parser.add_argument('--output-path', type=str, help='tokenized data save path')
18 | parser.add_argument('--image-area', type=int, default=720 * 720)
19 |
20 | args = parser.parse_args()
21 | return args
22 |
23 |
24 | def smart_resize(image, image_area: int = 720 * 720):
25 | w, h = image.size
26 | current_area = h * w
27 | target_ratio = (image_area / current_area) ** 0.5
28 |
29 | th = int(round(h * target_ratio))
30 | tw = int(round(w * target_ratio))
31 |
32 | image = image.resize((tw, th))
33 | return image
34 |
35 |
36 | def main():
37 | args = prepare_args()
38 |
39 | image_processor = Emu3VisionVQImageProcessor.from_pretrained(args.model_path)
40 | image_tokenizer = Emu3VisionVQModel.from_pretrained(args.model_path, device_map="cuda:0")
41 | image_tokenizer.eval()
42 |
43 | os.makedirs(f"{args.output_path}/feature", exist_ok=True)
44 | os.makedirs(f"{args.output_path}/list", exist_ok=True)
45 |
46 | datalist = {
47 | "prefix": f"{args.output_path}/feature",
48 | "path_list": []
49 | }
50 |
51 | with open(args.data_path) as f:
52 | input_data = json.load(f)
53 |
54 | for inp in input_data:
55 | name = inp["name"]
56 | prompt = inp["text"]
57 |
58 | image = Image.open(inp["image"]).convert("RGB")
59 | image = smart_resize(image, args.image_area)
60 |
61 | image = image_processor(image, return_tensors="pt")["pixel_values"]
62 | with torch.no_grad():
63 | image = image.cuda()
64 | token_ids = image_tokenizer.encode(image)
65 |
66 | token_ids = token_ids.squeeze(0).cpu().numpy()
67 | data = {
68 | "name": name,
69 | "images": token_ids,
70 | "texts": prompt
71 | }
72 |
73 | torch.save(data, f"{args.output_path}/feature/{name}.pth")
74 | datalist["path_list"].append(f"{name}.pth")
75 |
76 | with open(f"{args.output_path}/list/train.json", 'w') as f:
77 | json.dump(datalist, f)
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
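The loop above reads `name`, `text`, and `image` from each record, so the --data-path JSON is simply a list of such records. A sketch of writing that file with placeholder paths; the invocation in the comment mirrors the script's argparse flags, with 518400 (720x720) matching the default image area.

import json

# Placeholder records; "image" must point at files PIL can open.
records = [
    {"name": "demo_0", "text": "a portrait of a young girl.", "image": "raw_images/demo_0.png"},
    {"name": "demo_1", "text": "a shiba inu", "image": "raw_images/demo_1.png"},
]

with open("raw_data.json", "w") as f:
    json.dump(records, f)

# Then, roughly:
#   python emu3/train/prepare_data.py \
#       --model-path BAAI/Emu3-VisionTokenizer \
#       --data-path raw_data.json \
#       --output-path toy_data \
#       --image-area 518400
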
/emu3/train/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from dataclasses import dataclass, field
4 | import os
5 | import os.path as osp
6 | import pathlib
7 | from typing import Optional, List
8 |
9 | import transformers as tf
10 | import torch
11 |
12 | from emu3.mllm import Emu3Config, Emu3Tokenizer, Emu3ForCausalLM
13 | from emu3.train.datasets import Emu3FeatureDataset
14 |
15 |
16 | @dataclass
17 | class ModelArguments:
18 | model_name_or_path: Optional[str] = field(default="BAAI/Emu3-Gen")
19 |
20 |
21 | @dataclass
22 | class DataArguments:
23 | data_path: Optional[str] = field(default=None)
24 | null_prompt_prob: float = field(default=0.05)
25 | apply_loss_on_only_vision: bool = field(default=True)
26 | apply_loss_on_only_text: bool = field(default=False)
27 | ignore_index: int = field(default=-100)
28 | visual_token_pattern: str = field(default="<|visual token {token_id:0>6d}|>")
29 | codebook_size: Optional[int] = field(default=32768)
30 |
31 |
32 | @dataclass
33 | class TrainingArguments(tf.TrainingArguments):
34 | report_to: List[str] = field(default_factory=list)
35 | remove_unused_columns: bool = field(default=False)
36 | min_learning_rate: Optional[float] = field(default=None)
37 | attn_type: Optional[str] = field(default="fa2")
38 | image_area: Optional[int] = field(default=None)
39 | max_position_embeddings: Optional[int] = field(default=None)
40 |
41 |
42 | def update_configs(model_config, args, fields):
43 | cross_update = lambda a, b, field_name: (
44 | setattr(b, field_name, getattr(a, field_name))
45 | if getattr(b, field_name, None) is None else
46 | setattr(a, field_name, getattr(b, field_name))
47 | )
48 |
49 | for f in fields:
50 | cross_update(model_config, args, f)
51 |
52 |
53 | def train():
54 | parser = tf.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
55 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
56 |
57 | model_config = Emu3Config.from_pretrained(model_args.model_name_or_path)
58 | update_configs(model_config, training_args, ["image_area", "max_position_embeddings"])
59 | if training_args.min_learning_rate is not None:
60 | training_args.lr_scheduler_kwargs["min_lr"] = training_args.min_learning_rate
61 |
62 | os.environ["WANDB_DIR"] = osp.join(training_args.output_dir, "wandb")
63 |
64 | model = Emu3ForCausalLM.from_pretrained(
65 | model_args.model_name_or_path,
66 | config=model_config,
67 | attn_implementation="flash_attention_2" if training_args.attn_type == "fa2" else None,
68 | torch_dtype=torch.bfloat16 if training_args.bf16 else None,
69 | )
70 |
71 | tokenizer = Emu3Tokenizer.from_pretrained(
72 | model_args.model_name_or_path,
73 | model_max_length=training_args.max_position_embeddings,
74 | padding_side="right",
75 | use_fast=False,
76 | )
77 |
78 | train_dataset = Emu3FeatureDataset(data_args, tokenizer=tokenizer)
79 |
80 | trainer = tf.Trainer(
81 | model=model,
82 | args=training_args,
83 | train_dataset=train_dataset,
84 | )
85 |
86 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
87 | trainer.train(resume_from_checkpoint=True)
88 | else:
89 | trainer.train()
90 | trainer.save_state()
91 |
92 | torch.cuda.synchronize()
93 | trainer.save_model(training_args.output_dir)
94 |
95 |
96 | if __name__ == "__main__":
97 | train()
98 |
--------------------------------------------------------------------------------
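A standalone illustration of the `update_configs` cross-fill above: for each listed field, whichever side is None inherits the other side's value, so an explicit CLI flag overrides the checkpoint config and an omitted one is filled from it. The numbers are illustrative only.

from types import SimpleNamespace

def cross_update(a, b, name):
    # Same rule as the lambda in train.py: fill the None side from the other side.
    if getattr(b, name, None) is None:
        setattr(b, name, getattr(a, name))
    else:
        setattr(a, name, getattr(b, name))

model_config = SimpleNamespace(image_area=518400, max_position_embeddings=9216)
training_args = SimpleNamespace(image_area=None, max_position_embeddings=10240)

for f in ("image_area", "max_position_embeddings"):
    cross_update(model_config, training_args, f)

print(training_args.image_area)              # 518400, inherited from the model config
print(model_config.max_position_embeddings)  # 10240, overridden by the CLI value
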
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import base64
4 | import io
5 | from PIL import Image
6 |
7 | import gradio as gr
8 | from transformers import (
9 | AutoTokenizer,
10 | AutoModelForCausalLM,
11 | AutoImageProcessor,
12 | AutoModel,
13 | )
14 | from transformers.generation.configuration_utils import GenerationConfig
15 | from transformers.generation import (
16 | LogitsProcessorList,
17 | PrefixConstrainedLogitsProcessor,
18 | UnbatchedClassifierFreeGuidanceLogitsProcessor,
19 | )
20 | import torch
21 |
22 | from emu3.mllm.processing_emu3 import Emu3Processor
23 |
24 | def image2str(image):
25 | buf = io.BytesIO()
26 | image.save(buf, format="PNG")
27 | i_str = base64.b64encode(buf.getvalue()).decode()
28 | return f'<img src="data:image/png;base64,{i_str}">'
29 |
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 |
32 | # Model paths
33 | EMU_GEN_HUB = "BAAI/Emu3-Gen"
34 | EMU_CHAT_HUB = "BAAI/Emu3-Chat"
35 | VQ_HUB = "BAAI/Emu3-VisionTokenizer"
36 |
37 | # Prepare models and processors
38 | gen_model = AutoModelForCausalLM.from_pretrained(
39 | EMU_GEN_HUB,
40 | device_map="cpu",
41 | torch_dtype=torch.bfloat16,
42 | attn_implementation="flash_attention_2",
43 | trust_remote_code=True,
44 | ).eval()
45 |
46 | chat_model = AutoModelForCausalLM.from_pretrained(
47 | EMU_CHAT_HUB,
48 | device_map="cpu",
49 | torch_dtype=torch.bfloat16,
50 | attn_implementation="flash_attention_2",
51 | trust_remote_code=True,
52 | ).eval()
53 |
54 | tokenizer = AutoTokenizer.from_pretrained(
55 | EMU_CHAT_HUB, trust_remote_code=True, padding_side="left",
56 | )
57 | image_processor = AutoImageProcessor.from_pretrained(
58 | VQ_HUB, trust_remote_code=True,
59 | )
60 | image_tokenizer = AutoModel.from_pretrained(
61 | VQ_HUB, device_map="cpu", trust_remote_code=True,
62 | ).eval()
63 |
64 | image_tokenizer.to(device)
65 |
66 | processor = Emu3Processor(
67 | image_processor, image_tokenizer, tokenizer
68 | )
69 |
70 | def generate_image(prompt):
71 | POSITIVE_PROMPT = " masterpiece, film grained, best quality."
72 | NEGATIVE_PROMPT = (
73 | "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, "
74 | "fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, "
75 | "signature, watermark, username, blurry."
76 | )
77 |
78 | classifier_free_guidance = 3.0
79 | full_prompt = prompt + POSITIVE_PROMPT
80 |
81 | kwargs = dict(
82 | mode="G",
83 | ratio="1:1",
84 | image_area=gen_model.config.image_area,
85 | return_tensors="pt",
86 | )
87 | pos_inputs = processor(text=full_prompt, **kwargs)
88 | neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
89 |
90 | # Prepare hyperparameters
91 | GENERATION_CONFIG = GenerationConfig(
92 | use_cache=True,
93 | eos_token_id=gen_model.config.eos_token_id,
94 | pad_token_id=gen_model.config.pad_token_id,
95 | max_new_tokens=40960,
96 | do_sample=True,
97 | top_k=2048,
98 | )
99 |
100 | torch.cuda.empty_cache()
101 | gen_model.to(device)
102 |
103 | h = pos_inputs.image_size[:, 0]
104 | w = pos_inputs.image_size[:, 1]
105 | constrained_fn = processor.build_prefix_constrained_fn(h, w)
106 | logits_processor = LogitsProcessorList([
107 | UnbatchedClassifierFreeGuidanceLogitsProcessor(
108 | classifier_free_guidance,
109 | gen_model,
110 | unconditional_ids=neg_inputs.input_ids.to(device),
111 | ),
112 | PrefixConstrainedLogitsProcessor(
113 | constrained_fn,
114 | num_beams=1,
115 | ),
116 | ])
117 |
118 | # Generate
119 | outputs = gen_model.generate(
120 | pos_inputs.input_ids.to(device),
121 | generation_config=GENERATION_CONFIG,
122 | logits_processor=logits_processor,
123 | attention_mask=pos_inputs.attention_mask.to(device),
124 | )
125 |
126 | mm_list = processor.decode(outputs[0])
127 | result = None
128 | for idx, im in enumerate(mm_list):
129 | if isinstance(im, Image.Image):
130 | result = im
131 | break
132 |
133 | gen_model.cpu()
134 | torch.cuda.empty_cache()
135 |
136 | return result
137 |
138 | def vision_language_understanding(image, text):
139 | inputs = processor(
140 | text=text,
141 | image=image,
142 | mode="U",
143 | padding="longest",
144 | return_tensors="pt",
145 | )
146 |
147 | # Prepare hyperparameters
148 | GENERATION_CONFIG = GenerationConfig(
149 | pad_token_id=tokenizer.pad_token_id,
150 | bos_token_id=tokenizer.bos_token_id,
151 | eos_token_id=tokenizer.eos_token_id,
152 | max_new_tokens=1024,
153 | )
154 |
155 | torch.cuda.empty_cache()
156 | chat_model.to(device)
157 |
158 | # Generate
159 | outputs = chat_model.generate(
160 | inputs.input_ids.to(device),
161 | generation_config=GENERATION_CONFIG,
162 | attention_mask=inputs.attention_mask.to(device),
163 | )
164 |
165 | outputs = outputs[:, inputs.input_ids.shape[-1] :]
166 | response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
167 |
168 | chat_model.cpu()
169 | torch.cuda.empty_cache()
170 |
171 | return response
172 |
173 |
174 | def chat(history, user_input, user_image):
175 | if user_image is not None:
176 | # Use Emu3-Chat for vision-language understanding
177 | response = vision_language_understanding(user_image, user_input)
178 | # Append the user input and response to the history
179 | history = history + [(image2str(user_image) + "<br>" + user_input, response)]
180 | else:
181 | # Use Emu3-Gen for image generation
182 | generated_image = generate_image(user_input)
183 | if generated_image is not None:
184 | # Append the user input and generated image to the history
185 | history = history + [(user_input, image2str(generated_image))]
186 | else:
187 | # If image generation failed, respond with an error message
188 | history = history + [
189 | (user_input, "Sorry, I could not generate an image.")
190 | ]
191 |
192 | return history, history, gr.update(value=None)
193 |
194 |
195 | def clear_input():
196 | return gr.update(value="")
197 |
198 |
199 | with gr.Blocks() as demo:
200 | gr.Markdown("# Emu3 Chatbot Demo")
201 | gr.Markdown(
202 | "This is a chatbot demo for image generation and vision-language understanding using Emu3 models."
203 | )
204 | gr.Markdown(
205 | "Please provide only text input for image generation (~600s), or both an image and text for vision-language understanding (~20s)."
206 | )
207 |
208 | state = gr.State([])
209 | with gr.Row():
210 | with gr.Column(scale=0.2):
211 | user_input = gr.Textbox(
212 | show_label=False, placeholder="Type your message here...", lines=10, container=False,
213 | )
214 | user_image = gr.Image(
215 | sources="upload", type="pil", label="Upload an image (optional)"
216 | )
217 | submit_btn = gr.Button("Send")
218 |
219 | with gr.Column(scale=0.8):
220 | chatbot = gr.Chatbot(height=800)
221 |
222 | submit_btn.click(
223 | chat,
224 | inputs=[state, user_input, user_image],
225 | outputs=[chatbot, state, user_image],
226 | ).then(fn=clear_input, inputs=[], outputs=user_input, queue=False)
227 | user_input.submit(
228 | chat,
229 | inputs=[state, user_input, user_image],
230 | outputs=[chatbot, state, user_image],
231 | ).then(fn=clear_input, inputs=[], outputs=user_input, queue=False)
232 |
233 | demo.queue().launch(max_threads=1)
234 |
--------------------------------------------------------------------------------
/image_generation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from PIL import Image
3 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
4 | from transformers.generation.configuration_utils import GenerationConfig
5 | from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
6 | import torch
7 |
8 | from emu3.mllm.processing_emu3 import Emu3Processor
9 |
10 |
11 | # model path
12 | EMU_HUB = "BAAI/Emu3-Gen"
13 | VQ_HUB = "BAAI/Emu3-VisionTokenizer"
14 |
15 | # prepare model and processor
16 | model = AutoModelForCausalLM.from_pretrained(
17 | EMU_HUB,
18 | device_map="cuda:0",
19 | torch_dtype=torch.bfloat16,
20 | attn_implementation="flash_attention_2",
21 | trust_remote_code=True,
22 | )
23 | model.eval()
24 |
25 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left")
26 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
27 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
28 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
29 |
30 | # prepare input
31 | POSITIVE_PROMPT = " masterpiece, film grained, best quality."
32 | NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
33 |
34 | classifier_free_guidance = 3.0
35 | prompt = ["a portrait of young girl.", "a shiba inu"]
36 | prompt = [p + POSITIVE_PROMPT for p in prompt]
37 |
38 | kwargs = dict(
39 | mode='G',
40 | ratio=["1:1", "16:9"],
41 | image_area=model.config.image_area,
42 | return_tensors="pt",
43 | padding="longest",
44 | )
45 | pos_inputs = processor(text=prompt, **kwargs)
46 | neg_inputs = processor(text=[NEGATIVE_PROMPT] * len(prompt), **kwargs)
47 |
48 | # prepare hyper parameters
49 | GENERATION_CONFIG = GenerationConfig(
50 | use_cache=True,
51 | eos_token_id=model.config.eos_token_id,
52 | pad_token_id=model.config.pad_token_id,
53 | max_new_tokens=40960,
54 | do_sample=True,
55 | top_k=2048,
56 | )
57 |
58 | h = pos_inputs.image_size[:, 0]
59 | w = pos_inputs.image_size[:, 1]
60 | constrained_fn = processor.build_prefix_constrained_fn(h, w)
61 | logits_processor = LogitsProcessorList([
62 | UnbatchedClassifierFreeGuidanceLogitsProcessor(
63 | classifier_free_guidance,
64 | model,
65 | unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
66 | ),
67 | PrefixConstrainedLogitsProcessor(
68 | constrained_fn,
69 | num_beams=1,
70 | ),
71 | ])
72 |
73 | # generate
74 | outputs = model.generate(
75 | pos_inputs.input_ids.to("cuda:0"),
76 | GENERATION_CONFIG,
77 | logits_processor=logits_processor,
78 | attention_mask=pos_inputs.attention_mask.to("cuda:0"),
79 | )
80 |
81 | for idx_i, out in enumerate(outputs):
82 | mm_list = processor.decode(out)
83 | for idx_j, im in enumerate(mm_list):
84 | if not isinstance(im, Image.Image):
85 | continue
86 | im.save(f"result_{idx_i}_{idx_j}.png")
87 |
--------------------------------------------------------------------------------
/multimodal_understanding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from PIL import Image
3 | from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
4 | from transformers.generation.configuration_utils import GenerationConfig
5 | import torch
6 |
7 | from emu3.mllm.processing_emu3 import Emu3Processor
8 |
9 |
10 | # model path
11 | EMU_HUB = "BAAI/Emu3-Chat"
12 | VQ_HUB = "BAAI/Emu3-VisionTokenizer"
13 |
14 | # prepare model and processor
15 | model = AutoModelForCausalLM.from_pretrained(
16 | EMU_HUB,
17 | device_map="cuda:0",
18 | torch_dtype=torch.bfloat16,
19 | attn_implementation="flash_attention_2",
20 | trust_remote_code=True,
21 | )
22 | model.eval()
23 |
24 | tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left")
25 | image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
26 | image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
27 | processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
28 |
29 | # prepare input
30 | text = ["Please describe the image", "Please describe the image"]
31 | image = Image.open("assets/demo.png")
32 | image = [image, image]
33 |
34 | inputs = processor(
35 | text=text,
36 | image=image,
37 | mode='U',
38 | padding_image=True,
39 | padding="longest",
40 | return_tensors="pt",
41 | )
42 |
43 | # prepare hyper parameters
44 | GENERATION_CONFIG = GenerationConfig(pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
45 |
46 | # generate
47 | outputs = model.generate(
48 | inputs.input_ids.to("cuda:0"),
49 | GENERATION_CONFIG,
50 | max_new_tokens=1024,
51 | attention_mask=inputs.attention_mask.to("cuda:0"),
52 | )
53 |
54 | outputs = outputs[:, inputs.input_ids.shape[-1]:]
55 | answers = processor.batch_decode(outputs, skip_special_tokens=True)
56 | for ans in answers:
57 | print(ans)
58 |
--------------------------------------------------------------------------------
/replicate_demo/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://cog.run/yaml
3 |
4 | build:
5 | # set to true if your model requires a GPU
6 | gpu: true
7 |
8 | # a list of ubuntu apt packages to install
9 | system_packages:
10 | - "libgl1-mesa-glx"
11 | - "libglib2.0-0"
12 |
13 | # python version in the form '3.11' or '3.11.4'
14 | python_version: "3.11"
15 |
16 | # a list of packages in the format ==
17 | python_packages:
18 | # - packaging
19 | - torch==2.2.1
20 | - transformers==4.44.0
21 | - tiktoken==0.6.0
22 | - accelerate
23 | - numpy<2
24 | run:
25 | - pip install flash-attn==2.5.7
26 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
27 |
28 | # predict.py defines how predictions are run on your model
29 | predict: "predict_chat.py:Predictor"
30 | # predict: "predict_gen.py:Predictor"
31 |
--------------------------------------------------------------------------------
/replicate_demo/predict_chat.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://cog.run/python
3 |
4 | import os
5 | import time
6 | import subprocess
7 | from PIL import Image
8 | from transformers import (
9 | AutoTokenizer,
10 | AutoModel,
11 | AutoImageProcessor,
12 | AutoModelForCausalLM,
13 | )
14 | from transformers.generation.configuration_utils import GenerationConfig
15 | import torch
16 | from cog import BasePredictor, Input, Path
17 |
18 | from emu3.mllm.processing_emu3 import Emu3Processor
19 |
20 |
21 | MODEL_CACHE = "model_cache_chat"
22 | MODEL_URL = (
23 | f"https://weights.replicate.delivery/default/baaivision/Emu3/{MODEL_CACHE}.tar"
24 | )
25 | os.environ["HF_DATASETS_OFFLINE"] = "1"
26 | os.environ["TRANSFORMERS_OFFLINE"] = "1"
27 | os.environ["HF_HOME"] = MODEL_CACHE
28 | os.environ["TORCH_HOME"] = MODEL_CACHE
29 | os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
30 | os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
31 | os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE
32 |
33 | TORCH_TYPE = torch.bfloat16
34 | DEVICE = "cuda:0"
35 |
36 |
37 | def download_weights(url, dest):
38 | start = time.time()
39 | print("downloading url: ", url)
40 | print("downloading to: ", dest)
41 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
42 | print("downloading took: ", time.time() - start)
43 |
44 |
45 | class Predictor(BasePredictor):
46 | def setup(self) -> None:
47 | """Load the model into memory to make running multiple predictions efficient"""
48 |
49 | if not os.path.exists(MODEL_CACHE):
50 | download_weights(MODEL_URL, MODEL_CACHE)
51 |
52 | # prepare model and processor
53 | self.model = AutoModelForCausalLM.from_pretrained(
54 | f"{MODEL_CACHE}/Emu3-Chat", # "BAAI/Emu3-Chat"
55 | device_map="cuda:0",
56 | torch_dtype=torch.bfloat16,
57 | attn_implementation="flash_attention_2",
58 | trust_remote_code=True,
59 | )
60 |
61 | tokenizer = AutoTokenizer.from_pretrained(
62 | f"{MODEL_CACHE}/Emu3-Chat", trust_remote_code=True
63 | ) # "BAAI/Emu3-Chat"
64 | image_processor = AutoImageProcessor.from_pretrained(
65 | f"{MODEL_CACHE}/Emu3-VisionTokenizer", trust_remote_code=True
66 | ) # "BAAI/Emu3-VisionTokenizer"
67 | image_tokenizer = AutoModel.from_pretrained(
68 | f"{MODEL_CACHE}/Emu3-VisionTokenizer",
69 | device_map="cuda:0",
70 | trust_remote_code=True,
71 | ).eval() # "BAAI/Emu3-VisionTokenizer"
72 | self.processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
73 | # prepare hyper parameters
74 | self.generation_config = GenerationConfig(
75 | pad_token_id=tokenizer.pad_token_id,
76 | bos_token_id=tokenizer.bos_token_id,
77 | eos_token_id=tokenizer.eos_token_id,
78 | )
79 |
80 | def predict(
81 | self,
82 | text: str = Input(
83 | description="Input prompt",
84 | default="Please describe the image.",
85 | ),
86 | image: Path = Input(
87 | description="Input image",
88 | ),
89 | temperature: float = Input(
90 | description="Controls randomness. Lower values make the model more deterministic, higher values make it more random.",
91 | default=0.7,
92 | ge=0.0,
93 | le=1.0,
94 | ),
95 | top_p: float = Input(
96 | description="Controls diversity of the output. Valid when temperature > 0. Lower values make the output more focused, higher values make it more diverse.",
97 | default=0.9,
98 | ge=0.0,
99 | le=1.0,
100 | ),
101 | max_new_tokens: int = Input(
102 | description="Maximum number of tokens to generate", default=256, ge=1
103 | ),
104 | ) -> str:
105 | """Run a single prediction on the model"""
106 |
107 | img = Image.open(str(image))
108 |
109 | inputs = self.processor(
110 | text=text,
111 | image=img,
112 | mode="U",
113 | padding_side="left",
114 | padding="longest",
115 | return_tensors="pt",
116 | )
117 |
118 | outputs = self.model.generate(
119 | inputs.input_ids.to("cuda:0"),
120 | self.generation_config,
121 | max_new_tokens=max_new_tokens,
122 | temperature=temperature,
123 | top_p=top_p,
124 | )
125 |
126 | outputs = outputs[:, inputs.input_ids.shape[-1] :]
127 | return self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
128 |
--------------------------------------------------------------------------------
/replicate_demo/predict_gen.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://cog.run/python
3 |
4 | import os
5 | import time
6 | import subprocess
7 | from PIL import Image
8 | from transformers import (
9 | AutoTokenizer,
10 | AutoModel,
11 | AutoImageProcessor,
12 | AutoModelForCausalLM,
13 | )
14 | from transformers.generation.configuration_utils import GenerationConfig
15 | from transformers.generation import (
16 | LogitsProcessorList,
17 | PrefixConstrainedLogitsProcessor,
18 | UnbatchedClassifierFreeGuidanceLogitsProcessor,
19 | )
20 | import torch
21 | from cog import BasePredictor, Input, Path
22 |
23 | from emu3.mllm.processing_emu3 import Emu3Processor
24 |
25 |
26 | MODEL_CACHE = "model_cache"
27 | MODEL_URL = (
28 | f"https://weights.replicate.delivery/default/baaivision/Emu3/{MODEL_CACHE}.tar"
29 | )
30 | os.environ["HF_DATASETS_OFFLINE"] = "1"
31 | os.environ["TRANSFORMERS_OFFLINE"] = "1"
32 | os.environ["HF_HOME"] = MODEL_CACHE
33 | os.environ["TORCH_HOME"] = MODEL_CACHE
34 | os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
35 | os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
36 | os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE
37 |
38 | TORCH_TYPE = torch.bfloat16
39 | DEVICE = "cuda:0"
40 |
41 |
42 | def download_weights(url, dest):
43 | start = time.time()
44 | print("downloading url: ", url)
45 | print("downloading to: ", dest)
46 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
47 | print("downloading took: ", time.time() - start)
48 |
49 |
50 | class Predictor(BasePredictor):
51 | def setup(self) -> None:
52 | """Load the model into memory to make running multiple predictions efficient"""
53 |
54 | if not os.path.exists(MODEL_CACHE):
55 | download_weights(MODEL_URL, MODEL_CACHE)
56 |
57 | # prepare model and processor
58 | self.model = AutoModelForCausalLM.from_pretrained(
59 | f"{MODEL_CACHE}/Emu3-Gen", # "BAAI/Emu3-Gen"
60 | device_map="cuda:0",
61 | torch_dtype=torch.bfloat16,
62 | attn_implementation="flash_attention_2",
63 | trust_remote_code=True,
64 | )
65 |
66 | tokenizer = AutoTokenizer.from_pretrained(
67 | f"{MODEL_CACHE}/Emu3-Gen", trust_remote_code=True
68 | ) # "BAAI/Emu3-Gen"
69 | image_processor = AutoImageProcessor.from_pretrained(
70 | f"{MODEL_CACHE}/Emu3-VisionTokenizer", trust_remote_code=True
71 | ) # "BAAI/Emu3-VisionTokenizer"
72 | image_tokenizer = AutoModel.from_pretrained(
73 | f"{MODEL_CACHE}/Emu3-VisionTokenizer",
74 | device_map="cuda:0",
75 | trust_remote_code=True,
76 | ).eval() # "BAAI/Emu3-VisionTokenizer"
77 | self.processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
78 |
79 | self.kwargs = dict(
80 | mode="G",
81 | ratio="1:1",
82 | image_area=self.model.config.image_area,
83 | return_tensors="pt",
84 | )
85 |
86 | # prepare hyper parameters
87 | self.generation_config = GenerationConfig(
88 | use_cache=True,
89 | eos_token_id=self.model.config.eos_token_id,
90 | pad_token_id=self.model.config.pad_token_id,
91 | max_new_tokens=40960,
92 | do_sample=True,
93 | top_k=2048,
94 | )
95 |
96 | def predict(
97 | self,
98 | prompt: str = Input(
99 | description="Input prompt",
100 | default="a portrait of young girl.",
101 | ),
102 | positive_prompt: str = Input(
103 | default="masterpiece, film grained, best quality.",
104 | ),
105 | negative_prompt: str = Input(
106 | description="Specify things to not see in the output",
107 | default="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.",
108 | ),
109 | guidance_scale: float = Input(
110 | description="Scale for classifier-free guidance", ge=1, le=20, default=3
111 | ),
112 | ) -> Path:
113 | """Run a single prediction on the model"""
114 |
115 | pos_inputs = self.processor(
116 | text=prompt + " " + positive_prompt.strip(), **self.kwargs
117 | )
118 | neg_inputs = self.processor(text=negative_prompt, **self.kwargs)
119 |
120 | h, w = pos_inputs.image_size[0]
121 | constrained_fn = self.processor.build_prefix_constrained_fn(h, w)
122 | logits_processor = LogitsProcessorList(
123 | [
124 | UnbatchedClassifierFreeGuidanceLogitsProcessor(
125 | guidance_scale,
126 | self.model,
127 | unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
128 | ),
129 | PrefixConstrainedLogitsProcessor(
130 | constrained_fn,
131 | num_beams=1,
132 | ),
133 | ]
134 | )
135 |
136 | # generate
137 | outputs = self.model.generate(
138 | pos_inputs.input_ids.to("cuda:0"),
139 | self.generation_config,
140 | logits_processor=logits_processor,
141 | )
142 |
143 | out_path = "/tmp/out.png"
144 |
145 | mm_list = self.processor.decode(outputs[0])
146 | print(len(mm_list))
147 | print(mm_list)
148 | for idx, im in enumerate(mm_list):
149 | if not isinstance(im, Image.Image):
150 | continue
151 | im.save(out_path)
152 | return Path(out_path)
153 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.1
2 | transformers==4.44.0
3 | tiktoken==0.6.0
4 | flash-attn==2.5.7
5 | pillow
6 | gradio==4.44.0
7 |
--------------------------------------------------------------------------------
/scripts/t2i_sft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | WORLD_SIZE=${WORLD_SIZE:-1}
4 | RANK=${RANK:-0}
5 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
6 | MASTER_PORT=${MASTER_PORT:-23456}
7 | NGPUS=$(python -c "import torch; print(torch.cuda.device_count())")
8 |
9 | export PYTHONPATH=$(pwd)
10 |
11 | DATAPATH="your data path (json file)"
12 | EXP_NAME="Emu3-T2I-SFT-Trial"
13 | torchrun \
14 | --nproc_per_node=${NGPUS} \
15 | --nnodes=${WORLD_SIZE} \
16 | --node_rank=${RANK} \
17 | --master_addr=${MASTER_ADDR} \
18 | --master_port=${MASTER_PORT} \
19 | emu3/train/train.py \
20 | --model_name_or_path BAAI/Emu3-Gen \
21 | --deepspeed scripts/zero3.json \
22 | --data_path ${DATAPATH} \
23 | --null_prompt_prob 0.05 \
24 | --apply_loss_on_only_vision True \
25 | --apply_loss_on_only_text False \
26 | --image_area 518400 \
27 | --max_position_embeddings 10240 \
28 | --output_dir "logs/"${EXP_NAME} \
29 | --bf16 True \
30 | --tf32 True \
31 | --num_train_epochs 4 \
32 | --per_device_train_batch_size 2 \
33 | --gradient_accumulation_steps 4 \
34 | --eval_strategy no \
35 | --save_strategy steps \
36 | --save_steps 500 \
37 | --save_total_limit 10 \
38 | --learning_rate 1e-5 \
39 | --min_learning_rate 1e-6 \
40 | --weight_decay 0.1 \
41 | --max_grad_norm 5.0 \
42 | --adam_beta1 0.9 \
43 | --adam_beta2 0.95 \
44 | --adam_epsilon 1e-6 \
45 | --warmup_steps 30 \
46 | --lr_scheduler_type "cosine_with_min_lr" \
47 | --logging_steps 1 \
48 | --gradient_checkpointing True \
49 | --dataloader_num_workers 4 \
50 | --report_to wandb tensorboard \
51 | --run_name ${EXP_NAME}
52 |
--------------------------------------------------------------------------------
/scripts/t2i_sft_offload.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | WORLD_SIZE=${WORLD_SIZE:-1}
4 | RANK=${RANK:-0}
5 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
6 | MASTER_PORT=${MASTER_PORT:-23456}
7 | NGPUS=$(python -c "import torch; print(torch.cuda.device_count())")
8 |
9 | export PYTHONPATH=$(pwd)
10 |
11 | DATAPATH="your data path"
12 | EXP_NAME="Emu3-T2I-SFT-Trial"
13 | torchrun \
14 | --nproc_per_node=${NGPUS} \
15 | --nnodes=${WORLD_SIZE} \
16 | --node_rank=${RANK} \
17 | --master_addr=${MASTER_ADDR} \
18 | --master_port=${MASTER_PORT} \
19 | emu3/train/train.py \
20 | --model_name_or_path BAAI/Emu3-Gen \
21 | --deepspeed scripts/zero3_offload.json \
22 | --data_path ${DATAPATH} \
23 | --null_prompt_prob 0.05 \
24 | --apply_loss_on_only_vision True \
25 | --apply_loss_on_only_text False \
26 | --image_area 518400 \
27 | --max_position_embeddings 10240 \
28 | --output_dir "logs/"${EXP_NAME} \
29 | --bf16 True \
30 | --tf32 True \
31 | --num_train_epochs 4 \
32 | --per_device_train_batch_size 2 \
33 | --gradient_accumulation_steps 4 \
34 | --eval_strategy no \
35 | --save_strategy steps \
36 | --save_steps 500 \
37 | --save_total_limit 10 \
38 | --learning_rate 1e-5 \
39 | --min_learning_rate 1e-6 \
40 | --weight_decay 0.1 \
41 | --max_grad_norm 5.0 \
42 | --adam_beta1 0.9 \
43 | --adam_beta2 0.95 \
44 | --adam_epsilon 1e-6 \
45 | --warmup_steps 30 \
46 | --lr_scheduler_type "cosine_with_min_lr" \
47 | --logging_steps 1 \
48 | --gradient_checkpointing True \
49 | --dataloader_num_workers 4 \
50 | --report_to wandb tensorboard \
51 | --run_name ${EXP_NAME}
52 |
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true,
27 | "offload_optimizer": {
28 | "device": "cpu",
29 | "pin_memory": true
30 | },
31 | "offload_param": {
32 | "device": "cpu",
33 | "pin_memory": true
34 | }
35 | }
36 | }
37 |
--------------------------------------------------------------------------------