├── .gitignore ├── LICENSE ├── README.md ├── args.py ├── assets ├── car.jpg ├── green_creature.jpg ├── mushroom.jpg └── norte_dam.jpg ├── dataset.py ├── misc ├── accelerate_ds2.yaml ├── convert_to_wds.py └── launch_edit.slurm ├── requirements.txt ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FluxEdit 2 | 3 | This project tries to teach [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) the task of image editing from language with the [Flux Control framework](https://blackforestlabs.ai/flux-1-tools/). We use the high-quality [`TIGER-Lab/OmniEdit-Filtered-1.2M`](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M/) dataset for training. 4 | 5 | * Fine-tuned edit model here: [`sayakpaul/FLUX.1-dev-edit-v0`](https://huggingface.co/sayakpaul/FLUX.1-dev-edit-v0) 6 | * Interactive to try out on Hugging Face Spaces: [`ariG23498/flux-edit`](https://huggingface.co/spaces/ariG23498/flux-edit) (thanks to `ariG23498` for this!) 7 | 8 |
9 | GIF 10 |
11 | 12 | >[!IMPORTANT] 13 | > Since we don't have the official Flux Control training details available, this project should be considered experimental and we welcome contributions from the community to make it better 🤗 14 | 15 | ## Setup 16 | 17 | Install the dependencies from [`requirements.txt`](./requirements.txt) and perform any other configuration that might be needed. 18 | 19 | The scripts were tested using PyTorch 2.5.1 and NVIDIA GPUs (H100). 20 | 21 | ## Training 22 | 23 | We first converted the original OmniEdit dataset into Webdataset shards using [this script](./misc/convert_to_wds.py) for efficiency. This script prepares the Webdataset shards and push them to an S3 bucket. But you can configure this as per your needs. 24 | 25 |
26 | Training Command 27 | 28 | ```bash 29 | export LR=1e-4 30 | export WEIGHT_DECAY=1e-4 31 | export GUIDANCE_SCALE=30.0 32 | export CAPTION_DROPOUT=0.0 33 | export LR_SCHEDULER="constant" 34 | 35 | srun --wait=60 --kill-on-bad-exit=1 accelerate launch --config_file=./misc/accelerate_ds2.yaml train.py \ 36 | --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ 37 | --output_dir="omniflux-lr_${LR}-wd_${WEIGHT_DECAY}-gs_${GUIDANCE_SCALE}-cd_${CAPTION_DROPOUT}-scheduler_${LR_SCHEDULER}-sim_flow-no8bitadam" \ 38 | --mixed_precision="bf16" \ 39 | --per_gpu_batch_size=4 \ 40 | --dataloader_num_workers=4 \ 41 | --gradient_accumulation_steps=4 \ 42 | --quality_threshold=10.0 \ 43 | --simplified_flow \ 44 | --gradient_checkpointing \ 45 | --proportion_empty_prompts=$CAPTION_DROPOUT \ 46 | --learning_rate=$LR \ 47 | --adam_weight_decay=$WEIGHT_DECAY \ 48 | --guidance_scale=$GUIDANCE_SCALE \ 49 | --validation_guidance_scale="10.,20.,30.,40." \ 50 | --report_to="wandb" --log_dataset_samples \ 51 | --lr_scheduler=$LR_SCHEDULER \ 52 | --lr_warmup_steps=0 \ 53 | --checkpointing_steps=4000 \ 54 | --resume_from_checkpoint="latest" --checkpoints_total_limit=2 \ 55 | --max_train_steps=20000 \ 56 | --validation_steps=400 \ 57 | --validation_image "assets/car.jpg" "assets/green_creature.jpg" "assets/norte_dam.jpg" "assets/mushroom.jpg" \ 58 | --validation_prompt "Give this the look of a traditional Japanese woodblock print." "transform the setting to a winter scene" "Change it to look like it's in the style of an impasto painting." "turn the color of mushroom to gray" \ 59 | --seed="0" \ 60 | --push_to_hub 61 | 62 | echo "END TIME: $(date)" 63 | ``` 64 | 65 |
66 | 67 | Training starts on 8 GPUs using DeepSpeed. You can configure the [`accelerate` config file](./misc/accelerate_ds2.yaml) to change that. 68 | 69 | Refer to the [`args.py`](./args.py) to know the different kinds of configurations supported. Training was conducted on a node of 8 H100s. If you prefer using Slurm, refer to [this Slurm](./misc/launch_edit.slurm) script for scheduling training. 70 | 71 | You can also use this [minimal version of the `train.py` script](https://github.com/huggingface/diffusers/blob/main/examples/flux-control/train_control_flux.py) with a [minimal version of the OmniEdit dataset](https://huggingface.co/datasets/sayakpaul/OmniEdit-mini) for quicker prototyping. 72 | 73 | ## Inference 74 | 75 | ```py 76 | from diffusers import FluxControlPipeline, FluxTransformer2DModel 77 | from diffusers.utils import load_image 78 | import torch 79 | 80 | path = "sayakpaul/FLUX.1-dev-edit-v0" 81 | edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16) 82 | pipeline = FluxControlPipeline.from_pretrained( 83 | "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16 84 | ).to("cuda") 85 | 86 | image = load_image("./assets/mushroom.jpg") # resize as needed. 87 | print(image.size) 88 | 89 | prompt = "turn the color of mushroom to gray" 90 | image = pipeline( 91 | control_image=image, 92 | prompt=prompt, 93 | guidance_scale=30., # change this as needed. 94 | num_inference_steps=50, # change this as needed. 95 | max_sequence_length=512, 96 | height=image.height, 97 | width=image.width, 98 | generator=torch.manual_seed(0) 99 | ).images[0] 100 | image.save("edited_image.png") 101 | ``` 102 | 103 | ### Speeding inference with a turbo LoRA 104 | 105 | We can speed up the inference by reducing the `num_inference_steps` to produce a nice image by using turbo LoRA like [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD). 106 | 107 | Make sure to install `peft` before running the code below: `pip install -U peft`. 108 | 109 |
110 | Code 111 | 112 | ```py 113 | from diffusers import FluxControlPipeline, FluxTransformer2DModel 114 | from diffusers.utils import load_image 115 | from huggingface_hub import hf_hub_download 116 | import torch 117 | 118 | path = "sayakpaul/FLUX.1-dev-edit-v0" 119 | edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16) 120 | control_pipe = FluxControlPipeline.from_pretrained( 121 | "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16 122 | ).to("cuda") 123 | 124 | # load the turbo LoRA 125 | control_pipe.load_lora_weights( 126 | hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" 127 | ) 128 | control_pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125]) 129 | 130 | image = load_image("./assets/mushroom.jpg") # resize as needed. 131 | print(image.size) 132 | 133 | prompt = "turn the color of mushroom to gray" 134 | image = pipeline( 135 | control_image=image, 136 | prompt=prompt, 137 | guidance_scale=30., # change this as needed. 138 | num_inference_steps=8, # change this as needed. 139 | max_sequence_length=512, 140 | height=image.height, 141 | width=image.width, 142 | generator=torch.manual_seed(0) 143 | ).images[0] 144 | image.save("edited_image.png") 145 | ``` 146 | 147 |
148 |
149 |
150 | Comparison 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 |
50 steps8 steps
50 steps 18 steps 1
50 steps 28 steps 2
50 steps 38 steps 3
50 steps 48 steps 4
174 | 175 | 176 |
177 | 178 | You can also choose to perform quantization if the memory requirements cannot be satisfied further w.r.t your hardware. Refer to the [Diffusers documentation](https://huggingface.co/docs/diffusers/main/en/quantization/overview) to learn more. 179 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | 19 | def parse_args(input_args=None): 20 | parser = argparse.ArgumentParser(description="Simple example of a Flux Control training script.") 21 | parser.add_argument( 22 | "--pretrained_model_name_or_path", 23 | type=str, 24 | default=None, 25 | required=True, 26 | help="Path to pretrained model or model identifier from huggingface.co/models.", 27 | ) 28 | parser.add_argument( 29 | "--variant", 30 | type=str, 31 | default=None, 32 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 33 | ) 34 | parser.add_argument( 35 | "--revision", 36 | type=str, 37 | default=None, 38 | required=False, 39 | help="Revision of pretrained model identifier from huggingface.co/models.", 40 | ) 41 | parser.add_argument( 42 | "--output_dir", 43 | type=str, 44 | default="flux-control", 45 | help="The output directory where the model predictions and checkpoints will be written.", 46 | ) 47 | parser.add_argument( 48 | "--cache_dir", 49 | type=str, 50 | default=None, 51 | help="The directory where the downloaded models and datasets will be stored.", 52 | ) 53 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 54 | parser.add_argument( 55 | "--resolution", 56 | type=int, 57 | default=1024, 58 | help=( 59 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 60 | " resolution" 61 | ), 62 | ) 63 | parser.add_argument( 64 | "--quality_threshold", 65 | type=float, 66 | default=None, 67 | help="Threshold to use to do some quality filtering." 68 | ) 69 | parser.add_argument( 70 | "--per_gpu_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 71 | ) 72 | # Original count is 1203497, extra 700 is for "dev". 73 | parser.add_argument("--num_train_examples", type=int, default=1202797, help="Number of samples in the training dataset.") 74 | parser.add_argument("--num_train_epochs", type=int, default=1) 75 | parser.add_argument( 76 | "--max_train_steps", 77 | type=int, 78 | default=None, 79 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 80 | ) 81 | parser.add_argument( 82 | "--checkpointing_steps", 83 | type=int, 84 | default=500, 85 | help=( 86 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 87 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 88 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 89 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 90 | "instructions." 91 | ), 92 | ) 93 | parser.add_argument( 94 | "--checkpoints_total_limit", 95 | type=int, 96 | default=None, 97 | help=("Max number of checkpoints to store."), 98 | ) 99 | parser.add_argument( 100 | "--resume_from_checkpoint", 101 | type=str, 102 | default=None, 103 | help=( 104 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 105 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 106 | ), 107 | ) 108 | parser.add_argument( 109 | "--proportion_empty_prompts", 110 | type=float, 111 | default=0, 112 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", 113 | ) 114 | parser.add_argument( 115 | "--gradient_accumulation_steps", 116 | type=int, 117 | default=1, 118 | help="Number of updates steps to accumulate before performing a backward/update pass.", 119 | ) 120 | parser.add_argument( 121 | "--gradient_checkpointing", 122 | action="store_true", 123 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 124 | ) 125 | parser.add_argument( 126 | "--learning_rate", 127 | type=float, 128 | default=5e-6, 129 | help="Initial learning rate (after the potential warmup period) to use.", 130 | ) 131 | parser.add_argument( 132 | "--scale_lr", 133 | action="store_true", 134 | default=False, 135 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 136 | ) 137 | parser.add_argument( 138 | "--lr_scheduler", 139 | type=str, 140 | default="constant", 141 | help=( 142 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 143 | ' "constant", "constant_with_warmup"]' 144 | ), 145 | ) 146 | parser.add_argument( 147 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 148 | ) 149 | parser.add_argument( 150 | "--lr_num_cycles", 151 | type=int, 152 | default=1, 153 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 154 | ) 155 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 156 | parser.add_argument( 157 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 158 | ) 159 | 160 | parser.add_argument( 161 | "--dataloader_num_workers", 162 | type=int, 163 | default=0, 164 | help=( 165 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 166 | ), 167 | ) 168 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 169 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 170 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 171 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 172 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 173 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 174 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 175 | parser.add_argument( 176 | "--hub_model_id", 177 | type=str, 178 | default=None, 179 | help="The name of the repository to keep in sync with the local `output_dir`.", 180 | ) 181 | parser.add_argument( 182 | "--logging_dir", 183 | type=str, 184 | default="logs", 185 | help=( 186 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 187 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 188 | ), 189 | ) 190 | parser.add_argument( 191 | "--allow_tf32", 192 | action="store_true", 193 | help=( 194 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 195 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 196 | ), 197 | ) 198 | parser.add_argument( 199 | "--report_to", 200 | type=str, 201 | default="tensorboard", 202 | help=( 203 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 204 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--mixed_precision", 209 | type=str, 210 | default=None, 211 | choices=["no", "fp16", "bf16"], 212 | help=( 213 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 214 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 215 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 216 | ), 217 | ) 218 | parser.add_argument( 219 | "--dataset_path", 220 | type=str, 221 | default="pipe:aws s3 cp s3://omniedit-wds/train-{00000..00570}-of-00571.tar -", 222 | ) 223 | parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.") 224 | parser.add_argument( 225 | "--validation_prompt", 226 | type=str, 227 | default=None, 228 | nargs="+", 229 | help=( 230 | "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." 231 | " Provide either a matching number of `--validation_image`s, a single `--validation_image`" 232 | " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." 233 | ), 234 | ) 235 | parser.add_argument( 236 | "--validation_image", 237 | type=str, 238 | default=None, 239 | nargs="+", 240 | help=( 241 | "A set of paths to the control conditioning image be evaluated every `--validation_steps`" 242 | " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" 243 | " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" 244 | " `--validation_image` that will be used with all `--validation_prompt`s." 245 | ), 246 | ) 247 | parser.add_argument( 248 | "--num_validation_images", 249 | type=int, 250 | default=1, 251 | help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", 252 | ) 253 | parser.add_argument( 254 | "--validation_steps", 255 | type=int, 256 | default=100, 257 | help=( 258 | "Run validation every X steps. Validation consists of running the prompt" 259 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 260 | " and logging the images." 261 | ), 262 | ) 263 | parser.add_argument( 264 | "--tracker_project_name", 265 | type=str, 266 | default="flux_train_control_omniedit", 267 | help=( 268 | "The `project_name` argument passed to Accelerator.init_trackers for" 269 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 270 | ), 271 | ) 272 | parser.add_argument( 273 | "--jsonl_for_train", 274 | type=str, 275 | default=None, 276 | help="Path to the jsonl file containing the training data.", 277 | ) 278 | parser.add_argument( 279 | "--only_target_transformer_blocks", 280 | action="store_true", 281 | help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).", 282 | ) 283 | parser.add_argument( 284 | "--guidance_scale", 285 | type=float, 286 | default=30.0, 287 | help="the guidance scale used for transformer.", 288 | ) 289 | parser.add_argument( 290 | "--validation_guidance_scale", 291 | type=str, 292 | default=None, 293 | help="List of guidance scales to use in validation. Provide as - 10,20,30, ..." 294 | ) 295 | parser.add_argument( 296 | "--upcast_before_saving", 297 | action="store_true", 298 | help=( 299 | "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " 300 | "Defaults to precision dtype used for training to save memory" 301 | ), 302 | ) 303 | parser.add_argument("--simplified_flow", action="store_true", help="Flow simplified.") 304 | parser.add_argument( 305 | "--weighting_scheme", 306 | type=str, 307 | default="none", 308 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], 309 | help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), 310 | ) 311 | parser.add_argument( 312 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." 313 | ) 314 | parser.add_argument( 315 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." 316 | ) 317 | parser.add_argument( 318 | "--mode_scale", 319 | type=float, 320 | default=1.29, 321 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", 322 | ) 323 | parser.add_argument( 324 | "--offload", 325 | action="store_true", 326 | help="Whether to offload the VAE and the text encoders to CPU when they are not used.", 327 | ) 328 | 329 | if input_args is not None: 330 | args = parser.parse_args(input_args) 331 | else: 332 | args = parser.parse_args() 333 | 334 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: 335 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") 336 | 337 | if args.validation_prompt is not None and args.validation_image is None: 338 | raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") 339 | 340 | if args.validation_prompt is None and args.validation_image is not None: 341 | raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") 342 | 343 | if ( 344 | args.validation_image is not None 345 | and args.validation_prompt is not None 346 | and len(args.validation_image) != 1 347 | and len(args.validation_prompt) != 1 348 | and len(args.validation_image) != len(args.validation_prompt) 349 | ): 350 | raise ValueError( 351 | "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," 352 | " or the same number of `--validation_prompt`s and `--validation_image`s" 353 | ) 354 | 355 | if args.resolution % 8 != 0: 356 | raise ValueError( 357 | "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer." 358 | ) 359 | 360 | if args.validation_guidance_scale is None: 361 | args.validation_guidance_scale = args.guidance_scale 362 | else: 363 | args.validation_guidance_scale = [float(value.strip()) for value in args.validation_guidance_scale.split(",")] 364 | 365 | return args -------------------------------------------------------------------------------- /assets/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/car.jpg -------------------------------------------------------------------------------- /assets/green_creature.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/green_creature.jpg -------------------------------------------------------------------------------- /assets/mushroom.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/mushroom.jpg -------------------------------------------------------------------------------- /assets/norte_dam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sayakpaul/flux-image-editing/7869a545285c42eb2700787ed8fcb8f972c928a0/assets/norte_dam.jpg -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import torch 3 | import numpy as np 4 | from torchvision import transforms 5 | import math 6 | 7 | # Taken from https://github.com/tmbdev-archive/webdataset-imagenet-2/blob/01a4ab54307b9156c527d45b6b171f88623d2dec/imagenet.py#L65. 8 | def nodesplitter(src, group=None): 9 | if torch.distributed.is_initialized(): 10 | if group is None: 11 | group = torch.distributed.group.WORLD 12 | rank = torch.distributed.get_rank(group=group) 13 | size = torch.distributed.get_world_size(group=group) 14 | count = 0 15 | for i, item in enumerate(src): 16 | if i % size == rank: 17 | yield item 18 | count += 1 19 | else: 20 | yield from src 21 | 22 | def collate_fn(samples): 23 | source_pixel_values = torch.stack([example["source_pixel_values"] for example in samples]) 24 | source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float() 25 | edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in samples]) 26 | edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() 27 | captions = [example["prompt"] for example in samples] 28 | return {"source_pixel_values": source_pixel_values, "edited_pixel_values": edited_pixel_values, "captions": captions} 29 | 30 | 31 | class QualityFilter: 32 | def __init__(self, threshold): 33 | self.threshold = threshold 34 | def __call__(self, x): 35 | avg_sc_score = (x["sc_score_1"] + x["sc_score_2"]) / 2 36 | select = avg_sc_score >= self.threshold and x["pq_score"] >= self.threshold and x["o_score"] >= self.threshold 37 | if select: 38 | return True 39 | else: 40 | return False 41 | 42 | class ControlFluxDataset: 43 | def __init__(self, args): 44 | self.args = args 45 | self.image_transforms = transforms.Compose( 46 | [ 47 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 50 | ] 51 | ) 52 | 53 | def get_dataset(self): 54 | args = self.args 55 | dataset = ( 56 | wds.WebDataset( 57 | args.dataset_path, 58 | handler=wds.warn_and_continue, 59 | nodesplitter=nodesplitter, 60 | shardshuffle=500, 61 | empty_check=False 62 | ) 63 | .shuffle(2000, handler=wds.warn_and_continue) 64 | .decode("pil", handler=wds.warn_and_continue) 65 | .rename( 66 | src_img="src_img.jpg", 67 | edited_img="edited_img.jpg", 68 | prompt_list="edited_prompt_list.json", 69 | sc_score_1="sc_score_1.txt", 70 | sc_score_2="sc_score_2.txt", 71 | pq_score="pq_score.txt", 72 | o_score="o_score.txt", 73 | handler=wds.warn_and_continue, 74 | ) 75 | ) 76 | dataset = dataset.map(self.preprocess_fn, handler=wds.warn_and_continue) 77 | dataset = dataset.select(QualityFilter(args.quality_threshold)) if args.quality_threshold else dataset 78 | return dataset 79 | 80 | def preprocess_fn(self, sample): 81 | source_pixel_values = self.image_transforms(sample["src_img"]) 82 | edited_pixel_values = self.image_transforms(sample["edited_img"]) 83 | prompt = np.random.choice(sample["prompt_list"]) if isinstance(sample["prompt_list"], list) else sample["prompt_list"] 84 | 85 | return { 86 | "source_pixel_values": source_pixel_values, 87 | "edited_pixel_values": edited_pixel_values, 88 | "prompt": prompt, 89 | "sc_score_1": float(sample["sc_score_1"]), 90 | "sc_score_2": float(sample["sc_score_2"]), 91 | "pq_score": float(sample["pq_score"]), 92 | "o_score": float(sample["o_score"]) 93 | } 94 | 95 | def prepare_dataloader(self, dataset): 96 | args = self.args 97 | # per dataloader worker 98 | num_worker_batches = math.ceil(args.num_train_examples / (args.global_batch_size * args.dataloader_num_workers)) 99 | dataset = dataset.batched( 100 | args.per_gpu_batch_size, partial=False, collation_fn=collate_fn 101 | ).with_epoch(num_worker_batches) 102 | dataloader = wds.WebLoader( 103 | dataset, 104 | batch_size=None, 105 | shuffle=False, 106 | num_workers=args.dataloader_num_workers, 107 | pin_memory=True, 108 | persistent_workers=True, 109 | ) 110 | return dataloader 111 | 112 | 113 | if __name__ == "__main__": 114 | from argparse import Namespace 115 | 116 | args = Namespace( 117 | dataset_path="pipe:aws s3 cp s3://omniedit-wds/train-{00000..00570}-of-00571.tar -", 118 | num_train_examples=1203497, 119 | per_gpu_batch_size=8, 120 | global_batch_size=64, 121 | num_workers=4, 122 | resolution=256, 123 | ) 124 | dataset_obj = ControlFluxDataset(args) 125 | dataset = dataset_obj.get_dataset() 126 | 127 | sample_count = 0 128 | for sample in dataset: 129 | print(sample.keys()) 130 | print(sample["prompt"]) 131 | sample_count += 1 132 | 133 | with open("dataset_actual_count.txt", "w") as f: 134 | f.write(str(sample_count)) 135 | 136 | dataloader = dataset_obj.prepare_dataloader(dataset) 137 | for batch in dataloader: 138 | print(batch.keys()) 139 | print(batch["pixel_values"].shape) 140 | break 141 | -------------------------------------------------------------------------------- /misc/accelerate_ds2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | gpu_ids: all 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /misc/convert_to_wds.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make sure to change `path` as needed. 3 | Install `smart_open`, `ray` before running the script. 4 | if you're serializing to an S3 bucket, make sure you're authenticated. 5 | """ 6 | 7 | from datasets import Dataset 8 | import json 9 | import webdataset as wds 10 | from smart_open import smart_open 11 | import os 12 | import ray 13 | import glob 14 | 15 | ray.init(num_cpus=16) 16 | 17 | 18 | if __name__ == "__main__": 19 | path = "/fsx/sayak/.cache/datasets--TIGER-Lab--OmniEdit-Filtered-1.2M/snapshots/82455c6cd66db7f0e5bfce8d7a236441af59d6df/data/" 20 | all_parquets = sorted(glob.glob(f"{path}/train-*.parquet")) 21 | 22 | @ray.remote 23 | def convert_to_wds(parquet_path): 24 | dataset = Dataset.from_parquet(parquet_path, split="train", cache_dir=path) 25 | shard_path = os.path.basename(parquet_path).replace(".parquet", ".tar") 26 | shard_path = os.path.join("s3://omniedit-wds", shard_path) 27 | 28 | with smart_open(shard_path, "wb") as s3_file: 29 | with wds.TarWriter(s3_file) as shard_writer: 30 | for i, example in enumerate(dataset): 31 | json_data = json.dumps(example["edited_prompt_list"]).encode("utf-8") 32 | src_img = example["src_img"].convert("RGB") 33 | edited_img = example["edited_img"].convert("RGB") 34 | 35 | wds_example = { 36 | "__key__": str(i), 37 | "omni_edit_id.txt": example["omni_edit_id"], 38 | "task.txt": example["task"], 39 | "src_img.jpg": src_img, 40 | "edited_img.jpg": edited_img, 41 | "edited_prompt_list.json": json_data, 42 | "sc_reasoning.txt": example["sc_reasoning"], 43 | "pq_reasoning.txt": example["pq_reasoning"], 44 | "height.txt": str(example["height"]), 45 | "width.txt": str(example["width"]), 46 | "sc_score_1.txt": str(example["sc_score_1"]), 47 | "sc_score_2.txt": str(example["sc_score_2"]), 48 | "pq_score.txt": str(example["pq_score"]), 49 | "o_score.txt": str(example["o_score"]), 50 | } 51 | shard_writer.write(wds_example) 52 | 53 | return shard_path 54 | 55 | futures = [convert_to_wds.remote(path) for path in all_parquets] 56 | ray.get(futures) 57 | -------------------------------------------------------------------------------- /misc/launch_edit.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=control-flux-omniedit 3 | #SBATCH --nodes=1 4 | # Set the QoS 5 | #SBATCH --qos=normal 6 | # set 48h for job wall time limit 7 | #SBATCH --time=2-00:00:00 8 | # activate the requeue option 9 | #SBATCH --requeue 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --cpus-per-task=96 12 | #SBATCH --gres=gpu:8 13 | #SBATCH --mem=1999G 14 | #SBATCH --partition=YOUR_PARTION_NAME 15 | #SBATCH -o /PATH/%x-%j-train.out 16 | 17 | set -xe 18 | 19 | export NCCL_ASYNC_ERROR_HANDLING=1 20 | export LR=1e-4 21 | export WEIGHT_DECAY=1e-4 22 | export GUIDANCE_SCALE=30.0 23 | export CAPTION_DROPOUT=0.0 24 | export LR_SCHEDULER="constant" 25 | 26 | srun --wait=60 --kill-on-bad-exit=1 accelerate launch --config_file=./misc/accelerate_ds2.yaml train.py \ 27 | --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ 28 | --output_dir="omniflux-lr_${LR}-wd_${WEIGHT_DECAY}-gs_${GUIDANCE_SCALE}-cd_${CAPTION_DROPOUT}-scheduler_${LR_SCHEDULER}-sim_flow-no8bitadam" \ 29 | --mixed_precision="bf16" \ 30 | --per_gpu_batch_size=4 \ 31 | --dataloader_num_workers=4 \ 32 | --gradient_accumulation_steps=4 \ 33 | --quality_threshold=10.0 \ 34 | --simplified_flow \ 35 | --gradient_checkpointing \ 36 | --proportion_empty_prompts=$CAPTION_DROPOUT \ 37 | --learning_rate=$LR \ 38 | --adam_weight_decay=$WEIGHT_DECAY \ 39 | --guidance_scale=$GUIDANCE_SCALE \ 40 | --validation_guidance_scale="10.,20.,30.,40." \ 41 | --report_to="wandb" --log_dataset_samples \ 42 | --lr_scheduler=$LR_SCHEDULER \ 43 | --lr_warmup_steps=0 \ 44 | --checkpointing_steps=4000 \ 45 | --resume_from_checkpoint="latest" --checkpoints_total_limit=2 \ 46 | --max_train_steps=20000 \ 47 | --validation_steps=400 \ 48 | --validation_image "car.jpg" "green_creature.jpg" "norte_dam.jpg" "mushroom.jpg" \ 49 | --validation_prompt "Give this the look of a traditional Japanese woodblock print." "transform the setting to a winter scene" "Change it to look like it's in the style of an impasto painting." "turn the color of mushroom to gray" \ 50 | --seed="0" \ 51 | --push_to_hub 52 | 53 | echo "END TIME: $(date)" 54 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.5.1 2 | torchvision 3 | transformers 4 | deepspeed 5 | accelerate 6 | webdataset 7 | git+https://github.com/huggingface/diffusers.git -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import copy 18 | import logging 19 | import math 20 | import os 21 | import random 22 | import shutil 23 | from pathlib import Path 24 | 25 | import accelerate 26 | import torch 27 | import transformers 28 | from accelerate import Accelerator 29 | from accelerate.state import AcceleratorState 30 | from accelerate.logging import get_logger 31 | from accelerate.utils import DistributedType, ProjectConfiguration, set_seed 32 | from huggingface_hub import create_repo, upload_folder 33 | from packaging import version 34 | from tqdm.auto import tqdm 35 | 36 | import diffusers 37 | from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel 38 | from diffusers.optimization import get_scheduler 39 | from diffusers.training_utils import ( 40 | compute_density_for_timestep_sampling, 41 | compute_loss_weighting_for_sd3, 42 | free_memory, 43 | ) 44 | from diffusers.utils import check_min_version, is_wandb_available 45 | from diffusers.utils.torch_utils import is_compiled_module 46 | from dataset import ControlFluxDataset 47 | from args import parse_args 48 | from utils import encode_images, log_validation, save_model_card, expand_tensor_to_dim, log_data_samples_to_wandb 49 | 50 | 51 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 52 | check_min_version("0.33.0.dev0") 53 | 54 | logger = get_logger(__name__) 55 | 56 | 57 | def main(args): 58 | if args.report_to == "wandb" and args.hub_token is not None: 59 | raise ValueError( 60 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 61 | " Please use `huggingface-cli login` to authenticate with the Hub." 62 | ) 63 | 64 | logging_out_dir = Path(args.output_dir, args.logging_dir) 65 | 66 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16": 67 | # due to pytorch#99272, MPS does not yet support bfloat16. 68 | raise ValueError( 69 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 70 | ) 71 | 72 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir)) 73 | 74 | accelerator = Accelerator( 75 | gradient_accumulation_steps=args.gradient_accumulation_steps, 76 | mixed_precision=args.mixed_precision, 77 | log_with=args.report_to, 78 | project_config=accelerator_project_config, 79 | ) 80 | if accelerator.distributed_type == DistributedType.DEEPSPEED: 81 | AcceleratorState().deepspeed_plugin.deepspeed_config[ 82 | "train_micro_batch_size_per_gpu" 83 | ] = args.per_gpu_batch_size 84 | 85 | # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices. 86 | if torch.backends.mps.is_available(): 87 | logger.info("MPS is enabled. Disabling AMP.") 88 | accelerator.native_amp = False 89 | 90 | # Make one log on every process with the configuration for debugging. 91 | logging.basicConfig( 92 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 93 | datefmt="%m/%d/%Y %H:%M:%S", 94 | # DEBUG, INFO, WARNING, ERROR, CRITICAL 95 | level=logging.INFO, 96 | ) 97 | logger.info(accelerator.state, main_process_only=False) 98 | 99 | if accelerator.is_local_main_process: 100 | transformers.utils.logging.set_verbosity_warning() 101 | diffusers.utils.logging.set_verbosity_info() 102 | else: 103 | transformers.utils.logging.set_verbosity_error() 104 | diffusers.utils.logging.set_verbosity_error() 105 | 106 | # If passed along, set the training seed now. 107 | if args.seed is not None: 108 | set_seed(args.seed) 109 | 110 | # Handle the repository creation 111 | if accelerator.is_main_process: 112 | if args.output_dir is not None: 113 | os.makedirs(args.output_dir, exist_ok=True) 114 | 115 | if args.push_to_hub: 116 | repo_id = create_repo( 117 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 118 | ).repo_id 119 | 120 | # Load models. We will load the text encoders later in a pipeline to compute 121 | # embeddings. 122 | vae = AutoencoderKL.from_pretrained( 123 | args.pretrained_model_name_or_path, 124 | subfolder="vae", 125 | revision=args.revision, 126 | variant=args.variant, 127 | ) 128 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 129 | flux_transformer = FluxTransformer2DModel.from_pretrained( 130 | args.pretrained_model_name_or_path, 131 | subfolder="transformer", 132 | revision=args.revision, 133 | variant=args.variant, 134 | ) 135 | logger.info("All models loaded successfully") 136 | 137 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 138 | args.pretrained_model_name_or_path, 139 | subfolder="scheduler", 140 | ) 141 | noise_scheduler_copy = copy.deepcopy(noise_scheduler) 142 | if not args.only_target_transformer_blocks: 143 | flux_transformer.requires_grad_(True) 144 | vae.requires_grad_(False) 145 | 146 | # cast down and move to the CPU 147 | weight_dtype = torch.float32 148 | if accelerator.mixed_precision == "fp16": 149 | weight_dtype = torch.float16 150 | elif accelerator.mixed_precision == "bf16": 151 | weight_dtype = torch.bfloat16 152 | 153 | # let's not move the VAE to the GPU yet. 154 | vae.to(dtype=torch.float32) # keep the VAE in float32. 155 | if not args.offload: 156 | vae.to(accelerator.device) 157 | 158 | # enable image inputs 159 | with torch.no_grad(): 160 | initial_input_channels = flux_transformer.config.in_channels 161 | new_linear = torch.nn.Linear( 162 | flux_transformer.x_embedder.in_features * 2, 163 | flux_transformer.x_embedder.out_features, 164 | bias=flux_transformer.x_embedder.bias is not None, 165 | dtype=flux_transformer.dtype, 166 | device=flux_transformer.device, 167 | ) 168 | new_linear.weight.zero_() 169 | new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight) 170 | if flux_transformer.x_embedder.bias is not None: 171 | new_linear.bias.copy_(flux_transformer.x_embedder.bias) 172 | flux_transformer.x_embedder = new_linear 173 | 174 | assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0) 175 | flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels) 176 | 177 | if args.only_target_transformer_blocks: 178 | flux_transformer.x_embedder.requires_grad_(True) 179 | for name, module in flux_transformer.named_modules(): 180 | if "transformer_blocks" in name: 181 | module.requires_grad_(True) 182 | 183 | def unwrap_model(model): 184 | model = accelerator.unwrap_model(model) 185 | model = model._orig_mod if is_compiled_module(model) else model 186 | return model 187 | 188 | # `accelerate` 0.16.0 will have better support for customized saving 189 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 190 | 191 | def save_model_hook(models, weights, output_dir): 192 | if accelerator.is_main_process: 193 | for model in models: 194 | if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))): 195 | model = unwrap_model(model) 196 | model.save_pretrained(os.path.join(output_dir, "transformer")) 197 | else: 198 | raise ValueError(f"unexpected save model: {model.__class__}") 199 | 200 | # make sure to pop weight so that corresponding model is not saved again 201 | if weights: 202 | weights.pop() 203 | 204 | def load_model_hook(models, input_dir): 205 | transformer_ = None 206 | 207 | if not accelerator.distributed_type == DistributedType.DEEPSPEED: 208 | while len(models) > 0: 209 | model = models.pop() 210 | 211 | if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))): 212 | transformer_ = model # noqa: F841 213 | else: 214 | raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}") 215 | 216 | else: 217 | transformer_ = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841 218 | 219 | accelerator.register_save_state_pre_hook(save_model_hook) 220 | accelerator.register_load_state_pre_hook(load_model_hook) 221 | 222 | if args.gradient_checkpointing: 223 | flux_transformer.enable_gradient_checkpointing() 224 | 225 | # Enable TF32 for faster training on Ampere GPUs, 226 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 227 | if args.allow_tf32: 228 | torch.backends.cuda.matmul.allow_tf32 = True 229 | 230 | if args.scale_lr: 231 | args.learning_rate = ( 232 | args.learning_rate 233 | * args.gradient_accumulation_steps 234 | * args.per_gpu_batch_size 235 | * accelerator.num_processes 236 | ) 237 | 238 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 239 | if args.use_8bit_adam: 240 | try: 241 | import bitsandbytes as bnb 242 | except ImportError: 243 | raise ImportError( 244 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 245 | ) 246 | 247 | optimizer_class = bnb.optim.AdamW8bit 248 | else: 249 | optimizer_class = torch.optim.AdamW 250 | 251 | # Optimization parameters 252 | optimizer = optimizer_class( 253 | flux_transformer.parameters(), 254 | lr=args.learning_rate, 255 | betas=(args.adam_beta1, args.adam_beta2), 256 | weight_decay=args.adam_weight_decay, 257 | eps=args.adam_epsilon, 258 | ) 259 | 260 | # Prepare dataset and dataloader. 261 | args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes 262 | num_batches = math.ceil(args.num_train_examples / args.global_batch_size) 263 | dataset_obj = ControlFluxDataset(args) 264 | train_dataset = dataset_obj.get_dataset() 265 | train_dataloader = dataset_obj.prepare_dataloader(train_dataset) 266 | 267 | # Scheduler and math around the number of training steps. 268 | # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. 269 | if args.max_train_steps is None: 270 | len_train_dataloader_after_sharding = math.ceil(args.num_train_examples / accelerator.num_processes) 271 | num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) 272 | num_training_steps_for_scheduler = ( 273 | args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes 274 | ) 275 | else: 276 | num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes 277 | 278 | lr_scheduler = get_scheduler( 279 | args.lr_scheduler, 280 | optimizer=optimizer, 281 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 282 | num_training_steps=args.max_train_steps * accelerator.num_processes, 283 | num_cycles=args.lr_num_cycles, 284 | power=args.lr_power, 285 | ) 286 | # Prepare everything with our `accelerator`. 287 | flux_transformer, optimizer, lr_scheduler = accelerator.prepare( 288 | flux_transformer, optimizer, lr_scheduler 289 | ) 290 | 291 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 292 | num_update_steps_per_epoch = math.ceil(num_batches / args.gradient_accumulation_steps) 293 | if args.max_train_steps is None: 294 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 295 | if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: 296 | logger.warning( 297 | f"The length of the 'train_dataloader' ({args.num_train_examples}) does not match " 298 | f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " 299 | f"This inconsistency may result in the learning rate scheduler not functioning properly." 300 | ) 301 | # Afterwards we recalculate our number of training epochs 302 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 303 | 304 | # We need to initialize the trackers we use, and also store our configuration. 305 | # The trackers initializes automatically on the main process. 306 | if accelerator.is_main_process: 307 | tracker_config = dict(vars(args)) 308 | 309 | # tensorboard cannot handle list types for config 310 | tracker_config.pop("validation_prompt") 311 | tracker_config.pop("validation_image") 312 | 313 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 314 | 315 | # Train! 316 | total_batch_size = args.per_gpu_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 317 | 318 | logger.info("***** Running training *****") 319 | logger.info(f" Num examples = {args.num_train_examples}") 320 | logger.info(f" Num Epochs = {args.num_train_epochs}") 321 | logger.info(f" Instantaneous batch size per device = {args.per_gpu_batch_size}") 322 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 323 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 324 | logger.info(f" Total optimization steps = {args.max_train_steps}") 325 | global_step = 0 326 | first_epoch = 0 327 | 328 | # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed. 329 | text_encoding_pipeline = FluxControlPipeline.from_pretrained( 330 | args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype 331 | ) 332 | if not args.offload: 333 | text_encoding_pipeline = text_encoding_pipeline.to("cuda") 334 | 335 | # Potentially load in the weights and states from a previous save 336 | if args.resume_from_checkpoint: 337 | if args.resume_from_checkpoint != "latest": 338 | path = os.path.basename(args.resume_from_checkpoint) 339 | else: 340 | # Get the most recent checkpoint 341 | dirs = os.listdir(args.output_dir) 342 | dirs = [d for d in dirs if d.startswith("checkpoint")] 343 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 344 | path = dirs[-1] if len(dirs) > 0 else None 345 | 346 | if path is None: 347 | logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") 348 | args.resume_from_checkpoint = None 349 | initial_global_step = 0 350 | else: 351 | logger.info(f"Resuming from checkpoint {path}") 352 | accelerator.load_state(os.path.join(args.output_dir, path)) 353 | global_step = int(path.split("-")[1]) 354 | 355 | initial_global_step = global_step 356 | first_epoch = global_step // num_update_steps_per_epoch 357 | else: 358 | initial_global_step = 0 359 | 360 | if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples: 361 | log_data_samples_to_wandb(dataloader=train_dataloader, accelerator=accelerator) 362 | 363 | progress_bar = tqdm( 364 | range(0, args.max_train_steps), 365 | initial=initial_global_step, 366 | desc="Steps", 367 | # Only show the progress bar once on each machine. 368 | disable=not accelerator.is_local_main_process, 369 | ) 370 | 371 | def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): 372 | sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) 373 | schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) 374 | timesteps = timesteps.to(accelerator.device) 375 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 376 | 377 | sigma = sigmas[step_indices] 378 | return expand_tensor_to_dim(sigma, n_dim) 379 | 380 | image_logs = None 381 | for epoch in range(first_epoch, args.num_train_epochs): 382 | flux_transformer.train() 383 | for step, batch in enumerate(train_dataloader): 384 | with accelerator.accumulate(flux_transformer): 385 | # Convert images to latent space 386 | # vae encode 387 | if args.offload: 388 | vae.to(accelerator.device) 389 | edited_pixel_values = batch["edited_pixel_values"].to(dtype=vae.dtype, device=accelerator.device, non_blocking=True) 390 | pixel_latents = encode_images(edited_pixel_values, vae, weight_dtype) 391 | conditioning_pixel_values = batch["source_pixel_values"].to( 392 | dtype=vae.dtype, device=accelerator.device, non_blocking=True 393 | ) 394 | control_latents = encode_images(conditioning_pixel_values, vae, weight_dtype) 395 | if args.offload: 396 | # offload vae to CPU. 397 | vae.cpu() 398 | 399 | # Sample a random timestep for each image 400 | # for weighting schemes where we sample timesteps non-uniformly 401 | bsz = pixel_latents.shape[0] 402 | noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype) 403 | if args.simplified_flow: 404 | sigmas = torch.rand(bsz, device=accelerator.device, dtype=pixel_latents.dtype) 405 | timesteps = (sigmas * noise_scheduler_copy.config.num_train_timesteps).long().to(accelerator.device) 406 | sigmas = expand_tensor_to_dim(sigmas, pixel_latents.ndim) 407 | else: 408 | u = compute_density_for_timestep_sampling( 409 | weighting_scheme=args.weighting_scheme, 410 | batch_size=bsz, 411 | logit_mean=args.logit_mean, 412 | logit_std=args.logit_std, 413 | mode_scale=args.mode_scale, 414 | ) 415 | indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() 416 | timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) 417 | sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype) 418 | 419 | # Add noise according to flow matching. 420 | noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise 421 | concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1) 422 | 423 | # pack the latents. 424 | packed_noisy_model_input = FluxControlPipeline._pack_latents( 425 | concatenated_noisy_model_input, 426 | batch_size=bsz, 427 | num_channels_latents=concatenated_noisy_model_input.shape[1], 428 | height=concatenated_noisy_model_input.shape[2], 429 | width=concatenated_noisy_model_input.shape[3], 430 | ) 431 | 432 | # latent image ids for RoPE. 433 | latent_image_ids = FluxControlPipeline._prepare_latent_image_ids( 434 | bsz, 435 | concatenated_noisy_model_input.shape[2] // 2, 436 | concatenated_noisy_model_input.shape[3] // 2, 437 | accelerator.device, 438 | weight_dtype, 439 | ) 440 | 441 | # handle guidance 442 | if unwrap_model(flux_transformer).config.guidance_embeds: 443 | guidance_vec = torch.full( 444 | (bsz,), 445 | args.guidance_scale, 446 | device=noisy_model_input.device, 447 | dtype=weight_dtype, 448 | ) 449 | else: 450 | guidance_vec = None 451 | 452 | # text encoding. 453 | captions = batch["captions"] 454 | if args.offload: 455 | text_encoding_pipeline = text_encoding_pipeline.to("cuda") 456 | with torch.no_grad(): 457 | prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( 458 | captions, prompt_2=None 459 | ) 460 | # this could be optimized by not having to do any text encoding and just 461 | # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds` 462 | if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: 463 | prompt_embeds.zero_() 464 | pooled_prompt_embeds.zero_() 465 | if args.offload: 466 | text_encoding_pipeline = text_encoding_pipeline.to("cpu") 467 | 468 | # Predict. 469 | model_pred = flux_transformer( 470 | hidden_states=packed_noisy_model_input, 471 | timestep=timesteps / 1000, 472 | guidance=guidance_vec, 473 | pooled_projections=pooled_prompt_embeds, 474 | encoder_hidden_states=prompt_embeds, 475 | txt_ids=text_ids, 476 | img_ids=latent_image_ids, 477 | return_dict=False, 478 | )[0] 479 | model_pred = FluxControlPipeline._unpack_latents( 480 | model_pred, 481 | height=noisy_model_input.shape[2] * vae_scale_factor, 482 | width=noisy_model_input.shape[3] * vae_scale_factor, 483 | vae_scale_factor=vae_scale_factor, 484 | ) 485 | # these weighting schemes use a uniform timestep sampling 486 | # and instead post-weight the loss 487 | weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) 488 | 489 | # flow-matching loss 490 | target = noise - pixel_latents 491 | loss = torch.mean( 492 | (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 493 | 1, 494 | ) 495 | loss = loss.mean() 496 | accelerator.backward(loss) 497 | 498 | if accelerator.sync_gradients: 499 | params_to_clip = flux_transformer.parameters() 500 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 501 | 502 | optimizer.step() 503 | lr_scheduler.step() 504 | optimizer.zero_grad() 505 | 506 | # Checks if the accelerator has performed an optimization step behind the scenes 507 | if accelerator.sync_gradients: 508 | progress_bar.update(1) 509 | global_step += 1 510 | 511 | # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. 512 | if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: 513 | if global_step % args.checkpointing_steps == 0: 514 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 515 | if args.checkpoints_total_limit is not None: 516 | checkpoints = os.listdir(args.output_dir) 517 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 518 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 519 | 520 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 521 | if len(checkpoints) >= args.checkpoints_total_limit: 522 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 523 | removing_checkpoints = checkpoints[0:num_to_remove] 524 | 525 | logger.info( 526 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 527 | ) 528 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 529 | 530 | for removing_checkpoint in removing_checkpoints: 531 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 532 | shutil.rmtree(removing_checkpoint) 533 | 534 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 535 | accelerator.save_state(save_path) 536 | logger.info(f"Saved state to {save_path}") 537 | 538 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 539 | image_logs = log_validation( 540 | flux_transformer=flux_transformer, 541 | args=args, 542 | accelerator=accelerator, 543 | weight_dtype=weight_dtype, 544 | step=global_step, 545 | ) 546 | 547 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 548 | progress_bar.set_postfix(**logs) 549 | accelerator.log(logs, step=global_step) 550 | 551 | if global_step >= args.max_train_steps: 552 | break 553 | 554 | # Create the pipeline using using the trained modules and save it. 555 | accelerator.wait_for_everyone() 556 | if accelerator.is_main_process: 557 | flux_transformer = unwrap_model(flux_transformer) 558 | if args.upcast_before_saving: 559 | flux_transformer.to(torch.float32) 560 | flux_transformer.save_pretrained(args.output_dir) 561 | 562 | del flux_transformer 563 | del text_encoding_pipeline 564 | del vae 565 | free_memory() 566 | 567 | # Run a final round of validation. 568 | image_logs = None 569 | if args.validation_prompt is not None: 570 | image_logs = log_validation( 571 | flux_transformer=None, 572 | args=args, 573 | accelerator=accelerator, 574 | weight_dtype=weight_dtype, 575 | step=global_step, 576 | is_final_validation=True, 577 | ) 578 | 579 | if args.push_to_hub: 580 | save_model_card( 581 | repo_id, 582 | image_logs=image_logs, 583 | base_model=args.pretrained_model_name_or_path, 584 | repo_folder=args.output_dir, 585 | ) 586 | upload_folder( 587 | repo_id=repo_id, 588 | folder_path=args.output_dir, 589 | commit_message="End of training", 590 | ignore_patterns=["step_*", "epoch_*", "checkpoint-*"], 591 | ) 592 | 593 | accelerator.end_training() 594 | 595 | 596 | if __name__ == "__main__": 597 | args = parse_args() 598 | main(args) 599 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import os 19 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 20 | from contextlib import nullcontext 21 | from diffusers import FluxControlPipeline, FluxTransformer2DModel 22 | from diffusers.utils import load_image, make_image_grid, is_wandb_available 23 | from diffusers.training_utils import free_memory 24 | from accelerate.logging import get_logger 25 | 26 | 27 | logger = get_logger(__name__) 28 | 29 | 30 | if is_wandb_available(): 31 | import wandb 32 | 33 | def expand_tensor_to_dim(tensor, ndim): 34 | tensor = tensor.flatten() 35 | while len(tensor.shape) < ndim: 36 | tensor = tensor.unsqueeze(-1) 37 | return tensor 38 | 39 | def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype): 40 | pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample() 41 | pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor 42 | return pixel_latents.to(weight_dtype) 43 | 44 | 45 | def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False): 46 | logger.info("Running validation... ") 47 | 48 | if not is_final_validation: 49 | flux_transformer = accelerator.unwrap_model(flux_transformer) 50 | pipeline = FluxControlPipeline.from_pretrained( 51 | args.pretrained_model_name_or_path, 52 | transformer=flux_transformer, 53 | torch_dtype=weight_dtype, 54 | ) 55 | else: 56 | transformer = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) 57 | pipeline = FluxControlPipeline.from_pretrained( 58 | args.pretrained_model_name_or_path, 59 | transformer=transformer, 60 | torch_dtype=weight_dtype, 61 | ) 62 | 63 | pipeline.to(accelerator.device) 64 | pipeline.set_progress_bar_config(disable=True) 65 | 66 | if args.seed is None: 67 | generator = None 68 | else: 69 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 70 | 71 | if len(args.validation_image) == len(args.validation_prompt): 72 | validation_images = args.validation_image 73 | validation_prompts = args.validation_prompt 74 | elif len(args.validation_image) == 1: 75 | validation_images = args.validation_image * len(args.validation_prompt) 76 | validation_prompts = args.validation_prompt 77 | elif len(args.validation_prompt) == 1: 78 | validation_images = args.validation_image 79 | validation_prompts = args.validation_prompt * len(args.validation_image) 80 | else: 81 | raise ValueError( 82 | "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" 83 | ) 84 | 85 | image_logs = [] 86 | if is_final_validation or torch.backends.mps.is_available(): 87 | autocast_ctx = nullcontext() 88 | else: 89 | autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype) 90 | 91 | vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) 92 | for validation_prompt, validation_image in zip(validation_prompts, validation_images): 93 | validation_image = load_image(validation_image) 94 | width, height = validation_image.size 95 | if height % (vae_scale_factor * 2) != 0 or width % (vae_scale_factor * 2) != 0: 96 | height, width = (1024, 1024) 97 | 98 | for gs in args.validation_guidance_scale: 99 | images = [] 100 | for _ in range(args.num_validation_images): 101 | with autocast_ctx: 102 | image = pipeline( 103 | prompt=validation_prompt, 104 | control_image=validation_image, 105 | num_inference_steps=50, 106 | guidance_scale=gs, 107 | generator=generator, 108 | max_sequence_length=512, 109 | height=height, 110 | width=width, 111 | ).images[0] 112 | images.append(image) 113 | image_logs.append( 114 | { 115 | "validation_image": validation_image, 116 | "images": images, 117 | "validation_prompt": validation_prompt, 118 | "guidance_scale": gs 119 | } 120 | ) 121 | 122 | tracker_key = "test" if is_final_validation else "validation" 123 | for tracker in accelerator.trackers: 124 | if tracker.name == "wandb": 125 | formatted_images = [] 126 | for log in image_logs: 127 | images = log["images"] 128 | guidance_scale = log["guidance_scale"] 129 | validation_prompt = log["validation_prompt"] 130 | validation_image = log["validation_image"] 131 | formatted_images.append(wandb.Image(validation_image, caption=f"Conditioning (gs: {guidance_scale})")) 132 | for image in images: 133 | image = wandb.Image(image, caption=validation_prompt) 134 | formatted_images.append(image) 135 | 136 | tracker.log({tracker_key: formatted_images}) 137 | else: 138 | logger.warning(f"image logging not implemented for {tracker.name}") 139 | 140 | del pipeline 141 | free_memory() 142 | return image_logs 143 | 144 | 145 | def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): 146 | img_str = "" 147 | if image_logs is not None: 148 | img_str = "You can find some example images below.\n\n" 149 | for i, log in enumerate(image_logs): 150 | images = log["images"] 151 | validation_prompt = log["validation_prompt"] 152 | validation_image = log["validation_image"] 153 | validation_image.save(os.path.join(repo_folder, "image_control.png")) 154 | img_str += f"prompt: {validation_prompt}\n" 155 | images = [validation_image] + images 156 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) 157 | img_str += f"![images_{i})](./images_{i}.png)\n" 158 | 159 | model_description = f""" 160 | # flux-control-{repo_id} 161 | 162 | These are Control weights trained on {base_model} and [TIGER-Lab/OmniEdit-Filtered-1.2M](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M). 163 | {img_str} 164 | 165 | ## License 166 | 167 | Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) 168 | """ 169 | 170 | model_card = load_or_create_model_card( 171 | repo_id_or_path=repo_id, 172 | from_training=True, 173 | license="other", 174 | base_model=base_model, 175 | model_description=model_description, 176 | inference=True, 177 | ) 178 | 179 | tags = [ 180 | "flux", 181 | "flux-diffusers", 182 | "text-to-image", 183 | "diffusers", 184 | "control", 185 | "diffusers-training", 186 | ] 187 | model_card = populate_model_card(model_card, tags=tags) 188 | 189 | model_card.save(os.path.join(repo_folder, "README.md")) 190 | 191 | 192 | def log_data_samples_to_wandb(dataloader, accelerator): 193 | logger.info("Logging some dataset samples.") 194 | formatted_src_images = [] 195 | formatted_edited_images = [] 196 | all_prompts = [] 197 | for i, batch in enumerate(dataloader): 198 | source_images = (batch["source_pixel_values"] + 1) / 2 199 | edited_images = (batch["edited_pixel_values"] + 1) / 2 200 | prompts = batch["captions"] 201 | 202 | if len(formatted_src_images) > 10: 203 | break 204 | 205 | for img, edited_img, prompt in zip(source_images, edited_images, prompts): 206 | formatted_src_images.append(img) 207 | formatted_edited_images.append(edited_img) 208 | all_prompts.append(prompt) 209 | 210 | logged_artifacts = [] 211 | for img, edited_img, prompt in zip(formatted_src_images, formatted_edited_images, all_prompts): 212 | logged_artifacts.append(wandb.Image(img, caption="Conditioning")) 213 | logged_artifacts.append(wandb.Image(edited_img, caption=prompt)) 214 | 215 | wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"] 216 | assert wandb_tracker, "wandb couldn't be found in the trackers." 217 | wandb_tracker[0].log({"dataset_samples": logged_artifacts}) --------------------------------------------------------------------------------