├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── assets
│   ├── meme.jpg
│   ├── octupusy.jpg
│   ├── teaser.webp
│   └── woman.jpg
├── data
│   ├── __init__.py
│   ├── configs
│   │   └── example.yaml
│   ├── data_utils.py
│   ├── dataset_base.py
│   ├── dataset_info.py
│   ├── distributed_iterable_dataset.py
│   ├── interleave_datasets
│   │   ├── __init__.py
│   │   ├── edit_dataset.py
│   │   └── interleave_t2i_dataset.py
│   ├── parquet_utils.py
│   ├── t2i_dataset.py
│   ├── transforms.py
│   ├── video_utils.py
│   └── vlm_dataset.py
├── example_workflows
│   ├── bagel_image_edit.json
│   ├── bagel_image_edit.png
│   ├── bagel_image_understanding.json
│   ├── bagel_image_understanding.png
│   ├── bagel_text_to_image.json
│   └── bagel_text_to_image.png
├── inferencer.py
├── modeling
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── bagel
│   │   ├── __init__.py
│   │   ├── bagel.py
│   │   ├── modeling_utils.py
│   │   ├── qwen2_navit.py
│   │   └── siglip_navit.py
│   ├── qwen2
│   │   ├── __init__.py
│   │   ├── configuration_qwen2.py
│   │   ├── modeling_qwen2.py
│   │   ├── tokenization_qwen2.py
│   │   └── tokenization_qwen2_fast.py
│   └── siglip
│       ├── __init__.py
│       ├── configuration_siglip.py
│       ├── convert_siglip_to_hf.py
│       ├── image_processing_siglip.py
│       ├── modeling_siglip.py
│       ├── processing_siglip.py
│       └── tokenization_siglip.py
├── node.py
├── pyproject.toml
└── requirements.txt
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | permissions:
12 | issues: write
13 |
14 | jobs:
15 | publish-node:
16 | name: Publish Custom Node to registry
17 | runs-on: ubuntu-latest
18 | if: ${{ github.repository_owner == 'neverbiasu' }}
19 | steps:
20 | - name: Check out code
21 | uses: actions/checkout@v4
22 | - name: Publish Custom Node
23 | uses: Comfy-Org/publish-node-action@v1
24 | with:
25 |           ## Add your own personal access token to your GitHub repository secrets and reference it here.
26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-Bagel
2 |
3 | A ComfyUI custom node package based on the BAGEL-7B-MoT multimodal model.
4 |
5 | ## About BAGEL
6 |
7 |
8 |
9 |
10 |
11 | BAGEL is an open-source multimodal foundation model with 7B active parameters (14B total) that adopts a Mixture-of-Transformer-Experts (MoT) architecture. It is designed for multimodal understanding and generation tasks, outperforming top-tier open-source VLMs like Qwen2.5-VL and InternVL-2.5 on standard multimodal understanding leaderboards, and delivering text-to-image quality competitive with specialist generators such as SD3.
12 |
13 | ## Features
14 |
15 | - **Text-to-Image Generation**: Generate high-quality images using natural language prompts
16 | - **Image Editing**: Edit existing images based on textual descriptions
17 | - **Image Understanding**: Perform Q&A and analysis on images
18 | - **Reasoning Process Display**: Optionally display the model's reasoning process
19 | - **DFloat11 Quantized Model Support**: Run the DFloat11 quantized version, which requires only ~22GB of VRAM
20 |
21 | ## Installation
22 |
23 | ### 1. Model Selection and Download
24 | The ComfyUI-Bagel node now supports automatic model selection via dropdown:
25 | - **ByteDance-Seed/BAGEL-7B-MoT**: Original standard model (~80GB VRAM recommended)
26 | - **DFloat11/BAGEL-7B-MoT-DF11**: Quantized model (~22GB VRAM, single 24GB GPU compatible)
27 |
28 | Models will be automatically downloaded to `models/bagel/` when first selected. You can also manually download them:
29 |
30 | #### Standard Model
31 | ```bash
32 | # Clone model using git lfs (recommended)
33 | git lfs install
34 | git clone https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT models/bagel/BAGEL-7B-MoT
35 |
36 | # Or use huggingface_hub
37 | pip install huggingface_hub
38 | python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='ByteDance-Seed/BAGEL-7B-MoT', local_dir='models/bagel/BAGEL-7B-MoT')"
39 | ```
40 |
41 | #### DFloat11 Quantized Model (Recommended for single GPU)
42 | ```bash
43 | # Clone DFloat11 quantized model
44 | git clone https://huggingface.co/DFloat11/BAGEL-7B-MoT-DF11 models/bagel/BAGEL-7B-MoT-DF11
45 |
46 | # Or use huggingface_hub
47 | python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='DFloat11/BAGEL-7B-MoT-DF11', local_dir='models/bagel/BAGEL-7B-MoT-DF11')"
48 | ```
49 |
50 | ### 2. Install Dependencies
51 | Install the required dependencies:
52 | ```bash
53 | pip install -r requirements.txt
54 | ```
55 |
56 | For DFloat11 quantized model support, also install:
57 | ```bash
58 | pip install dfloat11
59 | ```
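
If you are unsure whether the package installed correctly, a quick import check (assuming the package exposes a top-level `dfloat11` module, as its PyPI name suggests) is:
```bash
python -c "import dfloat11; print('dfloat11 import OK')"
```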
60 |
61 | ### 3. Restart ComfyUI
62 | Restart ComfyUI to load the new nodes.
63 |
64 | ## Workflows
65 |
66 | ### Text-to-Image Generation
67 | 
68 | Generate high-quality images from text descriptions. Suitable for creative design and content generation.
69 |
70 | ### Image Editing Workflow
71 | 
72 | Edit existing images based on textual descriptions, supporting local modifications and style adjustments.
73 |
74 | ### Image Understanding Workflow
75 | 
76 | Analyze and answer questions about image content, suitable for content understanding and information extraction.
77 |
78 | ## Performance Comparison
79 |
80 | | Metric | BAGEL-7B-MoT (Standard Model) | BAGEL-7B-MoT (DFloat11 Quantized Model) |
81 | |--------|-------------------------------|-----------------------------------------|
82 | | Model Size | 29.21 GB | 19.89 GB |
83 | | Peak GPU Memory (1024x1024 image generation) | 30.07 GB | 21.76 GB |
84 | | Generation Time (on an RTX4090 GPU) | 482.95 seconds | 154.39 seconds |
85 |
86 | The DFloat11 quantized model significantly reduces both VRAM requirements and generation time, making it ideal for single-GPU setups.
87 |
88 | ## Related Links
89 |
90 | - [BAGEL Official Paper](https://arxiv.org/abs/2505.14683)
91 | - [BAGEL Model Homepage](https://bagel-ai.org/)
92 | - [Hugging Face Model](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)
93 | - [Online Demo](https://demo.bagel-ai.org/)
94 | - [Discord Community](https://discord.gg/Z836xxzy)
95 |
96 | ## License
97 |
98 | This project is licensed under the Apache 2.0 License. Please refer to the official license terms for the use of the BAGEL model.
99 |
100 | ## Contribution
101 |
102 | Contributions are welcome! Please submit issue reports and feature requests. If you wish to contribute code, please create an issue to discuss your ideas first.
103 |
104 | ## FAQ
105 |
106 | ### 1. VRAM Requirements
107 | The official recommendation for generating a 1024×1024 image is more than 80GB of GPU memory. However, multi-GPU setups can distribute the memory load. For example:
108 | - **Single GPU**: A100 (40GB) takes approximately 340-380 seconds per image.
109 | - **Multi-GPU**: 3 RTX3090 GPUs (24GB each) complete the task in about 1 minute.
110 | - **Compressed Model**: Using the DFloat11 version requires only 22GB VRAM and can run on a single 24GB GPU, with peak memory usage around 21.76GB (A100) and generation time of approximately 58 seconds.
111 |
112 | For more details, visit the [GitHub issue](https://github.com/ByteDance-Seed/Bagel/issues/4).
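
To check how much VRAM a GPU actually offers before choosing a model variant, a quick PyTorch one-liner (run inside the ComfyUI Python environment) is:
```bash
python -c "import torch; p = torch.cuda.get_device_properties(0); print(f'{p.name}: {p.total_memory / 1024**3:.1f} GB')"
```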
113 |
114 | ### 2. Quantized Version
115 | A quantized version of BAGEL is currently under development, which aims to reduce VRAM requirements further.
116 |
117 | ### 3. NameError: 'Qwen2Config' is not defined
118 | This issue is likely related to environment or dependency problems. For more information, refer to [this GitHub issue](https://github.com/neverbiasu/ComfyUI-BAGEL/issues/7).
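
As a first diagnostic step (assuming the error comes from the installed `transformers` package rather than this node's bundled model code), print the library version and import path to confirm which installation is being used:
```bash
python -c "import transformers; print(transformers.__version__, transformers.__file__)"
```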
119 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ComfyUI-Bagel - ComfyUI custom node package for the BAGEL multimodal model
3 | """
4 |
5 | from .node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
6 |
7 | # Export node mappings for ComfyUI
8 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
9 |
--------------------------------------------------------------------------------
/assets/meme.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/meme.jpg
--------------------------------------------------------------------------------
/assets/octupusy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/octupusy.jpg
--------------------------------------------------------------------------------
/assets/teaser.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/teaser.webp
--------------------------------------------------------------------------------
/assets/woman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neverbiasu/ComfyUI-BAGEL/777a359273afd21a978ac67ae613c035f18a41a7/assets/woman.jpg
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
--------------------------------------------------------------------------------
/data/configs/example.yaml:
--------------------------------------------------------------------------------
1 | t2i_pretrain:
2 | dataset_names:
3 | - t2i
4 | image_transform_args:
5 | image_stride: 16
6 | max_image_size: 1024
7 | min_image_size: 512
8 | is_mandatory: true
9 |   num_used_data: # The sum should be larger than NUM_GPUS x NUM_WORKERS
10 | - 10
11 | weight: 1
12 |
13 | unified_edit:
14 | dataset_names:
15 | - seedxedit_multi
16 | image_transform_args:
17 | image_stride: 16
18 | max_image_size: 1024
19 | min_image_size: 512
20 | vit_image_transform_args:
21 | image_stride: 14
22 | max_image_size: 518
23 | min_image_size: 224
24 | is_mandatory: false
25 | num_used_data:
26 | - 10
27 | weight: 1
28 |
29 | vlm_sft:
30 | dataset_names:
31 | - llava_ov
32 | image_transform_args:
33 | image_stride: 14
34 | max_image_size: 980
35 | min_image_size: 378
36 | max_pixels: 2_007_040
37 | frame_sampler_args:
38 | max_num_frames: 12
39 | min_num_frames: 8
40 | is_mandatory: true
41 | shuffle_lines: True
42 | shuffle_seed: 0
43 | num_used_data:
44 | - 1000
45 | weight: 1
46 |
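# Note (added for clarity): each top-level key above (t2i_pretrain, unified_edit, vlm_sft)
# matches an entry in DATASET_REGISTRY / DATASET_INFO in data/dataset_info.py, and
# `num_used_data` holds one value per name in `dataset_names`. Keeping the sum above
# NUM_GPUS x NUM_WORKERS ensures every rank/worker is assigned at least one data shard.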
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 |
5 | import math
6 | import random
7 | from PIL import Image
8 |
9 | import torch
10 | from torch.nn.attention.flex_attention import or_masks, and_masks
11 |
12 |
13 | def create_sparse_mask(document_lens, split_lens, attn_modes, device):
14 | def causal_mask(b, h, q_idx, kv_idx):
15 | return q_idx >= kv_idx
16 |
17 | def full_and_noise_mask(b, h, q_idx, kv_idx):
18 | return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
19 |
20 | def remove_noise_mask(b, h, q_idx, kv_idx):
21 | return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
22 |
23 | def sample_mask(b, h, q_idx, kv_idx):
24 | return document_id[q_idx] == document_id[kv_idx]
25 |
26 | full_and_noise_tmp = []
27 | noise_tmp = []
28 |
29 |     for i, (length, mode) in enumerate(zip(split_lens, attn_modes)):
30 |         value = i if mode in ['full', 'noise'] else -1
31 |         full_and_noise_tmp.extend([value] * length)
32 |         value_noise = i if mode == 'noise' else -1
33 | noise_tmp.extend([value_noise] * length)
34 |
35 | full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
36 | noise_seq_id = torch.Tensor(noise_tmp).to(device)
37 |
38 | document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
39 |
40 | return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
41 |
42 |
43 | def patchify(image, patch_size):
44 | p = patch_size
45 | c, h, w = image.shape
46 | assert h % p == 0 and w % p == 0
47 | image = image.reshape(c, h // p, p, w // p, p)
48 | image = torch.einsum("chpwq->hwpqc", image)
49 | image = image.reshape(-1, p**2 * c)
50 | return image
51 |
52 |
53 | def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
54 | num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
55 | coords_h = torch.arange(0, num_patches_h)
56 | coords_w = torch.arange(0, num_patches_w)
57 | pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
58 | return pos_ids
59 |
60 |
61 | def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
62 | num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
63 | boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
64 | fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
65 | fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
66 | bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
67 | bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
68 | pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
69 | return pos_ids
70 |
71 |
72 | def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
73 | """
74 |     split_lens: a list of ints; each int is the length of one split within the sample,
75 |         where a sample consists of multiple splits with different attn modes.
76 |     attn_modes: the attention mode ('causal', 'full', or 'noise') used for each split.
77 | """
78 | sample_len = sum(split_lens)
79 | attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
80 |
81 | csum = 0
82 | for s, attn_mode in zip(split_lens, attn_modes):
83 | assert attn_mode in ['causal', 'full', 'noise']
84 | if attn_mode == "causal":
85 | attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
86 | attention_mask[csum:csum + s, :csum] = 1
87 | else:
88 | attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
89 | attention_mask[csum:csum + s, :csum] = 1
90 | csum += s
91 |
92 | csum = 0
93 | for s, attn_mode in zip(split_lens, attn_modes):
94 | if attn_mode == "noise":
95 | attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
96 | attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
97 | csum += s
98 |
99 | attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
100 | ~attention_mask, float("-inf")
101 | )
102 |
103 | return attention_mask
104 |
105 |
106 | def split_integer_exp_decay(S, ng_sample_decay=1.0):
107 | if ng_sample_decay == 1.0:
108 | N = random.randint(1, S)
109 | else:
110 | base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
111 | p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
112 | N = random.choices(list(range(1, S + 1)), p, k=1)[0]
113 | cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
114 | result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
115 | return result, cumsum
116 |
117 |
118 | def pil_img2rgb(image):
119 | if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
120 | image = image.convert("RGBA")
121 | white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
122 | white.paste(image, mask=image.split()[3])
123 | image = white
124 | else:
125 | image = image.convert("RGB")
126 |
127 | return image
128 |
129 |
130 | def add_special_tokens(tokenizer):
131 | all_special_tokens = []
132 | for k, v in tokenizer.special_tokens_map.items():
133 | if isinstance(v, str):
134 | all_special_tokens.append(v)
135 | elif isinstance(v, list):
136 | all_special_tokens += v
137 |
138 | new_tokens = []
139 |
140 | if '<|im_start|>' not in all_special_tokens:
141 | new_tokens.append('<|im_start|>')
142 |
143 | if '<|im_end|>' not in all_special_tokens:
144 | new_tokens.append('<|im_end|>')
145 |
146 | if '<|vision_start|>' not in all_special_tokens:
147 | new_tokens.append('<|vision_start|>')
148 |
149 | if '<|vision_end|>' not in all_special_tokens:
150 | new_tokens.append('<|vision_end|>')
151 |
152 | num_new_tokens = tokenizer.add_tokens(new_tokens)
153 | bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
154 | eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
155 | start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
156 | end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
157 |
158 | new_token_ids = dict(
159 | bos_token_id=bos_token_id,
160 | eos_token_id=eos_token_id,
161 | start_of_image=start_of_image,
162 | end_of_image=end_of_image,
163 | )
164 |
165 | return tokenizer, new_token_ids, num_new_tokens
166 |
167 |
168 | def len2weight(x, loss_reduction='square'):
169 | if x == 0:
170 | return x
171 | if loss_reduction == 'token':
172 | return 1
173 | if loss_reduction == 'sample':
174 | return 1 / x
175 | if loss_reduction == 'square':
176 | return 1 / (x ** 0.5)
177 | raise NotImplementedError(loss_reduction)
178 |
--------------------------------------------------------------------------------
/data/dataset_info.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from .interleave_datasets import UnifiedEditIterableDataset
5 | from .t2i_dataset import T2IIterableDataset
6 | from .vlm_dataset import SftJSONLIterableDataset
7 |
8 |
9 | DATASET_REGISTRY = {
10 | 't2i_pretrain': T2IIterableDataset,
11 | 'vlm_sft': SftJSONLIterableDataset,
12 | 'unified_edit': UnifiedEditIterableDataset,
13 | }
14 |
15 |
16 | DATASET_INFO = {
17 | 't2i_pretrain': {
18 | 't2i': {
19 | 'data_dir': 'your_data_path/bagel_example/t2i', # path of the parquet files
20 | 'num_files': 10, # number of data units to be sharded across all ranks and workers
21 | 'num_total_samples': 1000, # number of total samples in the dataset
22 | },
23 | },
24 | 'unified_edit':{
25 | 'seedxedit_multi': {
26 | 'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
27 | 'num_files': 10,
28 | 'num_total_samples': 1000,
29 | "parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json', # information of the parquet files
30 | },
31 | },
32 | 'vlm_sft': {
33 | 'llava_ov': {
34 | 'data_dir': 'your_data_path/bagel_example/vlm/images',
35 | 'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
36 | 'num_total_samples': 1000
37 | },
38 | },
39 | }
--------------------------------------------------------------------------------
/data/distributed_iterable_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import random
5 | import torch
6 |
7 |
8 | class DistributedIterableDataset(torch.utils.data.IterableDataset):
9 | def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
10 | self.dataset_name = dataset_name
11 | self.local_rank = local_rank
12 | self.world_size = world_size
13 | self.num_workers = num_workers
14 | self.rng = random.Random()
15 | self.data_paths = None
16 |
17 | def get_data_paths(self, *args, **kwargs):
18 | raise NotImplementedError
19 |
20 | def set_epoch(self, seed=42):
21 | if self.data_paths is None:
22 | return
23 |
24 | if isinstance(self.data_paths[0], tuple):
25 | data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
26 | elif isinstance(self.data_paths[0], str):
27 | data_paths = sorted(self.data_paths)
28 | else:
29 | raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
30 |
31 | self.rng.seed(seed)
32 | self.rng.shuffle(data_paths)
33 |
34 | num_files_per_rank = len(data_paths) // self.world_size
35 | local_start = self.local_rank * num_files_per_rank
36 | local_end = (self.local_rank + 1) * num_files_per_rank
37 | self.num_files_per_rank = num_files_per_rank
38 | self.data_paths_per_rank = data_paths[local_start:local_end]
39 |
40 | def get_data_paths_per_worker(self):
41 | if self.data_paths is None:
42 | return None
43 |
44 | info = torch.utils.data.get_worker_info()
45 | if info is None:
46 | # Single worker: Use all files assigned to the rank
47 | return self.data_paths_per_rank, 0
48 |
49 | worker_id = info.id
50 | num_files_per_worker = self.num_files_per_rank // info.num_workers
51 | start = num_files_per_worker * worker_id
52 | end = num_files_per_worker * (worker_id + 1)
53 | data_paths_per_worker = self.data_paths_per_rank[start:end]
54 |
55 | return data_paths_per_worker[::-1], worker_id
56 |
57 | def __iter__(self):
58 | raise NotImplementedError
59 |
--------------------------------------------------------------------------------
/data/interleave_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from .edit_dataset import UnifiedEditIterableDataset
5 |
6 |
--------------------------------------------------------------------------------
/data/interleave_datasets/edit_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import io
5 | import random
6 | from PIL import Image, ImageFile, PngImagePlugin
7 |
8 | from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
9 | from ..data_utils import pil_img2rgb
10 |
11 |
12 | Image.MAX_IMAGE_PIXELS = 200000000
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | MaximumDecompressedSize = 1024
15 | MegaByte = 2 ** 20
16 | PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
17 |
18 |
19 | class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
20 |
21 | def parse_row(self, row):
22 | image_num = len(row["image_list"])
23 | # randomly choose start and end, return [0, 1] when only two images
24 | start_idx = random.choice(range(image_num - 1))
25 | max_end = min(start_idx + 3, image_num)
26 | end_idx = random.choice(range(start_idx + 1, max_end))
27 |
28 | data = self._init_data()
29 | data = self._add_image(
30 | data,
31 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
32 | need_loss=False,
33 | need_vae=True,
34 | need_vit=True,
35 | )
36 |
37 |         if end_idx - start_idx > 1 and random.random() < 0.5:  # concat multiple instructions
38 | if end_idx == image_num - 1:
39 | end_idx -= 1
40 |
41 | instruction = ""
42 | for idx in range(start_idx + 1, end_idx + 1):
43 | instruction += random.choice(row["instruction_list"][idx-1]) + ". "
44 | data = self._add_text(data, instruction.rstrip(), need_loss=False)
45 | data = self._add_image(
46 | data,
47 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
48 | need_loss=True,
49 | need_vae=False,
50 | need_vit=False,
51 | )
52 | else:
53 | for idx in range(start_idx + 1, end_idx + 1):
54 | instruction = random.choice(row["instruction_list"][idx-1])
55 | data = self._add_text(data, instruction, need_loss=False)
56 | if idx != end_idx:
57 | data = self._add_image(
58 | data,
59 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
60 | need_loss=True,
61 | need_vae=True,
62 | need_vit=True,
63 | )
64 | else:
65 | data = self._add_image(
66 | data,
67 | pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
68 | need_loss=True,
69 | need_vae=False,
70 | need_vit=False,
71 | )
72 | return data
73 |
--------------------------------------------------------------------------------
/data/interleave_datasets/interleave_t2i_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pyarrow.parquet as pq
5 |
6 | from ..distributed_iterable_dataset import DistributedIterableDataset
7 | from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
8 |
9 |
10 | class InterleavedBaseIterableDataset(DistributedIterableDataset):
11 |
12 | def _init_data(self):
13 | data = {
14 | 'sequence_plan': [],
15 | 'text_ids_list': [],
16 | 'image_tensor_list': [],
17 | 'num_tokens': 0,
18 | }
19 | return data
20 |
21 | def _add_text(self, data, text, need_loss, enable_cfg=True):
22 | text_ids = self.tokenizer.encode(text)
23 | data['num_tokens'] += len(text_ids)
24 | data['text_ids_list'].append(text_ids)
25 | data['sequence_plan'].append(
26 | {
27 | 'type': 'text',
28 | 'enable_cfg': int(enable_cfg),
29 | 'loss': int(need_loss),
30 | 'special_token_loss': 0,
31 | 'special_token_label': None,
32 | }
33 | )
34 | return data
35 |
36 | def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
37 | assert need_loss or need_vae or need_vit
38 |
39 | if need_loss:
40 | data['sequence_plan'].append(
41 | {
42 | 'type': 'vae_image',
43 | 'enable_cfg': 0,
44 | 'loss': 1,
45 | 'special_token_loss': 0,
46 | 'special_token_label': None,
47 | }
48 | )
49 |
50 | image_tensor = self.transform(image)
51 | height, width = image_tensor.shape[1:]
52 | data['num_tokens'] += width * height // self.transform.stride ** 2
53 | data['image_tensor_list'].append(image_tensor)
54 |
55 | if need_vae:
56 | data['sequence_plan'].append(
57 | {
58 | 'type': 'vae_image',
59 | 'enable_cfg': int(enable_cfg),
60 | 'loss': 0,
61 | 'special_token_loss': 0,
62 | 'special_token_label': None,
63 | }
64 | )
65 |
66 | image_tensor = self.transform(image)
67 | height, width = image_tensor.shape[1:]
68 | data['num_tokens'] += width * height // self.transform.stride ** 2
69 | data['image_tensor_list'].append(image_tensor.clone())
70 |
71 | if need_vit:
72 | data['sequence_plan'].append(
73 | {
74 | 'type': 'vit_image',
75 | 'enable_cfg': int(enable_cfg),
76 | 'loss': 0,
77 | 'special_token_loss': 0,
78 | 'special_token_label': None,
79 | },
80 | )
81 | vit_image_tensor = self.vit_transform(image)
82 | height, width = vit_image_tensor.shape[1:]
83 | data['num_tokens'] += width * height // self.vit_transform.stride ** 2
84 | data['image_tensor_list'].append(vit_image_tensor)
85 |
86 | return data
87 |
88 | def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
89 | assert int(need_loss) + int(need_vae) == 1
90 |
91 | if need_loss:
92 | for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
93 | current_sequence_plan = {
94 | 'type': 'vae_image',
95 | 'enable_cfg': 0,
96 | 'loss': 1,
97 | 'special_token_loss': 0,
98 | 'special_token_label': None,
99 | 'split_start': idx == 0,
100 | 'split_end': idx == len(frames) - 1,
101 | }
102 | if idx < len(frame_indexes) - 1:
103 | current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
104 | data['sequence_plan'].append(current_sequence_plan)
105 | image_tensor = self.transform(image)
106 | height, width = image_tensor.shape[1:]
107 | data['image_tensor_list'].append(image_tensor)
108 | data['num_tokens'] += width * height // self.transform.stride ** 2
109 |
110 | elif need_vae:
111 | for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
112 | current_sequence_plan = {
113 | 'type': 'vae_image',
114 | 'enable_cfg': int(enable_cfg),
115 | 'loss': 0,
116 | 'special_token_loss': 0,
117 | 'special_token_label': None,
118 | 'split_start': idx == 0,
119 | 'split_end': idx == len(frames) - 1,
120 | }
121 | if idx < len(frame_indexes) - 1:
122 | current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
123 | data['sequence_plan'].append(current_sequence_plan)
124 | image_tensor = self.transform(image)
125 | height, width = image_tensor.shape[1:]
126 | data['image_tensor_list'].append(image_tensor)
127 | data['num_tokens'] += width * height // self.transform.stride ** 2
128 |
129 | return data
130 |
131 |
132 | class ParquetStandardIterableDataset(DistributedIterableDataset):
133 |
134 | def __init__(
135 | self, dataset_name, transform, tokenizer, vit_transform,
136 | data_dir_list, num_used_data, parquet_info,
137 | local_rank=0, world_size=1, num_workers=8, data_status=None,
138 | ):
139 | """
140 |         data_dir_list: list of data directories containing parquet files
141 |         num_used_data: list of the number of data paths to sample from each data directory
142 |         vit_transform: input transform for the ViT model.
143 | """
144 | super().__init__(dataset_name, local_rank, world_size, num_workers)
145 | self.transform = transform
146 | self.vit_transform = vit_transform
147 | self.tokenizer = tokenizer
148 | self.data_status = data_status
149 | self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
150 | self.set_epoch()
151 |
152 | def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
153 | row_groups = []
154 | for data_dir, num_data_path in zip(data_dir_list, num_used_data):
155 | data_paths = get_parquet_data_paths([data_dir], [num_data_path])
156 | for data_path in data_paths:
157 | if data_path in parquet_info.keys():
158 | num_row_groups = parquet_info[data_path]['num_row_groups']
159 | for rg_idx in range(num_row_groups):
160 | row_groups.append((data_path, rg_idx))
161 | return row_groups
162 |
163 | def parse_row(self, row):
164 | raise NotImplementedError
165 |
166 | def __iter__(self):
167 | file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
168 | if self.data_status is not None:
169 | global_row_group_start_id = self.data_status[worker_id][0]
170 | row_start_id = self.data_status[worker_id][1] + 1
171 | else:
172 | global_row_group_start_id = 0
173 | row_start_id = 0
174 |
175 | print(
176 | f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
177 | f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
178 | )
179 |
180 | while True:
181 | file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
182 | for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
183 | file_paths_per_worker_, start=global_row_group_start_id
184 | ):
185 | fs = init_arrow_pf_fs(parquet_file_path)
186 | with fs.open_input_file(parquet_file_path) as f:
187 | try:
188 | fr = pq.ParquetFile(f)
189 | df = fr.read_row_group(row_group_id).to_pandas()
190 | df = df.iloc[row_start_id:]
191 | except Exception as e:
192 | print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
193 | continue
194 |
195 | for row_idx, row in df.iterrows():
196 | try:
197 | data = self.parse_row(row)
198 | if len(data) == 0:
199 | continue
200 | data['data_indexes'] = {
201 | "data_indexes": [global_row_group_idx, row_idx],
202 | "worker_id": worker_id,
203 | "dataset_name": self.dataset_name,
204 | }
205 | except Exception as e:
206 | print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
207 | continue
208 | yield data
209 |
210 | row_start_id = 0
211 | global_row_group_start_id = 0
212 | print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
213 |
--------------------------------------------------------------------------------
/data/parquet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 |
5 | import os
6 | import xml.etree.ElementTree as ET
7 | import subprocess
8 | import logging
9 |
10 | import pyarrow.fs as pf
11 | import torch.distributed as dist
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
17 | num_data_dirs = len(data_dir_list)
18 | if world_size > 1:
19 | chunk_size = (num_data_dirs + world_size - 1) // world_size
20 | start_idx = rank * chunk_size
21 | end_idx = min(start_idx + chunk_size, num_data_dirs)
22 | local_data_dir_list = data_dir_list[start_idx:end_idx]
23 | local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
24 | else:
25 | local_data_dir_list = data_dir_list
26 | local_num_sampled_data_paths = num_sampled_data_paths
27 |
28 | local_data_paths = []
29 | for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
30 | if data_dir.startswith("hdfs://"):
31 | files = hdfs_ls_cmd(data_dir)
32 | data_paths_per_dir = [
33 | file for file in files if file.endswith(".parquet")
34 | ]
35 | else:
36 | files = os.listdir(data_dir)
37 | data_paths_per_dir = [
38 | os.path.join(data_dir, name)
39 | for name in files
40 | if name.endswith(".parquet")
41 | ]
42 | repeat = num_data_path // len(data_paths_per_dir)
43 | data_paths_per_dir = data_paths_per_dir * (repeat + 1)
44 | local_data_paths.extend(data_paths_per_dir[:num_data_path])
45 |
46 | if world_size > 1:
47 | gather_list = [None] * world_size
48 | dist.all_gather_object(gather_list, local_data_paths)
49 |
50 | combined_chunks = []
51 | for chunk_list in gather_list:
52 | if chunk_list is not None:
53 | combined_chunks.extend(chunk_list)
54 | else:
55 | combined_chunks = local_data_paths
56 |
57 | return combined_chunks
58 |
59 |
60 | # NOTE: customize this function for your cluster
61 | def get_hdfs_host():
62 | return "hdfs://xxx"
63 |
64 |
65 | # NOTE: customize this function for your cluster
66 | def get_hdfs_block_size():
67 | return 134217728
68 |
69 |
70 | # NOTE: customize this function for your cluster
71 | def get_hdfs_extra_conf():
72 | return None
73 |
74 |
75 | def init_arrow_pf_fs(parquet_file_path):
76 | if parquet_file_path.startswith("hdfs://"):
77 | fs = pf.HadoopFileSystem(
78 | host=get_hdfs_host(),
79 | port=0,
80 | buffer_size=get_hdfs_block_size(),
81 | extra_conf=get_hdfs_extra_conf(),
82 | )
83 | else:
84 | fs = pf.LocalFileSystem()
85 | return fs
86 |
87 |
88 | def hdfs_ls_cmd(dir):
89 |     result = subprocess.run(["hdfs", "dfs", "-ls", dir], capture_output=True, text=True).stdout
90 | return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
91 |
--------------------------------------------------------------------------------
/data/t2i_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import io
5 | import json
6 | import pyarrow.parquet as pq
7 | import random
8 | from PIL import Image
9 |
10 | from .data_utils import pil_img2rgb
11 | from .distributed_iterable_dataset import DistributedIterableDataset
12 | from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
13 |
14 | Image.MAX_IMAGE_PIXELS = 20_000_000
15 |
16 |
17 | class T2IIterableDataset(DistributedIterableDataset):
18 | def __init__(
19 | self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
20 | local_rank=0, world_size=1, num_workers=8, data_status=None,
21 | ):
22 | """
23 |         data_dir_list: list of data directories containing parquet files
24 |         num_used_data: list of the number of data paths to sample from each data directory
25 | """
26 | super().__init__(dataset_name, local_rank, world_size, num_workers)
27 | self.transform = transform
28 | self.tokenizer = tokenizer
29 | self.data_status = data_status
30 | self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
31 | self.set_epoch()
32 |
33 | def get_data_paths(self, data_dir_list, num_used_data):
34 | return get_parquet_data_paths(data_dir_list, num_used_data)
35 |
36 | def __iter__(self):
37 | data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
38 | if self.data_status is not None:
39 | parquet_start_id = self.data_status[worker_id][0]
40 | row_group_start_id = self.data_status[worker_id][1]
41 | row_start_id = self.data_status[worker_id][2] + 1
42 | else:
43 | parquet_start_id = 0
44 | row_group_start_id = 0
45 | row_start_id = 0
46 | transform_stride = self.transform.stride
47 |
48 | print(
49 | f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
50 | f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
51 | )
52 |
53 | while True:
54 | data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
55 | for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
56 | fs = init_arrow_pf_fs(parquet_file_path)
57 | with fs.open_input_file(parquet_file_path) as f:
58 | fr = pq.ParquetFile(f)
59 | row_group_ids = list(range(fr.num_row_groups))
60 | row_group_ids_ = row_group_ids[row_group_start_id:]
61 |
62 | for row_group_id in row_group_ids_:
63 | df = fr.read_row_group(row_group_id).to_pandas()
64 | df = df.iloc[row_start_id:]
65 |
66 | for row_idx, row in df.iterrows():
67 | num_tokens = 0
68 | try:
69 | image_byte = row['image']
70 | image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
71 | except Exception as e:
72 | print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
73 | continue
74 | image_tensor = self.transform(image)
75 | height, width = image_tensor.shape[1:]
76 | num_tokens += width * height // transform_stride ** 2
77 |
78 | try:
79 | caption_dict = row['captions']
80 | caption_dict = json.loads(caption_dict)
81 | except Exception as e:
82 | print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
83 | continue
84 |
85 | caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
86 | if len(caps_token) == 0:
87 | print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
88 | caption_token = self.tokenizer.encode(' ')
89 | else:
90 | caption_token = random.choice(caps_token)
91 |
92 | sequence_plan, text_ids_list = [], []
93 | text_ids = caption_token
94 | num_tokens += len(caption_token)
95 | text_ids_list.append(text_ids)
96 | sequence_plan.append({
97 | 'type': 'text',
98 | 'enable_cfg': 1,
99 | 'loss': 0,
100 | 'special_token_loss': 0,
101 | 'special_token_label': None,
102 | })
103 |
104 | sequence_plan.append({
105 | 'type': 'vae_image',
106 | 'enable_cfg': 0,
107 | 'loss': 1,
108 | 'special_token_loss': 0,
109 | 'special_token_label': None,
110 | })
111 |
112 | sample = dict(
113 | image_tensor_list=[image_tensor],
114 | text_ids_list=text_ids_list,
115 | num_tokens=num_tokens,
116 | sequence_plan=sequence_plan,
117 | data_indexes={
118 | "data_indexes": [parquet_idx, row_group_id, row_idx],
119 | "worker_id": worker_id,
120 | "dataset_name": self.dataset_name,
121 | }
122 | )
123 | yield sample
124 |
125 | row_start_id = 0
126 | row_group_start_id = 0
127 | parquet_start_id = 0
128 | print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
129 |
--------------------------------------------------------------------------------
/data/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import random
5 | from PIL import Image
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | from torchvision import transforms
11 | from torchvision.transforms import functional as F
12 | from torchvision.transforms import InterpolationMode
13 |
14 |
15 | class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
16 | """Resize the input image so that its longest side and shortest side are within a specified range,
17 | ensuring that both sides are divisible by a specified stride.
18 |
19 | Args:
20 | max_size (int): Maximum size for the longest edge of the image.
21 | min_size (int): Minimum size for the shortest edge of the image.
22 | stride (int): Value by which the height and width of the image must be divisible.
23 | max_pixels (int): Maximum pixels for the full image.
24 | interpolation (InterpolationMode): Desired interpolation enum defined by
25 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
26 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
27 | ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
28 | The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
29 | antialias (bool, optional): Whether to apply antialiasing (default is True).
30 | """
31 |
32 | def __init__(
33 | self,
34 | max_size: int,
35 | min_size: int,
36 | stride: int,
37 | max_pixels: int,
38 | interpolation=InterpolationMode.BICUBIC,
39 | antialias=True
40 | ):
41 | super().__init__()
42 | self.max_size = max_size
43 | self.min_size = min_size
44 | self.stride = stride
45 | self.max_pixels = max_pixels
46 | self.interpolation = interpolation
47 | self.antialias = antialias
48 |
49 | def _make_divisible(self, value, stride):
50 | """Ensure the value is divisible by the stride."""
51 | return max(stride, int(round(value / stride) * stride))
52 |
53 | def _apply_scale(self, width, height, scale):
54 | new_width = round(width * scale)
55 | new_height = round(height * scale)
56 | new_width = self._make_divisible(new_width, self.stride)
57 | new_height = self._make_divisible(new_height, self.stride)
58 | return new_width, new_height
59 |
60 | def forward(self, img, img_num=1):
61 | """
62 | Args:
63 | img (PIL Image): Image to be resized.
64 | img_num (int): Number of images, used to change max_tokens.
65 | Returns:
66 | PIL Image or Tensor: Rescaled image with divisible dimensions.
67 | """
68 | if isinstance(img, torch.Tensor):
69 | height, width = img.shape[-2:]
70 | else:
71 | width, height = img.size
72 |
73 | scale = min(self.max_size / max(width, height), 1.0)
74 | scale = max(scale, self.min_size / min(width, height))
75 | new_width, new_height = self._apply_scale(width, height, scale)
76 |
77 | # Ensure the number of pixels does not exceed max_pixels
78 | if new_width * new_height > self.max_pixels / img_num:
79 | scale = self.max_pixels / img_num / (new_width * new_height)
80 | new_width, new_height = self._apply_scale(new_width, new_height, scale)
81 |
82 | # Ensure longest edge does not exceed max_size
83 | if max(new_width, new_height) > self.max_size:
84 | scale = self.max_size / max(new_width, new_height)
85 | new_width, new_height = self._apply_scale(new_width, new_height, scale)
86 |
87 | return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
88 |
89 |
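# Editor-added illustration (not part of the original file): a minimal sketch of how the
# resize above behaves; the concrete sizes below are assumptions, not repo defaults.
def _demo_resize_sketch():
    resize = MaxLongEdgeMinShortEdgeResize(max_size=1024, min_size=512, stride=16, max_pixels=1024 * 1024)
    out = resize(Image.new("RGB", (1920, 1080)))
    # Both output sides come out as multiples of the stride (16 here).
    assert out.size[0] % 16 == 0 and out.size[1] % 16 == 0
    return out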
90 | class ImageTransform:
91 | def __init__(
92 | self,
93 | max_image_size,
94 | min_image_size,
95 | image_stride,
96 | max_pixels=14*14*9*1024,
97 | image_mean=[0.5, 0.5, 0.5],
98 | image_std=[0.5, 0.5, 0.5]
99 | ):
100 | self.stride = image_stride
101 |
102 | self.resize_transform = MaxLongEdgeMinShortEdgeResize(
103 | max_size=max_image_size,
104 | min_size=min_image_size,
105 | stride=image_stride,
106 | max_pixels=max_pixels,
107 | )
108 | self.to_tensor_transform = transforms.ToTensor()
109 | self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
110 |
111 | def __call__(self, img, img_num=1):
112 | img = self.resize_transform(img, img_num=img_num)
113 | img = self.to_tensor_transform(img)
114 | img = self.normalize_transform(img)
115 | return img
116 |
117 |
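# Editor-added illustration (not part of the original file): ImageTransform chains the
# resize above with ToTensor and Normalize(mean=0.5, std=0.5), so values land in [-1, 1].
# The sizes below are assumptions for illustration only.
def _demo_image_transform_sketch():
    transform = ImageTransform(max_image_size=1024, min_image_size=512, image_stride=16)
    tensor = transform(Image.new("RGB", (800, 600)))
    # Channels-first float tensor whose height and width are multiples of the stride.
    assert tensor.shape[0] == 3 and tensor.shape[1] % 16 == 0 and tensor.shape[2] % 16 == 0
    return tensor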
118 | def decolorization(image):
119 | gray_image = image.convert('L')
120 | return Image.merge('RGB', [gray_image] * 3) if image.mode == 'RGB' else gray_image
121 |
122 |
123 | def downscale(image, scale_factor):
124 | new_width = int(round(image.width * scale_factor))
125 | new_height = int(round(image.height * scale_factor))
126 | new_width = max(1, new_width)
127 | new_height = max(1, new_height)
128 | return image.resize((new_width, new_height), resample=Image.BICUBIC)
129 |
130 |
131 | def crop(image, crop_factors):
132 | target_h, target_w = crop_factors
133 | img_w, img_h = image.size
134 |
135 | if target_h > img_h or target_w > img_w:
136 | raise ValueError("Crop size exceeds image dimensions")
137 |
138 | x = random.randint(0, img_w - target_w)
139 | y = random.randint(0, img_h - target_h)
140 |
141 | return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
142 |
143 |
144 | def motion_blur_opencv(image, kernel_size=15, angle=0):
145 | # Build a linear (horizontal line) kernel
146 | kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
147 | kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
148 |
149 | # Rotate the kernel to the requested angle
150 | center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
151 | M = cv2.getRotationMatrix2D(center, angle, 1)
152 | rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
153 |
154 | # Normalize the kernel (guard against an all-zero sum)
155 | rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
156 |
157 | img = np.array(image)
158 | if img.ndim == 2:
159 | blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
160 | else:
161 | # For color images, filter each channel independently
162 | blurred = np.zeros_like(img)
163 | for c in range(img.shape[2]):
164 | blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
165 |
166 | return Image.fromarray(blurred.astype(np.uint8))
167 |
168 |
169 | def shuffle_patch(image, num_splits, gap_size=2):
170 | """将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
171 | h_splits, w_splits = num_splits
172 | img_w, img_h = image.size
173 |
174 | base_patch_h = img_h // h_splits
175 | patch_heights = [base_patch_h] * (h_splits - 1)
176 | patch_heights.append(img_h - sum(patch_heights))
177 |
178 | base_patch_w = img_w // w_splits
179 | patch_widths = [base_patch_w] * (w_splits - 1)
180 | patch_widths.append(img_w - sum(patch_widths))
181 |
182 | patches = []
183 | current_y = 0
184 | for i in range(h_splits):
185 | current_x = 0
186 | patch_h = patch_heights[i]
187 | for j in range(w_splits):
188 | patch_w = patch_widths[j]
189 | patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
190 | patches.append(patch)
191 | current_x += patch_w
192 | current_y += patch_h
193 |
194 | random.shuffle(patches)
195 |
196 | total_width = sum(patch_widths) + (w_splits - 1) * gap_size
197 | total_height = sum(patch_heights) + (h_splits - 1) * gap_size
198 | new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
199 |
200 | current_y = 0  # Y coordinate where the current row starts
201 | patch_idx = 0  # index of the next patch to place
202 | for i in range(h_splits):
203 | current_x = 0  # X coordinate where the current column starts
204 | patch_h = patch_heights[i]  # height of the patches in this row
205 | for j in range(w_splits):
206 | # take the next shuffled patch
207 | patch = patches[patch_idx]
208 | patch_w = patch_widths[j]  # width of the patches in this column
209 | # paste the patch with its top-left corner at (current_x, current_y)
210 | new_image.paste(patch, (current_x, current_y))
211 | # advance X: the next patch starts after this patch's width plus the gap
212 | current_x += patch_w + gap_size
213 | patch_idx += 1
214 | # advance Y: the next row starts after this row's height plus the gap
215 | current_y += patch_h + gap_size
216 |
217 | return new_image
218 |
219 |
220 | def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
221 | """
222 | Split the image into patches and blank out a random subset, producing inputs for an inpainting task.
223 |
224 | Args:
225 | image: PIL.Image, input image (RGB mode)
226 | num_splits: (h_splits, w_splits), number of splits along the vertical and
227 | horizontal directions
228 | blank_ratio: float, fraction of patches to blank out (0 to 1)
229 | blank_color: tuple, RGB color of the blanked regions (e.g. white (255, 255, 255))
230 |
231 | Returns:
232 | PIL.Image: the re-assembled image with the selected patches blanked
233 | """
234 | h_splits, w_splits = num_splits
235 | img_w, img_h = image.size
236 |
237 | base_patch_h = img_h // h_splits
238 | patch_heights = [base_patch_h] * (h_splits - 1)
239 | patch_heights.append(img_h - sum(patch_heights))
240 |
241 | base_patch_w = img_w // w_splits
242 | patch_widths = [base_patch_w] * (w_splits - 1)
243 | patch_widths.append(img_w - sum(patch_widths))
244 |
245 | patches = []
246 | current_y = 0
247 | for i in range(h_splits):
248 | current_x = 0
249 | patch_h = patch_heights[i]
250 | for j in range(w_splits):
251 | patch_w = patch_widths[j]
252 | patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
253 | patches.append(patch)
254 | current_x += patch_w
255 | current_y += patch_h
256 |
257 | total_patches = h_splits * w_splits
258 | num_blank = int(total_patches * blank_ratio)
259 | num_blank = max(0, min(num_blank, total_patches))
260 | blank_indices = random.sample(range(total_patches), num_blank)
261 |
262 | processed_patches = []
263 | for idx, patch in enumerate(patches):
264 | if idx in blank_indices:
265 | blank_patch = Image.new("RGB", patch.size, color=blank_color)
266 | processed_patches.append(blank_patch)
267 | else:
268 | processed_patches.append(patch)
269 |
270 | # Create the output image (same size as the input)
271 | result_image = Image.new("RGB", (img_w, img_h))
272 | current_y = 0
273 | patch_idx = 0
274 | for i in range(h_splits):
275 | current_x = 0
276 | patch_h = patch_heights[i]
277 | for j in range(w_splits):
278 | # take the processed patch
279 | patch = processed_patches[patch_idx]
280 | patch_w = patch_widths[j]
281 | # paste it back at its original position
282 | result_image.paste(patch, (current_x, current_y))
283 | current_x += patch_w
284 | patch_idx += 1
285 | current_y += patch_h
286 |
287 | return result_image
288 |
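# Editor-added illustration (not part of the original file): the helpers above are image
# degradation / corruption ops, presumably used to build editing-style training pairs.
# A minimal sketch applying each of them to a dummy image; all parameters are assumptions.
def _demo_degradations_sketch():
    img = Image.new("RGB", (256, 256), color=(128, 64, 32))
    gray = decolorization(img)                           # 3-channel grayscale copy
    small = downscale(img, scale_factor=0.5)             # 128 x 128
    cropped, corners = crop(img, crop_factors=(64, 64))  # random 64 x 64 crop and its box
    blurred = motion_blur_opencv(img, kernel_size=15, angle=30)
    shuffled = shuffle_patch(img, num_splits=(4, 4), gap_size=2)
    masked = inpainting(img, num_splits=(4, 4), blank_ratio=0.3)
    return gray, small, cropped, corners, blurred, shuffled, masked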
--------------------------------------------------------------------------------
/data/video_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 OpenGVLab
2 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3 | # SPDX-License-Identifier: MIT
4 | #
5 | # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6 | #
7 | # Original file was released under MIT, with the full license text
8 | # available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
9 | #
10 | # This modified file is released under the same license.
11 |
12 |
13 | import io
14 | import os
15 | import random
16 | import re
17 |
18 | import numpy as np
19 | import decord
20 | from PIL import Image
21 |
22 |
23 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
24 | if sample in ['rand', 'middle']: # uniform sampling
25 | acc_samples = min(num_frames, vlen)
26 | # split the video into `acc_samples` intervals, and sample from each interval.
27 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
28 | ranges = []
29 | for idx, interv in enumerate(intervals[:-1]):
30 | ranges.append((interv, intervals[idx + 1] - 1))
31 | if sample == 'rand':
32 | try:
33 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
34 | except Exception:  # an interval may be empty, making random.choice fail
35 | frame_indices = np.random.permutation(vlen)[:acc_samples]
36 | frame_indices.sort()
37 | frame_indices = list(frame_indices)
38 | elif fix_start is not None:
39 | frame_indices = [x[0] + fix_start for x in ranges]
40 | elif sample == 'middle':
41 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
42 | else:
43 | raise NotImplementedError
44 |
45 | if len(frame_indices) < num_frames: # padded with last frame
46 | padded_frame_indices = [frame_indices[-1]] * num_frames
47 | padded_frame_indices[:len(frame_indices)] = frame_indices
48 | frame_indices = padded_frame_indices
49 | elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
50 | output_fps = float(sample[3:])
51 | duration = float(vlen) / input_fps
52 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
53 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
54 | frame_indices = np.around(frame_seconds * input_fps).astype(int)
55 | frame_indices = [e for e in frame_indices if e < vlen]
56 | if max_num_frames > 0 and len(frame_indices) > max_num_frames:
57 | frame_indices = frame_indices[:max_num_frames]
58 | else:
59 | raise ValueError
60 | return frame_indices
61 |
62 |
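# Editor-added illustration (not part of the original file): a minimal sketch of the two
# sampling modes above. 'rand'/'middle' pick one frame per uniform interval, while 'fpsX'
# samples sequentially at X frames per second of video. The numbers are assumptions.
def _demo_frame_indices_sketch():
    uniform = get_frame_indices(num_frames=8, vlen=100, sample='middle')
    half_fps = get_frame_indices(num_frames=8, vlen=300, sample='fps0.5', input_fps=30)
    return uniform, half_fps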
63 | def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
64 | video_reader = decord.VideoReader(video_path, num_threads=1)
65 | vlen = len(video_reader)
66 | fps = video_reader.get_avg_fps()
67 | duration = vlen / float(fps)
68 | if clip:
69 | start, end = clip
70 | duration = end - start
71 | vlen = int(duration * fps)
72 | start_index = int(start * fps)
73 |
74 | t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
75 |
76 | frame_indices = get_frame_indices(
77 | t_num_frames, vlen, sample=sample, fix_start=fix_start,
78 | input_fps=fps
79 | )
80 | if clip:
81 | frame_indices = [f + start_index for f in frame_indices]
82 | frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
83 | frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
84 | return frames
85 |
86 |
87 | def extract_frame_number(filename):
88 | # Extract the numeric part from the filename using regular expressions
89 | match = re.search(r'_(\d+).jpg$', filename)
90 | return int(match.group(1)) if match else -1
91 |
92 |
93 | def sort_frames(frame_paths):
94 | # Extract filenames from each path and sort by their numeric part
95 | return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
96 |
97 |
98 | def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
99 | image_list = sort_frames(list(os.listdir(video_path)))
100 | frames = []
101 | for image in image_list:
102 | fp = os.path.join(video_path, image)
103 | frame = Image.open(fp).convert('RGB')
104 | frames.append(frame)
105 | vlen = len(frames)
106 |
107 | t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
108 |
109 | if vlen > t_num_frames:
110 | frame_indices = get_frame_indices(
111 | t_num_frames, vlen, sample=sample, fix_start=fix_start
112 | )
113 | frames = [frames[i] for i in frame_indices]
114 | return frames
115 |
116 |
117 | class FrameSampler:
118 | def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
119 | self.max_num_frames = max_num_frames
120 | self.min_num_frames = min_num_frames
121 | self.sample = sample
122 |
123 | def __call__(self, file_name):
124 | fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
125 | frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
126 | return frames
127 |
128 |
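# Editor-added illustration (not part of the original file): FrameSampler dispatches on the
# path format -- a trailing '/' means a directory of pre-extracted frames, anything else is
# decoded as a video file with decord. The paths below are placeholders, not real files.
def _demo_frame_sampler_sketch():
    sampler = FrameSampler(max_num_frames=16, min_num_frames=4, sample='rand')
    video_frames = sampler('/path/to/clip.mp4')       # goes through read_frames_decord
    folder_frames = sampler('/path/to/frame_dir/')    # goes through read_frames_folder
    return video_frames, folder_frames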
129 | def decode_video_byte(video_bytes):
130 | video_stream = io.BytesIO(video_bytes)
131 | vr = decord.VideoReader(video_stream)
132 | return vr
133 |
134 |
135 | def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
136 | if isinstance(mp4_p, str):
137 | vr = decord.VideoReader(mp4_p, num_threads=1)
138 | elif isinstance(mp4_p, decord.video_reader.VideoReader):
139 | vr = mp4_p
140 | video_fps = vr.get_avg_fps()  # average frame rate of the video
141 | video_duration = len(vr) / video_fps
142 | if n_frames is not None:
143 | if random_sample:
144 | frame_indices = sorted(random.sample(range(len(vr)), n_frames))
145 | else:
146 | frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
147 | else:
148 | frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
149 | frames = vr.get_batch(frame_indices).asnumpy()  # decode the selected frames to a numpy array
150 | frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
151 | if not return_frame_indices:
152 | return frames, video_duration
153 | else:
154 | return frames, video_duration, frame_indices
155 |
156 |
157 | def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
158 | if isinstance(mp4_p, str):
159 | vr = decord.VideoReader(mp4_p, num_threads=1)
160 | elif isinstance(mp4_p, decord.video_reader.VideoReader):
161 | vr = mp4_p
162 | # sample the frames in frame_indices
163 | frames = vr.get_batch(frame_indices).asnumpy()  # decode the selected frames to a numpy array
164 | frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
165 | return frames
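# Editor-added illustration (not part of the original file): sample_mp4_frames either takes
# a fixed number of evenly spaced frames (n_frames=...) or samples at a target rate (fps=...),
# returning PIL frames plus the clip duration. The path below is a placeholder.
def _demo_sample_mp4_sketch(video_path='/path/to/clip.mp4'):
    fixed_frames, duration = sample_mp4_frames(video_path, n_frames=8)
    rate_frames, duration, indices = sample_mp4_frames(video_path, fps=1.0, return_frame_indices=True)
    return fixed_frames, rate_frames, indices, duration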
--------------------------------------------------------------------------------
/data/vlm_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import json
5 | import os
6 | import traceback
7 | from PIL import Image, ImageFile, PngImagePlugin
8 |
9 | from .data_utils import pil_img2rgb
10 | from .distributed_iterable_dataset import DistributedIterableDataset
11 |
12 |
13 | Image.MAX_IMAGE_PIXELS = 200000000
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | MaximumDecompressedSize = 1024
16 | MegaByte = 2 ** 20
17 | PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
18 |
19 |
20 | class SftJSONLIterableDataset(DistributedIterableDataset):
21 | def __init__(
22 | self, dataset_name, transform, tokenizer, frame_sampler,
23 | jsonl_path_list, data_dir_list, num_used_data,
24 | local_rank=0, world_size=1, num_workers=8, data_status=None,
25 | shuffle_lines=False, shuffle_seed=0,
26 | ):
27 | """
28 | jsonl_path_list: list of jsonl file paths
29 | data_dir_list: list of image directories containing the images of each jsonl file
30 | num_used_data: list of number of sampled data points for each jsonl
31 | """
32 | super().__init__(dataset_name, local_rank, world_size, num_workers)
33 | self.transform = transform
34 | self.tokenizer = tokenizer
35 | self.frame_sampler = frame_sampler
36 | self.data_status = data_status
37 | self.data_paths = self.get_data_paths(
38 | jsonl_path_list,
39 | data_dir_list,
40 | num_used_data,
41 | shuffle_lines,
42 | shuffle_seed,
43 | )
44 | self.set_epoch()
45 |
46 | def get_data_paths(
47 | self,
48 | jsonl_path_list,
49 | data_dir_list,
50 | num_used_data,
51 | shuffle_lines,
52 | shuffle_seed,
53 | ):
54 | data_paths = []
55 | for jsonl_path, image_dir, num_data_point in zip(
56 | jsonl_path_list, data_dir_list, num_used_data
57 | ):
58 | with open(jsonl_path, 'r') as f:
59 | raw_data = f.readlines()
60 | if shuffle_lines:
61 | self.rng.seed(shuffle_seed)
62 | self.rng.shuffle(raw_data)
63 | raw_data = raw_data[:num_data_point]
64 | data_paths.extend([(json_data, image_dir) for json_data in raw_data])
65 | return data_paths
66 |
67 | def change_format(self, data, num_images):
68 | elements = []
69 | for conversation in data['conversations']:
70 | if conversation['from'] == 'human':
71 | if '<image>' not in conversation['value']:
72 | elements.append({
73 | 'type': 'text',
74 | 'has_loss': 0,
75 | 'text': conversation['value'],
76 | })
77 | else:
78 | text_list = conversation['value'].split('<image>')
79 | for idx, text in enumerate(text_list):
80 | if text.strip() != '':
81 | elements.append({
82 | 'type': 'text',
83 | 'has_loss': 0,
84 | 'text': text.strip(),
85 | })
86 | if (idx != len(text_list) - 1) and (idx < num_images):
87 | elements.append({'type': 'image',})
88 | elif conversation['from'] == 'gpt':
89 | elements.append({
90 | 'type': 'text',
91 | 'has_loss': 1,
92 | 'text': conversation['value'],
93 | })
94 | return elements
95 |
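# Editor-added illustration (not part of the original file): change_format flattens a
# LLaVA-style conversation into the element list consumed by __iter__ below. Human turns
# carry no loss and '<image>' markers become image placeholders; gpt turns carry loss.
# The conversation content below is an assumption for illustration only.
def _demo_change_format_sketch(dataset):
    data = {
        'conversations': [
            {'from': 'human', 'value': '<image>\nDescribe the picture.'},
            {'from': 'gpt', 'value': 'A cat sleeping on a sofa.'},
        ]
    }
    # Expected: [{'type': 'image'}, {'type': 'text', 'has_loss': 0, ...}, {'type': 'text', 'has_loss': 1, ...}]
    return dataset.change_format(data, num_images=1)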
96 | def __iter__(self):
97 | data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
98 | if self.data_status is not None:
99 | row_start_id = self.data_status[worker_id] + 1
100 | else:
101 | row_start_id = 0
102 | transform_stride = self.transform.stride
103 |
104 | print(
105 | f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
106 | f"resuming data at row#{row_start_id}"
107 | )
108 |
109 | while True:
110 | data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
111 | for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
112 | num_tokens = 0
113 | image_tensor_list = []
114 | text_ids_list = []
115 | sequence_plan = []
116 |
117 | try:
118 | data_item = json.loads(data)
119 | raw_images = None
120 | if 'image' in data_item:
121 | if type(data_item['image']) == list:
122 | raw_images = [
123 | pil_img2rgb(Image.open(os.path.join(image_dir, image)))
124 | for image in data_item['image']
125 | ]
126 | else:
127 | raw_images = [
128 | pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
129 | ]
130 | elif 'video' in data_item:
131 | raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
132 | special_tokens = '<image>' * len(raw_images)
133 | for item in data_item['conversations']:
134 | if '