├── .gitignore ├── LICENSE ├── README.md ├── assets ├── boy.png ├── girl.png ├── hybrid.png └── kaori.jpg ├── docs └── images │ ├── comfyuiexample.png │ ├── gpt4o_comparison.png │ ├── gradio.png │ ├── lora_scale.png │ ├── official_workflow.png │ ├── teaser.png │ ├── windows_install.png │ ├── workflow.png │ ├── workflow_t8.png │ └── workflow_tutorial.png ├── requirements.txt ├── scripts ├── config.json ├── gradio_demo.py └── inference.py └── train ├── README.md ├── assets ├── book.jpg ├── clock.jpg ├── coffee.png ├── monalisa.jpg ├── oranges.jpg ├── penguin.jpg ├── room_corner.jpg └── vase.jpg ├── parquet └── prepare.sh ├── requirements.txt ├── runs ├── 20250513-085800 │ └── config.yaml ├── 20250513-090048 │ └── config.yaml └── 20250513-090235 │ └── config.yaml ├── src ├── flux │ ├── block.py │ ├── condition.py │ ├── generate.py │ ├── lora_controller.py │ ├── pipeline_tools.py │ └── transformer.py └── train │ ├── callbacks.py │ ├── data.py │ ├── model.py │ └── train.py └── train ├── config └── normal_lora.yaml └── script └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.pyc 2 | gradio_results/ 3 | normal/ 4 | scripts/inference_50seed.py 5 | success.txt 6 | icedit/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | FLUX.1 [dev] Non-Commercial License 2 | Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models, including FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA and FLUX.1 Depth [dev] LoRA, and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). 3 | By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity. 4 | 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings: 5 | a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. 
For the avoidance of doubt, Outputs are not considered Derivatives under this License. 6 | b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be. 7 | c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose. 8 | d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters. 9 | e. “you” or “your” means the individual or entity entering into this License with Company. 10 | 2. License Grant. 11 | a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf. 12 | b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: info@blackforestlabs.ai. 13 | c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License. 14 | d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. 
You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model. 15 | 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions: 16 | a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License; 17 | b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”): 18 | “The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc. 19 | IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.” 20 | c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and 21 | d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions. 22 | e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing. 23 | 4. Restrictions. You will not, and will not permit, assist or cause any third party to 24 | a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing; 25 | b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model; 26 | c. 
utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or 27 | d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License. 28 | e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model; 29 | f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods. 30 | 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. 31 | 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. 32 | 7. 
INDEMNIFICATION 33 | 34 | You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties. 35 | 8. Termination; Survival. 36 | a. This License will automatically terminate upon any breach by you of the terms of this License. 37 | b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. 38 | c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated. 39 | d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11. 40 | 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. 41 | 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators. 42 | 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. 
If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company. 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# In-Context Edit: Enabling Instructional Image Editing with In-Context Generation in Large Scale Diffusion Transformer

Zechuan Zhang, Ji Xie, Yu Lu, Zongxin Yang, Yi Yang✉

ReLER, CCAI, Zhejiang University; Harvard University

✉ Corresponding Author

Arxiv · Huggingface Demo 🤗 · Model 🤗 · Project Page

Image Editing is worth a single LoRA! We present In-Context Edit, a novel approach that achieves state-of-the-art instruction-based editing using just 0.5% of the training data and 1% of the parameters required by prior SOTA methods. The first row illustrates a series of multi-turn edits, executed with high precision, while the second and third rows highlight diverse, visually impressive single-turn editing results from our method.

29 |
30 | 31 | :open_book: For more visual results, go check out our project page 32 | 33 | 34 |
35 | 36 | 37 | ### 📢 Attention All: Incorrect ComfyUI Workflow Usage Alert — Read Now! 38 | - ### We have released our **[official ComfyUI workflow](#official-comfyui-workflow)** for proper usage! Check our repository and give it a try! 39 | - You need to **add the fixed pre-prompt "A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but {instruction}"** before inputting the edit instruction, otherwise you may get bad results! (This is mentioned in the paper! The Hugging Face Gradio demo already embeds this prompt, so there you can simply input the editing instruction without additional setup.) 40 | - The width of the input image must be resized to **512** pixels (there is no restriction on the height). 41 | - Please **[use the Normal LoRA](https://huggingface.co/RiverZ/normal-lora/tree/main)**, not the MoE-LoRA, because the MoE-LoRA cannot be loaded correctly with the ComfyUI LoRA loader. 42 | - 🔥💐🎆 You are welcome to share your **creative workflows** (such as combining Redux, ACE, etc.) in the Issues section and showcase the results! We will add references so that more people can see your creativity. 43 | 44 | 45 | # 🎆 News 46 | 47 | ### 👑 Feel free to share your results in this [Gallery](https://github.com/River-Zhang/ICEdit/discussions/21)! 48 | - **[2025/5/16]** 🌟 Many thanks to [gluttony-10 (十字鱼)](https://github.com/River-Zhang/ICEdit/pull/47#issue-3067039788) for adapting the Gradio demo with [GGUF quantization](#inference-in-gradio-demo), further reducing memory usage to **10GB**. 49 | - **[2025/5/14]** 🔥 With the help of the [official comfy-org](https://www.comfy.org/zh-cn/), we have integrated our ComfyUI nodes into the [Comfy Registry](https://registry.comfy.org/nodes/ICEdit)! 50 | - **[2025/5/13]** 🔥 We have released the [training code](./train/)! Train your own editing LoRAs now! 51 | - **[2025/5/11]** 🌟 Great thanks to [gluttony-10 (十字鱼)](https://github.com/River-Zhang/ICEdit/issues/23#issue-3050804566) for making a [Windows Gradio demo](#inference-in-gradio-demo-on-windows) so our project can be used on Windows! 52 | - **[2025/5/8]** 🔥 We have released our **[official ComfyUI workflow](#official-comfyui-workflow)**! 🚀 Check the repository and give it a try! 53 | - **[2025/5/8]** 🔥 We have added a LoRA scale slider to the Gradio demo. Try different scales to discover more interesting results! 54 |
55 | 56 |
57 | 58 | - **[2025/5/7]** 🌟 We updated some notes on using the ComfyUI workflow to help avoid unsatisfactory results! 59 | - **[2025/5/6]** 🔥 ICEdit currently ranks **2nd** on the overall/weekly trending list of [Hugging Face space](https://huggingface.co/spaces). Thank you all for your support and love!🤗 60 | - **[2025/5/5]** 🌟 Heartfelt thanks to [Datou](https://x.com/Datou) for creating a fantastic [ComfyUI workflow](https://openart.ai/workflows/datou/icedit-moe-lora-flux-fill/QFmaWNKsQo3P5liYz4RB) on OpenArt! 🚀 Have a try! 61 | - **[2025/5/2]** 🌟 Heartfelt thanks to [judian17](https://github.com/River-Zhang/ICEdit/issues/1#issuecomment-2846568411) for crafting an amazing [ComfyUI-nunchaku demo](https://github.com/River-Zhang/ICEdit/issues/1#issuecomment-2846568411)! A GPU with only **4GB of VRAM** is enough to run it with ComfyUI-nunchaku!🚀 Dive in and give it a spin! 62 | - **[2025/4/30]** 🔥 We release the [Huggingface Demo](https://huggingface.co/spaces/RiverZ/ICEdit) 🤗! Have a try! 63 | - **[2025/4/30]** 🔥 We release the [paper](https://arxiv.org/abs/2504.20690) on arXiv! 64 | - **[2025/4/29]** We release the [project page](https://river-zhang.github.io/ICEdit-gh-pages/) and demo video! Code will be made available next week~ Happy Labor Day! 65 | 66 | # 🎈 Tutorial on Bilibili or Youtube 67 | 68 | - **[2025/5/15]** 🌟 We find that [啦啦啦的小黄瓜](https://space.bilibili.com/219572544) has made a detailed [bilibili tutorial](https://www.bilibili.com/video/BV1tSEqzJE7q/?share_source=copy_web&vd_source=8fcb933ee576af56337afc41509fa095) introducing our model! What a great video! 69 | - **[2025/5/14]** 🌟 We find that [Nenly同学](https://space.bilibili.com/1814756990) has made a fantastic [bilibili tutorial](https://www.bilibili.com/video/BV1bNEvzrEn1/?share_source=copy_web&vd_source=8fcb933ee576af56337afc41509fa095) on how to use our repository! Great thanks to him! 70 | - **[2025/5/10]** 🌟 Great thanks to [月下Hugo](https://www.bilibili.com/video/BV1JZVRzuE12/?share_source=copy_web&vd_source=8fcb933ee576af56337afc41509fa095) for making a [Chinese tutorial](https://www.bilibili.com/video/BV1JZVRzuE12/?share_source=copy_web&vd_source=8fcb933ee576af56337afc41509fa095) on how to use our official workflow! 71 | - **[2025/5/7]** 🌟 Heartfelt thanks to [T8star](https://x.com/T8star_Aix) for making a [tutorial](https://www.youtube.com/watch?v=s6GMKL-Jjos) and [ComfyUI workflow](https://www.runninghub.cn/post/1920075398585974786/?utm_source=kol01-RH099) on how to **increase the editing success rate to 100%**!🚀 Have a try! 72 | - **[2025/5/3]** 🌟 Heartfelt thanks to [softicelee2](https://github.com/softicelee2) for making a [Youtube video](https://youtu.be/rRMc5DE4qMo) on how to use our model!
73 | # 📖 Table of Contents 74 | 75 | - [🎆 News](#-news) 76 | - [👑 Feel free to share your results in this Gallery!](#-feel-free-to-share-your-results-in-this-gallery) 77 | - [🎈 Tutorial on Bilibili or Youtube](#-tutorial-on-bilibili-or-youtube) 78 | - [📖 Table of Contents](#-table-of-contents) 79 | - [🎨ComfyUI Workflow](#comfyui-workflow) 80 | - [Official ComfyUI-workflow](#official-comfyui-workflow) 81 | - [ComfyUI-workflow for increased editing success rate](#comfyui-workflow-for-increased-editing-success-rate) 82 | - [ComfyUI-nunchaku](#comfyui-nunchaku) 83 | - [ComfyUI-workflow](#comfyui-workflow-1) 84 | - [⚠️ Tips](#️-tips) 85 | - [If you encounter such a failure case, please **try again with a different seed**!](#if-you-encounter-such-a-failure-case-please-try-again-with-a-different-seed) 86 | - [⚠️ Clarification](#️-clarification) 87 | - [💼 Installation](#-installation) 88 | - [Conda environment setup](#conda-environment-setup) 89 | - [Download pretrained weights](#download-pretrained-weights) 90 | - [Inference in bash (w/o VLM Inference-time Scaling)](#inference-in-bash-wo-vlm-inference-time-scaling) 91 | - [Inference in Gradio Demo](#inference-in-gradio-demo) 92 | - [💼 Windows one-click package](#-windows-one-click-package) 93 | - [🔧 Training](#-training) 94 | - [💪 To Do List](#-to-do-list) 95 | - [💪 Comparison with Commercial Models](#-comparison-with-commercial-models) 96 | - [🌟 Star History](#-star-history) 97 | - [Bibtex](#bibtex) 98 | 99 | 100 | 101 | # 🎨ComfyUI Workflow 102 | 103 | 104 | ### Official ComfyUI-workflow 105 | We have released our **official ComfyUI workflow** in this repository for correct usage of our model! **We have embedded the prompt "A diptych with two side-by-side images of the same scene ... but" into our nodes** and you just need to input the edit instructions such as "make the girl wear pink sunglasses". We also add a high resolution refinement module for better image quality! The total VRAM consumption is about 14GB. Use this [workflow](https://github.com/hayd-zju/ICEdit-ComfyUI-official) and the [ICEdit-normal-lora](https://huggingface.co/RiverZ/normal-lora/tree/main) to fulfill your creative ideas! 106 | 107 | We have specially created [a repository for the workflow](https://github.com/hayd-zju/ICEdit-ComfyUI-official) and you can **install it directly in ComfyUI**. Just open the manager tab and click **'Install via Git URL'**, copy the following URL and you are able to use it. For more details please refer to this [issue](https://github.com/River-Zhang/ICEdit/issues/22#issuecomment-2864977880) 108 | 109 | **URL:** [https://github.com/hayd-zju/ICEdit-ComfyUI-official](https://github.com/hayd-zju/ICEdit-ComfyUI-official) 110 | 111 | 112 | 113 | 114 | Great thanks to [月下Hugo](https://www.bilibili.com/video/BV1JZVRzuE12/?share_source=copy_web&vd_source=8fcb933ee576af56337afc41509fa095) for making a [Chinese tutorial](https://www.bilibili.com/video/BV1JZVRzuE12/?share_source=copy_web&vd_source=8fcb933ee576af56337afc41509fa095) on how to use our official workflow! 115 | 116 | ### ComfyUI-workflow for increased editing success rate 117 | Thanks to [T8star](https://x.com/T8star_Aix)! 
He made a tutorial ([Youtube](https://www.youtube.com/watch?v=s6GMKL-Jjos) and [bilibili](https://www.bilibili.com/video/BV11HVhz1Eky/?spm_id_from=333.40164.top_right_bar_window_dynamic.content.click&vd_source=2a911c0bc75f6d9b9d056bf0e7410d45)) and a creative workflow ([OpenArt](https://openart.ai/workflows/t8star/icedit100v1/HN4EZ2Cej98ZX8CC1RK5) and [RunningHub](https://www.runninghub.cn/post/1920075398585974786/?utm_source=kol01-RH099)) that can greatly increase the editing success rate (to about 100%)! Give it a try! 118 | 119 | 120 | 121 | 122 | ### ComfyUI-nunchaku 123 | 124 | We extend our heartfelt thanks to @[judian17](https://github.com/judian17) for crafting a ComfyUI [workflow](https://github.com/River-Zhang/ICEdit/issues/1#issuecomment-2846568411) that facilitates seamless usage of our model. Explore this excellent [workflow](https://github.com/River-Zhang/ICEdit/issues/1#issuecomment-2846568411) to effortlessly run our model within ComfyUI. A GPU with only **4GB of VRAM** is enough to run it with ComfyUI-nunchaku! 125 | 126 | This workflow incorporates high-definition refinement, yielding remarkably good results. Moreover, integrating this LoRA with Redux enables outfit changes to a certain degree. Once again, a huge thank you to @[judian17](https://github.com/judian17) for his innovative contributions! 127 | 128 | ![comfyui image](docs/images/comfyuiexample.png) 129 | 130 | 131 | ### ComfyUI-workflow 132 | 133 | Thanks to [Datou](https://x.com/Datou), a workflow of ICEdit in ComfyUI can also be downloaded [here](https://openart.ai/workflows/datou/icedit-moe-lora-flux-fill/QFmaWNKsQo3P5liYz4RB). Try it with the [normal lora ckpt](https://huggingface.co/RiverZ/normal-lora/tree/main). 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | # ⚠️ Tips 143 | 144 | ### If you encounter such a failure case, please **try again with a different seed**! 145 | 146 | - Our base model, FLUX, does not inherently support a wide range of styles, so a large portion of our dataset involves style transfer. As a result, the model **may sometimes inexplicably change your artistic style**. 147 | 148 | - Our training dataset is **mostly targeted at realistic images**. For non-realistic images, such as **anime** or **blurry pictures**, the editing success rate **drops, which could potentially affect the final image quality**. 149 | 150 | - While the success rates for adding objects, modifying color attributes, applying style transfer, and changing backgrounds are high, the success rate for object removal is relatively lower due to the low quality of the removal dataset we use. 151 | 152 | The current model is the one used in the experiments in the paper, trained with only 4 A800 GPUs (total `batch_size` = 2 x 2 x 4 = 16). In the future, we will enhance the dataset, scale up training, and finally release a more powerful model. 153 | 154 | ### ⚠️ Clarification 155 | 156 | We've noticed numerous web pages related to ICEdit, including [https://icedit.net/](https://icedit.net/) and [https://icedit.org/](https://icedit.org/). Kudos to those who built these pages! 157 | 158 | However, we'd like to emphasize two important points: 159 | - **No Commercial Use**: Our project **cannot** be used for commercial purposes. Please check the [LICENSE](https://github.com/River-Zhang/ICEdit/blob/main/LICENSE) for details. 160 | - **Official Page**: The official project page is [https://river-zhang.github.io/ICEdit-gh-pages/](https://river-zhang.github.io/ICEdit-gh-pages/).
161 | 162 | 163 | 164 | # 💼 Installation 165 | 166 | ## Conda environment setup 167 | 168 | ```bash 169 | conda create -n icedit python=3.10 170 | conda activate icedit 171 | pip install -r requirements.txt 172 | pip install -U huggingface_hub 173 | ``` 174 | 175 | ## Download pretrained weights 176 | 177 | If you can connect to Hugging Face, you don't need to download the weights. Otherwise, you need to download the weights locally. 178 | 179 | - [Flux.1-fill-dev](https://huggingface.co/black-forest-labs/flux.1-fill-dev). 180 | - [ICEdit-normal-LoRA](https://huggingface.co/RiverZ/normal-lora/tree/main). 181 | 182 | Note: Due to some cooperation permission issues, we have to withdraw the weights and code of the MoE-LoRA temporarily. What is currently released is the ordinary LoRA, but it still offers strong performance. If you urgently need the MoE-LoRA weights described in the paper, please email the authors. 183 | 184 | ## Inference in bash (w/o VLM Inference-time Scaling) 185 | 186 | Now you can have a try! 187 | 188 | > Our model can **only edit images with a width of 512 pixels** (there is no restriction on the height). If you pass in an image with a width other than 512 pixels, the model will automatically resize it to 512 pixels. 189 | 190 | > If you find that the model fails to generate the expected results, please try changing the `--seed` parameter. Inference-time scaling with a VLM can help a lot to improve the results. 191 | 192 | ```bash 193 | python scripts/inference.py --image assets/girl.png \ 194 | --instruction "Make her hair dark green and her clothes checked." \ 195 | --seed 304897401 196 | ``` 197 | 198 | Editing a 512×768 image requires 35 GB of GPU memory. If you need to run on a system with 24 GB of GPU memory (for example, an NVIDIA RTX 3090), you can add the `--enable-model-cpu-offload` parameter. 199 | 200 | ```bash 201 | python scripts/inference.py --image assets/girl.png \ 202 | --instruction "Make her hair dark green and her clothes checked." \ 203 | --enable-model-cpu-offload 204 | ``` 205 | 206 | If you have downloaded the pretrained weights locally, please pass their paths during inference, as in: 207 | 208 | ```bash 209 | python scripts/inference.py --image assets/girl.png \ 210 | --instruction "Make her hair dark green and her clothes checked." \ 211 | --flux-path /path/to/flux.1-fill-dev \ 212 | --lora-path /path/to/ICEdit-normal-LoRA 213 | ``` 214 | 215 | ## Inference in Gradio Demo 216 | 217 | We provide a Gradio demo so you can edit images in a more user-friendly way. Run the following command to start the demo. 218 | 219 | ```bash 220 | python scripts/gradio_demo.py --port 7860 221 | ``` 222 | 223 | Like the inference script, if you want to run the demo on a system with 24 GB of GPU memory, you can add the `--enable-model-cpu-offload` parameter.
And if you have downloaded the pretrained weights locally, pass their paths during inference (each of the flags below is optional), as in: 224 | 225 | ```bash 226 | python scripts/gradio_demo.py --port 7860 \ 227 | --flux-path /path/to/flux.1-fill-dev \ 228 | --lora-path /path/to/ICEdit-normal-LoRA \ 229 | --enable-model-cpu-offload 230 | ``` 231 | 232 | Or, if you want to run the demo on a system with only 10 GB of GPU memory, you can download the GGUF models from [FLUX.1-Fill-dev-gguf](https://huggingface.co/YarvixPA/FLUX.1-Fill-dev-gguf) and [t5-v1_1-xxl-encoder-gguf](https://huggingface.co/city96/t5-v1_1-xxl-encoder-gguf) and pass them during inference, as in: 233 | 234 | ```bash 235 | python scripts/gradio_demo.py --port 7861 \ 236 | --flux-path models/flux.1-fill-dev \ 237 | --lora-path models/ICEdit-normal-LoRA \ 238 | --transformer models/flux1-fill-dev-Q4_0.gguf \ 239 | --text_encoder_2 models/t5-v1_1-xxl-encoder-Q8_0.gguf \ 240 | --enable-model-cpu-offload 241 | ``` 242 | 243 | Then you can open the link in your browser to edit images. 244 | 245 |
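For reference, the GGUF route is a condensed version of what `scripts/gradio_demo.py` already does when you pass the flags above. The sketch below mirrors that script; the `models/...` paths are the same placeholder locations as in the command, so point them at wherever you actually saved the files.

```python
import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
from transformers import T5EncoderModel

# Quantized Fill transformer loaded from a single GGUF file; the config matches scripts/config.json.
transformer = FluxTransformer2DModel.from_single_file(
    "models/flux1-fill-dev-Q4_0.gguf",
    config="scripts/config.json",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Quantized T5 text encoder loaded from its GGUF weights.
text_encoder_2 = T5EncoderModel.from_pretrained(
    "models/flux.1-fill-dev",
    subfolder="text_encoder_2",
    gguf_file="models/t5-v1_1-xxl-encoder-Q8_0.gguf",
    torch_dtype=torch.bfloat16,
)

# Assemble the Fill pipeline around the quantized parts and attach the ICEdit LoRA.
pipe = FluxFillPipeline.from_pretrained(
    "models/flux.1-fill-dev",
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights("models/ICEdit-normal-LoRA", adapter_name="icedit")
pipe.enable_model_cpu_offload()  # same effect as the --enable-model-cpu-offload flag
```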
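Both scripts wrap your instruction into the fixed diptych pre-prompt and build the inpainting mask for you. If you would rather call the pipeline from your own Python code, the core of `scripts/inference.py` boils down to the condensed sketch below; the image path, instruction, and seed are just the example values from above, so substitute your own.

```python
import numpy as np
import torch
from diffusers import FluxFillPipeline
from PIL import Image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/flux.1-fill-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("RiverZ/normal-lora")

# Resize so the width is exactly 512 pixels (the only width the LoRA is trained for).
image = Image.open("assets/girl.png").convert("RGB")
w = 512
h = int(image.size[1] * w / image.size[0]) // 8 * 8
image = image.resize((w, h))

# Wrap the edit instruction in the fixed diptych pre-prompt.
instruction = "Make her hair dark green and her clothes checked."
prompt = (
    "A diptych with two side-by-side images of the same scene. "
    f"On the right, the scene is exactly the same as on the left but {instruction}"
)

# Source image on both halves; the right half is fully masked and will be regenerated.
diptych = Image.new("RGB", (w * 2, h))
diptych.paste(image, (0, 0))
diptych.paste(image, (w, 0))
mask = np.zeros((h, w * 2), dtype=np.uint8)
mask[:, w:] = 255

result = pipe(
    prompt=prompt,
    image=diptych,
    mask_image=Image.fromarray(mask),
    height=h,
    width=w * 2,
    guidance_scale=50,
    num_inference_steps=28,
    generator=torch.Generator("cpu").manual_seed(304897401),
).images[0]

# The edited image is the right half of the generated diptych.
result.crop((w, 0, w * 2, h)).save("result.png")
```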
246 |
247 | 248 |

Gradio Demo: just input the instruction and wait for the result!

249 |
250 | 251 |
252 | 253 | Here is also a Chinese tutorial [Youtube video](https://www.youtube.com/watch?v=rRMc5DE4qMo) on how to install and use ICEdit, created by [softicelee2](https://github.com/softicelee2). It's definitely worth a watch! 254 | 255 | ## 💼 Windows one-click package 256 | 257 | Great thanks to [gluttony-10](https://github.com/River-Zhang/ICEdit/issues/23#issue-3050804566), a famous [Bilibili Up](https://space.bilibili.com/893892)! He made a tutorial ([Youtube](https://youtu.be/C-OpWlJi424) and [Bilibili](https://www.bilibili.com/video/BV1oT5uzzEbs)) on how to install our project on Windows, as well as a one-click package for Windows! **Just unzip it and it's ready to use**. The package has been quantized, takes up only 14GB of space, and supports 50-series graphics cards. 258 | 259 | Download link: [Google Drive](https://drive.google.com/drive/folders/16j3wQvWjuzCRKnVolszLmhCtc_yOCqcx?usp=sharing) or [Baidu Wangpan](https://www.bilibili.com/video/BV1oT5uzzEbs/?vd_source=2a911c0bc75f6d9b9d056bf0e7410d45) (refer to the comment section of the video) 260 | 261 | 262 | 263 | # 🔧 Training 264 | 265 | Find more details here: [Training Code](./train/) 266 | 267 | # 💪 To Do List 268 | 269 | - [x] Inference Code 270 | - [ ] Inference-time Scaling with VLM 271 | - [x] Pretrained Weights 272 | - [x] More Inference Demos 273 | - [x] Gradio demo 274 | - [x] Comfy UI demo (by @[judian17](https://github.com/River-Zhang/ICEdit/issues/1#issuecomment-2846568411), compatible with [nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku); supports high-res refinement and FLUX Redux. A GPU with only 4GB of VRAM is enough to run it!) 275 | - [x] Comfy UI demo with normal lora (by @[Datou](https://openart.ai/workflows/datou/icedit-moe-lora-flux-fill/QFmaWNKsQo3P5liYz4RB) on OpenArt) 276 | - [x] Official ComfyUI workflow 277 | - [x] Training Code 278 | - [ ] LoRA for higher image resolution (768, 1024) 279 | 280 | 281 | 282 | # 💪 Comparison with Commercial Models 283 | 284 |
285 |
286 | 287 |

Compared with commercial models such as Gemini and GPT-4o, our method is comparable to, and in some cases even superior to, these commercial models in terms of character ID preservation and instruction following. Moreover, our model is open-source, with lower cost, faster speed (it takes about 9 seconds to process one image), and strong performance.

288 |
289 | 290 | 291 |
292 | 293 | 294 | # 🌟 Star History 295 | 296 | [![Star History Chart](https://api.star-history.com/svg?repos=River-Zhang/ICEdit&type=Date)](https://www.star-history.com/#River-Zhang/ICEdit&Date) 297 | 298 | # Bibtex 299 | If this work is helpful for your research, please consider citing the following BibTeX entry. 300 | 301 | ``` 302 | @misc{zhang2025ICEdit, 303 | title={In-Context Edit: Enabling Instructional Image Editing with In-Context Generation in Large Scale Diffusion Transformer}, 304 | author={Zechuan Zhang and Ji Xie and Yu Lu and Zongxin Yang and Yi Yang}, 305 | year={2025}, 306 | eprint={2504.20690}, 307 | archivePrefix={arXiv}, 308 | primaryClass={cs.CV}, 309 | url={https://arxiv.org/abs/2504.20690}, 310 | } 311 | ``` 312 | -------------------------------------------------------------------------------- /assets/boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/boy.png -------------------------------------------------------------------------------- /assets/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/girl.png -------------------------------------------------------------------------------- /assets/hybrid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/hybrid.png -------------------------------------------------------------------------------- /assets/kaori.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/kaori.jpg -------------------------------------------------------------------------------- /docs/images/comfyuiexample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/comfyuiexample.png -------------------------------------------------------------------------------- /docs/images/gpt4o_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/gpt4o_comparison.png -------------------------------------------------------------------------------- /docs/images/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/gradio.png -------------------------------------------------------------------------------- /docs/images/lora_scale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/lora_scale.png -------------------------------------------------------------------------------- /docs/images/official_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/official_workflow.png 
-------------------------------------------------------------------------------- /docs/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/teaser.png -------------------------------------------------------------------------------- /docs/images/windows_install.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/windows_install.png -------------------------------------------------------------------------------- /docs/images/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/workflow.png -------------------------------------------------------------------------------- /docs/images/workflow_t8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/workflow_t8.png -------------------------------------------------------------------------------- /docs/images/workflow_tutorial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/workflow_tutorial.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | diffusers==0.33.0 3 | gradio 4 | numpy 5 | peft 6 | protobuf 7 | sentencepiece 8 | spaces 9 | torch==2.7.0 10 | torchvision 11 | transformers==4.51.3 12 | gguf 13 | -------------------------------------------------------------------------------- /scripts/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "FluxTransformer2DModel", 3 | "_diffusers_version": "0.32.0.dev0", 4 | "attention_head_dim": 128, 5 | "axes_dims_rope": [ 6 | 16, 7 | 56, 8 | 56 9 | ], 10 | "guidance_embeds": true, 11 | "in_channels": 384, 12 | "joint_attention_dim": 4096, 13 | "num_attention_heads": 24, 14 | "num_layers": 19, 15 | "num_single_layers": 38, 16 | "out_channels": 64, 17 | "patch_size": 1, 18 | "pooled_projection_dim": 768 19 | } 20 | -------------------------------------------------------------------------------- /scripts/gradio_demo.py: -------------------------------------------------------------------------------- 1 | ''' 2 | python scripts/gradio_demo.py 3 | ''' 4 | 5 | import sys 6 | import os 7 | # workspace_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../icedit")) 8 | 9 | # if workspace_dir not in sys.path: 10 | # sys.path.insert(0, workspace_dir) 11 | 12 | from diffusers import FluxFillPipeline, FluxTransformer2DModel, GGUFQuantizationConfig 13 | import gradio as gr 14 | import numpy as np 15 | import torch 16 | import spaces 17 | import argparse 18 | import random 19 | from diffusers import FluxFillPipeline 20 | from PIL import Image 21 | 22 | from transformers import T5EncoderModel 23 | 24 | MAX_SEED = np.iinfo(np.int32).max 25 | MAX_IMAGE_SIZE = 1024 26 | 27 | current_lora_scale = 1.0 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--server_name", 
type=str, default="127.0.0.1") 31 | parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio app") 32 | parser.add_argument("--share", action="store_true") 33 | parser.add_argument("--output-dir", type=str, default="gradio_results", help="Directory to save the output image") 34 | parser.add_argument("--flux-path", type=str, default='black-forest-labs/flux.1-fill-dev', help="Path to the model") 35 | parser.add_argument("--lora-path", type=str, default='RiverZ/normal-lora', help="Path to the LoRA weights") 36 | parser.add_argument("--transformer", type=str, default=None, help="The gguf model of FluxTransformer2DModel") 37 | parser.add_argument("--text_encoder_2", type=str, default=None, help="The gguf model of T5EncoderModel") 38 | parser.add_argument("--enable-model-cpu-offload", action="store_true", help="Enable CPU offloading for the model") 39 | args = parser.parse_args() 40 | 41 | if args.transformer: 42 | args.transformer = os.path.abspath(args.transformer) 43 | transformer = FluxTransformer2DModel.from_single_file( 44 | args.transformer, 45 | config="scripts/config.json", 46 | quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), 47 | torch_dtype=torch.bfloat16, 48 | ) 49 | else: 50 | transformer = FluxTransformer2DModel.from_pretrained( 51 | args.flux_path, 52 | subfolder="transformer", 53 | torch_dtype=torch.bfloat16, 54 | ) 55 | 56 | if args.text_encoder_2: 57 | args.text_encoder_2 = os.path.abspath(args.text_encoder_2) 58 | text_encoder_2 = T5EncoderModel.from_pretrained( 59 | args.flux_path, 60 | subfolder="text_encoder_2", 61 | gguf_file=f"{args.text_encoder_2}", 62 | torch_dtype=torch.bfloat16, 63 | ) 64 | else: 65 | text_encoder_2 = T5EncoderModel.from_pretrained( 66 | args.flux_path, 67 | subfolder="text_encoder_2", 68 | torch_dtype=torch.bfloat16, 69 | ) 70 | 71 | pipe = FluxFillPipeline.from_pretrained( 72 | args.flux_path, 73 | transformer=transformer, 74 | text_encoder_2=text_encoder_2, 75 | torch_dtype=torch.bfloat16 76 | ) 77 | pipe.load_lora_weights(args.lora_path, adapter_name="icedit") 78 | pipe.set_adapters("icedit", 1.0) 79 | 80 | if args.enable_model_cpu_offload: 81 | pipe.enable_model_cpu_offload() 82 | else: 83 | pipe = pipe.to("cuda") 84 | 85 | 86 | @spaces.GPU 87 | def infer(edit_images, 88 | prompt, 89 | seed=666, 90 | randomize_seed=False, 91 | width=1024, 92 | height=1024, 93 | guidance_scale=50, 94 | num_inference_steps=28, 95 | lora_scale=1.0, 96 | progress=gr.Progress(track_tqdm=True)): 97 | 98 | global current_lora_scale 99 | 100 | if lora_scale != current_lora_scale: 101 | print(f"\033[93m[INFO] LoRA scale changed from {current_lora_scale} to {lora_scale}, reloading LoRA weights\033[0m") 102 | pipe.set_adapters("icedit", lora_scale) 103 | current_lora_scale = lora_scale 104 | 105 | image = edit_images 106 | 107 | if image.size[0] != 512: 108 | print("\033[93m[WARNING] We can only deal with the case where the image's width is 512.\033[0m") 109 | new_width = 512 110 | scale = new_width / image.size[0] 111 | new_height = int(image.size[1] * scale) 112 | new_height = (new_height // 8) * 8 113 | image = image.resize((new_width, new_height)) 114 | print(f"\033[93m[WARNING] Resizing the image to {new_width} x {new_height}\033[0m") 115 | 116 | image = image.convert("RGB") 117 | width, height = image.size 118 | image = image.resize((512, int(512 * height / width))) 119 | combined_image = Image.new("RGB", (width * 2, height)) 120 | combined_image.paste(image, (0, 0)) 121 | mask_array = np.zeros((height, width * 2), 
dtype=np.uint8) 122 | mask_array[:, width:] = 255 123 | mask = Image.fromarray(mask_array) 124 | instruction = f'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but {prompt}' 125 | 126 | if randomize_seed: 127 | seed = random.randint(0, MAX_SEED) 128 | 129 | image = pipe( 130 | prompt=instruction, 131 | image=combined_image, 132 | mask_image=mask, 133 | height=height, 134 | width=width * 2, 135 | guidance_scale=guidance_scale, 136 | num_inference_steps=num_inference_steps, 137 | generator=torch.Generator().manual_seed(seed), 138 | ).images[0] 139 | 140 | w, h = image.size 141 | image = image.crop((w // 2, 0, w, h)) 142 | 143 | os.makedirs(args.output_dir, exist_ok=True) 144 | 145 | index = len(os.listdir(args.output_dir)) 146 | image.save(f"{args.output_dir}/result_{index}.png") 147 | 148 | return image, seed 149 | 150 | original_examples = [ 151 | "a tiny astronaut hatching from an egg on the moon", 152 | "a cat holding a sign that says hello world", 153 | "an anime illustration of a wiener schnitzel", 154 | ] 155 | 156 | new_examples = [ 157 | ['assets/girl.png', 'Make her hair dark green and her clothes checked.', 304897401], 158 | ['assets/boy.png', 'Change the sunglasses to a Christmas hat.', 748891420], 159 | ['assets/kaori.jpg', 'Make it a sketch.', 484817364] 160 | ] 161 | 162 | css = """ 163 | #col-container { 164 | margin: 0 auto; 165 | max-width: 1000px; 166 | } 167 | """ 168 | 169 | with gr.Blocks(css=css) as demo: 170 | 171 | with gr.Column(elem_id="col-container"): 172 | gr.Markdown(f"""# IC-Edit 173 | A demo for [IC-Edit](https://arxiv.org/pdf/2504.20690). 174 | More **open-source**, with **lower costs**, **faster speed** (it takes about 9 seconds to process one image), and **powerful performance**. 175 | For more details, check out our [Github Repository](https://github.com/River-Zhang/ICEdit) and [website](https://river-zhang.github.io/ICEdit-gh-pages/). If our project resonates with you or proves useful, we'd be truly grateful if you could spare a moment to give it a star. 
176 | """) 177 | with gr.Row(): 178 | with gr.Column(): 179 | edit_image = gr.Image( 180 | label='Upload image for editing', 181 | type='pil', 182 | sources=["upload", "webcam"], 183 | image_mode='RGB', 184 | height=600 185 | ) 186 | prompt = gr.Text( 187 | label="Prompt", 188 | show_label=False, 189 | max_lines=1, 190 | placeholder="Enter your prompt", 191 | container=False, 192 | ) 193 | run_button = gr.Button("Run") 194 | 195 | result = gr.Image(label="Result", show_label=False) 196 | 197 | with gr.Accordion("Advanced Settings", open=True): 198 | 199 | seed = gr.Slider( 200 | label="Seed", 201 | minimum=0, 202 | maximum=MAX_SEED, 203 | step=1, 204 | value=0, 205 | ) 206 | 207 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 208 | 209 | with gr.Row(): 210 | 211 | width = gr.Slider( 212 | label="Width", 213 | minimum=512, 214 | maximum=MAX_IMAGE_SIZE, 215 | step=32, 216 | value=1024, 217 | visible=False 218 | ) 219 | 220 | height = gr.Slider( 221 | label="Height", 222 | minimum=512, 223 | maximum=MAX_IMAGE_SIZE, 224 | step=32, 225 | value=1024, 226 | visible=False 227 | ) 228 | 229 | with gr.Row(): 230 | 231 | guidance_scale = gr.Slider( 232 | label="Guidance Scale", 233 | minimum=1, 234 | maximum=100, 235 | step=0.5, 236 | value=50, 237 | ) 238 | 239 | num_inference_steps = gr.Slider( 240 | label="Number of inference steps", 241 | minimum=1, 242 | maximum=50, 243 | step=1, 244 | value=28, 245 | ) 246 | 247 | lora_scale = gr.Slider( 248 | label="LoRA Scale", 249 | minimum=0, 250 | maximum=1.0, 251 | step=0.01, 252 | value=1.0, 253 | ) 254 | 255 | def process_example(edit_image, prompt, seed, randomize_seed): 256 | result, seed_out = infer(edit_image, prompt, seed, False, 1024, 1024, 50, 28, 1.0) 257 | return result, seed_out, False 258 | 259 | gr.Examples( 260 | examples=new_examples, 261 | inputs=[edit_image, prompt, seed, randomize_seed], 262 | outputs=[result, seed, randomize_seed], 263 | fn=process_example, 264 | cache_examples=False 265 | ) 266 | 267 | gr.on( 268 | triggers=[run_button.click, prompt.submit], 269 | fn=infer, 270 | inputs=[edit_image, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_scale], 271 | outputs=[result, seed] 272 | ) 273 | 274 | if __name__ == "__main__": 275 | demo.launch( 276 | server_name=args.server_name, 277 | server_port=args.port, 278 | share=args.share, 279 | inbrowser=True, 280 | ) 281 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | # Use the modified diffusers & peft library 2 | import sys 3 | import os 4 | # workspace_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../icedit")) 5 | 6 | # if workspace_dir not in sys.path: 7 | # sys.path.insert(0, workspace_dir) 8 | 9 | from diffusers import FluxFillPipeline 10 | 11 | # Below is the original library 12 | import torch 13 | from PIL import Image 14 | import numpy as np 15 | import argparse 16 | import random 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--image", type=str, help="Name of the image to be edited", required=True) 20 | parser.add_argument("--instruction", type=str, help="Instruction for editing the image", required=True) 21 | parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") 22 | parser.add_argument("--output-dir", type=str, default=".", help="Directory to save the output image") 23 | 
parser.add_argument("--flux-path", type=str, default='black-forest-labs/flux.1-fill-dev', help="Path to the model") 24 | parser.add_argument("--lora-path", type=str, default='RiverZ/normal-lora', help="Path to the LoRA weights") 25 | parser.add_argument("--enable-model-cpu-offload", action="store_true", help="Enable CPU offloading for the model") 26 | 27 | 28 | args = parser.parse_args() 29 | pipe = FluxFillPipeline.from_pretrained(args.flux_path, torch_dtype=torch.bfloat16) 30 | pipe.load_lora_weights(args.lora_path) 31 | 32 | if args.enable_model_cpu_offload: 33 | pipe.enable_model_cpu_offload() 34 | else: 35 | pipe = pipe.to("cuda") 36 | 37 | image = Image.open(args.image) 38 | image = image.convert("RGB") 39 | 40 | if image.size[0] != 512: 41 | print("\033[93m[WARNING] We can only deal with the case where the image's width is 512.\033[0m") 42 | new_width = 512 43 | scale = new_width / image.size[0] 44 | new_height = int(image.size[1] * scale) 45 | new_height = (new_height // 8) * 8 46 | image = image.resize((new_width, new_height)) 47 | print(f"\033[93m[WARNING] Resizing the image to {new_width} x {new_height}\033[0m") 48 | 49 | instruction = args.instruction 50 | 51 | print(f"Instruction: {instruction}") 52 | instruction = f'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but {instruction}' 53 | 54 | width, height = image.size 55 | combined_image = Image.new("RGB", (width * 2, height)) 56 | combined_image.paste(image, (0, 0)) 57 | combined_image.paste(image, (width, 0)) 58 | mask_array = np.zeros((height, width * 2), dtype=np.uint8) 59 | mask_array[:, width:] = 255 60 | mask = Image.fromarray(mask_array) 61 | 62 | result_image = pipe( 63 | prompt=instruction, 64 | image=combined_image, 65 | mask_image=mask, 66 | height=height, 67 | width=width * 2, 68 | guidance_scale=50, 69 | num_inference_steps=28, 70 | generator=torch.Generator("cpu").manual_seed(args.seed) if args.seed is not None else None, 71 | ).images[0] 72 | 73 | result_image = result_image.crop((width,0,width*2,height)) 74 | 75 | os.makedirs(args.output_dir, exist_ok=True) 76 | 77 | image_name = args.image.split("/")[-1] 78 | result_image.save(os.path.join(args.output_dir, f"{image_name}")) 79 | print(f"\033[92mResult saved as {os.path.abspath(os.path.join(args.output_dir, image_name))}\033[0m") 80 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # ICEdit Training Repository 2 | 3 | This repository contains the training code for ICEdit, a model for image editing based on text instructions. It utilizes conditional generation to perform instructional image edits. 4 | 5 | This codebase is based heavily on the [OminiControl](https://github.com/Yuanshi9815/OminiControl) repository. We thank the authors for their work and contributions to the field! 
6 | 7 | ## Setup and Installation 8 | 9 | ```bash 10 | # Create a new conda environment 11 | conda create -n train python=3.10 12 | conda activate train 13 | 14 | # Install requirements 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Project Structure 19 | 20 | - `src/`: Source code directory 21 | - `train/`: Training modules 22 | - `train.py`: Main training script 23 | - `data.py`: Dataset classes for handling different data formats 24 | - `model.py`: Model definition using Flux pipeline 25 | - `callbacks.py`: Training callbacks for logging and checkpointing 26 | - `flux/`: Flux model implementation 27 | - `assets/`: Asset files 28 | - `parquet/`: Parquet data files 29 | - `requirements.txt`: Dependency list 30 | 31 | ## Datasets 32 | 33 | Download training datasets (part of OmniEdit) to the `parquet/` directory. You can use the provided scripts `parquet/prepare.sh`. 34 | 35 | ```bash 36 | cd parquet 37 | bash prepare.sh 38 | ``` 39 | 40 | ## Training 41 | 42 | ```bash 43 | bash train/script/train.sh 44 | ``` 45 | 46 | You can modify the training configuration in `train/config/normal_lora.yaml`. 47 | -------------------------------------------------------------------------------- /train/assets/book.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/book.jpg -------------------------------------------------------------------------------- /train/assets/clock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/clock.jpg -------------------------------------------------------------------------------- /train/assets/coffee.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/coffee.png -------------------------------------------------------------------------------- /train/assets/monalisa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/monalisa.jpg -------------------------------------------------------------------------------- /train/assets/oranges.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/oranges.jpg -------------------------------------------------------------------------------- /train/assets/penguin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/penguin.jpg -------------------------------------------------------------------------------- /train/assets/room_corner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/room_corner.jpg -------------------------------------------------------------------------------- /train/assets/vase.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/vase.jpg -------------------------------------------------------------------------------- /train/parquet/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p $(dirname "$0") 4 | 5 | BASE_URL_105="https://huggingface.co/datasets/sayakpaul/OmniEdit-mini/resolve/main/data" 6 | BASE_URL_571="https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M/resolve/main/data" 7 | 8 | FILES=( 9 | "train-00053-of-00105.parquet" 10 | "train-00008-of-00105.parquet" 11 | "train-00093-of-00105.parquet" 12 | "train-00097-of-00105.parquet" 13 | "train-00009-of-00105.parquet" 14 | "train-00069-of-00105.parquet" 15 | "train-00029-of-00105.parquet" 16 | "train-00083-of-00105.parquet" 17 | "train-00037-of-00105.parquet" 18 | "train-00079-of-00105.parquet" 19 | "train-00085-of-00105.parquet" 20 | "train-00087-of-00105.parquet" 21 | "train-00038-of-00105.parquet" 22 | "train-00041-of-00105.parquet" 23 | "train-00047-of-00105.parquet" 24 | "train-00145-of-00571.parquet" 25 | "train-00091-of-00105.parquet" 26 | "train-00004-of-00105.parquet" 27 | "train-00014-of-00105.parquet" 28 | "train-00016-of-00105.parquet" 29 | "train-00035-of-00105.parquet" 30 | "train-00017-of-00105.parquet" 31 | "train-00066-of-00105.parquet" 32 | "train-00071-of-00105.parquet" 33 | "train-00043-of-00105.parquet" 34 | "train-00067-of-00105.parquet" 35 | "train-00074-of-00105.parquet" 36 | "train-00001-of-00105.parquet" 37 | "train-00115-of-00571.parquet" 38 | "train-00048-of-00105.parquet" 39 | "train-00064-of-00105.parquet" 40 | "train-00010-of-00105.parquet" 41 | "train-00011-of-00105.parquet" 42 | "train-00062-of-00105.parquet" 43 | "train-00567-of-00571.parquet" 44 | "train-00032-of-00105.parquet" 45 | "train-00070-of-00105.parquet" 46 | "train-00160-of-00571.parquet" 47 | "train-00046-of-00105.parquet" 48 | "train-00073-of-00105.parquet" 49 | "train-00006-of-00105.parquet" 50 | "train-00061-of-00105.parquet" 51 | "train-00050-of-00105.parquet" 52 | "train-00056-of-00105.parquet" 53 | "train-00003-of-00105.parquet" 54 | "train-00012-of-00105.parquet" 55 | "train-00089-of-00105.parquet" 56 | "train-00028-of-00105.parquet" 57 | "train-00015-of-00105.parquet" 58 | "train-00103-of-00105.parquet" 59 | "train-00099-of-00105.parquet" 60 | "train-00020-of-00105.parquet" 61 | "train-00033-of-00105.parquet" 62 | "train-00078-of-00105.parquet" 63 | "train-00000-of-00105.parquet" 64 | "train-00566-of-00571.parquet" 65 | "train-00054-of-00105.parquet" 66 | "train-00044-of-00105.parquet" 67 | "train-00100-of-00571.parquet" 68 | "train-00049-of-00105.parquet" 69 | "train-00019-of-00105.parquet" 70 | "train-00076-of-00105.parquet" 71 | "train-00025-of-00105.parquet" 72 | "train-00081-of-00105.parquet" 73 | "train-00045-of-00105.parquet" 74 | "train-00036-of-00105.parquet" 75 | "train-00080-of-00105.parquet" 76 | "train-00034-of-00105.parquet" 77 | "train-00057-of-00105.parquet" 78 | "train-00082-of-00105.parquet" 79 | "train-00059-of-00105.parquet" 80 | "train-00058-of-00105.parquet" 81 | "train-00013-of-00105.parquet" 82 | "train-00084-of-00105.parquet" 83 | "train-00100-of-00105.parquet" 84 | "train-00090-of-00105.parquet" 85 | "train-00094-of-00105.parquet" 86 | "train-00060-of-00105.parquet" 87 | "train-00175-of-00571.parquet" 88 | "train-00065-of-00105.parquet" 89 | "train-00040-of-00105.parquet" 90 | "train-00023-of-00105.parquet" 91 | 
"train-00088-of-00105.parquet" 92 | "train-00068-of-00105.parquet" 93 | "train-00027-of-00105.parquet" 94 | "train-00568-of-00571.parquet" 95 | "train-00098-of-00105.parquet" 96 | "train-00031-of-00105.parquet" 97 | "train-00063-of-00105.parquet" 98 | "train-00002-of-00105.parquet" 99 | "train-00007-of-00105.parquet" 100 | "train-00569-of-00571.parquet" 101 | "train-00052-of-00105.parquet" 102 | "train-00102-of-00105.parquet" 103 | "train-00104-of-00105.parquet" 104 | "train-00072-of-00105.parquet" 105 | "train-00051-of-00105.parquet" 106 | "train-00101-of-00105.parquet" 107 | "train-00570-of-00571.parquet" 108 | "train-00095-of-00105.parquet" 109 | "train-00092-of-00105.parquet" 110 | "train-00030-of-00105.parquet" 111 | "train-00055-of-00105.parquet" 112 | "train-00042-of-00105.parquet" 113 | "train-00018-of-00105.parquet" 114 | "train-00096-of-00105.parquet" 115 | "train-00005-of-00105.parquet" 116 | "train-00022-of-00105.parquet" 117 | "train-00086-of-00105.parquet" 118 | "train-00024-of-00105.parquet" 119 | "train-00077-of-00105.parquet" 120 | "train-00075-of-00105.parquet" 121 | "train-00039-of-00105.parquet" 122 | "train-00021-of-00105.parquet" 123 | "train-00130-of-00571.parquet" 124 | "train-00026-of-00105.parquet" 125 | "train-00000-of-00571.parquet" 126 | ) 127 | 128 | TOTAL=${#FILES[@]} 129 | CURRENT=0 130 | 131 | for file in "${FILES[@]}"; do 132 | CURRENT=$((CURRENT+1)) 133 | echo "[$CURRENT/$TOTAL] $file" 134 | 135 | if [[ $file == *"-of-00105.parquet" ]]; then 136 | BASE_URL=$BASE_URL_105 137 | else 138 | BASE_URL=$BASE_URL_571 139 | fi 140 | 141 | wget -c "$BASE_URL/$file" -O "$file" || { 142 | echo "Download $file failed, trying to continue..." 143 | } 144 | 145 | if [ -f "$file" ]; then 146 | filesize=$(du -h "$file" | cut -f1) 147 | echo "Downloaded : $file ($filesize)" 148 | else 149 | echo "Warning: $file failed!" 150 | fi 151 | 152 | echo "------------------------------------" 153 | done 154 | 155 | echo "All files downloaded!" -------------------------------------------------------------------------------- /train/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.32.0 2 | datasets==3.6.0 3 | transformers 4 | peft 5 | opencv-python 6 | protobuf 7 | sentencepiece 8 | gradio 9 | jupyter 10 | torchao 11 | lightning 12 | torchvision 13 | prodigyopt 14 | wandb -------------------------------------------------------------------------------- /train/runs/20250513-085800/config.yaml: -------------------------------------------------------------------------------- 1 | dtype: bfloat16 2 | flux_path: black-forest-labs/flux.1-fill-dev 3 | model: 4 | add_cond_attn: false 5 | latent_lora: false 6 | union_cond_attn: true 7 | use_sep: false 8 | train: 9 | accumulate_grad_batches: 1 10 | batch_size: 2 11 | condition_type: edit 12 | dataloader_workers: 5 13 | dataset: 14 | condition_size: 512 15 | drop_image_prob: 0.1 16 | drop_text_prob: 0.1 17 | image_size: 512 18 | padding: 8 19 | path: parquet/*.parquet 20 | target_size: 512 21 | type: edit_with_omini 22 | gradient_checkpointing: false 23 | lora_config: 24 | init_lora_weights: gaussian 25 | lora_alpha: 32 26 | r: 32 27 | target_modules: (.*x_embedder|.*(? torch.FloatTensor: 17 | batch_size, _, _ = ( 18 | hidden_states.shape 19 | if encoder_hidden_states is None 20 | else encoder_hidden_states.shape 21 | ) 22 | 23 | with enable_lora( 24 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False) 25 | ): 26 | # `sample` projections. 
27 | query = attn.to_q(hidden_states) 28 | key = attn.to_k(hidden_states) 29 | value = attn.to_v(hidden_states) 30 | 31 | inner_dim = key.shape[-1] 32 | head_dim = inner_dim // attn.heads 33 | 34 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 35 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 36 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 37 | 38 | if attn.norm_q is not None: 39 | query = attn.norm_q(query) 40 | if attn.norm_k is not None: 41 | key = attn.norm_k(key) 42 | 43 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 44 | if encoder_hidden_states is not None: 45 | # `context` projections. 46 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 47 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 48 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 49 | 50 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 51 | batch_size, -1, attn.heads, head_dim 52 | ).transpose(1, 2) 53 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 54 | batch_size, -1, attn.heads, head_dim 55 | ).transpose(1, 2) 56 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 57 | batch_size, -1, attn.heads, head_dim 58 | ).transpose(1, 2) 59 | 60 | if attn.norm_added_q is not None: 61 | encoder_hidden_states_query_proj = attn.norm_added_q( 62 | encoder_hidden_states_query_proj 63 | ) 64 | if attn.norm_added_k is not None: 65 | encoder_hidden_states_key_proj = attn.norm_added_k( 66 | encoder_hidden_states_key_proj 67 | ) 68 | 69 | # attention 70 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 71 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 72 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 73 | 74 | if image_rotary_emb is not None: 75 | from diffusers.models.embeddings import apply_rotary_emb 76 | 77 | query = apply_rotary_emb(query, image_rotary_emb) 78 | key = apply_rotary_emb(key, image_rotary_emb) 79 | 80 | if condition_latents is not None: 81 | cond_query = attn.to_q(condition_latents) 82 | cond_key = attn.to_k(condition_latents) 83 | cond_value = attn.to_v(condition_latents) 84 | 85 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( 86 | 1, 2 87 | ) 88 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 89 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( 90 | 1, 2 91 | ) 92 | if attn.norm_q is not None: 93 | cond_query = attn.norm_q(cond_query) 94 | if attn.norm_k is not None: 95 | cond_key = attn.norm_k(cond_key) 96 | 97 | if cond_rotary_emb is not None: 98 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb) 99 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb) 100 | 101 | if condition_latents is not None: 102 | query = torch.cat([query, cond_query], dim=2) 103 | key = torch.cat([key, cond_key], dim=2) 104 | value = torch.cat([value, cond_value], dim=2) 105 | 106 | if not model_config.get("union_cond_attn", True): 107 | # If we don't want to use the union condition attention, we need to mask the attention 108 | # between the hidden states and the condition latents 109 | attention_mask = torch.ones( 110 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool 111 | ) 112 | condition_n = cond_query.shape[2] 113 | attention_mask[-condition_n:, :-condition_n] = False 114 | 
attention_mask[:-condition_n, -condition_n:] = False 115 | if hasattr(attn, "c_factor"): 116 | attention_mask = torch.zeros( 117 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype 118 | ) 119 | condition_n = cond_query.shape[2] 120 | bias = torch.log(attn.c_factor[0]) 121 | attention_mask[-condition_n:, :-condition_n] = bias 122 | attention_mask[:-condition_n, -condition_n:] = bias 123 | hidden_states = F.scaled_dot_product_attention( 124 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask 125 | ) 126 | hidden_states = hidden_states.transpose(1, 2).reshape( 127 | batch_size, -1, attn.heads * head_dim 128 | ) 129 | hidden_states = hidden_states.to(query.dtype) 130 | 131 | if encoder_hidden_states is not None: 132 | if condition_latents is not None: 133 | encoder_hidden_states, hidden_states, condition_latents = ( 134 | hidden_states[:, : encoder_hidden_states.shape[1]], 135 | hidden_states[ 136 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1] 137 | ], 138 | hidden_states[:, -condition_latents.shape[1] :], 139 | ) 140 | else: 141 | encoder_hidden_states, hidden_states = ( 142 | hidden_states[:, : encoder_hidden_states.shape[1]], 143 | hidden_states[:, encoder_hidden_states.shape[1] :], 144 | ) 145 | 146 | with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)): 147 | # linear proj 148 | hidden_states = attn.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = attn.to_out[1](hidden_states) 151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 152 | 153 | if condition_latents is not None: 154 | condition_latents = attn.to_out[0](condition_latents) 155 | condition_latents = attn.to_out[1](condition_latents) 156 | 157 | return ( 158 | (hidden_states, encoder_hidden_states, condition_latents) 159 | if condition_latents is not None 160 | else (hidden_states, encoder_hidden_states) 161 | ) 162 | elif condition_latents is not None: 163 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents 164 | hidden_states, condition_latents = ( 165 | hidden_states[:, : -condition_latents.shape[1]], 166 | hidden_states[:, -condition_latents.shape[1] :], 167 | ) 168 | return hidden_states, condition_latents 169 | else: 170 | return hidden_states 171 | 172 | 173 | def block_forward( 174 | self, 175 | hidden_states: torch.FloatTensor, 176 | encoder_hidden_states: torch.FloatTensor, 177 | condition_latents: torch.FloatTensor, 178 | temb: torch.FloatTensor, 179 | cond_temb: torch.FloatTensor, 180 | cond_rotary_emb=None, 181 | image_rotary_emb=None, 182 | model_config: Optional[Dict[str, Any]] = {}, 183 | ): 184 | use_cond = condition_latents is not None 185 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)): 186 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 187 | hidden_states, emb=temb 188 | ) 189 | 190 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( 191 | self.norm1_context(encoder_hidden_states, emb=temb) 192 | ) 193 | 194 | if use_cond: 195 | ( 196 | norm_condition_latents, 197 | cond_gate_msa, 198 | cond_shift_mlp, 199 | cond_scale_mlp, 200 | cond_gate_mlp, 201 | ) = self.norm1(condition_latents, emb=cond_temb) 202 | 203 | # Attention. 
204 | result = attn_forward( 205 | self.attn, 206 | model_config=model_config, 207 | hidden_states=norm_hidden_states, 208 | encoder_hidden_states=norm_encoder_hidden_states, 209 | condition_latents=norm_condition_latents if use_cond else None, 210 | image_rotary_emb=image_rotary_emb, 211 | cond_rotary_emb=cond_rotary_emb if use_cond else None, 212 | ) 213 | attn_output, context_attn_output = result[:2] 214 | cond_attn_output = result[2] if use_cond else None 215 | 216 | # Process attention outputs for the `hidden_states`. 217 | # 1. hidden_states 218 | attn_output = gate_msa.unsqueeze(1) * attn_output 219 | hidden_states = hidden_states + attn_output 220 | # 2. encoder_hidden_states 221 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 222 | encoder_hidden_states = encoder_hidden_states + context_attn_output 223 | # 3. condition_latents 224 | if use_cond: 225 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 226 | condition_latents = condition_latents + cond_attn_output 227 | if model_config.get("add_cond_attn", False): 228 | hidden_states += cond_attn_output 229 | 230 | # LayerNorm + MLP. 231 | # 1. hidden_states 232 | norm_hidden_states = self.norm2(hidden_states) 233 | norm_hidden_states = ( 234 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 235 | ) 236 | # 2. encoder_hidden_states 237 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 238 | norm_encoder_hidden_states = ( 239 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 240 | ) 241 | # 3. condition_latents 242 | if use_cond: 243 | norm_condition_latents = self.norm2(condition_latents) 244 | norm_condition_latents = ( 245 | norm_condition_latents * (1 + cond_scale_mlp[:, None]) 246 | + cond_shift_mlp[:, None] 247 | ) 248 | 249 | # Feed-forward. 250 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)): 251 | # 1. hidden_states 252 | ff_output = self.ff(norm_hidden_states) 253 | ff_output = gate_mlp.unsqueeze(1) * ff_output 254 | # 2. encoder_hidden_states 255 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 256 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output 257 | # 3. condition_latents 258 | if use_cond: 259 | cond_ff_output = self.ff(norm_condition_latents) 260 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 261 | 262 | # Process feed-forward outputs. 263 | hidden_states = hidden_states + ff_output 264 | encoder_hidden_states = encoder_hidden_states + context_ff_output 265 | if use_cond: 266 | condition_latents = condition_latents + cond_ff_output 267 | 268 | # Clip to avoid overflow. 
269 | if encoder_hidden_states.dtype == torch.float16: 270 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 271 | 272 | return encoder_hidden_states, hidden_states, condition_latents if use_cond else None 273 | 274 | 275 | def single_block_forward( 276 | self, 277 | hidden_states: torch.FloatTensor, 278 | temb: torch.FloatTensor, 279 | image_rotary_emb=None, 280 | condition_latents: torch.FloatTensor = None, 281 | cond_temb: torch.FloatTensor = None, 282 | cond_rotary_emb=None, 283 | model_config: Optional[Dict[str, Any]] = {}, 284 | ): 285 | 286 | using_cond = condition_latents is not None 287 | residual = hidden_states 288 | with enable_lora( 289 | ( 290 | self.norm.linear, 291 | self.proj_mlp, 292 | ), 293 | model_config.get("latent_lora", False), 294 | ): 295 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 296 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 297 | if using_cond: 298 | residual_cond = condition_latents 299 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb) 300 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents)) 301 | 302 | attn_output = attn_forward( 303 | self.attn, 304 | model_config=model_config, 305 | hidden_states=norm_hidden_states, 306 | image_rotary_emb=image_rotary_emb, 307 | **( 308 | { 309 | "condition_latents": norm_condition_latents, 310 | "cond_rotary_emb": cond_rotary_emb if using_cond else None, 311 | } 312 | if using_cond 313 | else {} 314 | ), 315 | ) 316 | if using_cond: 317 | attn_output, cond_attn_output = attn_output 318 | 319 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)): 320 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 321 | gate = gate.unsqueeze(1) 322 | hidden_states = gate * self.proj_out(hidden_states) 323 | hidden_states = residual + hidden_states 324 | if using_cond: 325 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 326 | cond_gate = cond_gate.unsqueeze(1) 327 | condition_latents = cond_gate * self.proj_out(condition_latents) 328 | condition_latents = residual_cond + condition_latents 329 | 330 | if hidden_states.dtype == torch.float16: 331 | hidden_states = hidden_states.clip(-65504, 65504) 332 | 333 | return hidden_states if not using_cond else (hidden_states, condition_latents) 334 | -------------------------------------------------------------------------------- /train/src/flux/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union, List, Tuple 3 | from diffusers.pipelines import FluxPipeline 4 | from PIL import Image, ImageFilter 5 | import numpy as np 6 | import cv2 7 | 8 | from .pipeline_tools import encode_images 9 | 10 | condition_dict = { 11 | "depth": 0, 12 | "canny": 1, 13 | "subject": 4, 14 | "coloring": 6, 15 | "deblurring": 7, 16 | "depth_pred": 8, 17 | "fill": 9, 18 | "sr": 10, 19 | } 20 | 21 | 22 | class Condition(object): 23 | def __init__( 24 | self, 25 | condition_type: str, 26 | raw_img: Union[Image.Image, torch.Tensor] = None, 27 | condition: Union[Image.Image, torch.Tensor] = None, 28 | mask=None, 29 | position_delta=None, 30 | ) -> None: 31 | self.condition_type = condition_type 32 | assert raw_img is not None or condition is not None 33 | if raw_img is not None: 34 | self.condition = self.get_condition(condition_type, raw_img) 35 | else: 36 | self.condition = condition 37 | self.position_delta = position_delta 38 | # 
TODO: Add mask support 39 | assert mask is None, "Mask not supported yet" 40 | 41 | def get_condition( 42 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] 43 | ) -> Union[Image.Image, torch.Tensor]: 44 | """ 45 | Returns the condition image. 46 | """ 47 | if condition_type == "depth": 48 | from transformers import pipeline 49 | 50 | depth_pipe = pipeline( 51 | task="depth-estimation", 52 | model="LiheYoung/depth-anything-small-hf", 53 | device="cuda", 54 | ) 55 | source_image = raw_img.convert("RGB") 56 | condition_img = depth_pipe(source_image)["depth"].convert("RGB") 57 | return condition_img 58 | elif condition_type == "canny": 59 | img = np.array(raw_img) 60 | edges = cv2.Canny(img, 100, 200) 61 | edges = Image.fromarray(edges).convert("RGB") 62 | return edges 63 | elif condition_type == "subject": 64 | return raw_img 65 | elif condition_type == "coloring": 66 | return raw_img.convert("L").convert("RGB") 67 | elif condition_type == "deblurring": 68 | condition_image = ( 69 | raw_img.convert("RGB") 70 | .filter(ImageFilter.GaussianBlur(10)) 71 | .convert("RGB") 72 | ) 73 | return condition_image 74 | elif condition_type == "fill": 75 | return raw_img.convert("RGB") 76 | return self.condition 77 | 78 | @property 79 | def type_id(self) -> int: 80 | """ 81 | Returns the type id of the condition. 82 | """ 83 | return condition_dict[self.condition_type] 84 | 85 | @classmethod 86 | def get_type_id(cls, condition_type: str) -> int: 87 | """ 88 | Returns the type id of the condition. 89 | """ 90 | return condition_dict[condition_type] 91 | 92 | def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]: 93 | """ 94 | Encodes the condition into tokens, ids and type_id. 95 | """ 96 | if self.condition_type in [ 97 | "depth", 98 | "canny", 99 | "subject", 100 | "coloring", 101 | "deblurring", 102 | "depth_pred", 103 | "fill", 104 | "sr", 105 | ]: 106 | tokens, ids = encode_images(pipe, self.condition) 107 | else: 108 | raise NotImplementedError( 109 | f"Condition type {self.condition_type} not implemented" 110 | ) 111 | if self.position_delta is None and self.condition_type == "subject": 112 | self.position_delta = [0, -self.condition.size[0] // 16] 113 | if self.position_delta is not None: 114 | ids[:, 1] += self.position_delta[0] 115 | ids[:, 2] += self.position_delta[1] 116 | type_id = torch.ones_like(ids[:, :1]) * self.type_id 117 | return tokens, ids, type_id 118 | -------------------------------------------------------------------------------- /train/src/flux/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml, os 3 | from diffusers.pipelines import FluxPipeline 4 | from typing import List, Union, Optional, Dict, Any, Callable 5 | from .transformer import tranformer_forward 6 | from .condition import Condition 7 | 8 | from diffusers.pipelines.flux.pipeline_flux import ( 9 | FluxPipelineOutput, 10 | calculate_shift, 11 | retrieve_timesteps, 12 | np, 13 | ) 14 | 15 | 16 | def get_config(config_path: str = None): 17 | config_path = config_path or os.environ.get("XFL_CONFIG") 18 | if not config_path: 19 | return {} 20 | with open(config_path, "r") as f: 21 | config = yaml.safe_load(f) 22 | return config 23 | 24 | 25 | def prepare_params( 26 | prompt: Union[str, List[str]] = None, 27 | prompt_2: Optional[Union[str, List[str]]] = None, 28 | height: Optional[int] = 512, 29 | width: Optional[int] = 512, 30 | num_inference_steps: int = 28, 31 | timesteps: List[int] = None, 32 | 
guidance_scale: float = 3.5, 33 | num_images_per_prompt: Optional[int] = 1, 34 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 35 | latents: Optional[torch.FloatTensor] = None, 36 | prompt_embeds: Optional[torch.FloatTensor] = None, 37 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 38 | output_type: Optional[str] = "pil", 39 | return_dict: bool = True, 40 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 41 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 42 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 43 | max_sequence_length: int = 512, 44 | **kwargs: dict, 45 | ): 46 | return ( 47 | prompt, 48 | prompt_2, 49 | height, 50 | width, 51 | num_inference_steps, 52 | timesteps, 53 | guidance_scale, 54 | num_images_per_prompt, 55 | generator, 56 | latents, 57 | prompt_embeds, 58 | pooled_prompt_embeds, 59 | output_type, 60 | return_dict, 61 | joint_attention_kwargs, 62 | callback_on_step_end, 63 | callback_on_step_end_tensor_inputs, 64 | max_sequence_length, 65 | ) 66 | 67 | 68 | def seed_everything(seed: int = 42): 69 | torch.backends.cudnn.deterministic = True 70 | torch.manual_seed(seed) 71 | np.random.seed(seed) 72 | 73 | 74 | @torch.no_grad() 75 | def generate( 76 | pipeline: FluxPipeline, 77 | conditions: List[Condition] = None, 78 | config_path: str = None, 79 | model_config: Optional[Dict[str, Any]] = {}, 80 | condition_scale: float = 1.0, 81 | default_lora: bool = False, 82 | **params: dict, 83 | ): 84 | model_config = model_config or get_config(config_path).get("model", {}) 85 | if condition_scale != 1: 86 | for name, module in pipeline.transformer.named_modules(): 87 | if not name.endswith(".attn"): 88 | continue 89 | module.c_factor = torch.ones(1, 1) * condition_scale 90 | 91 | self = pipeline 92 | ( 93 | prompt, 94 | prompt_2, 95 | height, 96 | width, 97 | num_inference_steps, 98 | timesteps, 99 | guidance_scale, 100 | num_images_per_prompt, 101 | generator, 102 | latents, 103 | prompt_embeds, 104 | pooled_prompt_embeds, 105 | output_type, 106 | return_dict, 107 | joint_attention_kwargs, 108 | callback_on_step_end, 109 | callback_on_step_end_tensor_inputs, 110 | max_sequence_length, 111 | ) = prepare_params(**params) 112 | 113 | height = height or self.default_sample_size * self.vae_scale_factor 114 | width = width or self.default_sample_size * self.vae_scale_factor 115 | 116 | # 1. Check inputs. Raise error if not correct 117 | self.check_inputs( 118 | prompt, 119 | prompt_2, 120 | height, 121 | width, 122 | prompt_embeds=prompt_embeds, 123 | pooled_prompt_embeds=pooled_prompt_embeds, 124 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 125 | max_sequence_length=max_sequence_length, 126 | ) 127 | 128 | self._guidance_scale = guidance_scale 129 | self._joint_attention_kwargs = joint_attention_kwargs 130 | self._interrupt = False 131 | 132 | # 2. 
Define call parameters 133 | if prompt is not None and isinstance(prompt, str): 134 | batch_size = 1 135 | elif prompt is not None and isinstance(prompt, list): 136 | batch_size = len(prompt) 137 | else: 138 | batch_size = prompt_embeds.shape[0] 139 | 140 | device = self._execution_device 141 | 142 | lora_scale = ( 143 | self.joint_attention_kwargs.get("scale", None) 144 | if self.joint_attention_kwargs is not None 145 | else None 146 | ) 147 | ( 148 | prompt_embeds, 149 | pooled_prompt_embeds, 150 | text_ids, 151 | ) = self.encode_prompt( 152 | prompt=prompt, 153 | prompt_2=prompt_2, 154 | prompt_embeds=prompt_embeds, 155 | pooled_prompt_embeds=pooled_prompt_embeds, 156 | device=device, 157 | num_images_per_prompt=num_images_per_prompt, 158 | max_sequence_length=max_sequence_length, 159 | lora_scale=lora_scale, 160 | ) 161 | 162 | # 4. Prepare latent variables 163 | num_channels_latents = self.transformer.config.in_channels // 4 164 | latents, latent_image_ids = self.prepare_latents( 165 | batch_size * num_images_per_prompt, 166 | num_channels_latents, 167 | height, 168 | width, 169 | prompt_embeds.dtype, 170 | device, 171 | generator, 172 | latents, 173 | ) 174 | 175 | # 4.1. Prepare conditions 176 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3)) 177 | use_condition = conditions is not None or [] 178 | if use_condition: 179 | assert len(conditions) <= 1, "Only one condition is supported for now." 180 | if not default_lora: 181 | pipeline.set_adapters(conditions[0].condition_type) 182 | for condition in conditions: 183 | tokens, ids, type_id = condition.encode(self) 184 | condition_latents.append(tokens) # [batch_size, token_n, token_dim] 185 | condition_ids.append(ids) # [token_n, id_dim(3)] 186 | condition_type_ids.append(type_id) # [token_n, 1] 187 | condition_latents = torch.cat(condition_latents, dim=1) 188 | condition_ids = torch.cat(condition_ids, dim=0) 189 | condition_type_ids = torch.cat(condition_type_ids, dim=0) 190 | 191 | # 5. Prepare timesteps 192 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 193 | image_seq_len = latents.shape[1] 194 | mu = calculate_shift( 195 | image_seq_len, 196 | self.scheduler.config.base_image_seq_len, 197 | self.scheduler.config.max_image_seq_len, 198 | self.scheduler.config.base_shift, 199 | self.scheduler.config.max_shift, 200 | ) 201 | timesteps, num_inference_steps = retrieve_timesteps( 202 | self.scheduler, 203 | num_inference_steps, 204 | device, 205 | timesteps, 206 | sigmas, 207 | mu=mu, 208 | ) 209 | num_warmup_steps = max( 210 | len(timesteps) - num_inference_steps * self.scheduler.order, 0 211 | ) 212 | self._num_timesteps = len(timesteps) 213 | 214 | # 6. 
Denoising loop 215 | with self.progress_bar(total=num_inference_steps) as progress_bar: 216 | for i, t in enumerate(timesteps): 217 | if self.interrupt: 218 | continue 219 | 220 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 221 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 222 | 223 | # handle guidance 224 | if self.transformer.config.guidance_embeds: 225 | guidance = torch.tensor([guidance_scale], device=device) 226 | guidance = guidance.expand(latents.shape[0]) 227 | else: 228 | guidance = None 229 | noise_pred = tranformer_forward( 230 | self.transformer, 231 | model_config=model_config, 232 | # Inputs of the condition (new feature) 233 | condition_latents=condition_latents if use_condition else None, 234 | condition_ids=condition_ids if use_condition else None, 235 | condition_type_ids=condition_type_ids if use_condition else None, 236 | # Inputs to the original transformer 237 | hidden_states=latents, 238 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) 239 | timestep=timestep / 1000, 240 | guidance=guidance, 241 | pooled_projections=pooled_prompt_embeds, 242 | encoder_hidden_states=prompt_embeds, 243 | txt_ids=text_ids, 244 | img_ids=latent_image_ids, 245 | joint_attention_kwargs=self.joint_attention_kwargs, 246 | return_dict=False, 247 | )[0] 248 | 249 | # compute the previous noisy sample x_t -> x_t-1 250 | latents_dtype = latents.dtype 251 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 252 | 253 | if latents.dtype != latents_dtype: 254 | if torch.backends.mps.is_available(): 255 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 256 | latents = latents.to(latents_dtype) 257 | 258 | if callback_on_step_end is not None: 259 | callback_kwargs = {} 260 | for k in callback_on_step_end_tensor_inputs: 261 | callback_kwargs[k] = locals()[k] 262 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 263 | 264 | latents = callback_outputs.pop("latents", latents) 265 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 266 | 267 | # call the callback, if provided 268 | if i == len(timesteps) - 1 or ( 269 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 270 | ): 271 | progress_bar.update() 272 | 273 | if output_type == "latent": 274 | image = latents 275 | 276 | else: 277 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 278 | latents = ( 279 | latents / self.vae.config.scaling_factor 280 | ) + self.vae.config.shift_factor 281 | image = self.vae.decode(latents, return_dict=False)[0] 282 | image = self.image_processor.postprocess(image, output_type=output_type) 283 | 284 | # Offload all models 285 | self.maybe_free_model_hooks() 286 | 287 | if condition_scale != 1: 288 | for name, module in pipeline.transformer.named_modules(): 289 | if not name.endswith(".attn"): 290 | continue 291 | del module.c_factor 292 | 293 | if not return_dict: 294 | return (image,) 295 | 296 | return FluxPipelineOutput(images=image) 297 | -------------------------------------------------------------------------------- /train/src/flux/lora_controller.py: -------------------------------------------------------------------------------- 1 | from peft.tuners.tuners_utils import BaseTunerLayer 2 | from typing import List, Any, Optional, Type 3 | 4 | 5 | class enable_lora: 6 | 
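    # Context manager used with `model_config["latent_lora"]`: when `activated` is False,
    # the LoRA layers of the given modules are scaled to 0 on entry and their original
    # scaling is restored on exit (so the wrapped projections temporarily run with base
    # weights only); when `activated` is True, this is a no-op and LoRA stays active.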
def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None: 7 | self.activated: bool = activated 8 | if activated: 9 | return 10 | self.lora_modules: List[BaseTunerLayer] = [ 11 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 12 | ] 13 | self.scales = [ 14 | { 15 | active_adapter: lora_module.scaling[active_adapter] 16 | for active_adapter in lora_module.active_adapters 17 | } 18 | for lora_module in self.lora_modules 19 | ] 20 | 21 | def __enter__(self) -> None: 22 | if self.activated: 23 | return 24 | 25 | for lora_module in self.lora_modules: 26 | if not isinstance(lora_module, BaseTunerLayer): 27 | continue 28 | lora_module.scale_layer(0) 29 | 30 | def __exit__( 31 | self, 32 | exc_type: Optional[Type[BaseException]], 33 | exc_val: Optional[BaseException], 34 | exc_tb: Optional[Any], 35 | ) -> None: 36 | if self.activated: 37 | return 38 | for i, lora_module in enumerate(self.lora_modules): 39 | if not isinstance(lora_module, BaseTunerLayer): 40 | continue 41 | for active_adapter in lora_module.active_adapters: 42 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 43 | 44 | 45 | class set_lora_scale: 46 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None: 47 | self.lora_modules: List[BaseTunerLayer] = [ 48 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 49 | ] 50 | self.scales = [ 51 | { 52 | active_adapter: lora_module.scaling[active_adapter] 53 | for active_adapter in lora_module.active_adapters 54 | } 55 | for lora_module in self.lora_modules 56 | ] 57 | self.scale = scale 58 | 59 | def __enter__(self) -> None: 60 | for lora_module in self.lora_modules: 61 | if not isinstance(lora_module, BaseTunerLayer): 62 | continue 63 | lora_module.scale_layer(self.scale) 64 | 65 | def __exit__( 66 | self, 67 | exc_type: Optional[Type[BaseException]], 68 | exc_val: Optional[BaseException], 69 | exc_tb: Optional[Any], 70 | ) -> None: 71 | for i, lora_module in enumerate(self.lora_modules): 72 | if not isinstance(lora_module, BaseTunerLayer): 73 | continue 74 | for active_adapter in lora_module.active_adapters: 75 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 76 | -------------------------------------------------------------------------------- /train/src/flux/pipeline_tools.py: -------------------------------------------------------------------------------- 1 | from diffusers.pipelines import FluxPipeline, FluxFillPipeline 2 | from diffusers.utils import logging 3 | from diffusers.pipelines.flux.pipeline_flux import logger 4 | from torch import Tensor 5 | import torch 6 | 7 | 8 | def encode_images(pipeline: FluxPipeline, images: Tensor): 9 | images = pipeline.image_processor.preprocess(images) 10 | images = images.to(pipeline.device).to(pipeline.dtype) 11 | images = pipeline.vae.encode(images).latent_dist.sample() 12 | images = ( 13 | images - pipeline.vae.config.shift_factor 14 | ) * pipeline.vae.config.scaling_factor 15 | images_tokens = pipeline._pack_latents(images, *images.shape) 16 | images_ids = pipeline._prepare_latent_image_ids( 17 | images.shape[0], 18 | images.shape[2], 19 | images.shape[3], 20 | pipeline.device, 21 | pipeline.dtype, 22 | ) 23 | if images_tokens.shape[1] != images_ids.shape[0]: 24 | images_ids = pipeline._prepare_latent_image_ids( 25 | images.shape[0], 26 | images.shape[2] // 2, 27 | images.shape[3] // 2, 28 | pipeline.device, 29 | pipeline.dtype, 30 | ) 31 | return images_tokens, images_ids 32 | 33 | 34 | 35 | def 
encode_images_fill(pipeline: FluxFillPipeline, image: Tensor, mask_image: Tensor, dtype: torch.dtype, device: str): 36 | images_tokens, images_ids = encode_images(pipeline, image.clone().detach()) 37 | height, width = image.shape[-2:] 38 | # print(f"height: {height}, width: {width}") 39 | image = pipeline.image_processor.preprocess(image, height=height, width=width) 40 | mask_image = pipeline.mask_processor.preprocess(mask_image, height=height, width=width) 41 | 42 | masked_image = image * (1 - mask_image) 43 | masked_image = masked_image.to(device=device, dtype=dtype) 44 | 45 | num_channels_latents = pipeline.vae.config.latent_channels 46 | height, width = image.shape[-2:] 47 | device = pipeline._execution_device 48 | mask, masked_image_latents = pipeline.prepare_mask_latents( 49 | mask_image, 50 | masked_image, 51 | image.shape[0], 52 | num_channels_latents, 53 | 1, 54 | height, 55 | width, 56 | dtype, 57 | device, 58 | None, 59 | ) 60 | masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) 61 | return images_tokens, masked_image_latents, images_ids 62 | 63 | 64 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512): 65 | # Turn off warnings (CLIP overflow) 66 | logger.setLevel(logging.ERROR) 67 | ( 68 | prompt_embeds, 69 | pooled_prompt_embeds, 70 | text_ids, 71 | ) = pipeline.encode_prompt( 72 | prompt=prompts, 73 | prompt_2=None, 74 | prompt_embeds=None, 75 | pooled_prompt_embeds=None, 76 | device=pipeline.device, 77 | num_images_per_prompt=1, 78 | max_sequence_length=max_sequence_length, 79 | lora_scale=None, 80 | ) 81 | # Turn on warnings 82 | logger.setLevel(logging.WARNING) 83 | return prompt_embeds, pooled_prompt_embeds, text_ids 84 | -------------------------------------------------------------------------------- /train/src/flux/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines import FluxPipeline 3 | from typing import List, Union, Optional, Dict, Any, Callable 4 | from .block import block_forward, single_block_forward 5 | from .lora_controller import enable_lora 6 | from diffusers.models.transformers.transformer_flux import ( 7 | FluxTransformer2DModel, 8 | Transformer2DModelOutput, 9 | USE_PEFT_BACKEND, 10 | is_torch_version, 11 | scale_lora_layers, 12 | unscale_lora_layers, 13 | logger, 14 | ) 15 | import numpy as np 16 | 17 | 18 | def prepare_params( 19 | hidden_states: torch.Tensor, 20 | encoder_hidden_states: torch.Tensor = None, 21 | pooled_projections: torch.Tensor = None, 22 | timestep: torch.LongTensor = None, 23 | img_ids: torch.Tensor = None, 24 | txt_ids: torch.Tensor = None, 25 | guidance: torch.Tensor = None, 26 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 27 | controlnet_block_samples=None, 28 | controlnet_single_block_samples=None, 29 | return_dict: bool = True, 30 | **kwargs: dict, 31 | ): 32 | return ( 33 | hidden_states, 34 | encoder_hidden_states, 35 | pooled_projections, 36 | timestep, 37 | img_ids, 38 | txt_ids, 39 | guidance, 40 | joint_attention_kwargs, 41 | controlnet_block_samples, 42 | controlnet_single_block_samples, 43 | return_dict, 44 | ) 45 | 46 | 47 | def tranformer_forward( 48 | transformer: FluxTransformer2DModel, 49 | condition_latents: torch.Tensor, 50 | condition_ids: torch.Tensor, 51 | condition_type_ids: torch.Tensor, 52 | model_config: Optional[Dict[str, Any]] = {}, 53 | c_t=0, 54 | **params: dict, 55 | ): 56 | self = transformer 57 | use_condition = condition_latents is not None 58 | 59 
| ( 60 | hidden_states, 61 | encoder_hidden_states, 62 | pooled_projections, 63 | timestep, 64 | img_ids, 65 | txt_ids, 66 | guidance, 67 | joint_attention_kwargs, 68 | controlnet_block_samples, 69 | controlnet_single_block_samples, 70 | return_dict, 71 | ) = prepare_params(**params) 72 | 73 | if joint_attention_kwargs is not None: 74 | joint_attention_kwargs = joint_attention_kwargs.copy() 75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 76 | else: 77 | lora_scale = 1.0 78 | 79 | if USE_PEFT_BACKEND: 80 | # weight the lora layers by setting `lora_scale` for each PEFT layer 81 | scale_lora_layers(self, lora_scale) 82 | else: 83 | if ( 84 | joint_attention_kwargs is not None 85 | and joint_attention_kwargs.get("scale", None) is not None 86 | ): 87 | logger.warning( 88 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 89 | ) 90 | 91 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): 92 | hidden_states = self.x_embedder(hidden_states) 93 | condition_latents = self.x_embedder(condition_latents) if use_condition else None 94 | 95 | timestep = timestep.to(hidden_states.dtype) * 1000 96 | 97 | if guidance is not None: 98 | guidance = guidance.to(hidden_states.dtype) * 1000 99 | else: 100 | guidance = None 101 | 102 | temb = ( 103 | self.time_text_embed(timestep, pooled_projections) 104 | if guidance is None 105 | else self.time_text_embed(timestep, guidance, pooled_projections) 106 | ) 107 | 108 | cond_temb = ( 109 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) 110 | if guidance is None 111 | else self.time_text_embed( 112 | torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections 113 | ) 114 | ) 115 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 116 | 117 | if txt_ids.ndim == 3: 118 | logger.warning( 119 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 120 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 121 | ) 122 | txt_ids = txt_ids[0] 123 | if img_ids.ndim == 3: 124 | logger.warning( 125 | "Passing `img_ids` 3d torch.Tensor is deprecated." 
126 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 127 | ) 128 | img_ids = img_ids[0] 129 | 130 | ids = torch.cat((txt_ids, img_ids), dim=0) 131 | image_rotary_emb = self.pos_embed(ids) 132 | if use_condition: 133 | # condition_ids[:, :1] = condition_type_ids 134 | cond_rotary_emb = self.pos_embed(condition_ids) 135 | 136 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) 137 | 138 | for index_block, block in enumerate(self.transformer_blocks): 139 | if self.training and self.gradient_checkpointing: 140 | ckpt_kwargs: Dict[str, Any] = ( 141 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 142 | ) 143 | encoder_hidden_states, hidden_states, condition_latents = ( 144 | torch.utils.checkpoint.checkpoint( 145 | block_forward, 146 | self=block, 147 | model_config=model_config, 148 | hidden_states=hidden_states, 149 | encoder_hidden_states=encoder_hidden_states, 150 | condition_latents=condition_latents if use_condition else None, 151 | temb=temb, 152 | cond_temb=cond_temb if use_condition else None, 153 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 154 | image_rotary_emb=image_rotary_emb, 155 | **ckpt_kwargs, 156 | ) 157 | ) 158 | 159 | else: 160 | encoder_hidden_states, hidden_states, condition_latents = block_forward( 161 | block, 162 | model_config=model_config, 163 | hidden_states=hidden_states, 164 | encoder_hidden_states=encoder_hidden_states, 165 | condition_latents=condition_latents if use_condition else None, 166 | temb=temb, 167 | cond_temb=cond_temb if use_condition else None, 168 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 169 | image_rotary_emb=image_rotary_emb, 170 | ) 171 | 172 | # controlnet residual 173 | if controlnet_block_samples is not None: 174 | interval_control = len(self.transformer_blocks) / len( 175 | controlnet_block_samples 176 | ) 177 | interval_control = int(np.ceil(interval_control)) 178 | hidden_states = ( 179 | hidden_states 180 | + controlnet_block_samples[index_block // interval_control] 181 | ) 182 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 183 | 184 | for index_block, block in enumerate(self.single_transformer_blocks): 185 | if self.training and self.gradient_checkpointing: 186 | ckpt_kwargs: Dict[str, Any] = ( 187 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 188 | ) 189 | result = torch.utils.checkpoint.checkpoint( 190 | single_block_forward, 191 | self=block, 192 | model_config=model_config, 193 | hidden_states=hidden_states, 194 | temb=temb, 195 | image_rotary_emb=image_rotary_emb, 196 | **( 197 | { 198 | "condition_latents": condition_latents, 199 | "cond_temb": cond_temb, 200 | "cond_rotary_emb": cond_rotary_emb, 201 | } 202 | if use_condition 203 | else {} 204 | ), 205 | **ckpt_kwargs, 206 | ) 207 | 208 | else: 209 | result = single_block_forward( 210 | block, 211 | model_config=model_config, 212 | hidden_states=hidden_states, 213 | temb=temb, 214 | image_rotary_emb=image_rotary_emb, 215 | **( 216 | { 217 | "condition_latents": condition_latents, 218 | "cond_temb": cond_temb, 219 | "cond_rotary_emb": cond_rotary_emb, 220 | } 221 | if use_condition 222 | else {} 223 | ), 224 | ) 225 | if use_condition: 226 | hidden_states, condition_latents = result 227 | else: 228 | hidden_states = result 229 | 230 | # controlnet residual 231 | if controlnet_single_block_samples is not None: 232 | interval_control = len(self.single_transformer_blocks) / len( 233 | controlnet_single_block_samples 234 | ) 235 | 
interval_control = int(np.ceil(interval_control)) 236 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 237 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 238 | + controlnet_single_block_samples[index_block // interval_control] 239 | ) 240 | 241 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 242 | 243 | hidden_states = self.norm_out(hidden_states, temb) 244 | output = self.proj_out(hidden_states) 245 | 246 | if USE_PEFT_BACKEND: 247 | # remove `lora_scale` from each PEFT layer 248 | unscale_lora_layers(self, lora_scale) 249 | 250 | if not return_dict: 251 | return (output,) 252 | return Transformer2DModelOutput(sample=output) 253 | -------------------------------------------------------------------------------- /train/src/train/callbacks.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from PIL import Image, ImageFilter, ImageDraw 3 | import numpy as np 4 | from transformers import pipeline 5 | # import cv2 6 | import torch 7 | import os 8 | from datetime import datetime 9 | 10 | try: 11 | import wandb 12 | except ImportError: 13 | wandb = None 14 | 15 | from ..flux.condition import Condition 16 | from ..flux.generate import generate 17 | 18 | 19 | class TrainingCallback(L.Callback): 20 | def __init__(self, run_name, training_config: dict = {}): 21 | self.run_name, self.training_config = run_name, training_config 22 | 23 | self.print_every_n_steps = training_config.get("print_every_n_steps", 10) 24 | self.save_interval = training_config.get("save_interval", 1000) 25 | self.sample_interval = training_config.get("sample_interval", 1000) 26 | self.save_path = training_config.get("save_path", "./output") 27 | 28 | self.wandb_config = training_config.get("wandb", None) 29 | self.use_wandb = ( 30 | wandb is not None and os.environ.get("WANDB_API_KEY") is not None 31 | ) 32 | 33 | self.total_steps = 0 34 | 35 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 36 | gradient_size = 0 37 | max_gradient_size = 0 38 | count = 0 39 | for _, param in pl_module.named_parameters(): 40 | if param.grad is not None: 41 | gradient_size += param.grad.norm(2).item() 42 | max_gradient_size = max(max_gradient_size, param.grad.norm(2).item()) 43 | count += 1 44 | if count > 0: 45 | gradient_size /= count 46 | 47 | self.total_steps += 1 48 | 49 | # Print training progress every n steps 50 | if self.use_wandb: 51 | report_dict = { 52 | "steps": batch_idx, 53 | "steps": self.total_steps, 54 | "epoch": trainer.current_epoch, 55 | "gradient_size": gradient_size, 56 | } 57 | loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches 58 | report_dict["loss"] = loss_value 59 | report_dict["t"] = pl_module.last_t 60 | wandb.log(report_dict) 61 | 62 | if self.total_steps % self.print_every_n_steps == 0: 63 | print( 64 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}" 65 | ) 66 | 67 | # Save LoRA weights at specified intervals 68 | if self.total_steps % self.save_interval == 0: 69 | print( 70 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights" 71 | ) 72 | pl_module.save_lora( 73 | f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}" 74 | ) 75 | 76 | # Generate and save a sample image at specified intervals 77 | if self.total_steps % self.sample_interval == 0: 78 | print( 79 | f"Epoch: 
{trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample" 80 | ) 81 | self.generate_a_sample( 82 | trainer, 83 | pl_module, 84 | f"{self.save_path}/{self.run_name}", 85 | f"lora_{self.total_steps}", 86 | batch["condition_type"][ 87 | 0 88 | ], # Use the condition type from the current batch 89 | ) 90 | 91 | @torch.no_grad() 92 | def generate_a_sample( 93 | self, 94 | trainer, 95 | pl_module, 96 | save_path, 97 | file_name, 98 | condition_type, 99 | ): 100 | 101 | file_name = [ 102 | "assets/coffee.png", 103 | "assets/coffee.png", 104 | "assets/coffee.png", 105 | "assets/coffee.png", 106 | "assets/clock.jpg", 107 | "assets/book.jpg", 108 | "assets/monalisa.jpg", 109 | "assets/oranges.jpg", 110 | "assets/penguin.jpg", 111 | "assets/vase.jpg", 112 | "assets/room_corner.jpg", 113 | ] 114 | 115 | test_instruction = [ 116 | "Make the image look like it's from an ancient Egyptian mural.", 117 | 'get rid of the coffee bean.', 118 | 'remove the cup.', 119 | "Change it to look like it's in the style of an impasto painting.", 120 | "Make this photo look like a comic book", 121 | "Give this the look of a traditional Japanese woodblock print.", 122 | 'delete the woman', 123 | "Change the image into a watercolor painting.", 124 | "Make it black and white.", 125 | "Make it pop art.", 126 | 'the sofa is leather, and the wall is black', 127 | ] 128 | 129 | pl_module.flux_fill_pipe.transformer.eval() 130 | for i, name in enumerate(file_name): 131 | test_image = Image.open(name) 132 | combined_image = Image.new('RGB', (test_image.size[0] * 2, test_image.size[1])) 133 | combined_image.paste(test_image, (0, 0)) 134 | combined_image.paste(test_image, (test_image.size[0], 0)) 135 | 136 | mask = Image.new('L', combined_image.size, 0) 137 | draw = ImageDraw.Draw(mask) 138 | draw.rectangle([test_image.size[0], 0, test_image.size[0] * 2, test_image.size[1]], fill=255) 139 | if condition_type == 'edit_n': 140 | prompt_ = "A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left. \n " + test_instruction[i] 141 | else: 142 | prompt_ = "A diptych with two side-by-side images of the same scene. 
On the right, the scene is exactly the same as on the left but " + test_instruction[i] 143 | 144 | image = pl_module.flux_fill_pipe( 145 | prompt=prompt_, 146 | image=combined_image, 147 | height=512, 148 | width=1024, 149 | mask_image=mask, 150 | guidance_scale=50, 151 | num_inference_steps=50, 152 | max_sequence_length=512, 153 | generator=torch.Generator("cpu").manual_seed(666) 154 | ).images[0] 155 | image.save(os.path.join(save_path, f'flux-fill-test-{self.total_steps}-{i}-{condition_type}.jpg')) 156 | 157 | pl_module.flux_fill_pipe.transformer.train() 158 | -------------------------------------------------------------------------------- /train/src/train/data.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageDraw 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as T 5 | import random 6 | from io import BytesIO 7 | import glob 8 | from tqdm import tqdm 9 | 10 | class EditDataset_with_Omini(Dataset): 11 | def __init__( 12 | self, 13 | magic_dataset, 14 | omni_dataset, 15 | condition_size: int = 512, 16 | target_size: int = 512, 17 | drop_text_prob: float = 0.1, 18 | return_pil_image: bool = False, 19 | crop_the_noise: bool = True, 20 | ): 21 | self.dataset = [magic_dataset['train'], magic_dataset['dev'], omni_dataset] 22 | 23 | from collections import Counter 24 | tasks = omni_dataset['task'] 25 | task_counts = Counter(tasks) 26 | print("\n task type statistic:") 27 | for task, count in task_counts.items(): 28 | print(f"{task}: {count} data ({count/len(tasks)*100:.2f}%)") 29 | 30 | self.condition_size = condition_size 31 | self.target_size = target_size 32 | self.drop_text_prob = drop_text_prob 33 | self.return_pil_image = return_pil_image 34 | self.crop_the_noise = crop_the_noise 35 | self.to_tensor = T.ToTensor() 36 | 37 | def __len__(self): 38 | return len(self.dataset[0]) + len(self.dataset[1]) + len(self.dataset[2]) 39 | 40 | 41 | def __getitem__(self, idx): 42 | split = 0 if idx < len(self.dataset[0]) else (1 if idx < len(self.dataset[0]) + len(self.dataset[1]) else 2) 43 | 44 | if idx >= len(self.dataset[0]) + len(self.dataset[1]): 45 | idx -= len(self.dataset[0]) + len(self.dataset[1]) 46 | elif idx >= len(self.dataset[0]): 47 | idx -= len(self.dataset[0]) 48 | 49 | image = self.dataset[split][idx]["source_img" if split != 2 else "src_img"] 50 | instruction = 'A diptych with two side-by-side images of the same scene. 
On the right, the scene is exactly the same as on the left but ' + (self.dataset[split][idx]["instruction"] if split != 2 else random.choice(self.dataset[split][idx]["edited_prompt_list"])) 51 | edited_image = self.dataset[split][idx]["target_img" if split != 2 else "edited_img"] 52 | 53 | if self.crop_the_noise and split <= 1: 54 | image = image.crop((0, 0, image.width, image.height - image.height // 32)) 55 | edited_image = edited_image.crop((0, 0, edited_image.width, edited_image.height - edited_image.height // 32)) 56 | 57 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB") 58 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB") 59 | 60 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size)) 61 | combined_image.paste(image, (0, 0)) 62 | combined_image.paste(edited_image, (self.condition_size, 0)) 63 | 64 | 65 | 66 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0) 67 | draw = ImageDraw.Draw(mask) 68 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 69 | 70 | mask_combined_image = combined_image.copy() 71 | draw = ImageDraw.Draw(mask_combined_image) 72 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 73 | 74 | if random.random() < self.drop_text_prob: 75 | instruction = " " 76 | 77 | return { 78 | "image": self.to_tensor(combined_image), 79 | "condition": self.to_tensor(mask), 80 | "condition_type": "edit", 81 | "description": instruction, 82 | "position_delta": np.array([0, 0]), 83 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}), 84 | } 85 | 86 | class OminiDataset(Dataset): 87 | def __init__( 88 | self, 89 | base_dataset, 90 | condition_size: int = 512, 91 | target_size: int = 512, 92 | drop_text_prob: float = 0.1, 93 | return_pil_image: bool = False, 94 | specific_task: list = None, 95 | ): 96 | self.base_dataset = base_dataset['train'] 97 | if specific_task is not None: 98 | self.specific_task = specific_task 99 | task_indices = [i for i, task in enumerate(self.base_dataset['task']) if task in self.specific_task] 100 | task_set = set([task for task in self.base_dataset['task']]) 101 | ori_len = len(self.base_dataset) 102 | self.base_dataset = self.base_dataset.select(task_indices) 103 | print(specific_task, len(self.base_dataset), ori_len) 104 | print(task_set) 105 | 106 | self.condition_size = condition_size 107 | self.target_size = target_size 108 | self.drop_text_prob = drop_text_prob 109 | self.return_pil_image = return_pil_image 110 | self.to_tensor = T.ToTensor() 111 | 112 | def __len__(self): 113 | return len(self.base_dataset) 114 | 115 | def __getitem__(self, idx): 116 | image = self.base_dataset[idx]["src_img"] 117 | instruction = 'A diptych with two side-by-side images of the same scene. 
On the right, the scene is exactly the same as on the left but ' + random.choice(self.base_dataset[idx]["edited_prompt_list"]) 118 | 119 | edited_image = self.base_dataset[idx]["edited_img"] 120 | 121 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB") 122 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB") 123 | 124 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size)) 125 | combined_image.paste(image, (0, 0)) 126 | combined_image.paste(edited_image, (self.condition_size, 0)) 127 | 128 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0) 129 | draw = ImageDraw.Draw(mask) 130 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 131 | 132 | mask_combined_image = combined_image.copy() 133 | draw = ImageDraw.Draw(mask_combined_image) 134 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 135 | 136 | if random.random() < self.drop_text_prob: 137 | instruction = "" 138 | 139 | return { 140 | "image": self.to_tensor(combined_image), 141 | "condition": self.to_tensor(mask), 142 | "condition_type": "edit", 143 | "description": instruction, 144 | "position_delta": np.array([0, 0]), 145 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}), 146 | } 147 | 148 | 149 | class EditDataset_mask(Dataset): 150 | def __init__( 151 | self, 152 | base_dataset, 153 | condition_size: int = 512, 154 | target_size: int = 512, 155 | drop_text_prob: float = 0.1, 156 | return_pil_image: bool = False, 157 | crop_the_noise: bool = True, 158 | ): 159 | print('THIS IS MAGICBRUSH!') 160 | self.base_dataset = base_dataset 161 | self.condition_size = condition_size 162 | self.target_size = target_size 163 | self.drop_text_prob = drop_text_prob 164 | self.return_pil_image = return_pil_image 165 | self.crop_the_noise = crop_the_noise 166 | self.to_tensor = T.ToTensor() 167 | 168 | def __len__(self): 169 | return len(self.base_dataset['train']) + len(self.base_dataset['dev']) 170 | 171 | def rgba_to_01_mask(image_rgba: Image.Image, reverse: bool = False, return_type: str = "numpy"): 172 | """ 173 | Convert an RGBA image to a binary mask with values in the range [0, 1], where 0 represents transparent areas 174 | and 1 represents non-transparent areas. The resulting mask has a shape of (1, H, W). 175 | 176 | :param image_rgba: An RGBA image in PIL format. 177 | :param reverse: If True, reverse the mask, making transparent areas 1 and non-transparent areas 0. 178 | :param return_type: Specifies the return type. "numpy" returns a NumPy array, "PIL" returns a PIL Image. 179 | 180 | :return: The binary mask as a NumPy array or a PIL Image in RGB format. 181 | """ 182 | alpha_channel = np.array(image_rgba)[:, :, 3] 183 | image_bw = (alpha_channel != 255).astype(np.uint8) 184 | if reverse: 185 | image_bw = 1 - image_bw 186 | mask = image_bw 187 | if return_type == "numpy": 188 | return mask 189 | else: # return PIL image 190 | mask = Image.fromarray(np.uint8(mask * 255) , 'L').convert('RGB') 191 | return mask 192 | 193 | def __getitem__(self, idx): 194 | split = 'train' if idx < len(self.base_dataset['train']) else 'dev' 195 | idx = idx - len(self.base_dataset['train']) if split == 'dev' else idx 196 | image = self.base_dataset[split][idx]["source_img"] 197 | instruction = 'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left. 
\n ' + self.base_dataset[split][idx]["instruction"] 198 | edited_image = self.base_dataset[split][idx]["target_img"] 199 | 200 | if self.crop_the_noise: 201 | image = image.crop((0, 0, image.width, image.height - image.height // 32)) 202 | edited_image = edited_image.crop((0, 0, edited_image.width, edited_image.height - edited_image.height // 32)) 203 | 204 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB") 205 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB") 206 | 207 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size)) 208 | combined_image.paste(image, (0, 0)) 209 | combined_image.paste(edited_image, (self.condition_size, 0)) 210 | 211 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0) 212 | draw = ImageDraw.Draw(mask) 213 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 214 | 215 | mask_combined_image = combined_image.copy() 216 | draw = ImageDraw.Draw(mask_combined_image) 217 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 218 | 219 | if random.random() < self.drop_text_prob: 220 | instruction = " \n " 221 | return { 222 | "image": self.to_tensor(combined_image), 223 | "condition": self.to_tensor(mask), 224 | "condition_type": "edit_n", 225 | "description": instruction, 226 | "position_delta": np.array([0, 0]), 227 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}), 228 | } 229 | 230 | class EditDataset(Dataset): 231 | def __init__( 232 | self, 233 | base_dataset, 234 | condition_size: int = 512, 235 | target_size: int = 512, 236 | drop_text_prob: float = 0.1, 237 | return_pil_image: bool = False, 238 | crop_the_noise: bool = True, 239 | ): 240 | print('THIS IS MAGICBRUSH!') 241 | self.base_dataset = base_dataset 242 | self.condition_size = condition_size 243 | self.target_size = target_size 244 | self.drop_text_prob = drop_text_prob 245 | self.return_pil_image = return_pil_image 246 | self.crop_the_noise = crop_the_noise 247 | self.to_tensor = T.ToTensor() 248 | 249 | def __len__(self): 250 | return len(self.base_dataset['train']) + len(self.base_dataset['dev']) 251 | 252 | def rgba_to_01_mask(image_rgba: Image.Image, reverse: bool = False, return_type: str = "numpy"): 253 | """ 254 | Convert an RGBA image to a binary mask with values in the range [0, 1], where 0 represents transparent areas 255 | and 1 represents non-transparent areas. The resulting mask has a shape of (1, H, W). 256 | 257 | :param image_rgba: An RGBA image in PIL format. 258 | :param reverse: If True, reverse the mask, making transparent areas 1 and non-transparent areas 0. 259 | :param return_type: Specifies the return type. "numpy" returns a NumPy array, "PIL" returns a PIL Image. 260 | 261 | :return: The binary mask as a NumPy array or a PIL Image in RGB format. 
262 | """ 263 | alpha_channel = np.array(image_rgba)[:, :, 3] 264 | image_bw = (alpha_channel != 255).astype(np.uint8) 265 | if reverse: 266 | image_bw = 1 - image_bw 267 | mask = image_bw 268 | if return_type == "numpy": 269 | return mask 270 | else: # return PIL image 271 | mask = Image.fromarray(np.uint8(mask * 255) , 'L').convert('RGB') 272 | return mask 273 | 274 | def __getitem__(self, idx): 275 | split = 'train' if idx < len(self.base_dataset['train']) else 'dev' 276 | idx = idx - len(self.base_dataset['train']) if split == 'dev' else idx 277 | image = self.base_dataset[split][idx]["source_img"] 278 | instruction = 'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but ' + self.base_dataset[split][idx]["instruction"] 279 | edited_image = self.base_dataset[split][idx]["target_img"] 280 | 281 | 282 | if self.crop_the_noise: 283 | image = image.crop((0, 0, image.width, image.height - image.height // 32)) 284 | edited_image = edited_image.crop((0, 0, edited_image.width, edited_image.height - edited_image.height // 32)) 285 | 286 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB") 287 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB") 288 | 289 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size)) 290 | combined_image.paste(image, (0, 0)) 291 | combined_image.paste(edited_image, (self.condition_size, 0)) 292 | 293 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0) 294 | draw = ImageDraw.Draw(mask) 295 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 296 | 297 | mask_combined_image = combined_image.copy() 298 | draw = ImageDraw.Draw(mask_combined_image) 299 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255) 300 | 301 | if random.random() < self.drop_text_prob: 302 | instruction = " " 303 | 304 | return { 305 | "image": self.to_tensor(combined_image), 306 | "condition": self.to_tensor(mask), 307 | "condition_type": "edit", 308 | "description": instruction, 309 | "position_delta": np.array([0, 0]), 310 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}), 311 | } 312 | -------------------------------------------------------------------------------- /train/src/train/model.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from diffusers.pipelines import FluxPipeline, FluxFillPipeline 3 | import torch 4 | from peft import LoraConfig, get_peft_model_state_dict 5 | import os 6 | import prodigyopt 7 | 8 | from ..flux.transformer import tranformer_forward 9 | from ..flux.condition import Condition 10 | from ..flux.pipeline_tools import encode_images, encode_images_fill, prepare_text_input 11 | 12 | 13 | class OminiModel(L.LightningModule): 14 | def __init__( 15 | self, 16 | flux_fill_id: str, 17 | lora_path: str = None, 18 | lora_config: dict = None, 19 | device: str = "cuda", 20 | dtype: torch.dtype = torch.bfloat16, 21 | model_config: dict = {}, 22 | optimizer_config: dict = None, 23 | gradient_checkpointing: bool = False, 24 | use_offset_noise: bool = False, 25 | ): 26 | # Initialize the LightningModule 27 | super().__init__() 28 | self.model_config = model_config 29 | 30 | self.optimizer_config = optimizer_config 31 | 32 | # Load the Flux pipeline 33 | self.flux_fill_pipe = 
FluxFillPipeline.from_pretrained(flux_fill_id).to(dtype=dtype).to(device) 34 | 35 | self.transformer = self.flux_fill_pipe.transformer 36 | self.text_encoder = self.flux_fill_pipe.text_encoder 37 | self.text_encoder_2 = self.flux_fill_pipe.text_encoder_2 38 | self.transformer.gradient_checkpointing = gradient_checkpointing 39 | self.transformer.train() 40 | # Freeze the Flux pipeline 41 | self.text_encoder.requires_grad_(False) 42 | self.text_encoder_2.requires_grad_(False) 43 | self.flux_fill_pipe.vae.requires_grad_(False).eval() 44 | self.use_offset_noise = use_offset_noise 45 | 46 | if use_offset_noise: 47 | print('[debug] use OFFSET NOISE.') 48 | 49 | self.lora_layers = self.init_lora(lora_path, lora_config) 50 | 51 | self.to(device).to(dtype) 52 | 53 | def init_lora(self, lora_path: str, lora_config: dict): 54 | assert lora_path or lora_config 55 | if lora_path: 56 | # TODO: Implement this 57 | raise NotImplementedError 58 | else: 59 | self.transformer.add_adapter(LoraConfig(**lora_config)) 60 | # TODO: Check if this is correct (p.requires_grad) 61 | lora_layers = filter( 62 | lambda p: p.requires_grad, self.transformer.parameters() 63 | ) 64 | return list(lora_layers) 65 | 66 | def save_lora(self, path: str): 67 | FluxFillPipeline.save_lora_weights( 68 | save_directory=path, 69 | transformer_lora_layers=get_peft_model_state_dict(self.transformer), 70 | safe_serialization=True, 71 | ) 72 | if self.model_config['use_sep']: 73 | torch.save(self.text_encoder_2.shared, os.path.join(path, "t5_embedding.pth")) 74 | torch.save(self.text_encoder.text_model.embeddings.token_embedding, os.path.join(path, "clip_embedding.pth")) 75 | 76 | def configure_optimizers(self): 77 | # Freeze the transformer 78 | self.transformer.requires_grad_(False) 79 | opt_config = self.optimizer_config 80 | 81 | # Set the trainable parameters 82 | self.trainable_params = self.lora_layers 83 | 84 | # Unfreeze trainable parameters 85 | for p in self.trainable_params: 86 | p.requires_grad_(True) 87 | 88 | # Initialize the optimizer 89 | if opt_config["type"] == "AdamW": 90 | optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"]) 91 | elif opt_config["type"] == "Prodigy": 92 | optimizer = prodigyopt.Prodigy( 93 | self.trainable_params, 94 | **opt_config["params"], 95 | ) 96 | elif opt_config["type"] == "SGD": 97 | optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"]) 98 | else: 99 | raise NotImplementedError 100 | 101 | return optimizer 102 | 103 | def training_step(self, batch, batch_idx): 104 | step_loss = self.step(batch) 105 | self.log_loss = ( 106 | step_loss.item() 107 | if not hasattr(self, "log_loss") 108 | else self.log_loss * 0.95 + step_loss.item() * 0.05 109 | ) 110 | return step_loss 111 | 112 | def step(self, batch): 113 | imgs = batch["image"] 114 | mask_imgs = batch["condition"] 115 | condition_types = batch["condition_type"] 116 | prompts = batch["description"] 117 | position_delta = batch["position_delta"][0] 118 | 119 | with torch.no_grad(): 120 | prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input( 121 | self.flux_fill_pipe, prompts 122 | ) 123 | 124 | x_0, x_cond, img_ids = encode_images_fill(self.flux_fill_pipe, imgs, mask_imgs, prompt_embeds.dtype, prompt_embeds.device) 125 | 126 | # Prepare t and x_t 127 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device)) 128 | x_1 = torch.randn_like(x_0).to(self.device) 129 | 130 | if self.use_offset_noise: 131 | x_1 = x_1 + 0.1 * torch.randn(x_1.shape[0], 1, 
x_1.shape[2]).to(self.device).to(self.dtype) 132 | 133 | t_ = t.unsqueeze(1).unsqueeze(1) 134 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype) 135 | 136 | # Prepare guidance 137 | guidance = ( 138 | torch.ones_like(t).to(self.device) 139 | if self.transformer.config.guidance_embeds 140 | else None 141 | ) 142 | 143 | # Forward pass 144 | transformer_out = self.transformer( 145 | hidden_states=torch.cat((x_t, x_cond), dim=2), 146 | timestep=t, 147 | guidance=guidance, 148 | pooled_projections=pooled_prompt_embeds, 149 | encoder_hidden_states=prompt_embeds, 150 | txt_ids=text_ids, 151 | img_ids=img_ids, 152 | joint_attention_kwargs=None, 153 | return_dict=False, 154 | ) 155 | pred = transformer_out[0] 156 | 157 | # Compute loss 158 | loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean") 159 | self.last_t = t.mean().item() 160 | return loss 161 | -------------------------------------------------------------------------------- /train/src/train/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | import lightning as L 4 | import yaml 5 | import os 6 | import random 7 | import time 8 | import numpy as np 9 | from datasets import load_dataset 10 | 11 | from .data import ( 12 | EditDataset, 13 | OminiDataset, 14 | EditDataset_with_Omini 15 | ) 16 | from .model import OminiModel 17 | from .callbacks import TrainingCallback 18 | 19 | 20 | def get_rank(): 21 | try: 22 | rank = int(os.environ.get("LOCAL_RANK")) 23 | except: 24 | rank = 0 25 | return rank 26 | 27 | 28 | def get_config(): 29 | config_path = os.environ.get("XFL_CONFIG") 30 | assert config_path is not None, "Please set the XFL_CONFIG environment variable" 31 | with open(config_path, "r") as f: 32 | config = yaml.safe_load(f) 33 | return config 34 | 35 | 36 | def init_wandb(wandb_config, run_name): 37 | import wandb 38 | 39 | try: 40 | assert os.environ.get("WANDB_API_KEY") is not None 41 | wandb.init( 42 | project=wandb_config["project"], 43 | name=run_name, 44 | config={}, 45 | ) 46 | except Exception as e: 47 | print("Failed to initialize WanDB:", e) 48 | 49 | 50 | def main(): 51 | # Initialize 52 | is_main_process, rank = get_rank() == 0, get_rank() 53 | torch.cuda.set_device(rank) 54 | config = get_config() 55 | training_config = config["train"] 56 | run_name = time.strftime("%Y%m%d-%H%M%S") 57 | 58 | seed = 666 59 | np.random.seed(seed) 60 | torch.manual_seed(seed) 61 | torch.cuda.manual_seed_all(seed) 62 | random.seed(seed) 63 | 64 | # Initialize WanDB 65 | wandb_config = training_config.get("wandb", None) 66 | if wandb_config is not None and is_main_process: 67 | init_wandb(wandb_config, run_name) 68 | 69 | print("Rank:", rank) 70 | if is_main_process: 71 | print("Config:", config) 72 | 73 | if 'use_offset_noise' not in config.keys(): 74 | config['use_offset_noise'] = False 75 | 76 | # Initialize dataset and dataloader 77 | 78 | if training_config["dataset"]["type"] == "edit": 79 | dataset = load_dataset('osunlp/MagicBrush') 80 | dataset = EditDataset( 81 | dataset, 82 | condition_size=training_config["dataset"]["condition_size"], 83 | target_size=training_config["dataset"]["target_size"], 84 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 85 | ) 86 | elif training_config["dataset"]["type"] == "omini": 87 | dataset = load_dataset(training_config["dataset"]["path"]) 88 | dataset = OminiDataset( 89 | dataset, 90 | condition_size=training_config["dataset"]["condition_size"], 91 | 
target_size=training_config["dataset"]["target_size"], 92 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 93 | ) 94 | 95 | elif training_config["dataset"]["type"] == "edit_with_omini": 96 | omni = load_dataset("parquet", data_files=os.path.abspath(training_config["dataset"]["path"]), split="train") 97 | magic = load_dataset('osunlp/MagicBrush') 98 | dataset = EditDataset_with_Omini( 99 | magic, 100 | omni, 101 | condition_size=training_config["dataset"]["condition_size"], 102 | target_size=training_config["dataset"]["target_size"], 103 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 104 | ) 105 | 106 | 107 | print("Dataset length:", len(dataset)) 108 | train_loader = DataLoader( 109 | dataset, 110 | batch_size=training_config["batch_size"], 111 | shuffle=True, 112 | num_workers=training_config["dataloader_workers"], 113 | ) 114 | 115 | # Initialize model 116 | trainable_model = OminiModel( 117 | flux_fill_id=config["flux_path"], 118 | lora_config=training_config["lora_config"], 119 | device=f"cuda", 120 | dtype=getattr(torch, config["dtype"]), 121 | optimizer_config=training_config["optimizer"], 122 | model_config=config.get("model", {}), 123 | gradient_checkpointing=training_config.get("gradient_checkpointing", False), 124 | use_offset_noise=config["use_offset_noise"], 125 | ) 126 | 127 | # Callbacks for logging and saving checkpoints 128 | training_callbacks = ( 129 | [TrainingCallback(run_name, training_config=training_config)] 130 | if is_main_process 131 | else [] 132 | ) 133 | 134 | # Initialize trainer 135 | trainer = L.Trainer( 136 | accumulate_grad_batches=training_config["accumulate_grad_batches"], 137 | callbacks=training_callbacks, 138 | enable_checkpointing=False, 139 | enable_progress_bar=False, 140 | logger=False, 141 | max_steps=training_config.get("max_steps", -1), 142 | max_epochs=training_config.get("max_epochs", -1), 143 | gradient_clip_val=training_config.get("gradient_clip_val", 0.5), 144 | ) 145 | 146 | setattr(trainer, "training_config", training_config) 147 | 148 | # Save config 149 | save_path = training_config.get("save_path", "./output") 150 | if is_main_process: 151 | os.makedirs(f"{save_path}/{run_name}") 152 | with open(f"{save_path}/{run_name}/config.yaml", "w") as f: 153 | yaml.dump(config, f) 154 | 155 | # Start training 156 | trainer.fit(trainable_model, train_loader) 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /train/train/config/normal_lora.yaml: -------------------------------------------------------------------------------- 1 | flux_path: "black-forest-labs/flux.1-fill-dev" 2 | dtype: "bfloat16" 3 | 4 | model: 5 | union_cond_attn: true 6 | add_cond_attn: false 7 | latent_lora: false 8 | use_sep: false 9 | 10 | train: 11 | batch_size: 2 12 | accumulate_grad_batches: 1 13 | dataloader_workers: 5 14 | save_interval: 1000 15 | sample_interval: 1000 16 | max_steps: -1 17 | gradient_checkpointing: true 18 | save_path: "runs" 19 | 20 | condition_type: "edit" 21 | dataset: 22 | type: "edit_with_omini" 23 | path: "parquet/*.parquet" 24 | condition_size: 512 25 | target_size: 512 26 | image_size: 512 27 | padding: 8 28 | drop_text_prob: 0.1 29 | drop_image_prob: 0.1 30 | # specific_task: ["removal", "style", "attribute_modification", "env", "swap"] 31 | 32 | wandb: 33 | project: "ICEdit" 34 | 35 | lora_config: 36 | r: 32 37 | lora_alpha: 32 38 | init_lora_weights: "gaussian" 39 | target_modules: "(.*x_embedder|.*(?
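
A note on running the training code above: get_config() in train/src/train/train.py reads the path of a YAML file such as this one from the XFL_CONFIG environment variable, and main() runs when the module is executed directly. The repository's own launch script is not reproduced in this section, so the following is only a minimal sketch of one way to start a run. It assumes the working directory is the top-level train/ folder (so that the relative imports .data, .model, and .callbacks resolve), that the config path below exists, and that the data referenced by the dataset section (MagicBrush and/or the prepared parquet files) is already available.

import os
import runpy

# Point get_config() at the LoRA training config (path is illustrative).
os.environ["XFL_CONFIG"] = "train/config/normal_lora.yaml"

# Execute src/train/train.py as a module so its relative imports resolve;
# run_name="__main__" makes the module's __main__ guard call main().
runpy.run_module("src.train.train", run_name="__main__")

Equivalently, the same run could be started from a shell in that directory with XFL_CONFIG=train/config/normal_lora.yaml python -m src.train.train.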