├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── meme.jpg ├── octupusy.jpg ├── teaser.webp └── woman.jpg ├── data ├── __init__.py ├── configs │ └── example.yaml ├── data_utils.py ├── dataset_base.py ├── dataset_info.py ├── distributed_iterable_dataset.py ├── interleave_datasets │ ├── __init__.py │ ├── edit_dataset.py │ └── interleave_t2i_dataset.py ├── parquet_utils.py ├── t2i_dataset.py ├── transforms.py ├── video_utils.py └── vlm_dataset.py ├── example_workflows ├── bagel_image_edit.json ├── bagel_image_edit.png ├── bagel_image_understanding.json ├── bagel_image_understanding.png ├── bagel_text_to_image.json └── bagel_text_to_image.png ├── inferencer.py ├── modeling ├── __init__.py ├── autoencoder.py ├── bagel │ ├── __init__.py │ ├── bagel.py │ ├── modeling_utils.py │ ├── qwen2_navit.py │ └── siglip_navit.py ├── qwen2 │ ├── __init__.py │ ├── configuration_qwen2.py │ ├── modeling_qwen2.py │ ├── tokenization_qwen2.py │ └── tokenization_qwen2_fast.py └── siglip │ ├── __init__.py │ ├── configuration_siglip.py │ ├── convert_siglip_to_hf.py │ ├── image_processing_siglip.py │ ├── modeling_siglip.py │ ├── processing_siglip.py │ └── tokenization_siglip.py ├── node.py ├── pyproject.toml └── requirements.txt /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'neverbiasu' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | - name: Publish Custom Node 23 | uses: Comfy-Org/publish-node-action@v1 24 | with: 25 | ## Add your own personal access token to your Github Repository secrets and reference it here. 26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-Bagel 2 | 3 | A ComfyUI custom node package based on the BAGEL-7B-MoT multimodal model. 4 | 5 | ## About BAGEL 6 | 7 |

8 | ![BAGEL](assets/teaser.webp)
9 |
10 | 11 | BAGEL is an open-source multimodal foundation model with 7B active parameters (14B total) that adopts a Mixture-of-Transformer-Experts (MoT) architecture. It is designed for multimodal understanding and generation tasks, outperforming top-tier open-source VLMs like Qwen2.5-VL and InternVL-2.5 on standard multimodal understanding leaderboards, and delivering text-to-image quality competitive with specialist generators such as SD3. 12 | 13 | ## Features 14 | 15 | - **Text-to-Image Generation**: Generate high-quality images using natural language prompts 16 | - **Image Editing**: Edit existing images based on textual descriptions 17 | - **Image Understanding**: Perform Q&A and analysis on images 18 | - **Reasoning Process Display**: Optionally display the model's reasoning process 19 | - **DFloat11 Quantized Model Support**: Support for DFloat11 quantized version that requires only ~22GB VRAM 20 | 21 | ## Installation 22 | 23 | ### 1. Model Selection and Download 24 | The ComfyUI-Bagel node now supports automatic model selection via dropdown: 25 | - **ByteDance-Seed/BAGEL-7B-MoT**: Original standard model (~80GB VRAM recommended) 26 | - **DFloat11/BAGEL-7B-MoT-DF11**: Quantized model (~22GB VRAM, single 24GB GPU compatible) 27 | 28 | Models will be automatically downloaded to `models/bagel/` when first selected. You can also manually download them: 29 | 30 | #### Standard Model 31 | ```bash 32 | # Clone model using git lfs (recommended) 33 | git lfs install 34 | git clone https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT models/bagel/BAGEL-7B-MoT 35 | 36 | # Or use huggingface_hub 37 | pip install huggingface_hub 38 | python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='ByteDance-Seed/BAGEL-7B-MoT', local_dir='models/bagel/BAGEL-7B-MoT')" 39 | ``` 40 | 41 | #### DFloat11 Quantized Model (Recommended for single GPU) 42 | ```bash 43 | # Clone DFloat11 quantized model 44 | git clone https://huggingface.co/DFloat11/BAGEL-7B-MoT-DF11 models/bagel/BAGEL-7B-MoT-DF11 45 | 46 | # Or use huggingface_hub 47 | python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='DFloat11/BAGEL-7B-MoT-DF11', local_dir='models/bagel/BAGEL-7B-MoT-DF11')" 48 | ``` 49 | 50 | ### 2. Install Dependencies 51 | Install the required dependencies: 52 | ```bash 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | For DFloat11 quantized model support, also install: 57 | ```bash 58 | pip install dfloat11 59 | ``` 60 | 61 | ### 3. Restart ComfyUI 62 | Restart ComfyUI to load the new nodes. 63 | 64 | ## Workflows 65 | 66 | ### Text-to-Image Generation 67 | ![text to image workflow](example_workflows/bagel_text_to_image.png) 68 | Generate high-quality images from text descriptions. Suitable for creative design and content generation. 69 | 70 | ### Image Editing Workflow 71 | ![image editing workflow](example_workflows/bagel_image_edit.png) 72 | Edit existing images based on textual descriptions, supporting local modifications and style adjustments. 73 | 74 | ### Image Understanding Workflow 75 | ![image understanding workflow](example_workflows/bagel_image_understanding.png) 76 | Analyze and answer questions about image content, suitable for content understanding and information extraction. 
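These example workflows can also be queued without the browser UI. Below is a minimal sketch (not part of this package), assuming a ComfyUI server running locally on the default port 8188 and a workflow JSON exported in API format (the files in `example_workflows/` may need to be re-exported via ComfyUI's "Save (API Format)" option):

```python
# Minimal sketch: queue an example workflow against a local ComfyUI server.
# Assumes http://127.0.0.1:8188 and an API-format workflow JSON.
import json
import urllib.request

with open("example_workflows/bagel_text_to_image.json") as f:
    workflow = json.load(f)

payload = json.dumps({"prompt": workflow}).encode("utf-8")
req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())  # the server responds with a prompt_id on success
```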
77 | 78 | ## Performance Comparison 79 | 80 | | Metric | BAGEL-7B-MoT (Standard Model) | BAGEL-7B-MoT (DFloat11 Quantized Model) | 81 | |--------|-------------------------------|-----------------------------------------| 82 | | Model Size | 29.21 GB | 19.89 GB | 83 | | Peak GPU Memory (1024x1024 image generation) | 30.07 GB | 21.76 GB | 84 | | Generation Time (on an RTX4090 GPU) | 482.95 seconds | 154.39 seconds | 85 | 86 | DFloat11 Quantized Model significantly reduces VRAM requirements and speeds up generation time, making it ideal for single GPU setups. 87 | 88 | ## Related Links 89 | 90 | - [BAGEL Official Paper](https://arxiv.org/abs/2505.14683) 91 | - [BAGEL Model Homepage](https://bagel-ai.org/) 92 | - [Hugging Face Model](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT) 93 | - [Online Demo](https://demo.bagel-ai.org/) 94 | - [Discord Community](https://discord.gg/Z836xxzy) 95 | 96 | ## License 97 | 98 | This project is licensed under the Apache 2.0 License. Please refer to the official license terms for the use of the BAGEL model. 99 | 100 | ## Contribution 101 | 102 | Contributions are welcome! Please submit issue reports and feature requests. If you wish to contribute code, please create an issue to discuss your ideas first. 103 | 104 | ## FAQ 105 | 106 | ### 1. VRAM Requirements 107 | The official recommendation for generating a 1024×1024 image is over 80GB GPU memory. However, multi-GPU setups can distribute the memory load. For example: 108 | - **Single GPU**: A100 (40GB) takes approximately 340-380 seconds per image. 109 | - **Multi-GPU**: 3 RTX3090 GPUs (24GB each) complete the task in about 1 minute. 110 | - **Compressed Model**: Using the DFloat11 version requires only 22GB VRAM and can run on a single 24GB GPU, with peak memory usage around 21.76GB (A100) and generation time of approximately 58 seconds. 111 | 112 | For more details, visit the [GitHub issue](https://github.com/ByteDance-Seed/Bagel/issues/4). 113 | 114 | ### 2. Quantized Version 115 | A quantized version of BAGEL is currently under development, which aims to reduce VRAM requirements further. 116 | 117 | ### 3. NameError: 'Qwen2Config' is not defined 118 | This issue is likely related to environment or dependency problems. For more information, refer to [this GitHub issue](https://github.com/neverbiasu/ComfyUI-BAGEL/issues/7). 
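As a quick sanity check for this error, the import below should succeed in the Python environment that runs ComfyUI. This is an illustrative snippet only and assumes a reasonably recent `transformers` release:

```python
# Hypothetical environment check for FAQ #3: verify that the installed
# transformers build exposes Qwen2Config.
import transformers
from transformers import Qwen2Config  # an ImportError here points to an outdated install

print("transformers", transformers.__version__, "-> Qwen2Config available")
```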
119 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ComfyUI-Bagel - ComfyUI custom node package for the BAGEL multimodal model 3 | """ 4 | 5 | from .node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 6 | 7 | # Export node mappings for ComfyUI 8 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 9 | -------------------------------------------------------------------------------- /assets/meme.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/meme.jpg -------------------------------------------------------------------------------- /assets/octupusy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/octupusy.jpg -------------------------------------------------------------------------------- /assets/teaser.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/teaser.webp -------------------------------------------------------------------------------- /assets/woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/woman.jpg -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /data/configs/example.yaml: -------------------------------------------------------------------------------- 1 | t2i_pretrain: 2 | dataset_names: 3 | - t2i 4 | image_transform_args: 5 | image_stride: 16 6 | max_image_size: 1024 7 | min_image_size: 512 8 | is_mandatory: true 9 | num_used_data: # The sum should be larger that NUM_GPUS x NUM_WORKERS 10 | - 10 11 | weight: 1 12 | 13 | unified_edit: 14 | dataset_names: 15 | - seedxedit_multi 16 | image_transform_args: 17 | image_stride: 16 18 | max_image_size: 1024 19 | min_image_size: 512 20 | vit_image_transform_args: 21 | image_stride: 14 22 | max_image_size: 518 23 | min_image_size: 224 24 | is_mandatory: false 25 | num_used_data: 26 | - 10 27 | weight: 1 28 | 29 | vlm_sft: 30 | dataset_names: 31 | - llava_ov 32 | image_transform_args: 33 | image_stride: 14 34 | max_image_size: 980 35 | min_image_size: 378 36 | max_pixels: 2_007_040 37 | frame_sampler_args: 38 | max_num_frames: 12 39 | min_num_frames: 8 40 | is_mandatory: true 41 | shuffle_lines: True 42 | shuffle_seed: 0 43 | num_used_data: 44 | - 1000 45 | weight: 1 46 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | import math 6 | import random 7 | from PIL import Image 8 | 9 | import torch 10 | from torch.nn.attention.flex_attention import or_masks, and_masks 11 | 12 | 13 | def create_sparse_mask(document_lens, split_lens, attn_modes, device): 14 | def causal_mask(b, h, q_idx, kv_idx): 15 | return q_idx >= kv_idx 16 | 17 | def full_and_noise_mask(b, h, q_idx, kv_idx): 18 | return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) 19 | 20 | def remove_noise_mask(b, h, q_idx, kv_idx): 21 | return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))) 22 | 23 | def sample_mask(b, h, q_idx, kv_idx): 24 | return document_id[q_idx] == document_id[kv_idx] 25 | 26 | full_and_noise_tmp = [] 27 | noise_tmp = [] 28 | 29 | for i, (length, model) in enumerate(zip(split_lens, attn_modes)): 30 | value = i if model in ['full', 'noise'] else -1 31 | full_and_noise_tmp.extend([value] * length) 32 | value_noise = i if model == 'noise' else -1 33 | noise_tmp.extend([value_noise] * length) 34 | 35 | full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) 36 | noise_seq_id = torch.Tensor(noise_tmp).to(device) 37 | 38 | document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) 39 | 40 | return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) 41 | 42 | 43 | def patchify(image, patch_size): 44 | p = patch_size 45 | c, h, w = image.shape 46 | assert h % p == 0 and w % p == 0 47 | image = image.reshape(c, h // p, p, w // p, p) 48 | image = torch.einsum("chpwq->hwpqc", image) 49 | image = image.reshape(-1, p**2 * c) 50 | return image 51 | 52 | 53 | def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): 54 | num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size 55 | coords_h = torch.arange(0, num_patches_h) 56 | coords_w = torch.arange(0, num_patches_w) 57 | pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() 58 | return pos_ids 59 | 60 | 61 | def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side): 62 | num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size 63 | boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side) 64 | fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h) 65 | fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w) 66 | bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) 67 | bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) 68 | pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten() 69 | return pos_ids 70 | 71 | 72 | def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"): 73 | """ 74 | nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within 75 | a sample, where each sample contains multiple splits with different attn modes. 76 | nested_attn_modes: whether to use full attn in each split. 
77 | """ 78 | sample_len = sum(split_lens) 79 | attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device) 80 | 81 | csum = 0 82 | for s, attn_mode in zip(split_lens, attn_modes): 83 | assert attn_mode in ['causal', 'full', 'noise'] 84 | if attn_mode == "causal": 85 | attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril() 86 | attention_mask[csum:csum + s, :csum] = 1 87 | else: 88 | attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s)) 89 | attention_mask[csum:csum + s, :csum] = 1 90 | csum += s 91 | 92 | csum = 0 93 | for s, attn_mode in zip(split_lens, attn_modes): 94 | if attn_mode == "noise": 95 | attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s)) 96 | attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s)) 97 | csum += s 98 | 99 | attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_( 100 | ~attention_mask, float("-inf") 101 | ) 102 | 103 | return attention_mask 104 | 105 | 106 | def split_integer_exp_decay(S, ng_sample_decay=1.0): 107 | if ng_sample_decay == 1.0: 108 | N = random.randint(1, S) 109 | else: 110 | base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S)) 111 | p = [base * math.pow(ng_sample_decay, i) for i in range(S)] 112 | N = random.choices(list(range(1, S + 1)), p, k=1)[0] 113 | cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S] 114 | result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)] 115 | return result, cumsum 116 | 117 | 118 | def pil_img2rgb(image): 119 | if image.mode == "RGBA" or image.info.get("transparency", None) is not None: 120 | image = image.convert("RGBA") 121 | white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255)) 122 | white.paste(image, mask=image.split()[3]) 123 | image = white 124 | else: 125 | image = image.convert("RGB") 126 | 127 | return image 128 | 129 | 130 | def add_special_tokens(tokenizer): 131 | all_special_tokens = [] 132 | for k, v in tokenizer.special_tokens_map.items(): 133 | if isinstance(v, str): 134 | all_special_tokens.append(v) 135 | elif isinstance(v, list): 136 | all_special_tokens += v 137 | 138 | new_tokens = [] 139 | 140 | if '<|im_start|>' not in all_special_tokens: 141 | new_tokens.append('<|im_start|>') 142 | 143 | if '<|im_end|>' not in all_special_tokens: 144 | new_tokens.append('<|im_end|>') 145 | 146 | if '<|vision_start|>' not in all_special_tokens: 147 | new_tokens.append('<|vision_start|>') 148 | 149 | if '<|vision_end|>' not in all_special_tokens: 150 | new_tokens.append('<|vision_end|>') 151 | 152 | num_new_tokens = tokenizer.add_tokens(new_tokens) 153 | bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>') 154 | eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>') 155 | start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>') 156 | end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>') 157 | 158 | new_token_ids = dict( 159 | bos_token_id=bos_token_id, 160 | eos_token_id=eos_token_id, 161 | start_of_image=start_of_image, 162 | end_of_image=end_of_image, 163 | ) 164 | 165 | return tokenizer, new_token_ids, num_new_tokens 166 | 167 | 168 | def len2weight(x, loss_reduction='square'): 169 | if x == 0: 170 | return x 171 | if loss_reduction == 'token': 172 | return 1 173 | if loss_reduction == 'sample': 174 | return 1 / x 175 | if loss_reduction == 'square': 176 | return 1 / (x ** 0.5) 177 | raise NotImplementedError(loss_reduction) 178 | 
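The patch utilities above can be exercised in isolation. The following is a minimal usage sketch (not part of the repository); it assumes the repository root is on `sys.path` so that `data.data_utils` is importable, and uses a dummy 3x64x64 tensor:

```python
# Minimal sketch: patchify a dummy image and compute its flattened position ids.
import torch
from data.data_utils import patchify, get_flattened_position_ids_extrapolate

image = torch.randn(3, 64, 64)            # (C, H, W); H and W divisible by patch_size
patches = patchify(image, patch_size=16)  # -> (16, 768): 4x4 patches of 16*16*3 values
pos_ids = get_flattened_position_ids_extrapolate(
    img_h=64, img_w=64, patch_size=16, max_num_patches_per_side=64
)
print(patches.shape, pos_ids.shape)       # torch.Size([16, 768]) torch.Size([16])
```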
-------------------------------------------------------------------------------- /data/dataset_info.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .interleave_datasets import UnifiedEditIterableDataset 5 | from .t2i_dataset import T2IIterableDataset 6 | from .vlm_dataset import SftJSONLIterableDataset 7 | 8 | 9 | DATASET_REGISTRY = { 10 | 't2i_pretrain': T2IIterableDataset, 11 | 'vlm_sft': SftJSONLIterableDataset, 12 | 'unified_edit': UnifiedEditIterableDataset, 13 | } 14 | 15 | 16 | DATASET_INFO = { 17 | 't2i_pretrain': { 18 | 't2i': { 19 | 'data_dir': 'your_data_path/bagel_example/t2i', # path of the parquet files 20 | 'num_files': 10, # number of data units to be sharded across all ranks and workers 21 | 'num_total_samples': 1000, # number of total samples in the dataset 22 | }, 23 | }, 24 | 'unified_edit':{ 25 | 'seedxedit_multi': { 26 | 'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi', 27 | 'num_files': 10, 28 | 'num_total_samples': 1000, 29 | "parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json', # information of the parquet files 30 | }, 31 | }, 32 | 'vlm_sft': { 33 | 'llava_ov': { 34 | 'data_dir': 'your_data_path/bagel_example/vlm/images', 35 | 'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl', 36 | 'num_total_samples': 1000 37 | }, 38 | }, 39 | } -------------------------------------------------------------------------------- /data/distributed_iterable_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import random 5 | import torch 6 | 7 | 8 | class DistributedIterableDataset(torch.utils.data.IterableDataset): 9 | def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8): 10 | self.dataset_name = dataset_name 11 | self.local_rank = local_rank 12 | self.world_size = world_size 13 | self.num_workers = num_workers 14 | self.rng = random.Random() 15 | self.data_paths = None 16 | 17 | def get_data_paths(self, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | def set_epoch(self, seed=42): 21 | if self.data_paths is None: 22 | return 23 | 24 | if isinstance(self.data_paths[0], tuple): 25 | data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1])) 26 | elif isinstance(self.data_paths[0], str): 27 | data_paths = sorted(self.data_paths) 28 | else: 29 | raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}") 30 | 31 | self.rng.seed(seed) 32 | self.rng.shuffle(data_paths) 33 | 34 | num_files_per_rank = len(data_paths) // self.world_size 35 | local_start = self.local_rank * num_files_per_rank 36 | local_end = (self.local_rank + 1) * num_files_per_rank 37 | self.num_files_per_rank = num_files_per_rank 38 | self.data_paths_per_rank = data_paths[local_start:local_end] 39 | 40 | def get_data_paths_per_worker(self): 41 | if self.data_paths is None: 42 | return None 43 | 44 | info = torch.utils.data.get_worker_info() 45 | if info is None: 46 | # Single worker: Use all files assigned to the rank 47 | return self.data_paths_per_rank, 0 48 | 49 | worker_id = info.id 50 | num_files_per_worker = self.num_files_per_rank // info.num_workers 51 | start = num_files_per_worker * worker_id 52 | end = num_files_per_worker * (worker_id + 1) 53 | data_paths_per_worker = 
self.data_paths_per_rank[start:end] 54 | 55 | return data_paths_per_worker[::-1], worker_id 56 | 57 | def __iter__(self): 58 | raise NotImplementedError 59 | -------------------------------------------------------------------------------- /data/interleave_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .edit_dataset import UnifiedEditIterableDataset 5 | 6 | -------------------------------------------------------------------------------- /data/interleave_datasets/edit_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import io 5 | import random 6 | from PIL import Image, ImageFile, PngImagePlugin 7 | 8 | from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset 9 | from ..data_utils import pil_img2rgb 10 | 11 | 12 | Image.MAX_IMAGE_PIXELS = 200000000 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | MaximumDecompressedSize = 1024 15 | MegaByte = 2 ** 20 16 | PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte 17 | 18 | 19 | class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset): 20 | 21 | def parse_row(self, row): 22 | image_num = len(row["image_list"]) 23 | # randomly choose start and end, return [0, 1] when only two images 24 | start_idx = random.choice(range(image_num - 1)) 25 | max_end = min(start_idx + 3, image_num) 26 | end_idx = random.choice(range(start_idx + 1, max_end)) 27 | 28 | data = self._init_data() 29 | data = self._add_image( 30 | data, 31 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))), 32 | need_loss=False, 33 | need_vae=True, 34 | need_vit=True, 35 | ) 36 | 37 | if end_idx - start_idx > 1 and random.random() < 0.5: # concat multiple insturction 38 | if end_idx == image_num - 1: 39 | end_idx -= 1 40 | 41 | instruction = "" 42 | for idx in range(start_idx + 1, end_idx + 1): 43 | instruction += random.choice(row["instruction_list"][idx-1]) + ". " 44 | data = self._add_text(data, instruction.rstrip(), need_loss=False) 45 | data = self._add_image( 46 | data, 47 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))), 48 | need_loss=True, 49 | need_vae=False, 50 | need_vit=False, 51 | ) 52 | else: 53 | for idx in range(start_idx + 1, end_idx + 1): 54 | instruction = random.choice(row["instruction_list"][idx-1]) 55 | data = self._add_text(data, instruction, need_loss=False) 56 | if idx != end_idx: 57 | data = self._add_image( 58 | data, 59 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))), 60 | need_loss=True, 61 | need_vae=True, 62 | need_vit=True, 63 | ) 64 | else: 65 | data = self._add_image( 66 | data, 67 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))), 68 | need_loss=True, 69 | need_vae=False, 70 | need_vit=False, 71 | ) 72 | return data 73 | -------------------------------------------------------------------------------- /data/interleave_datasets/interleave_t2i_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pyarrow.parquet as pq 5 | 6 | from ..distributed_iterable_dataset import DistributedIterableDataset 7 | from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs 8 | 9 | 10 | class InterleavedBaseIterableDataset(DistributedIterableDataset): 11 | 12 | def _init_data(self): 13 | data = { 14 | 'sequence_plan': [], 15 | 'text_ids_list': [], 16 | 'image_tensor_list': [], 17 | 'num_tokens': 0, 18 | } 19 | return data 20 | 21 | def _add_text(self, data, text, need_loss, enable_cfg=True): 22 | text_ids = self.tokenizer.encode(text) 23 | data['num_tokens'] += len(text_ids) 24 | data['text_ids_list'].append(text_ids) 25 | data['sequence_plan'].append( 26 | { 27 | 'type': 'text', 28 | 'enable_cfg': int(enable_cfg), 29 | 'loss': int(need_loss), 30 | 'special_token_loss': 0, 31 | 'special_token_label': None, 32 | } 33 | ) 34 | return data 35 | 36 | def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True): 37 | assert need_loss or need_vae or need_vit 38 | 39 | if need_loss: 40 | data['sequence_plan'].append( 41 | { 42 | 'type': 'vae_image', 43 | 'enable_cfg': 0, 44 | 'loss': 1, 45 | 'special_token_loss': 0, 46 | 'special_token_label': None, 47 | } 48 | ) 49 | 50 | image_tensor = self.transform(image) 51 | height, width = image_tensor.shape[1:] 52 | data['num_tokens'] += width * height // self.transform.stride ** 2 53 | data['image_tensor_list'].append(image_tensor) 54 | 55 | if need_vae: 56 | data['sequence_plan'].append( 57 | { 58 | 'type': 'vae_image', 59 | 'enable_cfg': int(enable_cfg), 60 | 'loss': 0, 61 | 'special_token_loss': 0, 62 | 'special_token_label': None, 63 | } 64 | ) 65 | 66 | image_tensor = self.transform(image) 67 | height, width = image_tensor.shape[1:] 68 | data['num_tokens'] += width * height // self.transform.stride ** 2 69 | data['image_tensor_list'].append(image_tensor.clone()) 70 | 71 | if need_vit: 72 | data['sequence_plan'].append( 73 | { 74 | 'type': 'vit_image', 75 | 'enable_cfg': int(enable_cfg), 76 | 'loss': 0, 77 | 'special_token_loss': 0, 78 | 'special_token_label': None, 79 | }, 80 | ) 81 | vit_image_tensor = self.vit_transform(image) 82 | height, width = vit_image_tensor.shape[1:] 83 | data['num_tokens'] += width * height // self.vit_transform.stride ** 2 84 | data['image_tensor_list'].append(vit_image_tensor) 85 | 86 | return data 87 | 88 | def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True): 89 | assert int(need_loss) + int(need_vae) == 1 90 | 91 | if need_loss: 92 | for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)): 93 | current_sequence_plan = { 94 | 'type': 'vae_image', 95 | 'enable_cfg': 0, 96 | 'loss': 1, 97 | 'special_token_loss': 0, 98 | 'special_token_label': None, 99 | 'split_start': idx == 0, 100 | 'split_end': idx == len(frames) - 1, 101 | } 102 | if idx < len(frame_indexes) - 1: 103 | current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx 104 | data['sequence_plan'].append(current_sequence_plan) 105 | image_tensor = self.transform(image) 106 | height, width = image_tensor.shape[1:] 107 | data['image_tensor_list'].append(image_tensor) 108 | data['num_tokens'] += width * height // self.transform.stride ** 2 109 | 110 | elif need_vae: 111 | for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)): 112 | current_sequence_plan = { 113 | 'type': 'vae_image', 114 | 'enable_cfg': int(enable_cfg), 115 | 'loss': 0, 116 | 'special_token_loss': 0, 117 | 'special_token_label': None, 118 | 
'split_start': idx == 0, 119 | 'split_end': idx == len(frames) - 1, 120 | } 121 | if idx < len(frame_indexes) - 1: 122 | current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx 123 | data['sequence_plan'].append(current_sequence_plan) 124 | image_tensor = self.transform(image) 125 | height, width = image_tensor.shape[1:] 126 | data['image_tensor_list'].append(image_tensor) 127 | data['num_tokens'] += width * height // self.transform.stride ** 2 128 | 129 | return data 130 | 131 | 132 | class ParquetStandardIterableDataset(DistributedIterableDataset): 133 | 134 | def __init__( 135 | self, dataset_name, transform, tokenizer, vit_transform, 136 | data_dir_list, num_used_data, parquet_info, 137 | local_rank=0, world_size=1, num_workers=8, data_status=None, 138 | ): 139 | """ 140 | data_dir_list: list of data directories contains parquet files 141 | num_used_data: list of number of sampled data paths for each data directory 142 | vit_transform: input transform for vit model. 143 | """ 144 | super().__init__(dataset_name, local_rank, world_size, num_workers) 145 | self.transform = transform 146 | self.vit_transform = vit_transform 147 | self.tokenizer = tokenizer 148 | self.data_status = data_status 149 | self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info) 150 | self.set_epoch() 151 | 152 | def get_data_paths(self, data_dir_list, num_used_data, parquet_info): 153 | row_groups = [] 154 | for data_dir, num_data_path in zip(data_dir_list, num_used_data): 155 | data_paths = get_parquet_data_paths([data_dir], [num_data_path]) 156 | for data_path in data_paths: 157 | if data_path in parquet_info.keys(): 158 | num_row_groups = parquet_info[data_path]['num_row_groups'] 159 | for rg_idx in range(num_row_groups): 160 | row_groups.append((data_path, rg_idx)) 161 | return row_groups 162 | 163 | def parse_row(self, row): 164 | raise NotImplementedError 165 | 166 | def __iter__(self): 167 | file_paths_per_worker, worker_id = self.get_data_paths_per_worker() 168 | if self.data_status is not None: 169 | global_row_group_start_id = self.data_status[worker_id][0] 170 | row_start_id = self.data_status[worker_id][1] + 1 171 | else: 172 | global_row_group_start_id = 0 173 | row_start_id = 0 174 | 175 | print( 176 | f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: " 177 | f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}" 178 | ) 179 | 180 | while True: 181 | file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:] 182 | for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate( 183 | file_paths_per_worker_, start=global_row_group_start_id 184 | ): 185 | fs = init_arrow_pf_fs(parquet_file_path) 186 | with fs.open_input_file(parquet_file_path) as f: 187 | try: 188 | fr = pq.ParquetFile(f) 189 | df = fr.read_row_group(row_group_id).to_pandas() 190 | df = df.iloc[row_start_id:] 191 | except Exception as e: 192 | print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}') 193 | continue 194 | 195 | for row_idx, row in df.iterrows(): 196 | try: 197 | data = self.parse_row(row) 198 | if len(data) == 0: 199 | continue 200 | data['data_indexes'] = { 201 | "data_indexes": [global_row_group_idx, row_idx], 202 | "worker_id": worker_id, 203 | "dataset_name": self.dataset_name, 204 | } 205 | except Exception as e: 206 | print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}') 207 | continue 208 | yield data 209 | 210 | row_start_id = 0 211 | global_row_group_start_id = 0 212 | 
print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}") 213 | -------------------------------------------------------------------------------- /data/parquet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | import os 6 | import xml.etree.ElementTree as ET 7 | import subprocess 8 | import logging 9 | 10 | import pyarrow.fs as pf 11 | import torch.distributed as dist 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1): 17 | num_data_dirs = len(data_dir_list) 18 | if world_size > 1: 19 | chunk_size = (num_data_dirs + world_size - 1) // world_size 20 | start_idx = rank * chunk_size 21 | end_idx = min(start_idx + chunk_size, num_data_dirs) 22 | local_data_dir_list = data_dir_list[start_idx:end_idx] 23 | local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx] 24 | else: 25 | local_data_dir_list = data_dir_list 26 | local_num_sampled_data_paths = num_sampled_data_paths 27 | 28 | local_data_paths = [] 29 | for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths): 30 | if data_dir.startswith("hdfs://"): 31 | files = hdfs_ls_cmd(data_dir) 32 | data_paths_per_dir = [ 33 | file for file in files if file.endswith(".parquet") 34 | ] 35 | else: 36 | files = os.listdir(data_dir) 37 | data_paths_per_dir = [ 38 | os.path.join(data_dir, name) 39 | for name in files 40 | if name.endswith(".parquet") 41 | ] 42 | repeat = num_data_path // len(data_paths_per_dir) 43 | data_paths_per_dir = data_paths_per_dir * (repeat + 1) 44 | local_data_paths.extend(data_paths_per_dir[:num_data_path]) 45 | 46 | if world_size > 1: 47 | gather_list = [None] * world_size 48 | dist.all_gather_object(gather_list, local_data_paths) 49 | 50 | combined_chunks = [] 51 | for chunk_list in gather_list: 52 | if chunk_list is not None: 53 | combined_chunks.extend(chunk_list) 54 | else: 55 | combined_chunks = local_data_paths 56 | 57 | return combined_chunks 58 | 59 | 60 | # NOTE: cumtomize this function for your cluster 61 | def get_hdfs_host(): 62 | return "hdfs://xxx" 63 | 64 | 65 | # NOTE: cumtomize this function for your cluster 66 | def get_hdfs_block_size(): 67 | return 134217728 68 | 69 | 70 | # NOTE: cumtomize this function for your cluster 71 | def get_hdfs_extra_conf(): 72 | return None 73 | 74 | 75 | def init_arrow_pf_fs(parquet_file_path): 76 | if parquet_file_path.startswith("hdfs://"): 77 | fs = pf.HadoopFileSystem( 78 | host=get_hdfs_host(), 79 | port=0, 80 | buffer_size=get_hdfs_block_size(), 81 | extra_conf=get_hdfs_extra_conf(), 82 | ) 83 | else: 84 | fs = pf.LocalFileSystem() 85 | return fs 86 | 87 | 88 | def hdfs_ls_cmd(dir): 89 | result = subprocess.run(["hdfs", "dfs", "ls", dir], capture_output=True, text=True).stdout 90 | return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i] 91 | -------------------------------------------------------------------------------- /data/t2i_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import io 5 | import json 6 | import pyarrow.parquet as pq 7 | import random 8 | from PIL import Image 9 | 10 | from .data_utils import pil_img2rgb 11 | from .distributed_iterable_dataset import DistributedIterableDataset 12 | from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs 13 | 14 | Image.MAX_IMAGE_PIXELS = 20_000_000 15 | 16 | 17 | class T2IIterableDataset(DistributedIterableDataset): 18 | def __init__( 19 | self, dataset_name, transform, tokenizer, data_dir_list, num_used_data, 20 | local_rank=0, world_size=1, num_workers=8, data_status=None, 21 | ): 22 | """ 23 | data_dir_list: list of data directories contains parquet files 24 | num_used_data: list of number of sampled data paths for each data directory 25 | """ 26 | super().__init__(dataset_name, local_rank, world_size, num_workers) 27 | self.transform = transform 28 | self.tokenizer = tokenizer 29 | self.data_status = data_status 30 | self.data_paths = self.get_data_paths(data_dir_list, num_used_data) 31 | self.set_epoch() 32 | 33 | def get_data_paths(self, data_dir_list, num_used_data): 34 | return get_parquet_data_paths(data_dir_list, num_used_data) 35 | 36 | def __iter__(self): 37 | data_paths_per_worker, worker_id = self.get_data_paths_per_worker() 38 | if self.data_status is not None: 39 | parquet_start_id = self.data_status[worker_id][0] 40 | row_group_start_id = self.data_status[worker_id][1] 41 | row_start_id = self.data_status[worker_id][2] + 1 42 | else: 43 | parquet_start_id = 0 44 | row_group_start_id = 0 45 | row_start_id = 0 46 | transform_stride = self.transform.stride 47 | 48 | print( 49 | f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: " 50 | f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}" 51 | ) 52 | 53 | while True: 54 | data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:] 55 | for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id): 56 | fs = init_arrow_pf_fs(parquet_file_path) 57 | with fs.open_input_file(parquet_file_path) as f: 58 | fr = pq.ParquetFile(f) 59 | row_group_ids = list(range(fr.num_row_groups)) 60 | row_group_ids_ = row_group_ids[row_group_start_id:] 61 | 62 | for row_group_id in row_group_ids_: 63 | df = fr.read_row_group(row_group_id).to_pandas() 64 | df = df.iloc[row_start_id:] 65 | 66 | for row_idx, row in df.iterrows(): 67 | num_tokens = 0 68 | try: 69 | image_byte = row['image'] 70 | image = pil_img2rgb(Image.open(io.BytesIO(image_byte))) 71 | except Exception as e: 72 | print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}') 73 | continue 74 | image_tensor = self.transform(image) 75 | height, width = image_tensor.shape[1:] 76 | num_tokens += width * height // transform_stride ** 2 77 | 78 | try: 79 | caption_dict = row['captions'] 80 | caption_dict = json.loads(caption_dict) 81 | except Exception as e: 82 | print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}') 83 | continue 84 | 85 | caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()] 86 | if len(caps_token) == 0: 87 | print(f'no caption in rg#{row_group_id}, {parquet_file_path}') 88 | caption_token = self.tokenizer.encode(' ') 89 | else: 90 | caption_token = random.choice(caps_token) 91 | 92 | sequence_plan, text_ids_list = [], [] 93 | text_ids = caption_token 94 | num_tokens += len(caption_token) 95 | text_ids_list.append(text_ids) 96 | sequence_plan.append({ 97 | 'type': 'text', 98 | 'enable_cfg': 1, 99 | 'loss': 
0, 100 | 'special_token_loss': 0, 101 | 'special_token_label': None, 102 | }) 103 | 104 | sequence_plan.append({ 105 | 'type': 'vae_image', 106 | 'enable_cfg': 0, 107 | 'loss': 1, 108 | 'special_token_loss': 0, 109 | 'special_token_label': None, 110 | }) 111 | 112 | sample = dict( 113 | image_tensor_list=[image_tensor], 114 | text_ids_list=text_ids_list, 115 | num_tokens=num_tokens, 116 | sequence_plan=sequence_plan, 117 | data_indexes={ 118 | "data_indexes": [parquet_idx, row_group_id, row_idx], 119 | "worker_id": worker_id, 120 | "dataset_name": self.dataset_name, 121 | } 122 | ) 123 | yield sample 124 | 125 | row_start_id = 0 126 | row_group_start_id = 0 127 | parquet_start_id = 0 128 | print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}") 129 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import random 5 | from PIL import Image 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | from torchvision import transforms 11 | from torchvision.transforms import functional as F 12 | from torchvision.transforms import InterpolationMode 13 | 14 | 15 | class MaxLongEdgeMinShortEdgeResize(torch.nn.Module): 16 | """Resize the input image so that its longest side and shortest side are within a specified range, 17 | ensuring that both sides are divisible by a specified stride. 18 | 19 | Args: 20 | max_size (int): Maximum size for the longest edge of the image. 21 | min_size (int): Minimum size for the shortest edge of the image. 22 | stride (int): Value by which the height and width of the image must be divisible. 23 | max_pixels (int): Maximum pixels for the full image. 24 | interpolation (InterpolationMode): Desired interpolation enum defined by 25 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 26 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, 27 | ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported. 28 | The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted. 29 | antialias (bool, optional): Whether to apply antialiasing (default is True). 30 | """ 31 | 32 | def __init__( 33 | self, 34 | max_size: int, 35 | min_size: int, 36 | stride: int, 37 | max_pixels: int, 38 | interpolation=InterpolationMode.BICUBIC, 39 | antialias=True 40 | ): 41 | super().__init__() 42 | self.max_size = max_size 43 | self.min_size = min_size 44 | self.stride = stride 45 | self.max_pixels = max_pixels 46 | self.interpolation = interpolation 47 | self.antialias = antialias 48 | 49 | def _make_divisible(self, value, stride): 50 | """Ensure the value is divisible by the stride.""" 51 | return max(stride, int(round(value / stride) * stride)) 52 | 53 | def _apply_scale(self, width, height, scale): 54 | new_width = round(width * scale) 55 | new_height = round(height * scale) 56 | new_width = self._make_divisible(new_width, self.stride) 57 | new_height = self._make_divisible(new_height, self.stride) 58 | return new_width, new_height 59 | 60 | def forward(self, img, img_num=1): 61 | """ 62 | Args: 63 | img (PIL Image): Image to be resized. 64 | img_num (int): Number of images, used to change max_tokens. 65 | Returns: 66 | PIL Image or Tensor: Rescaled image with divisible dimensions. 
67 | """ 68 | if isinstance(img, torch.Tensor): 69 | height, width = img.shape[-2:] 70 | else: 71 | width, height = img.size 72 | 73 | scale = min(self.max_size / max(width, height), 1.0) 74 | scale = max(scale, self.min_size / min(width, height)) 75 | new_width, new_height = self._apply_scale(width, height, scale) 76 | 77 | # Ensure the number of pixels does not exceed max_pixels 78 | if new_width * new_height > self.max_pixels / img_num: 79 | scale = self.max_pixels / img_num / (new_width * new_height) 80 | new_width, new_height = self._apply_scale(new_width, new_height, scale) 81 | 82 | # Ensure longest edge does not exceed max_size 83 | if max(new_width, new_height) > self.max_size: 84 | scale = self.max_size / max(new_width, new_height) 85 | new_width, new_height = self._apply_scale(new_width, new_height, scale) 86 | 87 | return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias) 88 | 89 | 90 | class ImageTransform: 91 | def __init__( 92 | self, 93 | max_image_size, 94 | min_image_size, 95 | image_stride, 96 | max_pixels=14*14*9*1024, 97 | image_mean=[0.5, 0.5, 0.5], 98 | image_std=[0.5, 0.5, 0.5] 99 | ): 100 | self.stride = image_stride 101 | 102 | self.resize_transform = MaxLongEdgeMinShortEdgeResize( 103 | max_size=max_image_size, 104 | min_size=min_image_size, 105 | stride=image_stride, 106 | max_pixels=max_pixels, 107 | ) 108 | self.to_tensor_transform = transforms.ToTensor() 109 | self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True) 110 | 111 | def __call__(self, img, img_num=1): 112 | img = self.resize_transform(img, img_num=img_num) 113 | img = self.to_tensor_transform(img) 114 | img = self.normalize_transform(img) 115 | return img 116 | 117 | 118 | def decolorization(image): 119 | gray_image = image.convert('L') 120 | return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image 121 | 122 | 123 | def downscale(image, scale_factor): 124 | new_width = int(round(image.width * scale_factor)) 125 | new_height = int(round(image.height * scale_factor)) 126 | new_width = max(1, new_width) 127 | new_height = max(1, new_height) 128 | return image.resize((new_width, new_height), resample=Image.BICUBIC) 129 | 130 | 131 | def crop(image, crop_factors): 132 | target_h, target_w = crop_factors 133 | img_w, img_h = image.size 134 | 135 | if target_h > img_h or target_w > img_w: 136 | raise ValueError("Crop size exceeds image dimensions") 137 | 138 | x = random.randint(0, img_w - target_w) 139 | y = random.randint(0, img_h - target_h) 140 | 141 | return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]] 142 | 143 | 144 | def motion_blur_opencv(image, kernel_size=15, angle=0): 145 | # Build a linear (horizontal) kernel 146 | kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32) 147 | kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32) 148 | 149 | # Rotate the kernel to the requested angle 150 | center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5) 151 | M = cv2.getRotationMatrix2D(center, angle, 1) 152 | rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size)) 153 | 154 | # Normalize the kernel 155 | rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1 156 | 157 | img = np.array(image) 158 | if img.ndim == 2: 159 | blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT) 160 | else: 161 | # For color images, filter each channel independently 162 | blurred = np.zeros_like(img) 163 | for c in range(img.shape[2]): 164 | blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, 
borderType=cv2.BORDER_REFLECT) 165 | 166 | return Image.fromarray(blurred.astype(np.uint8)) 167 | 168 | 169 | def shuffle_patch(image, num_splits, gap_size=2): 170 | """Split the image into patches (sizes need not divide evenly), shuffle them, and reassemble with gaps between patches""" 171 | h_splits, w_splits = num_splits 172 | img_w, img_h = image.size 173 | 174 | base_patch_h = img_h // h_splits 175 | patch_heights = [base_patch_h] * (h_splits - 1) 176 | patch_heights.append(img_h - sum(patch_heights)) 177 | 178 | base_patch_w = img_w // w_splits 179 | patch_widths = [base_patch_w] * (w_splits - 1) 180 | patch_widths.append(img_w - sum(patch_widths)) 181 | 182 | patches = [] 183 | current_y = 0 184 | for i in range(h_splits): 185 | current_x = 0 186 | patch_h = patch_heights[i] 187 | for j in range(w_splits): 188 | patch_w = patch_widths[j] 189 | patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h)) 190 | patches.append(patch) 191 | current_x += patch_w 192 | current_y += patch_h 193 | 194 | random.shuffle(patches) 195 | 196 | total_width = sum(patch_widths) + (w_splits - 1) * gap_size 197 | total_height = sum(patch_heights) + (h_splits - 1) * gap_size 198 | new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255)) 199 | 200 | current_y = 0 # Starting Y coordinate of the current row 201 | patch_idx = 0 # Index of the patch currently being placed 202 | for i in range(h_splits): 203 | current_x = 0 # Starting X coordinate of the current column 204 | patch_h = patch_heights[i] # Height of the patches in the current row 205 | for j in range(w_splits): 206 | # Take the next shuffled patch 207 | patch = patches[patch_idx] 208 | patch_w = patch_widths[j] # Width of the patches in the current column 209 | # Paste the patch (top-left corner at (current_x, current_y)) 210 | new_image.paste(patch, (current_x, current_y)) 211 | # Advance X (next patch starts after the current patch width + gap) 212 | current_x += patch_w + gap_size 213 | patch_idx += 1 214 | # Advance Y (next row starts after the current row height + gap) 215 | current_y += patch_h + gap_size 216 | 217 | return new_image 218 | 219 | 220 | def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)): 221 | """ 222 | Split the image into patches and randomly blank out some of them, for the inpainting task 223 | 224 | Args: 225 | image: PIL.Image, input image (RGB mode) 226 | h_splits: int, number of row splits (vertical direction) 227 | w_splits: int, number of column splits (horizontal direction) 228 | blank_ratio: float, fraction of patches to blank out (0~1) 229 | blank_color: tuple, color of the blank regions (RGB, e.g. white (255, 255, 255)) 230 | 231 | Returns: 232 | PIL.Image, the reassembled image after processing 233 | """ 234 | h_splits, w_splits = num_splits 235 | img_w, img_h = image.size 236 | 237 | base_patch_h = img_h // h_splits 238 | patch_heights = [base_patch_h] * (h_splits - 1) 239 | patch_heights.append(img_h - sum(patch_heights)) 240 | 241 | base_patch_w = img_w // w_splits 242 | patch_widths = [base_patch_w] * (w_splits - 1) 243 | patch_widths.append(img_w - sum(patch_widths)) 244 | 245 | patches = [] 246 | current_y = 0 247 | for i in range(h_splits): 248 | current_x = 0 249 | patch_h = patch_heights[i] 250 | for j in range(w_splits): 251 | patch_w = patch_widths[j] 252 | patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h)) 253 | patches.append(patch) 254 | current_x += patch_w 255 | current_y += patch_h 256 | 257 | total_patches = h_splits * w_splits 258 | num_blank = int(total_patches * blank_ratio) 259 | num_blank = max(0, min(num_blank, total_patches)) 260 | blank_indices = random.sample(range(total_patches), num_blank) 261 | 262 | processed_patches = [] 263 | for idx, patch in enumerate(patches): 264 | if idx in blank_indices: 265 | blank_patch = Image.new("RGB", patch.size, color=blank_color) 266 | processed_patches.append(blank_patch) 267 | else: 268 | processed_patches.append(patch) 269 | 270 | # Create the output image (same size as the original) 271 | result_image = Image.new("RGB", (img_w, img_h)) 
272 | current_y = 0 273 | patch_idx = 0 274 | for i in range(h_splits): 275 | current_x = 0 276 | patch_h = patch_heights[i] 277 | for j in range(w_splits): 278 | # Take the processed patch 279 | patch = processed_patches[patch_idx] 280 | patch_w = patch_widths[j] 281 | # Paste it back at its original position 282 | result_image.paste(patch, (current_x, current_y)) 283 | current_x += patch_w 284 | patch_idx += 1 285 | current_y += patch_h 286 | 287 | return result_image 288 | -------------------------------------------------------------------------------- /data/video_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 OpenGVLab 2 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. 3 | # SPDX-License-Identifier: MIT 4 | # 5 | # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20. 6 | # 7 | # Original file was released under MIT, with the full license text 8 | # available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE. 9 | # 10 | # This modified file is released under the same license. 11 | 12 | 13 | import io 14 | import os 15 | import random 16 | import re 17 | 18 | import numpy as np 19 | import decord 20 | from PIL import Image 21 | 22 | 23 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): 24 | if sample in ['rand', 'middle']: # uniform sampling 25 | acc_samples = min(num_frames, vlen) 26 | # split the video into `acc_samples` intervals, and sample from each interval. 27 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 28 | ranges = [] 29 | for idx, interv in enumerate(intervals[:-1]): 30 | ranges.append((interv, intervals[idx + 1] - 1)) 31 | if sample == 'rand': 32 | try: 33 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] 34 | except: 35 | frame_indices = np.random.permutation(vlen)[:acc_samples] 36 | frame_indices.sort() 37 | frame_indices = list(frame_indices) 38 | elif fix_start is not None: 39 | frame_indices = [x[0] + fix_start for x in ranges] 40 | elif sample == 'middle': 41 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges] 42 | else: 43 | raise NotImplementedError 44 | 45 | if len(frame_indices) < num_frames: # padded with last frame 46 | padded_frame_indices = [frame_indices[-1]] * num_frames 47 | padded_frame_indices[:len(frame_indices)] = frame_indices 48 | frame_indices = padded_frame_indices 49 | elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps 50 | output_fps = float(sample[3:]) 51 | duration = float(vlen) / input_fps 52 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents 53 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) 54 | frame_indices = np.around(frame_seconds * input_fps).astype(int) 55 | frame_indices = [e for e in frame_indices if e < vlen] 56 | if max_num_frames > 0 and len(frame_indices) > max_num_frames: 57 | frame_indices = frame_indices[:max_num_frames] 58 | else: 59 | raise ValueError 60 | return frame_indices 61 | 62 | 63 | def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4): 64 | video_reader = decord.VideoReader(video_path, num_threads=1) 65 | vlen = len(video_reader) 66 | fps = video_reader.get_avg_fps() 67 | duration = vlen / float(fps) 68 | if clip: 69 | start, end = clip 70 | duration = end - start 71 | vlen = int(duration * fps) 72 | start_index = int(start * fps) 73 | 74 | t_num_frames = np.random.randint(min_num_frames, 
num_frames + 1) 75 | 76 | frame_indices = get_frame_indices( 77 | t_num_frames, vlen, sample=sample, fix_start=fix_start, 78 | input_fps=fps 79 | ) 80 | if clip: 81 | frame_indices = [f + start_index for f in frame_indices] 82 | frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8 83 | frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])] 84 | return frames 85 | 86 | 87 | def extract_frame_number(filename): 88 | # Extract the numeric part from the filename using regular expressions 89 | match = re.search(r'_(\d+).jpg$', filename) 90 | return int(match.group(1)) if match else -1 91 | 92 | 93 | def sort_frames(frame_paths): 94 | # Extract filenames from each path and sort by their numeric part 95 | return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x))) 96 | 97 | 98 | def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4): 99 | image_list = sort_frames(list(os.listdir(video_path))) 100 | frames = [] 101 | for image in image_list: 102 | fp = os.path.join(video_path, image) 103 | frame = Image.open(fp).convert('RGB') 104 | frames.append(frame) 105 | vlen = len(frames) 106 | 107 | t_num_frames = np.random.randint(min_num_frames, num_frames + 1) 108 | 109 | if vlen > t_num_frames: 110 | frame_indices = get_frame_indices( 111 | t_num_frames, vlen, sample=sample, fix_start=fix_start 112 | ) 113 | frames = [frames[i] for i in frame_indices] 114 | return frames 115 | 116 | 117 | class FrameSampler: 118 | def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'): 119 | self.max_num_frames = max_num_frames 120 | self.min_num_frames = min_num_frames 121 | self.sample = sample 122 | 123 | def __call__(self, file_name): 124 | fn = read_frames_folder if file_name.endswith('/') else read_frames_decord 125 | frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample) 126 | return frames 127 | 128 | 129 | def decode_video_byte(video_bytes): 130 | video_stream = io.BytesIO(video_bytes) 131 | vr = decord.VideoReader(video_stream) 132 | return vr 133 | 134 | 135 | def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False): 136 | if isinstance(mp4_p, str): 137 | vr = decord.VideoReader(mp4_p, num_threads=1) 138 | elif isinstance(mp4_p, decord.video_reader.VideoReader): 139 | vr = mp4_p 140 | video_fps = vr.get_avg_fps() # get the video frame rate 141 | video_duration = len(vr) / video_fps 142 | if n_frames is not None: 143 | if random_sample: 144 | frame_indices = sorted(random.sample(range(len(vr)), n_frames)) 145 | else: 146 | frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist() 147 | else: 148 | frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)] 149 | frames = vr.get_batch(frame_indices).asnumpy() # convert to a numpy array 150 | frames = [Image.fromarray(frame).convert("RGB") for frame in frames] 151 | if not return_frame_indices: 152 | return frames, video_duration 153 | else: 154 | return frames, video_duration, frame_indices 155 | 156 | 157 | def sample_mp4_frames_by_indices(mp4_p, frame_indices: list): 158 | if isinstance(mp4_p, str): 159 | vr = decord.VideoReader(mp4_p, num_threads=1) 160 | elif isinstance(mp4_p, decord.video_reader.VideoReader): 161 | vr = mp4_p 162 | # sample the frames in frame_indices 163 | frames = vr.get_batch(frame_indices).asnumpy() # convert to a numpy array 164 | frames = [Image.fromarray(frame).convert("RGB") for frame in frames] 165 | return frames 
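A minimal usage sketch for the frame-sampling helpers above, kept outside the repository files; the import path, video path, and frame counts are hypothetical and only illustrate the expected call pattern:

# Hypothetical example, not part of the repository; paths and counts are illustrative only.
from data.video_utils import FrameSampler, get_frame_indices

# 'fps0.5' samples sequentially at 0.5 frames per second: a 240-frame clip at 24 fps (10 s)
# yields indices roughly 2 s apart, then capped at max_num_frames if needed.
indices = get_frame_indices(num_frames=8, vlen=240, sample='fps0.5', input_fps=24, max_num_frames=8)

# FrameSampler routes paths ending in '/' to read_frames_folder and regular file
# paths to the decord-backed read_frames_decord.
sampler = FrameSampler(max_num_frames=16, min_num_frames=8, sample='rand')
frames = sampler('/path/to/video.mp4')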
-------------------------------------------------------------------------------- /data/vlm_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import json 5 | import os 6 | import traceback 7 | from PIL import Image, ImageFile, PngImagePlugin 8 | 9 | from .data_utils import pil_img2rgb 10 | from .distributed_iterable_dataset import DistributedIterableDataset 11 | 12 | 13 | Image.MAX_IMAGE_PIXELS = 200000000 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | MaximumDecompressedSize = 1024 16 | MegaByte = 2 ** 20 17 | PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte 18 | 19 | 20 | class SftJSONLIterableDataset(DistributedIterableDataset): 21 | def __init__( 22 | self, dataset_name, transform, tokenizer, frame_sampler, 23 | jsonl_path_list, data_dir_list, num_used_data, 24 | local_rank=0, world_size=1, num_workers=8, data_status=None, 25 | shuffle_lines=False, shuffle_seed=0, 26 | ): 27 | """ 28 | jsonl_path_list: list of jsonl file paths 29 | data_dir_list: list of image directories containing the images of each jsonl file 30 | num_used_data: list of numbers of sampled data points for each jsonl 31 | """ 32 | super().__init__(dataset_name, local_rank, world_size, num_workers) 33 | self.transform = transform 34 | self.tokenizer = tokenizer 35 | self.frame_sampler = frame_sampler 36 | self.data_status = data_status 37 | self.data_paths = self.get_data_paths( 38 | jsonl_path_list, 39 | data_dir_list, 40 | num_used_data, 41 | shuffle_lines, 42 | shuffle_seed, 43 | ) 44 | self.set_epoch() 45 | 46 | def get_data_paths( 47 | self, 48 | jsonl_path_list, 49 | data_dir_list, 50 | num_used_data, 51 | shuffle_lines, 52 | shuffle_seed, 53 | ): 54 | data_paths = [] 55 | for jsonl_path, image_dir, num_data_point in zip( 56 | jsonl_path_list, data_dir_list, num_used_data 57 | ): 58 | with open(jsonl_path, 'r') as f: 59 | raw_data = f.readlines() 60 | if shuffle_lines: 61 | self.rng.seed(shuffle_seed) 62 | self.rng.shuffle(raw_data) 63 | raw_data = raw_data[:num_data_point] 64 | data_paths.extend([(json_data, image_dir) for json_data in raw_data]) 65 | return data_paths 66 | 67 | def change_format(self, data, num_images): 68 | elements = [] 69 | for conversation in data['conversations']: 70 | if conversation['from'] == 'human': 71 | if '<image>' not in conversation['value']: 72 | elements.append({ 73 | 'type': 'text', 74 | 'has_loss': 0, 75 | 'text': conversation['value'], 76 | }) 77 | else: 78 | text_list = conversation['value'].split('<image>') 79 | for idx, text in enumerate(text_list): 80 | if text.strip() != '': 81 | elements.append({ 82 | 'type': 'text', 83 | 'has_loss': 0, 84 | 'text': text.strip(), 85 | }) 86 | if (idx != len(text_list) - 1) and (idx < num_images): 87 | elements.append({'type': 'image',}) 88 | elif conversation['from'] == 'gpt': 89 | elements.append({ 90 | 'type': 'text', 91 | 'has_loss': 1, 92 | 'text': conversation['value'], 93 | }) 94 | return elements 95 | 96 | def __iter__(self): 97 | data_paths_per_worker, worker_id = self.get_data_paths_per_worker() 98 | if self.data_status is not None: 99 | row_start_id = self.data_status[worker_id] + 1 100 | else: 101 | row_start_id = 0 102 | transform_stride = self.transform.stride 103 | 104 | print( 105 | f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: " 106 | f"resuming data at row#{row_start_id}" 107 | ) 108 | 109 | while True: 110 | data_paths_per_worker_ = 
data_paths_per_worker[row_start_id:] 111 | for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id): 112 | num_tokens = 0 113 | image_tensor_list = [] 114 | text_ids_list = [] 115 | sequence_plan = [] 116 | 117 | try: 118 | data_item = json.loads(data) 119 | raw_images = None 120 | if 'image' in data_item: 121 | if type(data_item['image']) == list: 122 | raw_images = [ 123 | pil_img2rgb(Image.open(os.path.join(image_dir, image))) 124 | for image in data_item['image'] 125 | ] 126 | else: 127 | raw_images = [ 128 | pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image']))) 129 | ] 130 | elif 'video' in data_item: 131 | raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video'])) 132 | special_tokens = '<image>' * len(raw_images) 133 | for item in data_item['conversations']: 134 | if '