├── .gitignore
├── LICENSE
├── README.md
├── args.py
├── assets
│   ├── car.jpg
│   ├── green_creature.jpg
│   ├── mushroom.jpg
│   └── norte_dam.jpg
├── dataset.py
├── misc
│   ├── accelerate_ds2.yaml
│   ├── convert_to_wds.py
│   └── launch_edit.slurm
├── requirements.txt
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FluxEdit
2 |
3 | This project tries to teach [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) the task of language-guided image editing with the [Flux Control framework](https://blackforestlabs.ai/flux-1-tools/). We use the high-quality [`TIGER-Lab/OmniEdit-Filtered-1.2M`](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M/) dataset for training.
4 |
5 | * Fine-tuned edit model here: [`sayakpaul/FLUX.1-dev-edit-v0`](https://huggingface.co/sayakpaul/FLUX.1-dev-edit-v0)
6 | * Interactive demo to try out on Hugging Face Spaces: [`ariG23498/flux-edit`](https://huggingface.co/spaces/ariG23498/flux-edit) (thanks to `ariG23498` for this!)
7 |
8 |
9 |
10 |
11 |
12 | >[!IMPORTANT]
13 | > Since the official Flux Control training details aren't available, this project should be considered experimental, and we welcome contributions from the community to make it better 🤗
14 |
15 | ## Setup
16 |
17 | Install the dependencies from [`requirements.txt`](./requirements.txt) and perform any other configuration that might be needed.
18 |
19 | The scripts were tested using PyTorch 2.5.1 and NVIDIA GPUs (H100).
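
As an optional sanity check of the environment (a minimal sketch; adjust to your setup), you can verify that PyTorch sees your GPUs and that they support `bf16`, which the training command below relies on:

```py
import torch

# Print the PyTorch version and confirm CUDA GPUs are visible.
print(torch.__version__)
print(torch.cuda.is_available(), torch.cuda.device_count())

# bf16 mixed precision requires Ampere or newer GPUs (e.g., A100/H100).
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0), torch.cuda.is_bf16_supported())
```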
20 |
21 | ## Training
22 |
23 | We first converted the original OmniEdit dataset into Webdataset shards using [this script](./misc/convert_to_wds.py) for efficiency. The script prepares the Webdataset shards and pushes them to an S3 bucket, but you can configure this to suit your needs.
24 |
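If you want to spot-check the converted shards before training, a minimal inspection sketch like the one below can help. The local shard name is a placeholder; point it at one of your own shards (or the same kind of `pipe:aws s3 cp ... -` URL used in `dataset.py`):

```py
import webdataset as wds

# Placeholder path; replace with one of your converted shards or an S3 pipe URL.
shard = "train-00000-of-00571.tar"

# Decode images to PIL and text/JSON fields to Python objects, then grab one sample.
dataset = wds.WebDataset(shard).decode("pil")
sample = next(iter(dataset))

# These keys are written by misc/convert_to_wds.py.
print(sorted(sample.keys()))
print(sample["edited_prompt_list.json"], sample["sc_score_1.txt"])
sample["src_img.jpg"].save("source_preview.png")
sample["edited_img.jpg"].save("edited_preview.png")
```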
25 |
26 | **Training Command**
27 |
28 | ```bash
29 | export LR=1e-4
30 | export WEIGHT_DECAY=1e-4
31 | export GUIDANCE_SCALE=30.0
32 | export CAPTION_DROPOUT=0.0
33 | export LR_SCHEDULER="constant"
34 |
35 | srun --wait=60 --kill-on-bad-exit=1 accelerate launch --config_file=./misc/accelerate_ds2.yaml train.py \
36 | --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
37 | --output_dir="omniflux-lr_${LR}-wd_${WEIGHT_DECAY}-gs_${GUIDANCE_SCALE}-cd_${CAPTION_DROPOUT}-scheduler_${LR_SCHEDULER}-sim_flow-no8bitadam" \
38 | --mixed_precision="bf16" \
39 | --per_gpu_batch_size=4 \
40 | --dataloader_num_workers=4 \
41 | --gradient_accumulation_steps=4 \
42 | --quality_threshold=10.0 \
43 | --simplified_flow \
44 | --gradient_checkpointing \
45 | --proportion_empty_prompts=$CAPTION_DROPOUT \
46 | --learning_rate=$LR \
47 | --adam_weight_decay=$WEIGHT_DECAY \
48 | --guidance_scale=$GUIDANCE_SCALE \
49 | --validation_guidance_scale="10.,20.,30.,40." \
50 | --report_to="wandb" --log_dataset_samples \
51 | --lr_scheduler=$LR_SCHEDULER \
52 | --lr_warmup_steps=0 \
53 | --checkpointing_steps=4000 \
54 | --resume_from_checkpoint="latest" --checkpoints_total_limit=2 \
55 | --max_train_steps=20000 \
56 | --validation_steps=400 \
57 | --validation_image "assets/car.jpg" "assets/green_creature.jpg" "assets/norte_dam.jpg" "assets/mushroom.jpg" \
58 | --validation_prompt "Give this the look of a traditional Japanese woodblock print." "transform the setting to a winter scene" "Change it to look like it's in the style of an impasto painting." "turn the color of mushroom to gray" \
59 | --seed="0" \
60 | --push_to_hub
61 |
62 | echo "END TIME: $(date)"
63 | ```
64 |
65 |
66 |
67 | Training runs on 8 GPUs with DeepSpeed ZeRO-2, for an effective batch size of 128 (per-GPU batch size of 4 × 8 GPUs × 4 gradient accumulation steps). You can adjust this via the [`accelerate` config file](./misc/accelerate_ds2.yaml).
68 |
69 | Refer to [`args.py`](./args.py) for the different configuration options supported. Training was conducted on a single node of 8 H100s. If you prefer using Slurm, refer to [this Slurm script](./misc/launch_edit.slurm) for scheduling training.
70 |
71 | You can also use this [minimal version of the `train.py` script](https://github.com/huggingface/diffusers/blob/main/examples/flux-control/train_control_flux.py) with a [minimal version of the OmniEdit dataset](https://huggingface.co/datasets/sayakpaul/OmniEdit-mini) for quicker prototyping.
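
For additional context on what `train.py` changes architecturally: the transformer's `x_embedder` is widened to twice its input width so that the packed latents of the source (control) image can be concatenated channel-wise with the packed noisy latents of the edited image, with the new input channels zero-initialized. Below is a shape-only sketch of that idea (illustrative tensors, not the actual training step):

```py
import torch

# Illustrative shapes: packed Flux latents have 64 channels per token by default.
batch_size, seq_len, packed_channels = 1, 4096, 64
packed_noisy_latents = torch.randn(batch_size, seq_len, packed_channels)
packed_control_latents = torch.randn(batch_size, seq_len, packed_channels)

# Channel-wise concatenation is why `x_embedder.in_features` is doubled in train.py.
transformer_input = torch.cat([packed_noisy_latents, packed_control_latents], dim=2)
print(transformer_input.shape)  # torch.Size([1, 4096, 128])
```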
72 |
73 | ## Inference
74 |
75 | ```py
76 | from diffusers import FluxControlPipeline, FluxTransformer2DModel
77 | from diffusers.utils import load_image
78 | import torch
79 |
80 | path = "sayakpaul/FLUX.1-dev-edit-v0"
81 | edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
82 | pipeline = FluxControlPipeline.from_pretrained(
83 | "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
84 | ).to("cuda")
85 |
86 | image = load_image("./assets/mushroom.jpg") # resize as needed.
87 | print(image.size)
88 |
89 | prompt = "turn the color of mushroom to gray"
90 | image = pipeline(
91 | control_image=image,
92 | prompt=prompt,
93 | guidance_scale=30., # change this as needed.
94 | num_inference_steps=50, # change this as needed.
95 | max_sequence_length=512,
96 | height=image.height,
97 | width=image.width,
98 | generator=torch.manual_seed(0)
99 | ).images[0]
100 | image.save("edited_image.png")
101 | ```
102 |
103 | ### Speeding up inference with a turbo LoRA
104 |
105 | We can speed up inference by using a turbo LoRA like [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD), which lets us reduce `num_inference_steps` while still producing a nice image.
106 |
107 | Make sure to install `peft` before running the code below: `pip install -U peft`.
108 |
109 |
110 | **Code**
111 |
112 | ```py
113 | from diffusers import FluxControlPipeline, FluxTransformer2DModel
114 | from diffusers.utils import load_image
115 | from huggingface_hub import hf_hub_download
116 | import torch
117 |
118 | path = "sayakpaul/FLUX.1-dev-edit-v0"
119 | edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
120 | control_pipe = FluxControlPipeline.from_pretrained(
121 | "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
122 | ).to("cuda")
123 |
124 | # load the turbo LoRA
125 | control_pipe.load_lora_weights(
126 | hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
127 | )
128 | control_pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
129 |
130 | image = load_image("./assets/mushroom.jpg") # resize as needed.
131 | print(image.size)
132 |
133 | prompt = "turn the color of mushroom to gray"
134 | image = control_pipe(
135 | control_image=image,
136 | prompt=prompt,
137 | guidance_scale=30., # change this as needed.
138 | num_inference_steps=8, # change this as needed.
139 | max_sequence_length=512,
140 | height=image.height,
141 | width=image.width,
142 | generator=torch.manual_seed(0)
143 | ).images[0]
144 | image.save("edited_image.png")
145 | ```
146 |
147 |
148 |
149 |
150 | **Comparison** (50 steps vs. 8 steps)
151 |
152 | _Side-by-side comparison images of the edited outputs at 50 vs. 8 inference steps are omitted here._
153 |
178 | You can also perform quantization if your hardware cannot otherwise meet the memory requirements. Refer to the [Diffusers documentation](https://huggingface.co/docs/diffusers/main/en/quantization/overview) to learn more.
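
As a rough starting point, and assuming `bitsandbytes` is installed, loading the edit transformer in 4-bit with Diffusers' `BitsAndBytesConfig` might look like the sketch below (based on the Diffusers quantization docs; not validated as part of this repository):

```py
from diffusers import BitsAndBytesConfig, FluxControlPipeline, FluxTransformer2DModel
import torch

# 4-bit NF4 quantization with bf16 compute for the edit transformer.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
edit_transformer = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/FLUX.1-dev-edit-v0",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()  # trade some speed for lower peak GPU memory
```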
179 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 |
19 | def parse_args(input_args=None):
20 | parser = argparse.ArgumentParser(description="Simple example of a Flux Control training script.")
21 | parser.add_argument(
22 | "--pretrained_model_name_or_path",
23 | type=str,
24 | default=None,
25 | required=True,
26 | help="Path to pretrained model or model identifier from huggingface.co/models.",
27 | )
28 | parser.add_argument(
29 | "--variant",
30 | type=str,
31 | default=None,
32 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
33 | )
34 | parser.add_argument(
35 | "--revision",
36 | type=str,
37 | default=None,
38 | required=False,
39 | help="Revision of pretrained model identifier from huggingface.co/models.",
40 | )
41 | parser.add_argument(
42 | "--output_dir",
43 | type=str,
44 | default="flux-control",
45 | help="The output directory where the model predictions and checkpoints will be written.",
46 | )
47 | parser.add_argument(
48 | "--cache_dir",
49 | type=str,
50 | default=None,
51 | help="The directory where the downloaded models and datasets will be stored.",
52 | )
53 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
54 | parser.add_argument(
55 | "--resolution",
56 | type=int,
57 | default=1024,
58 | help=(
59 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
60 | " resolution"
61 | ),
62 | )
63 | parser.add_argument(
64 | "--quality_threshold",
65 | type=float,
66 | default=None,
67 | help="Threshold to use to do some quality filtering."
68 | )
69 | parser.add_argument(
70 | "--per_gpu_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
71 | )
72 |     # The original count is 1203497; 700 samples are held out for a "dev" split, leaving 1202797.
73 | parser.add_argument("--num_train_examples", type=int, default=1202797, help="Number of samples in the training dataset.")
74 | parser.add_argument("--num_train_epochs", type=int, default=1)
75 | parser.add_argument(
76 | "--max_train_steps",
77 | type=int,
78 | default=None,
79 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
80 | )
81 | parser.add_argument(
82 | "--checkpointing_steps",
83 | type=int,
84 | default=500,
85 | help=(
86 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
87 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
88 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
89 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
90 | "instructions."
91 | ),
92 | )
93 | parser.add_argument(
94 | "--checkpoints_total_limit",
95 | type=int,
96 | default=None,
97 | help=("Max number of checkpoints to store."),
98 | )
99 | parser.add_argument(
100 | "--resume_from_checkpoint",
101 | type=str,
102 | default=None,
103 | help=(
104 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
105 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
106 | ),
107 | )
108 | parser.add_argument(
109 | "--proportion_empty_prompts",
110 | type=float,
111 | default=0,
112 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
113 | )
114 | parser.add_argument(
115 | "--gradient_accumulation_steps",
116 | type=int,
117 | default=1,
118 | help="Number of updates steps to accumulate before performing a backward/update pass.",
119 | )
120 | parser.add_argument(
121 | "--gradient_checkpointing",
122 | action="store_true",
123 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
124 | )
125 | parser.add_argument(
126 | "--learning_rate",
127 | type=float,
128 | default=5e-6,
129 | help="Initial learning rate (after the potential warmup period) to use.",
130 | )
131 | parser.add_argument(
132 | "--scale_lr",
133 | action="store_true",
134 | default=False,
135 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
136 | )
137 | parser.add_argument(
138 | "--lr_scheduler",
139 | type=str,
140 | default="constant",
141 | help=(
142 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
143 | ' "constant", "constant_with_warmup"]'
144 | ),
145 | )
146 | parser.add_argument(
147 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
148 | )
149 | parser.add_argument(
150 | "--lr_num_cycles",
151 | type=int,
152 | default=1,
153 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
154 | )
155 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
156 | parser.add_argument(
157 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
158 | )
159 |
160 | parser.add_argument(
161 | "--dataloader_num_workers",
162 | type=int,
163 | default=0,
164 | help=(
165 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
166 | ),
167 | )
168 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
169 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
170 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
171 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
172 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
173 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
174 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
175 | parser.add_argument(
176 | "--hub_model_id",
177 | type=str,
178 | default=None,
179 | help="The name of the repository to keep in sync with the local `output_dir`.",
180 | )
181 | parser.add_argument(
182 | "--logging_dir",
183 | type=str,
184 | default="logs",
185 | help=(
186 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
187 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
188 | ),
189 | )
190 | parser.add_argument(
191 | "--allow_tf32",
192 | action="store_true",
193 | help=(
194 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
195 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
196 | ),
197 | )
198 | parser.add_argument(
199 | "--report_to",
200 | type=str,
201 | default="tensorboard",
202 | help=(
203 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
204 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
205 | ),
206 | )
207 | parser.add_argument(
208 | "--mixed_precision",
209 | type=str,
210 | default=None,
211 | choices=["no", "fp16", "bf16"],
212 | help=(
213 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
214 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
215 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
216 | ),
217 | )
218 | parser.add_argument(
219 | "--dataset_path",
220 | type=str,
221 | default="pipe:aws s3 cp s3://omniedit-wds/train-{00000..00570}-of-00571.tar -",
222 | )
223 | parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
224 | parser.add_argument(
225 | "--validation_prompt",
226 | type=str,
227 | default=None,
228 | nargs="+",
229 | help=(
230 | "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
231 | " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
232 | " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
233 | ),
234 | )
235 | parser.add_argument(
236 | "--validation_image",
237 | type=str,
238 | default=None,
239 | nargs="+",
240 | help=(
241 | "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
242 | " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
243 | " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
244 | " `--validation_image` that will be used with all `--validation_prompt`s."
245 | ),
246 | )
247 | parser.add_argument(
248 | "--num_validation_images",
249 | type=int,
250 | default=1,
251 | help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
252 | )
253 | parser.add_argument(
254 | "--validation_steps",
255 | type=int,
256 | default=100,
257 | help=(
258 | "Run validation every X steps. Validation consists of running the prompt"
259 | " `args.validation_prompt` multiple times: `args.num_validation_images`"
260 | " and logging the images."
261 | ),
262 | )
263 | parser.add_argument(
264 | "--tracker_project_name",
265 | type=str,
266 | default="flux_train_control_omniedit",
267 | help=(
268 | "The `project_name` argument passed to Accelerator.init_trackers for"
269 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
270 | ),
271 | )
272 | parser.add_argument(
273 | "--jsonl_for_train",
274 | type=str,
275 | default=None,
276 | help="Path to the jsonl file containing the training data.",
277 | )
278 | parser.add_argument(
279 | "--only_target_transformer_blocks",
280 | action="store_true",
281 | help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
282 | )
283 | parser.add_argument(
284 | "--guidance_scale",
285 | type=float,
286 | default=30.0,
287 | help="the guidance scale used for transformer.",
288 | )
289 | parser.add_argument(
290 | "--validation_guidance_scale",
291 | type=str,
292 | default=None,
293 | help="List of guidance scales to use in validation. Provide as - 10,20,30, ..."
294 | )
295 | parser.add_argument(
296 | "--upcast_before_saving",
297 | action="store_true",
298 | help=(
299 | "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
300 | "Defaults to precision dtype used for training to save memory"
301 | ),
302 | )
303 | parser.add_argument("--simplified_flow", action="store_true", help="Flow simplified.")
304 | parser.add_argument(
305 | "--weighting_scheme",
306 | type=str,
307 | default="none",
308 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
309 | help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
310 | )
311 | parser.add_argument(
312 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
313 | )
314 | parser.add_argument(
315 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
316 | )
317 | parser.add_argument(
318 | "--mode_scale",
319 | type=float,
320 | default=1.29,
321 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
322 | )
323 | parser.add_argument(
324 | "--offload",
325 | action="store_true",
326 | help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
327 | )
328 |
329 | if input_args is not None:
330 | args = parser.parse_args(input_args)
331 | else:
332 | args = parser.parse_args()
333 |
334 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
335 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
336 |
337 | if args.validation_prompt is not None and args.validation_image is None:
338 | raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
339 |
340 | if args.validation_prompt is None and args.validation_image is not None:
341 | raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
342 |
343 | if (
344 | args.validation_image is not None
345 | and args.validation_prompt is not None
346 | and len(args.validation_image) != 1
347 | and len(args.validation_prompt) != 1
348 | and len(args.validation_image) != len(args.validation_prompt)
349 | ):
350 | raise ValueError(
351 | "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
352 | " or the same number of `--validation_prompt`s and `--validation_image`s"
353 | )
354 |
355 | if args.resolution % 8 != 0:
356 | raise ValueError(
357 | "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
358 | )
359 |
360 | if args.validation_guidance_scale is None:
361 | args.validation_guidance_scale = args.guidance_scale
362 | else:
363 | args.validation_guidance_scale = [float(value.strip()) for value in args.validation_guidance_scale.split(",")]
364 |
365 | return args
--------------------------------------------------------------------------------
/assets/car.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/car.jpg
--------------------------------------------------------------------------------
/assets/green_creature.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/green_creature.jpg
--------------------------------------------------------------------------------
/assets/mushroom.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/mushroom.jpg
--------------------------------------------------------------------------------
/assets/norte_dam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/norte_dam.jpg
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import webdataset as wds
2 | import torch
3 | import numpy as np
4 | from torchvision import transforms
5 | import math
6 |
7 | # Taken from https://github.com/tmbdev-archive/webdataset-imagenet-2/blob/01a4ab54307b9156c527d45b6b171f88623d2dec/imagenet.py#L65.
8 | def nodesplitter(src, group=None):
9 | if torch.distributed.is_initialized():
10 | if group is None:
11 | group = torch.distributed.group.WORLD
12 | rank = torch.distributed.get_rank(group=group)
13 | size = torch.distributed.get_world_size(group=group)
14 | count = 0
15 | for i, item in enumerate(src):
16 | if i % size == rank:
17 | yield item
18 | count += 1
19 | else:
20 | yield from src
21 |
22 | def collate_fn(samples):
23 | source_pixel_values = torch.stack([example["source_pixel_values"] for example in samples])
24 | source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
25 | edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in samples])
26 | edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
27 | captions = [example["prompt"] for example in samples]
28 | return {"source_pixel_values": source_pixel_values, "edited_pixel_values": edited_pixel_values, "captions": captions}
29 |
30 |
31 | class QualityFilter:
32 |     def __init__(self, threshold):
33 |         self.threshold = threshold
34 |
35 |     def __call__(self, x):
36 |         # Keep only samples whose average sc score, pq score, and overall score all meet the threshold.
37 |         avg_sc_score = (x["sc_score_1"] + x["sc_score_2"]) / 2
38 |         return (
39 |             avg_sc_score >= self.threshold and x["pq_score"] >= self.threshold and x["o_score"] >= self.threshold
40 |         )
41 |
42 | class ControlFluxDataset:
43 | def __init__(self, args):
44 | self.args = args
45 | self.image_transforms = transforms.Compose(
46 | [
47 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
48 | transforms.ToTensor(),
49 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
50 | ]
51 | )
52 |
53 | def get_dataset(self):
54 | args = self.args
55 | dataset = (
56 | wds.WebDataset(
57 | args.dataset_path,
58 | handler=wds.warn_and_continue,
59 | nodesplitter=nodesplitter,
60 | shardshuffle=500,
61 | empty_check=False
62 | )
63 | .shuffle(2000, handler=wds.warn_and_continue)
64 | .decode("pil", handler=wds.warn_and_continue)
65 | .rename(
66 | src_img="src_img.jpg",
67 | edited_img="edited_img.jpg",
68 | prompt_list="edited_prompt_list.json",
69 | sc_score_1="sc_score_1.txt",
70 | sc_score_2="sc_score_2.txt",
71 | pq_score="pq_score.txt",
72 | o_score="o_score.txt",
73 | handler=wds.warn_and_continue,
74 | )
75 | )
76 | dataset = dataset.map(self.preprocess_fn, handler=wds.warn_and_continue)
77 | dataset = dataset.select(QualityFilter(args.quality_threshold)) if args.quality_threshold else dataset
78 | return dataset
79 |
80 | def preprocess_fn(self, sample):
81 | source_pixel_values = self.image_transforms(sample["src_img"])
82 | edited_pixel_values = self.image_transforms(sample["edited_img"])
83 | prompt = np.random.choice(sample["prompt_list"]) if isinstance(sample["prompt_list"], list) else sample["prompt_list"]
84 |
85 | return {
86 | "source_pixel_values": source_pixel_values,
87 | "edited_pixel_values": edited_pixel_values,
88 | "prompt": prompt,
89 | "sc_score_1": float(sample["sc_score_1"]),
90 | "sc_score_2": float(sample["sc_score_2"]),
91 | "pq_score": float(sample["pq_score"]),
92 | "o_score": float(sample["o_score"])
93 | }
94 |
95 | def prepare_dataloader(self, dataset):
96 | args = self.args
97 | # per dataloader worker
98 | num_worker_batches = math.ceil(args.num_train_examples / (args.global_batch_size * args.dataloader_num_workers))
99 | dataset = dataset.batched(
100 | args.per_gpu_batch_size, partial=False, collation_fn=collate_fn
101 | ).with_epoch(num_worker_batches)
102 | dataloader = wds.WebLoader(
103 | dataset,
104 | batch_size=None,
105 | shuffle=False,
106 | num_workers=args.dataloader_num_workers,
107 | pin_memory=True,
108 | persistent_workers=True,
109 | )
110 | return dataloader
111 |
112 |
113 | if __name__ == "__main__":
114 | from argparse import Namespace
115 |
116 | args = Namespace(
117 | dataset_path="pipe:aws s3 cp s3://omniedit-wds/train-{00000..00570}-of-00571.tar -",
118 | num_train_examples=1203497,
119 | per_gpu_batch_size=8,
120 | global_batch_size=64,
121 |         dataloader_num_workers=4,
122 | resolution=256,
123 | )
124 | dataset_obj = ControlFluxDataset(args)
125 | dataset = dataset_obj.get_dataset()
126 |
127 | sample_count = 0
128 | for sample in dataset:
129 | print(sample.keys())
130 | print(sample["prompt"])
131 | sample_count += 1
132 |
133 | with open("dataset_actual_count.txt", "w") as f:
134 | f.write(str(sample_count))
135 |
136 | dataloader = dataset_obj.prepare_dataloader(dataset)
137 | for batch in dataloader:
138 | print(batch.keys())
139 | print(batch["pixel_values"].shape)
140 | break
141 |
--------------------------------------------------------------------------------
/misc/accelerate_ds2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 1
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | gpu_ids: all
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/misc/convert_to_wds.py:
--------------------------------------------------------------------------------
1 | """
2 | Make sure to change `path` as needed.
3 | Install `smart_open`, `ray` before running the script.
4 | If you're serializing to an S3 bucket, make sure you're authenticated.
5 | """
6 |
7 | from datasets import Dataset
8 | import json
9 | import webdataset as wds
10 | from smart_open import smart_open
11 | import os
12 | import ray
13 | import glob
14 |
15 | ray.init(num_cpus=16)
16 |
17 |
18 | if __name__ == "__main__":
19 | path = "/fsx/sayak/.cache/datasets--TIGER-Lab--OmniEdit-Filtered-1.2M/snapshots/82455c6cd66db7f0e5bfce8d7a236441af59d6df/data/"
20 | all_parquets = sorted(glob.glob(f"{path}/train-*.parquet"))
21 |
22 | @ray.remote
23 | def convert_to_wds(parquet_path):
24 | dataset = Dataset.from_parquet(parquet_path, split="train", cache_dir=path)
25 | shard_path = os.path.basename(parquet_path).replace(".parquet", ".tar")
26 | shard_path = os.path.join("s3://omniedit-wds", shard_path)
27 |
28 | with smart_open(shard_path, "wb") as s3_file:
29 | with wds.TarWriter(s3_file) as shard_writer:
30 | for i, example in enumerate(dataset):
31 | json_data = json.dumps(example["edited_prompt_list"]).encode("utf-8")
32 | src_img = example["src_img"].convert("RGB")
33 | edited_img = example["edited_img"].convert("RGB")
34 |
35 | wds_example = {
36 | "__key__": str(i),
37 | "omni_edit_id.txt": example["omni_edit_id"],
38 | "task.txt": example["task"],
39 | "src_img.jpg": src_img,
40 | "edited_img.jpg": edited_img,
41 | "edited_prompt_list.json": json_data,
42 | "sc_reasoning.txt": example["sc_reasoning"],
43 | "pq_reasoning.txt": example["pq_reasoning"],
44 | "height.txt": str(example["height"]),
45 | "width.txt": str(example["width"]),
46 | "sc_score_1.txt": str(example["sc_score_1"]),
47 | "sc_score_2.txt": str(example["sc_score_2"]),
48 | "pq_score.txt": str(example["pq_score"]),
49 | "o_score.txt": str(example["o_score"]),
50 | }
51 | shard_writer.write(wds_example)
52 |
53 | return shard_path
54 |
55 | futures = [convert_to_wds.remote(path) for path in all_parquets]
56 | ray.get(futures)
57 |
--------------------------------------------------------------------------------
/misc/launch_edit.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=control-flux-omniedit
3 | #SBATCH --nodes=1
4 | # Set the QoS
5 | #SBATCH --qos=normal
6 | # set 48h for job wall time limit
7 | #SBATCH --time=2-00:00:00
8 | # activate the requeue option
9 | #SBATCH --requeue
10 | #SBATCH --ntasks-per-node=1
11 | #SBATCH --cpus-per-task=96
12 | #SBATCH --gres=gpu:8
13 | #SBATCH --mem=1999G
14 | #SBATCH --partition=YOUR_PARTITION_NAME
15 | #SBATCH -o /PATH/%x-%j-train.out
16 |
17 | set -xe
18 |
19 | export NCCL_ASYNC_ERROR_HANDLING=1
20 | export LR=1e-4
21 | export WEIGHT_DECAY=1e-4
22 | export GUIDANCE_SCALE=30.0
23 | export CAPTION_DROPOUT=0.0
24 | export LR_SCHEDULER="constant"
25 |
26 | srun --wait=60 --kill-on-bad-exit=1 accelerate launch --config_file=./misc/accelerate_ds2.yaml train.py \
27 | --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
28 | --output_dir="omniflux-lr_${LR}-wd_${WEIGHT_DECAY}-gs_${GUIDANCE_SCALE}-cd_${CAPTION_DROPOUT}-scheduler_${LR_SCHEDULER}-sim_flow-no8bitadam" \
29 | --mixed_precision="bf16" \
30 | --per_gpu_batch_size=4 \
31 | --dataloader_num_workers=4 \
32 | --gradient_accumulation_steps=4 \
33 | --quality_threshold=10.0 \
34 | --simplified_flow \
35 | --gradient_checkpointing \
36 | --proportion_empty_prompts=$CAPTION_DROPOUT \
37 | --learning_rate=$LR \
38 | --adam_weight_decay=$WEIGHT_DECAY \
39 | --guidance_scale=$GUIDANCE_SCALE \
40 | --validation_guidance_scale="10.,20.,30.,40." \
41 | --report_to="wandb" --log_dataset_samples \
42 | --lr_scheduler=$LR_SCHEDULER \
43 | --lr_warmup_steps=0 \
44 | --checkpointing_steps=4000 \
45 | --resume_from_checkpoint="latest" --checkpoints_total_limit=2 \
46 | --max_train_steps=20000 \
47 | --validation_steps=400 \
48 | --validation_image "car.jpg" "green_creature.jpg" "norte_dam.jpg" "mushroom.jpg" \
49 | --validation_prompt "Give this the look of a traditional Japanese woodblock print." "transform the setting to a winter scene" "Change it to look like it's in the style of an impasto painting." "turn the color of mushroom to gray" \
50 | --seed="0" \
51 | --push_to_hub
52 |
53 | echo "END TIME: $(date)"
54 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.5.1
2 | torchvision
3 | transformers
4 | deepspeed
5 | accelerate
6 | webdataset
7 | git+https://github.com/huggingface/diffusers.git
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import copy
18 | import logging
19 | import math
20 | import os
21 | import random
22 | import shutil
23 | from pathlib import Path
24 |
25 | import accelerate
26 | import torch
27 | import transformers
28 | from accelerate import Accelerator
29 | from accelerate.state import AcceleratorState
30 | from accelerate.logging import get_logger
31 | from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
32 | from huggingface_hub import create_repo, upload_folder
33 | from packaging import version
34 | from tqdm.auto import tqdm
35 |
36 | import diffusers
37 | from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
38 | from diffusers.optimization import get_scheduler
39 | from diffusers.training_utils import (
40 | compute_density_for_timestep_sampling,
41 | compute_loss_weighting_for_sd3,
42 | free_memory,
43 | )
44 | from diffusers.utils import check_min_version, is_wandb_available
45 | from diffusers.utils.torch_utils import is_compiled_module
46 | from dataset import ControlFluxDataset
47 | from args import parse_args
48 | from utils import encode_images, log_validation, save_model_card, expand_tensor_to_dim, log_data_samples_to_wandb
49 |
50 |
51 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
52 | check_min_version("0.33.0.dev0")
53 |
54 | logger = get_logger(__name__)
55 |
56 |
57 | def main(args):
58 | if args.report_to == "wandb" and args.hub_token is not None:
59 | raise ValueError(
60 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
61 | " Please use `huggingface-cli login` to authenticate with the Hub."
62 | )
63 |
64 | logging_out_dir = Path(args.output_dir, args.logging_dir)
65 |
66 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
67 | # due to pytorch#99272, MPS does not yet support bfloat16.
68 | raise ValueError(
69 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
70 | )
71 |
72 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
73 |
74 | accelerator = Accelerator(
75 | gradient_accumulation_steps=args.gradient_accumulation_steps,
76 | mixed_precision=args.mixed_precision,
77 | log_with=args.report_to,
78 | project_config=accelerator_project_config,
79 | )
80 | if accelerator.distributed_type == DistributedType.DEEPSPEED:
81 | AcceleratorState().deepspeed_plugin.deepspeed_config[
82 | "train_micro_batch_size_per_gpu"
83 | ] = args.per_gpu_batch_size
84 |
85 |     # Disable AMP for MPS (Apple's Metal Performance Shaders backend for accelerating ML on macOS and iOS devices).
86 | if torch.backends.mps.is_available():
87 | logger.info("MPS is enabled. Disabling AMP.")
88 | accelerator.native_amp = False
89 |
90 | # Make one log on every process with the configuration for debugging.
91 | logging.basicConfig(
92 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
93 | datefmt="%m/%d/%Y %H:%M:%S",
94 | # DEBUG, INFO, WARNING, ERROR, CRITICAL
95 | level=logging.INFO,
96 | )
97 | logger.info(accelerator.state, main_process_only=False)
98 |
99 | if accelerator.is_local_main_process:
100 | transformers.utils.logging.set_verbosity_warning()
101 | diffusers.utils.logging.set_verbosity_info()
102 | else:
103 | transformers.utils.logging.set_verbosity_error()
104 | diffusers.utils.logging.set_verbosity_error()
105 |
106 | # If passed along, set the training seed now.
107 | if args.seed is not None:
108 | set_seed(args.seed)
109 |
110 | # Handle the repository creation
111 | if accelerator.is_main_process:
112 | if args.output_dir is not None:
113 | os.makedirs(args.output_dir, exist_ok=True)
114 |
115 | if args.push_to_hub:
116 | repo_id = create_repo(
117 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
118 | ).repo_id
119 |
120 | # Load models. We will load the text encoders later in a pipeline to compute
121 | # embeddings.
122 | vae = AutoencoderKL.from_pretrained(
123 | args.pretrained_model_name_or_path,
124 | subfolder="vae",
125 | revision=args.revision,
126 | variant=args.variant,
127 | )
128 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
129 | flux_transformer = FluxTransformer2DModel.from_pretrained(
130 | args.pretrained_model_name_or_path,
131 | subfolder="transformer",
132 | revision=args.revision,
133 | variant=args.variant,
134 | )
135 | logger.info("All models loaded successfully")
136 |
137 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
138 | args.pretrained_model_name_or_path,
139 | subfolder="scheduler",
140 | )
141 | noise_scheduler_copy = copy.deepcopy(noise_scheduler)
142 | if not args.only_target_transformer_blocks:
143 | flux_transformer.requires_grad_(True)
144 | vae.requires_grad_(False)
145 |
146 |     # Determine the weight dtype based on the mixed-precision setting.
147 | weight_dtype = torch.float32
148 | if accelerator.mixed_precision == "fp16":
149 | weight_dtype = torch.float16
150 | elif accelerator.mixed_precision == "bf16":
151 | weight_dtype = torch.bfloat16
152 |
153 | # let's not move the VAE to the GPU yet.
154 | vae.to(dtype=torch.float32) # keep the VAE in float32.
155 | if not args.offload:
156 | vae.to(accelerator.device)
157 |
158 | # enable image inputs
159 | with torch.no_grad():
160 | initial_input_channels = flux_transformer.config.in_channels
161 | new_linear = torch.nn.Linear(
162 | flux_transformer.x_embedder.in_features * 2,
163 | flux_transformer.x_embedder.out_features,
164 | bias=flux_transformer.x_embedder.bias is not None,
165 | dtype=flux_transformer.dtype,
166 | device=flux_transformer.device,
167 | )
168 | new_linear.weight.zero_()
169 | new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
170 | if flux_transformer.x_embedder.bias is not None:
171 | new_linear.bias.copy_(flux_transformer.x_embedder.bias)
172 | flux_transformer.x_embedder = new_linear
173 |
174 | assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
175 | flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
176 |
177 | if args.only_target_transformer_blocks:
178 | flux_transformer.x_embedder.requires_grad_(True)
179 | for name, module in flux_transformer.named_modules():
180 | if "transformer_blocks" in name:
181 | module.requires_grad_(True)
182 |
183 | def unwrap_model(model):
184 | model = accelerator.unwrap_model(model)
185 | model = model._orig_mod if is_compiled_module(model) else model
186 | return model
187 |
188 | # `accelerate` 0.16.0 will have better support for customized saving
189 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
190 |
191 | def save_model_hook(models, weights, output_dir):
192 | if accelerator.is_main_process:
193 | for model in models:
194 | if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
195 | model = unwrap_model(model)
196 | model.save_pretrained(os.path.join(output_dir, "transformer"))
197 | else:
198 |                         raise ValueError(f"unexpected model to save: {model.__class__}")
199 |
200 |                 # make sure to pop the weights so that the corresponding model is not saved again
201 | if weights:
202 | weights.pop()
203 |
204 | def load_model_hook(models, input_dir):
205 | transformer_ = None
206 |
207 | if not accelerator.distributed_type == DistributedType.DEEPSPEED:
208 | while len(models) > 0:
209 | model = models.pop()
210 |
211 | if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
212 | transformer_ = model # noqa: F841
213 | else:
214 |                     raise ValueError(f"unexpected model to load: {unwrap_model(model).__class__}")
215 |
216 | else:
217 | transformer_ = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841
218 |
219 | accelerator.register_save_state_pre_hook(save_model_hook)
220 | accelerator.register_load_state_pre_hook(load_model_hook)
221 |
222 | if args.gradient_checkpointing:
223 | flux_transformer.enable_gradient_checkpointing()
224 |
225 | # Enable TF32 for faster training on Ampere GPUs,
226 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
227 | if args.allow_tf32:
228 | torch.backends.cuda.matmul.allow_tf32 = True
229 |
230 | if args.scale_lr:
231 | args.learning_rate = (
232 | args.learning_rate
233 | * args.gradient_accumulation_steps
234 | * args.per_gpu_batch_size
235 | * accelerator.num_processes
236 | )
237 |
238 |     # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
239 | if args.use_8bit_adam:
240 | try:
241 | import bitsandbytes as bnb
242 | except ImportError:
243 | raise ImportError(
244 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
245 | )
246 |
247 | optimizer_class = bnb.optim.AdamW8bit
248 | else:
249 | optimizer_class = torch.optim.AdamW
250 |
251 | # Optimization parameters
252 | optimizer = optimizer_class(
253 | flux_transformer.parameters(),
254 | lr=args.learning_rate,
255 | betas=(args.adam_beta1, args.adam_beta2),
256 | weight_decay=args.adam_weight_decay,
257 | eps=args.adam_epsilon,
258 | )
259 |
260 | # Prepare dataset and dataloader.
261 | args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
262 | num_batches = math.ceil(args.num_train_examples / args.global_batch_size)
263 | dataset_obj = ControlFluxDataset(args)
264 | train_dataset = dataset_obj.get_dataset()
265 | train_dataloader = dataset_obj.prepare_dataloader(train_dataset)
266 |
267 | # Scheduler and math around the number of training steps.
268 | # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
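    # When `args.max_train_steps` is not given, the number of scheduler steps is estimated from the
    # per-process dataloader length and the number of epochs; the true dataloader length is only
    # known after `accelerator.prepare`, hence the recomputation and warning further below.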
269 | if args.max_train_steps is None:
270 |         len_train_dataloader_after_sharding = math.ceil(args.num_train_examples / (args.per_gpu_batch_size * accelerator.num_processes))
271 | num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
272 | num_training_steps_for_scheduler = (
273 | args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
274 | )
275 | else:
276 | num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
277 |
278 | lr_scheduler = get_scheduler(
279 | args.lr_scheduler,
280 | optimizer=optimizer,
281 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
282 |         num_training_steps=num_training_steps_for_scheduler,
283 | num_cycles=args.lr_num_cycles,
284 | power=args.lr_power,
285 | )
286 | # Prepare everything with our `accelerator`.
287 | flux_transformer, optimizer, lr_scheduler = accelerator.prepare(
288 | flux_transformer, optimizer, lr_scheduler
289 | )
290 |
291 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
292 | num_update_steps_per_epoch = math.ceil(num_batches / args.gradient_accumulation_steps)
293 | if args.max_train_steps is None:
294 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
295 | if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
296 | logger.warning(
297 | f"The length of the 'train_dataloader' ({args.num_train_examples}) does not match "
298 | f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
299 | f"This inconsistency may result in the learning rate scheduler not functioning properly."
300 | )
301 | # Afterwards we recalculate our number of training epochs
302 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
303 |
304 | # We need to initialize the trackers we use, and also store our configuration.
305 |     # The trackers are initialized automatically on the main process.
306 | if accelerator.is_main_process:
307 | tracker_config = dict(vars(args))
308 |
309 | # tensorboard cannot handle list types for config
310 | tracker_config.pop("validation_prompt")
311 | tracker_config.pop("validation_image")
312 |
313 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
314 |
315 | # Train!
316 | total_batch_size = args.per_gpu_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
317 |
318 | logger.info("***** Running training *****")
319 | logger.info(f" Num examples = {args.num_train_examples}")
320 | logger.info(f" Num Epochs = {args.num_train_epochs}")
321 | logger.info(f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
322 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
323 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
324 | logger.info(f" Total optimization steps = {args.max_train_steps}")
325 | global_step = 0
326 | first_epoch = 0
327 |
328 | # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed.
329 | text_encoding_pipeline = FluxControlPipeline.from_pretrained(
330 | args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
331 | )
332 | if not args.offload:
333 | text_encoding_pipeline = text_encoding_pipeline.to("cuda")
334 |
335 | # Potentially load in the weights and states from a previous save
336 | if args.resume_from_checkpoint:
337 | if args.resume_from_checkpoint != "latest":
338 | path = os.path.basename(args.resume_from_checkpoint)
339 | else:
340 | # Get the most recent checkpoint
341 | dirs = os.listdir(args.output_dir)
342 | dirs = [d for d in dirs if d.startswith("checkpoint")]
343 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
344 | path = dirs[-1] if len(dirs) > 0 else None
345 |
346 | if path is None:
347 | logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
348 | args.resume_from_checkpoint = None
349 | initial_global_step = 0
350 | else:
351 | logger.info(f"Resuming from checkpoint {path}")
352 | accelerator.load_state(os.path.join(args.output_dir, path))
353 | global_step = int(path.split("-")[1])
354 |
355 | initial_global_step = global_step
356 | first_epoch = global_step // num_update_steps_per_epoch
357 | else:
358 | initial_global_step = 0
359 |
360 | if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
361 | log_data_samples_to_wandb(dataloader=train_dataloader, accelerator=accelerator)
362 |
363 | progress_bar = tqdm(
364 | range(0, args.max_train_steps),
365 | initial=initial_global_step,
366 | desc="Steps",
367 | # Only show the progress bar once on each machine.
368 | disable=not accelerator.is_local_main_process,
369 | )
370 |
371 | def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
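        # Map each sampled timestep to its index in the (copied) scheduler, gather the corresponding
        # sigma, and broadcast it to `n_dim` dimensions so it can scale the latent tensors directly.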
372 | sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
373 | schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
374 | timesteps = timesteps.to(accelerator.device)
375 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
376 |
377 | sigma = sigmas[step_indices]
378 | return expand_tensor_to_dim(sigma, n_dim)
379 |
380 | image_logs = None
381 | for epoch in range(first_epoch, args.num_train_epochs):
382 | flux_transformer.train()
383 | for step, batch in enumerate(train_dataloader):
384 | with accelerator.accumulate(flux_transformer):
385 | # Convert images to latent space
386 | # vae encode
387 | if args.offload:
388 | vae.to(accelerator.device)
389 | edited_pixel_values = batch["edited_pixel_values"].to(dtype=vae.dtype, device=accelerator.device, non_blocking=True)
390 | pixel_latents = encode_images(edited_pixel_values, vae, weight_dtype)
391 | conditioning_pixel_values = batch["source_pixel_values"].to(
392 | dtype=vae.dtype, device=accelerator.device, non_blocking=True
393 | )
394 | control_latents = encode_images(conditioning_pixel_values, vae, weight_dtype)
395 | if args.offload:
396 | # offload vae to CPU.
397 | vae.cpu()
398 |
399 | # Sample a random timestep for each image
400 | # for weighting schemes where we sample timesteps non-uniformly
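                # Two sampling paths: with `args.simplified_flow`, sigmas are drawn uniformly in [0, 1)
                # and converted to timesteps directly; otherwise `compute_density_for_timestep_sampling`
                # draws u according to the configured weighting scheme (e.g. logit-normal) and u indexes
                # into the scheduler's discrete timesteps.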
401 | bsz = pixel_latents.shape[0]
402 | noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)
403 | if args.simplified_flow:
404 | sigmas = torch.rand(bsz, device=accelerator.device, dtype=pixel_latents.dtype)
405 | timesteps = (sigmas * noise_scheduler_copy.config.num_train_timesteps).long().to(accelerator.device)
406 | sigmas = expand_tensor_to_dim(sigmas, pixel_latents.ndim)
407 | else:
408 | u = compute_density_for_timestep_sampling(
409 | weighting_scheme=args.weighting_scheme,
410 | batch_size=bsz,
411 | logit_mean=args.logit_mean,
412 | logit_std=args.logit_std,
413 | mode_scale=args.mode_scale,
414 | )
415 | indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
416 | timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
417 | sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
418 |
419 | # Add noise according to flow matching.
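                # Flow-matching interpolation: x_t = (1 - sigma) * x_0 + sigma * noise, so the model is
                # trained to predict the velocity (noise - x_0), which is the regression target below.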
420 | noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
421 | concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
422 |
423 | # pack the latents.
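                # Packing folds each 2x2 patch of latent pixels into the channel dimension, turning the
                # (B, C, H, W) latents into the flat token sequence the Flux transformer consumes.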
424 | packed_noisy_model_input = FluxControlPipeline._pack_latents(
425 | concatenated_noisy_model_input,
426 | batch_size=bsz,
427 | num_channels_latents=concatenated_noisy_model_input.shape[1],
428 | height=concatenated_noisy_model_input.shape[2],
429 | width=concatenated_noisy_model_input.shape[3],
430 | )
431 |
432 | # latent image ids for RoPE.
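                # The latent image ids encode the row/column position of each packed token at half the
                # latent resolution; the transformer uses them for its rotary position embeddings (RoPE).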
433 | latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(
434 | bsz,
435 | concatenated_noisy_model_input.shape[2] // 2,
436 | concatenated_noisy_model_input.shape[3] // 2,
437 | accelerator.device,
438 | weight_dtype,
439 | )
440 |
441 | # handle guidance
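                # Guidance-distilled checkpoints such as FLUX.1-dev expect a guidance embedding; when the
                # loaded config exposes `guidance_embeds`, a constant vector filled with
                # `args.guidance_scale` is fed during training, otherwise no guidance is passed.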
442 | if unwrap_model(flux_transformer).config.guidance_embeds:
443 | guidance_vec = torch.full(
444 | (bsz,),
445 | args.guidance_scale,
446 | device=noisy_model_input.device,
447 | dtype=weight_dtype,
448 | )
449 | else:
450 | guidance_vec = None
451 |
452 | # text encoding.
453 | captions = batch["captions"]
454 | if args.offload:
455 | text_encoding_pipeline = text_encoding_pipeline.to("cuda")
456 | with torch.no_grad():
457 | prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
458 | captions, prompt_2=None
459 | )
460 |                     # This could be optimized by skipping text encoding entirely and passing zero tensors
461 |                     # of the appropriate shapes for `prompt_embeds` and `pooled_prompt_embeds`.
462 | if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
463 | prompt_embeds.zero_()
464 | pooled_prompt_embeds.zero_()
465 | if args.offload:
466 | text_encoding_pipeline = text_encoding_pipeline.to("cpu")
467 |
468 | # Predict.
469 | model_pred = flux_transformer(
470 | hidden_states=packed_noisy_model_input,
471 | timestep=timesteps / 1000,
472 | guidance=guidance_vec,
473 | pooled_projections=pooled_prompt_embeds,
474 | encoder_hidden_states=prompt_embeds,
475 | txt_ids=text_ids,
476 | img_ids=latent_image_ids,
477 | return_dict=False,
478 | )[0]
479 | model_pred = FluxControlPipeline._unpack_latents(
480 | model_pred,
481 | height=noisy_model_input.shape[2] * vae_scale_factor,
482 | width=noisy_model_input.shape[3] * vae_scale_factor,
483 | vae_scale_factor=vae_scale_factor,
484 | )
485 | # these weighting schemes use a uniform timestep sampling
486 | # and instead post-weight the loss
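                # `compute_loss_weighting_for_sd3` turns the chosen weighting scheme into per-sample
                # weights derived from the sigmas; the weighted MSE below is taken against the flow
                # target (noise - x_0).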
487 | weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
488 |
489 | # flow-matching loss
490 | target = noise - pixel_latents
491 | loss = torch.mean(
492 | (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
493 | 1,
494 | )
495 | loss = loss.mean()
496 | accelerator.backward(loss)
497 |
498 | if accelerator.sync_gradients:
499 | params_to_clip = flux_transformer.parameters()
500 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
501 |
502 | optimizer.step()
503 | lr_scheduler.step()
504 | optimizer.zero_grad()
505 |
506 | # Checks if the accelerator has performed an optimization step behind the scenes
507 | if accelerator.sync_gradients:
508 | progress_bar.update(1)
509 | global_step += 1
510 |
511 | # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
512 | if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
513 | if global_step % args.checkpointing_steps == 0:
514 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
515 | if args.checkpoints_total_limit is not None:
516 | checkpoints = os.listdir(args.output_dir)
517 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
518 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
519 |
520 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
521 | if len(checkpoints) >= args.checkpoints_total_limit:
522 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
523 | removing_checkpoints = checkpoints[0:num_to_remove]
524 |
525 | logger.info(
526 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
527 | )
528 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
529 |
530 | for removing_checkpoint in removing_checkpoints:
531 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
532 | shutil.rmtree(removing_checkpoint)
533 |
534 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
535 | accelerator.save_state(save_path)
536 | logger.info(f"Saved state to {save_path}")
537 |
538 | if args.validation_prompt is not None and global_step % args.validation_steps == 0:
539 | image_logs = log_validation(
540 | flux_transformer=flux_transformer,
541 | args=args,
542 | accelerator=accelerator,
543 | weight_dtype=weight_dtype,
544 | step=global_step,
545 | )
546 |
547 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
548 | progress_bar.set_postfix(**logs)
549 | accelerator.log(logs, step=global_step)
550 |
551 | if global_step >= args.max_train_steps:
552 | break
553 |
554 |     # Save the trained transformer and, if requested, push it to the Hub.
555 | accelerator.wait_for_everyone()
556 | if accelerator.is_main_process:
557 | flux_transformer = unwrap_model(flux_transformer)
558 | if args.upcast_before_saving:
559 | flux_transformer.to(torch.float32)
560 | flux_transformer.save_pretrained(args.output_dir)
561 |
562 | del flux_transformer
563 | del text_encoding_pipeline
564 | del vae
565 | free_memory()
566 |
567 | # Run a final round of validation.
568 | image_logs = None
569 | if args.validation_prompt is not None:
570 | image_logs = log_validation(
571 | flux_transformer=None,
572 | args=args,
573 | accelerator=accelerator,
574 | weight_dtype=weight_dtype,
575 | step=global_step,
576 | is_final_validation=True,
577 | )
578 |
579 | if args.push_to_hub:
580 | save_model_card(
581 | repo_id,
582 | image_logs=image_logs,
583 | base_model=args.pretrained_model_name_or_path,
584 | repo_folder=args.output_dir,
585 | )
586 | upload_folder(
587 | repo_id=repo_id,
588 | folder_path=args.output_dir,
589 | commit_message="End of training",
590 | ignore_patterns=["step_*", "epoch_*", "checkpoint-*"],
591 | )
592 |
593 | accelerator.end_training()
594 |
595 |
596 | if __name__ == "__main__":
597 | args = parse_args()
598 | main(args)
599 |
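# Example launch (illustrative sketch; flag names are assumed to mirror the `args.*`
# attributes used above; see args.py for the authoritative argument list):
#
#   accelerate launch train.py \
#     --pretrained_model_name_or_path black-forest-labs/FLUX.1-dev \
#     --output_dir flux-edit-control \
#     --per_gpu_batch_size 1 \
#     --gradient_accumulation_steps 4 \
#     --max_train_steps 10000 \
#     --checkpointing_steps 1000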
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import os
19 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
20 | from contextlib import nullcontext
21 | from diffusers import FluxControlPipeline, FluxTransformer2DModel
22 | from diffusers.utils import load_image, make_image_grid, is_wandb_available
23 | from diffusers.training_utils import free_memory
24 | from accelerate.logging import get_logger
25 |
26 |
27 | logger = get_logger(__name__)
28 |
29 |
30 | if is_wandb_available():
31 | import wandb
32 |
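# Broadcasting helper: flatten the tensor and append trailing singleton dimensions until it has
# `ndim` dimensions (used to broadcast per-sample sigmas over latent tensors).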
33 | def expand_tensor_to_dim(tensor, ndim):
34 | tensor = tensor.flatten()
35 | while len(tensor.shape) < ndim:
36 | tensor = tensor.unsqueeze(-1)
37 | return tensor
38 |
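# Encode pixels with the VAE and normalize the latents with the VAE's shift/scaling factors
# before casting them back to the training dtype.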
39 | def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
40 | pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
41 | pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor
42 | return pixel_latents.to(weight_dtype)
43 |
44 |
45 | def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
46 | logger.info("Running validation... ")
47 |
48 | if not is_final_validation:
49 | flux_transformer = accelerator.unwrap_model(flux_transformer)
50 | pipeline = FluxControlPipeline.from_pretrained(
51 | args.pretrained_model_name_or_path,
52 | transformer=flux_transformer,
53 | torch_dtype=weight_dtype,
54 | )
55 | else:
56 | transformer = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
57 | pipeline = FluxControlPipeline.from_pretrained(
58 | args.pretrained_model_name_or_path,
59 | transformer=transformer,
60 | torch_dtype=weight_dtype,
61 | )
62 |
63 | pipeline.to(accelerator.device)
64 | pipeline.set_progress_bar_config(disable=True)
65 |
66 | if args.seed is None:
67 | generator = None
68 | else:
69 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
70 |
71 | if len(args.validation_image) == len(args.validation_prompt):
72 | validation_images = args.validation_image
73 | validation_prompts = args.validation_prompt
74 | elif len(args.validation_image) == 1:
75 | validation_images = args.validation_image * len(args.validation_prompt)
76 | validation_prompts = args.validation_prompt
77 | elif len(args.validation_prompt) == 1:
78 | validation_images = args.validation_image
79 | validation_prompts = args.validation_prompt * len(args.validation_image)
80 | else:
81 | raise ValueError(
82 | "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
83 | )
84 |
85 | image_logs = []
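    # The final validation loads the saved transformer already cast to `weight_dtype`, and autocast
    # support on MPS is limited, so autocast is skipped in both of those cases.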
86 | if is_final_validation or torch.backends.mps.is_available():
87 | autocast_ctx = nullcontext()
88 | else:
89 | autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)
90 |
91 | vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
92 | for validation_prompt, validation_image in zip(validation_prompts, validation_images):
93 | validation_image = load_image(validation_image)
94 | width, height = validation_image.size
95 | if height % (vae_scale_factor * 2) != 0 or width % (vae_scale_factor * 2) != 0:
96 | height, width = (1024, 1024)
97 |
98 | for gs in args.validation_guidance_scale:
99 | images = []
100 | for _ in range(args.num_validation_images):
101 | with autocast_ctx:
102 | image = pipeline(
103 | prompt=validation_prompt,
104 | control_image=validation_image,
105 | num_inference_steps=50,
106 | guidance_scale=gs,
107 | generator=generator,
108 | max_sequence_length=512,
109 | height=height,
110 | width=width,
111 | ).images[0]
112 | images.append(image)
113 | image_logs.append(
114 | {
115 | "validation_image": validation_image,
116 | "images": images,
117 | "validation_prompt": validation_prompt,
118 | "guidance_scale": gs
119 | }
120 | )
121 |
122 | tracker_key = "test" if is_final_validation else "validation"
123 | for tracker in accelerator.trackers:
124 | if tracker.name == "wandb":
125 | formatted_images = []
126 | for log in image_logs:
127 | images = log["images"]
128 | guidance_scale = log["guidance_scale"]
129 | validation_prompt = log["validation_prompt"]
130 | validation_image = log["validation_image"]
131 | formatted_images.append(wandb.Image(validation_image, caption=f"Conditioning (gs: {guidance_scale})"))
132 | for image in images:
133 | image = wandb.Image(image, caption=validation_prompt)
134 | formatted_images.append(image)
135 |
136 | tracker.log({tracker_key: formatted_images})
137 | else:
138 | logger.warning(f"image logging not implemented for {tracker.name}")
139 |
140 | del pipeline
141 | free_memory()
142 | return image_logs
143 |
144 |
145 | def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
146 | img_str = ""
147 | if image_logs is not None:
148 | img_str = "You can find some example images below.\n\n"
149 | for i, log in enumerate(image_logs):
150 | images = log["images"]
151 | validation_prompt = log["validation_prompt"]
152 | validation_image = log["validation_image"]
153 | validation_image.save(os.path.join(repo_folder, "image_control.png"))
154 | img_str += f"prompt: {validation_prompt}\n"
155 | images = [validation_image] + images
156 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
157 | img_str += f"\n"
158 |
159 | model_description = f"""
160 | # flux-control-{repo_id}
161 |
162 | These are Control weights trained on {base_model} and [TIGER-Lab/OmniEdit-Filtered-1.2M](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M).
163 | {img_str}
164 |
165 | ## License
166 |
167 | Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
168 | """
169 |
170 | model_card = load_or_create_model_card(
171 | repo_id_or_path=repo_id,
172 | from_training=True,
173 | license="other",
174 | base_model=base_model,
175 | model_description=model_description,
176 | inference=True,
177 | )
178 |
179 | tags = [
180 | "flux",
181 | "flux-diffusers",
182 | "text-to-image",
183 | "diffusers",
184 | "control",
185 | "diffusers-training",
186 | ]
187 | model_card = populate_model_card(model_card, tags=tags)
188 |
189 | model_card.save(os.path.join(repo_folder, "README.md"))
190 |
191 |
192 | def log_data_samples_to_wandb(dataloader, accelerator):
193 | logger.info("Logging some dataset samples.")
194 | formatted_src_images = []
195 | formatted_edited_images = []
196 | all_prompts = []
197 | for i, batch in enumerate(dataloader):
198 | source_images = (batch["source_pixel_values"] + 1) / 2
199 | edited_images = (batch["edited_pixel_values"] + 1) / 2
200 | prompts = batch["captions"]
201 |
202 | if len(formatted_src_images) > 10:
203 | break
204 |
205 | for img, edited_img, prompt in zip(source_images, edited_images, prompts):
206 | formatted_src_images.append(img)
207 | formatted_edited_images.append(edited_img)
208 | all_prompts.append(prompt)
209 |
210 | logged_artifacts = []
211 | for img, edited_img, prompt in zip(formatted_src_images, formatted_edited_images, all_prompts):
212 | logged_artifacts.append(wandb.Image(img, caption="Conditioning"))
213 | logged_artifacts.append(wandb.Image(edited_img, caption=prompt))
214 |
215 | wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
216 | assert wandb_tracker, "wandb couldn't be found in the trackers."
217 | wandb_tracker[0].log({"dataset_samples": logged_artifacts})
--------------------------------------------------------------------------------