├── .gitignore ├── LICENSE ├── README.md ├── flux ├── __init__.py ├── __main__.py ├── _version.py ├── api.py ├── math.py ├── model.py ├── modules │ ├── autoencoder.py │ ├── conditioner.py │ └── layers.py ├── sampling.py └── util.py ├── gradio_kv_edit.py ├── gradio_kv_edit_gpu.py ├── gradio_kv_edit_inf.py ├── models └── kv_edit.py ├── requirements.txt └── resources ├── example.jpeg ├── pipeline.jpg └── teaser.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | script/ 2 | *.symlink 3 | output/ 4 | regress_result/ 5 | google 6 | checkpoints 7 | huggingface_models 8 | openai 9 | __pycache__/ 10 | .vscode/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # KV-Edit: Training-Free Image Editing for Precise Background Preservation 4 | 5 | [Tianrui Zhu](https://github.com/Xilluill)1*, [Shiyi Zhang](https://shiyi-zh0408.github.io/)1*, [Jiawei Shao](https://shaojiawei07.github.io/)2, [Yansong Tang](https://andytang15.github.io/)1† 6 | 7 | 1 Tsinghua University, 2 Institute of Artificial Intelligence (TeleAI) 8 | 9 | 10 | 11 | [![arXiv](https://img.shields.io/badge/arXiv-2502.17363-b31b1b.svg)](https://arxiv.org/abs/2502.17363) 12 | [![Huggingface space](https://img.shields.io/badge/🤗-Huggingface%20Space-orange.svg)](https://huggingface.co/spaces/xilluill/KV-Edit) 13 | [![GitHub Stars](https://img.shields.io/github/stars/Xilluill/KV-Edit)](https://github.com/Xilluill/KV-Edit) 14 | 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/kv-edit-training-free-image-editing-for/text-based-image-editing-on-pie-bench)](https://paperswithcode.com/sota/text-based-image-editing-on-pie-bench?p=kv-edit-training-free-image-editing-for) 16 | [![Static Badge](https://img.shields.io/badge/comfyUI-KV_Edit-blue)](https://github.com/smthemex/ComfyUI_KV_Edit) 17 | 18 |
19 | 20 |

21 | We propose KV-Edit, a training-free image editing approach that strictly preserves background consistency between the original and edited images. Our method achieves impressive performance on various editing tasks, including object addition, removal, and replacement. 22 |

23 | 24 | 25 | 26 |

27 | 28 |

29 | 30 | # 🔥 News 31 | - [2025.3.12] Thanks to @[smthemex](https://github.com/smthemex) for integrating KV-Edit into [ComfyUI](https://github.com/smthemex/ComfyUI_KV_Edit)! 32 | - [2025.3.4] We updated the "attention scale" feature to reduce discontinuity with the background. 33 | - [2025.2.26] Our paper is featured in [huggingface Papers](https://huggingface.co/papers/2502.17363)! 34 | - [2025.2.25] Code for image editing is released! 35 | - [2025.2.25] Paper released! 36 | - [2025.2.25] More results can be found on our [project page](https://xilluill.github.io/projectpages/KV-Edit/)! 37 | 38 | # 👨‍💻 ToDo 39 | - ☑️ Release the gradio demo 40 | - ☑️ Release the huggingface space for image editing 41 | - ☑️ Release the paper 42 | 43 | 44 | # 📖 Pipeline 45 |

46 | 47 | We implemented a KV Cache in our DiT-based generative model, which stores the key-value pairs of background tokens during the inversion process and concatenates them with the foreground content during denoising. Since background tokens are preserved rather than regenerated, KV-Edit can strictly maintain background consistency while generating seamlessly integrated new content. 48 | 49 | # 🚀 Getting Started 50 | ## Environment Requirement 🌍 51 | Our code uses the same environment as FLUX; you can refer to the [official repo](https://github.com/black-forest-labs/flux/tree/main) of FLUX, or run the following commands to set up a simplified environment. 52 | 53 | Clone the repo: 54 | ``` 55 | git clone https://github.com/Xilluill/KV-Edit 56 | ``` 57 | We recommend first creating a virtual environment with conda, then running: 58 | ``` 59 | conda create --name KV-Edit python=3.10 60 | conda activate KV-Edit 61 | pip install -r requirements.txt 62 | ``` 63 | ## Running Gradio demo 🛫 64 | We provide three demo scripts for different hardware configurations. For users with server access and sufficient CPU/GPU memory (>40/24 GB), we recommend: 65 | ``` 66 | python gradio_kv_edit.py 67 | ``` 68 | For users with 2 GPUs (e.g., 3090/4090), which avoids offloading the models and speeds up inference, you can use: 69 | ``` 70 | python gradio_kv_edit_gpu.py --gpus 71 | ``` 72 | For users with limited GPU memory, we recommend: 73 | ``` 74 | python gradio_kv_edit.py --offload 75 | ``` 76 | For users with limited CPU memory, such as a personal PC, we recommend: 77 | ``` 78 | python gradio_kv_edit_inf.py --offload 79 | ``` 80 | Here's a sample workflow for our demo: 81 | 82 | 1️⃣ Upload your image that needs to be edited.
83 | 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion.
84 | 3️⃣ Use the brush tool to draw your mask area.
85 | 4️⃣ Fill in your target prompt, then adjust the hyperparameters.
86 | 5️⃣ Click the "Edit" button to generate your edited image!
87 | 88 |

89 | 90 |
91 | 92 | ### 💡Important Notes: 93 | - 🎨 When using the inversion-based version, you only need to perform the inversion once for each image. You can then repeat steps 3-5 for multiple editing attempts! 94 | - 🎨 "re_init" means the new content is generated from the image blended with noise instead of from the inversion result. 95 | - 🎨 When the "attn_mask" option is checked, you need to input the mask before performing the inversion. 96 | - 🎨 When the mask is large and fewer skip steps or "re_init" are used, the content of the masked area may not blend smoothly with the background; try increasing "attn_scale". 97 | 98 | 99 | # 🖋️ Citation 100 | 101 | If you find our work helpful, please **star 🌟** this repo and **cite 📑** our paper. Thanks for your support! 102 | ``` 103 | @article{zhu2025kv, 104 | title={KV-Edit: Training-Free Image Editing for Precise Background Preservation}, 105 | author={Zhu, Tianrui and Zhang, Shiyi and Shao, Jiawei and Tang, Yansong}, 106 | journal={arXiv preprint arXiv:2502.17363}, 107 | year={2025} 108 | } 109 | ``` 110 | 111 | # 👍🏻 Acknowledgements 112 | Our code builds on [FLUX](https://github.com/black-forest-labs/flux) and [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit). Special thanks to [Wenke Huang](https://wenkehuang.github.io/) for his early inspiration and helpful guidance on this project! 113 | 114 | # 📧 Contact 115 | This repository is currently under active development and restructuring. The codebase is being optimized for better stability and reproducibility. While we strive to maintain code quality, you may encounter temporary issues during this transition period. For any questions or technical discussions, feel free to open an issue or contact us via email at xilluill070513@gmail.com. -------------------------------------------------------------------------------- /flux/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ # type: ignore 3 | from ._version import version_tuple 4 | except ImportError: 5 | __version__ = "unknown (no version information available)" 6 | version_tuple = (0, 0, "unknown", "noinfo") 7 | 8 | from pathlib import Path 9 | 10 | PACKAGE = __package__.replace("_", "-") 11 | PACKAGE_ROOT = Path(__file__).parent 12 | -------------------------------------------------------------------------------- /flux/__main__.py: -------------------------------------------------------------------------------- 1 | from .cli import app 2 | 3 | if __name__ == "__main__": 4 | app() 5 | -------------------------------------------------------------------------------- /flux/_version.py: -------------------------------------------------------------------------------- 1 | # file generated by setuptools_scm 2 | # don't change, don't track in version control 3 | TYPE_CHECKING = False 4 | if TYPE_CHECKING: 5 | from typing import Tuple, Union 6 | VERSION_TUPLE = Tuple[Union[int, str], ...]
7 | else: 8 | VERSION_TUPLE = object 9 | 10 | version: str 11 | __version__: str 12 | __version_tuple__: VERSION_TUPLE 13 | version_tuple: VERSION_TUPLE 14 | 15 | __version__ = version = '0.0.post6+ge52e00f.d20250111' 16 | __version_tuple__ = version_tuple = (0, 0, 'ge52e00f.d20250111') 17 | -------------------------------------------------------------------------------- /flux/api.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import time 4 | from pathlib import Path 5 | 6 | import requests 7 | from PIL import Image 8 | 9 | API_ENDPOINT = "https://api.bfl.ml" 10 | 11 | 12 | class ApiException(Exception): 13 | def __init__(self, status_code: int, detail: str | list[dict] | None = None): 14 | super().__init__() 15 | self.detail = detail 16 | self.status_code = status_code 17 | 18 | def __str__(self) -> str: 19 | return self.__repr__() 20 | 21 | def __repr__(self) -> str: 22 | if self.detail is None: 23 | message = None 24 | elif isinstance(self.detail, str): 25 | message = self.detail 26 | else: 27 | message = "[" + ",".join(d["msg"] for d in self.detail) + "]" 28 | return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" 29 | 30 | 31 | class ImageRequest: 32 | def __init__( 33 | self, 34 | prompt: str, 35 | width: int = 1024, 36 | height: int = 1024, 37 | name: str = "flux.1-pro", 38 | num_steps: int = 50, 39 | prompt_upsampling: bool = False, 40 | seed: int | None = None, 41 | validate: bool = True, 42 | launch: bool = True, 43 | api_key: str | None = None, 44 | ): 45 | """ 46 | Manages an image generation request to the API. 47 | 48 | Args: 49 | prompt: Prompt to sample 50 | width: Width of the image in pixel 51 | height: Height of the image in pixel 52 | name: Name of the model 53 | num_steps: Number of network evaluations 54 | prompt_upsampling: Use prompt upsampling 55 | seed: Fix the generation seed 56 | validate: Run input validation 57 | launch: Directly launches request 58 | api_key: Your API key if not provided by the environment 59 | 60 | Raises: 61 | ValueError: For invalid input 62 | ApiException: For errors raised from the API 63 | """ 64 | if validate: 65 | if name not in ["flux.1-pro"]: 66 | raise ValueError(f"Invalid model {name}") 67 | elif width % 32 != 0: 68 | raise ValueError(f"width must be divisible by 32, got {width}") 69 | elif not (256 <= width <= 1440): 70 | raise ValueError(f"width must be between 256 and 1440, got {width}") 71 | elif height % 32 != 0: 72 | raise ValueError(f"height must be divisible by 32, got {height}") 73 | elif not (256 <= height <= 1440): 74 | raise ValueError(f"height must be between 256 and 1440, got {height}") 75 | elif not (1 <= num_steps <= 50): 76 | raise ValueError(f"steps must be between 1 and 50, got {num_steps}") 77 | 78 | self.request_json = { 79 | "prompt": prompt, 80 | "width": width, 81 | "height": height, 82 | "variant": name, 83 | "steps": num_steps, 84 | "prompt_upsampling": prompt_upsampling, 85 | } 86 | if seed is not None: 87 | self.request_json["seed"] = seed 88 | 89 | self.request_id: str | None = None 90 | self.result: dict | None = None 91 | self._image_bytes: bytes | None = None 92 | self._url: str | None = None 93 | if api_key is None: 94 | self.api_key = os.environ.get("BFL_API_KEY") 95 | else: 96 | self.api_key = api_key 97 | 98 | if launch: 99 | self.request() 100 | 101 | def request(self): 102 | """ 103 | Request to generate the image. 
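The request id is stored in self.request_id; calling this method again after a request has already been sent is a no-op.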
104 | """ 105 | if self.request_id is not None: 106 | return 107 | response = requests.post( 108 | f"{API_ENDPOINT}/v1/image", 109 | headers={ 110 | "accept": "application/json", 111 | "x-key": self.api_key, 112 | "Content-Type": "application/json", 113 | }, 114 | json=self.request_json, 115 | ) 116 | result = response.json() 117 | if response.status_code != 200: 118 | raise ApiException(status_code=response.status_code, detail=result.get("detail")) 119 | self.request_id = response.json()["id"] 120 | 121 | def retrieve(self) -> dict: 122 | """ 123 | Wait for the generation to finish and retrieve response. 124 | """ 125 | if self.request_id is None: 126 | self.request() 127 | while self.result is None: 128 | response = requests.get( 129 | f"{API_ENDPOINT}/v1/get_result", 130 | headers={ 131 | "accept": "application/json", 132 | "x-key": self.api_key, 133 | }, 134 | params={ 135 | "id": self.request_id, 136 | }, 137 | ) 138 | result = response.json() 139 | if "status" not in result: 140 | raise ApiException(status_code=response.status_code, detail=result.get("detail")) 141 | elif result["status"] == "Ready": 142 | self.result = result["result"] 143 | elif result["status"] == "Pending": 144 | time.sleep(0.5) 145 | else: 146 | raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") 147 | return self.result 148 | 149 | @property 150 | def bytes(self) -> bytes: 151 | """ 152 | Generated image as bytes. 153 | """ 154 | if self._image_bytes is None: 155 | response = requests.get(self.url) 156 | if response.status_code == 200: 157 | self._image_bytes = response.content 158 | else: 159 | raise ApiException(status_code=response.status_code) 160 | return self._image_bytes 161 | 162 | @property 163 | def url(self) -> str: 164 | """ 165 | Public url to retrieve the image from 166 | """ 167 | if self._url is None: 168 | result = self.retrieve() 169 | self._url = result["sample"] 170 | return self._url 171 | 172 | @property 173 | def image(self) -> Image.Image: 174 | """ 175 | Load the image as a PIL Image 176 | """ 177 | return Image.open(io.BytesIO(self.bytes)) 178 | 179 | def save(self, path: str): 180 | """ 181 | Save the generated image to a local path 182 | """ 183 | suffix = Path(self.url).suffix 184 | if not path.endswith(suffix): 185 | path = path + suffix 186 | Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) 187 | with open(path, "wb") as file: 188 | file.write(self.bytes) 189 | 190 | 191 | if __name__ == "__main__": 192 | from fire import Fire 193 | 194 | Fire(ImageRequest) 195 | -------------------------------------------------------------------------------- /flux/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor 4 | 5 | 6 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor,pe_q = None, attention_mask = None) -> Tensor: 7 | if pe_q is None: 8 | q, k = apply_rope(q, k, pe) 9 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v,attn_mask=attention_mask) 10 | x = rearrange(x, "B H L D -> B L (H D)") 11 | return x 12 | else: 13 | q, k = apply_rope_qk(q, k, pe_q, pe) 14 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v,attn_mask=attention_mask) 15 | x = rearrange(x, "B H L D -> B L (H D)") 16 | return x 17 | 18 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 19 | assert dim % 2 == 0 20 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim # dim =16 + 56 + 56 21 | omega = 
1.0 / (theta**scale) # 64 omega 22 | out = torch.einsum("...n,d->...nd", pos, omega) 23 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 24 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 25 | return out.float() 26 | 27 | 28 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 29 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 30 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 31 | xq_out = freqs_cis[:, :, :xq_.shape[2], :, :, 0] * xq_[..., 0] + freqs_cis[:, :, :xq_.shape[2], :, :, 1] * xq_[..., 1] 32 | xk_out = freqs_cis[:, :, :xk_.shape[2], :, :, 0] * xk_[..., 0] + freqs_cis[:, :, :xk_.shape[2], :, :, 1] * xk_[..., 1] 33 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 34 | 35 | def apply_rope_qk(xq: Tensor, xk: Tensor, freqs_cis_q: Tensor,freqs_cis_k: Tensor) -> tuple[Tensor, Tensor]: 36 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 37 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 38 | xq_out = freqs_cis_q[:, :, :xq_.shape[2], :, :, 0] * xq_[..., 0] + freqs_cis_q[:, :, :xq_.shape[2], :, :, 1] * xq_[..., 1] 39 | xk_out = freqs_cis_k[:, :, :xk_.shape[2], :, :, 0] * xk_[..., 0] + freqs_cis_k[:, :, :xk_.shape[2], :, :, 1] * xk_[..., 1] 40 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 41 | -------------------------------------------------------------------------------- /flux/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, 7 | MLPEmbedder, SingleStreamBlock, 8 | SingleStreamBlock_kv,DoubleStreamBlock_kv, 9 | timestep_embedding) 10 | 11 | 12 | @dataclass 13 | class FluxParams: 14 | in_channels: int 15 | vec_in_dim: int 16 | context_in_dim: int 17 | hidden_size: int 18 | mlp_ratio: float 19 | num_heads: int 20 | depth: int 21 | depth_single_blocks: int 22 | axes_dim: list[int] 23 | theta: int 24 | qkv_bias: bool 25 | guidance_embed: bool 26 | 27 | 28 | class Flux(nn.Module): 29 | """ 30 | Transformer model for flow matching on sequences. 
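Image and text tokens are processed by parallel double-stream blocks, then concatenated and refined by single-stream blocks before the final projection layer.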
31 | """ 32 | 33 | def __init__(self, params: FluxParams,double_block_cls=DoubleStreamBlock,single_block_cls=SingleStreamBlock): 34 | super().__init__() 35 | 36 | self.params = params 37 | self.in_channels = params.in_channels 38 | self.out_channels = self.in_channels 39 | if params.hidden_size % params.num_heads != 0: 40 | raise ValueError( 41 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 42 | ) 43 | pe_dim = params.hidden_size // params.num_heads 44 | if sum(params.axes_dim) != pe_dim: 45 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") 46 | self.hidden_size = params.hidden_size 47 | self.num_heads = params.num_heads 48 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) 49 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 50 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 51 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 52 | self.guidance_in = ( 53 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 54 | ) 55 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 56 | 57 | self.double_blocks = nn.ModuleList( 58 | [ 59 | double_block_cls( 60 | self.hidden_size, 61 | self.num_heads, 62 | mlp_ratio=params.mlp_ratio, 63 | qkv_bias=params.qkv_bias, 64 | ) 65 | for _ in range(params.depth) 66 | ] 67 | ) 68 | 69 | self.single_blocks = nn.ModuleList( 70 | [ 71 | single_block_cls(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 72 | for _ in range(params.depth_single_blocks) 73 | ] 74 | ) 75 | 76 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 77 | 78 | def forward( 79 | self, 80 | img: Tensor, 81 | img_ids: Tensor, 82 | txt: Tensor, 83 | txt_ids: Tensor, 84 | timesteps: Tensor, 85 | y: Tensor, 86 | guidance: Tensor | None = None, 87 | ) -> Tensor: 88 | if img.ndim != 3 or txt.ndim != 3: 89 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 90 | 91 | # running on sequences img 92 | img = self.img_in(img) 93 | vec = self.time_in(timestep_embedding(timesteps, 256)) 94 | if self.params.guidance_embed: 95 | if guidance is None: 96 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 97 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 98 | vec = vec + self.vector_in(y) 99 | txt = self.txt_in(txt) 100 | 101 | ids = torch.cat((txt_ids, img_ids), dim=1) 102 | pe = self.pe_embedder(ids) 103 | 104 | for block in self.double_blocks: 105 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 106 | 107 | img = torch.cat((txt, img), 1) 108 | for block in self.single_blocks: 109 | img = block(img, vec=vec, pe=pe) 110 | img = img[:, txt.shape[1] :, ...] 
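# the text tokens were sliced off above; final_layer below projects the image tokens back to patch space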
111 | 112 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 113 | return img 114 | 115 | class Flux_kv(Flux): 116 | """ 117 | Inherits from Flux and overrides the forward method to thread the KV-cache 'info' dict through every double- and single-stream block. 118 | """ 119 | 120 | def __init__(self, params: FluxParams,double_block_cls=DoubleStreamBlock_kv,single_block_cls=SingleStreamBlock_kv): 121 | super().__init__(params,double_block_cls,single_block_cls) 122 | 123 | def forward( 124 | self, 125 | img: Tensor, 126 | img_ids: Tensor, 127 | txt: Tensor, 128 | txt_ids: Tensor, 129 | timesteps: Tensor, 130 | y: Tensor, 131 | guidance: Tensor | None = None, 132 | info: dict = {}, 133 | ) -> Tensor: 134 | if img.ndim != 3 or txt.ndim != 3: 135 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 136 | 137 | # running on sequences img 138 | img = self.img_in(img) 139 | vec = self.time_in(timestep_embedding(timesteps, 256)) 140 | if self.params.guidance_embed: 141 | if guidance is None: 142 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 143 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 144 | vec = vec + self.vector_in(y) 145 | txt = self.txt_in(txt) 146 | 147 | ids = torch.cat((txt_ids, img_ids), dim=1) 148 | pe = self.pe_embedder(ids) 149 | if not info['inverse']: 150 | mask_indices = info['mask_indices'] 151 | info['pe_mask'] = torch.cat((pe[:, :, :512, ...],pe[:, :, mask_indices+512, ...]),dim=2) 152 | 153 | cnt = 0 154 | for block in self.double_blocks: 155 | info['id'] = cnt 156 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe, info=info) 157 | cnt += 1 158 | 159 | cnt = 0 160 | x = torch.cat((txt, img), 1) 161 | for block in self.single_blocks: 162 | info['id'] = cnt 163 | x = block(x, vec=vec, pe=pe, info=info) 164 | cnt += 1 165 | 166 | img = x[:, txt.shape[1] :, ...]
167 | 168 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 169 | 170 | return img 171 | -------------------------------------------------------------------------------- /flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import Tensor, nn 6 | 7 | 8 | @dataclass 9 | class AutoEncoderParams: 10 | resolution: int 11 | in_channels: int 12 | ch: int 13 | out_ch: int 14 | ch_mult: list[int] 15 | num_res_blocks: int 16 | z_channels: int 17 | scale_factor: float 18 | shift_factor: float 19 | 20 | 21 | def swish(x: Tensor) -> Tensor: 22 | return x * torch.sigmoid(x) 23 | 24 | 25 | class AttnBlock(nn.Module): 26 | def __init__(self, in_channels: int): 27 | super().__init__() 28 | self.in_channels = in_channels 29 | 30 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 31 | 32 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 33 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 34 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 35 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | 37 | def attention(self, h_: Tensor) -> Tensor: 38 | h_ = self.norm(h_) 39 | q = self.q(h_) 40 | k = self.k(h_) 41 | v = self.v(h_) 42 | 43 | b, c, h, w = q.shape 44 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 45 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 46 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 47 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 48 | 49 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | return x + self.proj_out(self.attention(x)) 53 | 54 | 55 | class ResnetBlock(nn.Module): 56 | def __init__(self, in_channels: int, out_channels: int): 57 | super().__init__() 58 | self.in_channels = in_channels 59 | out_channels = in_channels if out_channels is None else out_channels 60 | self.out_channels = out_channels 61 | 62 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 63 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 64 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 65 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 66 | if self.in_channels != self.out_channels: 67 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 68 | 69 | def forward(self, x): 70 | h = x 71 | h = self.norm1(h) 72 | h = swish(h) 73 | h = self.conv1(h) 74 | 75 | h = self.norm2(h) 76 | h = swish(h) 77 | h = self.conv2(h) 78 | 79 | if self.in_channels != self.out_channels: 80 | x = self.nin_shortcut(x) 81 | 82 | return x + h 83 | 84 | 85 | class Downsample(nn.Module): 86 | def __init__(self, in_channels: int): 87 | super().__init__() 88 | # no asymmetric padding in torch conv, must do it ourselves 89 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 90 | 91 | def forward(self, x: Tensor): 92 | pad = (0, 1, 0, 1) 93 | x = nn.functional.pad(x, pad, mode="constant", value=0) 94 | x = self.conv(x) 95 | return x 96 | 97 | 98 | class Upsample(nn.Module): 99 | def __init__(self, in_channels: int): 100 | super().__init__() 101 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, 
stride=1, padding=1) 102 | 103 | def forward(self, x: Tensor): 104 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 105 | x = self.conv(x) 106 | return x 107 | 108 | 109 | class Encoder(nn.Module): 110 | def __init__( 111 | self, 112 | resolution: int, 113 | in_channels: int, 114 | ch: int, 115 | ch_mult: list[int], 116 | num_res_blocks: int, 117 | z_channels: int, 118 | ): 119 | super().__init__() 120 | self.ch = ch 121 | self.num_resolutions = len(ch_mult) 122 | self.num_res_blocks = num_res_blocks 123 | self.resolution = resolution 124 | self.in_channels = in_channels 125 | # downsampling 126 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 127 | 128 | curr_res = resolution 129 | in_ch_mult = (1,) + tuple(ch_mult) 130 | self.in_ch_mult = in_ch_mult 131 | self.down = nn.ModuleList() 132 | block_in = self.ch 133 | for i_level in range(self.num_resolutions): 134 | block = nn.ModuleList() 135 | attn = nn.ModuleList() 136 | block_in = ch * in_ch_mult[i_level] 137 | block_out = ch * ch_mult[i_level] 138 | for _ in range(self.num_res_blocks): 139 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 140 | block_in = block_out 141 | down = nn.Module() 142 | down.block = block 143 | down.attn = attn 144 | if i_level != self.num_resolutions - 1: 145 | down.downsample = Downsample(block_in) 146 | curr_res = curr_res // 2 147 | self.down.append(down) 148 | 149 | # middle 150 | self.mid = nn.Module() 151 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 152 | self.mid.attn_1 = AttnBlock(block_in) 153 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 154 | 155 | # end 156 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 157 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 158 | 159 | def forward(self, x: Tensor) -> Tensor: 160 | # downsampling 161 | hs = [self.conv_in(x)] 162 | for i_level in range(self.num_resolutions): 163 | for i_block in range(self.num_res_blocks): 164 | h = self.down[i_level].block[i_block](hs[-1]) 165 | if len(self.down[i_level].attn) > 0: 166 | h = self.down[i_level].attn[i_block](h) 167 | hs.append(h) 168 | if i_level != self.num_resolutions - 1: 169 | hs.append(self.down[i_level].downsample(hs[-1])) 170 | 171 | # middle 172 | h = hs[-1] 173 | h = self.mid.block_1(h) 174 | h = self.mid.attn_1(h) 175 | h = self.mid.block_2(h) 176 | # end 177 | h = self.norm_out(h) 178 | h = swish(h) 179 | h = self.conv_out(h) 180 | return h 181 | 182 | 183 | class Decoder(nn.Module): 184 | def __init__( 185 | self, 186 | ch: int, 187 | out_ch: int, 188 | ch_mult: list[int], 189 | num_res_blocks: int, 190 | in_channels: int, 191 | resolution: int, 192 | z_channels: int, 193 | ): 194 | super().__init__() 195 | self.ch = ch 196 | self.num_resolutions = len(ch_mult) 197 | self.num_res_blocks = num_res_blocks 198 | self.resolution = resolution 199 | self.in_channels = in_channels 200 | self.ffactor = 2 ** (self.num_resolutions - 1) 201 | 202 | # compute in_ch_mult, block_in and curr_res at lowest res 203 | block_in = ch * ch_mult[self.num_resolutions - 1] 204 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 205 | self.z_shape = (1, z_channels, curr_res, curr_res) 206 | 207 | # z to block_in 208 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 209 | 210 | # middle 211 | self.mid = nn.Module() 212 | self.mid.block_1 = 
ResnetBlock(in_channels=block_in, out_channels=block_in) 213 | self.mid.attn_1 = AttnBlock(block_in) 214 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 215 | 216 | # upsampling 217 | self.up = nn.ModuleList() 218 | for i_level in reversed(range(self.num_resolutions)): 219 | block = nn.ModuleList() 220 | attn = nn.ModuleList() 221 | block_out = ch * ch_mult[i_level] 222 | for _ in range(self.num_res_blocks + 1): 223 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 224 | block_in = block_out 225 | up = nn.Module() 226 | up.block = block 227 | up.attn = attn 228 | if i_level != 0: 229 | up.upsample = Upsample(block_in) 230 | curr_res = curr_res * 2 231 | self.up.insert(0, up) # prepend to get consistent order 232 | 233 | # end 234 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 235 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 236 | 237 | def forward(self, z: Tensor) -> Tensor: 238 | # z to block_in 239 | h = self.conv_in(z) 240 | 241 | # middle 242 | h = self.mid.block_1(h) 243 | h = self.mid.attn_1(h) 244 | h = self.mid.block_2(h) 245 | 246 | # upsampling 247 | for i_level in reversed(range(self.num_resolutions)): 248 | for i_block in range(self.num_res_blocks + 1): 249 | h = self.up[i_level].block[i_block](h) 250 | if len(self.up[i_level].attn) > 0: 251 | h = self.up[i_level].attn[i_block](h) 252 | if i_level != 0: 253 | h = self.up[i_level].upsample(h) 254 | 255 | # end 256 | h = self.norm_out(h) 257 | h = swish(h) 258 | h = self.conv_out(h) 259 | return h 260 | 261 | 262 | class DiagonalGaussian(nn.Module): 263 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 264 | super().__init__() 265 | self.sample = sample 266 | self.chunk_dim = chunk_dim 267 | 268 | def forward(self, z: Tensor) -> Tensor: 269 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 270 | # import pdb;pdb.set_trace() 271 | if self.sample: 272 | std = torch.exp(0.5 * logvar) 273 | return mean #+ std * torch.randn_like(mean) 274 | else: 275 | return mean 276 | 277 | 278 | class AutoEncoder(nn.Module): 279 | def __init__(self, params: AutoEncoderParams): 280 | super().__init__() 281 | self.encoder = Encoder( 282 | resolution=params.resolution, 283 | in_channels=params.in_channels, 284 | ch=params.ch, 285 | ch_mult=params.ch_mult, 286 | num_res_blocks=params.num_res_blocks, 287 | z_channels=params.z_channels, 288 | ) 289 | self.decoder = Decoder( 290 | resolution=params.resolution, 291 | in_channels=params.in_channels, 292 | ch=params.ch, 293 | out_ch=params.out_ch, 294 | ch_mult=params.ch_mult, 295 | num_res_blocks=params.num_res_blocks, 296 | z_channels=params.z_channels, 297 | ) 298 | self.reg = DiagonalGaussian() 299 | 300 | self.scale_factor = params.scale_factor 301 | self.shift_factor = params.shift_factor 302 | 303 | def encode(self, x: Tensor) -> Tensor: 304 | z = self.reg(self.encoder(x)) 305 | z = self.scale_factor * (z - self.shift_factor) 306 | return z 307 | 308 | def decode(self, z: Tensor) -> Tensor: 309 | z = z / self.scale_factor + self.shift_factor 310 | return self.decoder(z) 311 | 312 | def forward(self, x: Tensor) -> Tensor: 313 | return self.decode(self.encode(x)) 314 | -------------------------------------------------------------------------------- /flux/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from transformers import (CLIPTextModel, CLIPTokenizer, 
T5EncoderModel, 3 | T5Tokenizer) 4 | 5 | 6 | class HFEmbedder(nn.Module): 7 | def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs): 8 | super().__init__() 9 | self.is_clip = is_clip 10 | self.max_length = max_length 11 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 12 | 13 | if version == 'black-forest-labs/FLUX.1-dev': 14 | if self.is_clip: 15 | self.tokenizer: T5Tokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer") 16 | self.hf_module: T5EncoderModel = CLIPTextModel.from_pretrained(version,subfolder='text_encoder' , **hf_kwargs) 17 | else: 18 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer_2") 19 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version,subfolder='text_encoder_2' , **hf_kwargs) 20 | else: 21 | if self.is_clip: 22 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) 23 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) 24 | else: 25 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) 26 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) 27 | 28 | self.hf_module = self.hf_module.eval().requires_grad_(False) 29 | 30 | def forward(self, text: list[str]) -> Tensor: 31 | batch_encoding = self.tokenizer( 32 | text, 33 | truncation=True, 34 | max_length=self.max_length, 35 | return_length=False, 36 | return_overflowing_tokens=False, 37 | padding="max_length", 38 | return_tensors="pt", 39 | ) 40 | 41 | outputs = self.hf_module( 42 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 43 | attention_mask=None, 44 | output_hidden_states=False, 45 | ) 46 | return outputs[self.output_key] 47 | -------------------------------------------------------------------------------- /flux/modules/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | 8 | from flux.math import attention, rope,apply_rope 9 | 10 | import os 11 | 12 | class EmbedND(nn.Module): 13 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 14 | super().__init__() 15 | self.dim = dim 16 | self.theta = theta 17 | self.axes_dim = axes_dim 18 | 19 | def forward(self, ids: Tensor) -> Tensor: 20 | n_axes = ids.shape[-1] 21 | emb = torch.cat( 22 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 23 | dim=-3, 24 | ) 25 | 26 | return emb.unsqueeze(1) 27 | 28 | 29 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 30 | """ 31 | Create sinusoidal timestep embeddings. 32 | :param t: a 1-D Tensor of N indices, one per batch element. 33 | These may be fractional. 34 | :param dim: the dimension of the output. 35 | :param max_period: controls the minimum frequency of the embeddings. 36 | :return: an (N, D) Tensor of positional embeddings. 
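Note: t is multiplied by time_factor (1000.0 by default) before the frequencies are computed.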
37 | """ 38 | t = time_factor * t 39 | half = dim // 2 40 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 41 | t.device 42 | ) 43 | 44 | args = t[:, None].float() * freqs[None] 45 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 46 | if dim % 2: 47 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 48 | if torch.is_floating_point(t): 49 | embedding = embedding.to(t) 50 | return embedding 51 | 52 | 53 | class MLPEmbedder(nn.Module): 54 | def __init__(self, in_dim: int, hidden_dim: int): 55 | super().__init__() 56 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 57 | self.silu = nn.SiLU() 58 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 59 | 60 | def forward(self, x: Tensor) -> Tensor: 61 | return self.out_layer(self.silu(self.in_layer(x))) 62 | 63 | 64 | class RMSNorm(torch.nn.Module): 65 | def __init__(self, dim: int): 66 | super().__init__() 67 | self.scale = nn.Parameter(torch.ones(dim)) 68 | 69 | def forward(self, x: Tensor): 70 | x_dtype = x.dtype 71 | x = x.float() 72 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 73 | return (x * rrms).to(dtype=x_dtype) * self.scale 74 | 75 | 76 | class QKNorm(torch.nn.Module): 77 | def __init__(self, dim: int): 78 | super().__init__() 79 | self.query_norm = RMSNorm(dim) 80 | self.key_norm = RMSNorm(dim) 81 | 82 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 83 | q = self.query_norm(q) 84 | k = self.key_norm(k) 85 | return q.to(v), k.to(v) 86 | 87 | 88 | class SelfAttention(nn.Module): 89 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 90 | super().__init__() 91 | self.num_heads = num_heads 92 | head_dim = dim // num_heads 93 | 94 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 95 | self.norm = QKNorm(head_dim) 96 | self.proj = nn.Linear(dim, dim) 97 | 98 | def forward(self, x: Tensor, pe: Tensor) -> Tensor: 99 | qkv = self.qkv(x) 100 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 101 | q, k = self.norm(q, k, v) 102 | x = attention(q, k, v, pe=pe) 103 | x = self.proj(x) 104 | return x 105 | 106 | 107 | @dataclass 108 | class ModulationOut: 109 | shift: Tensor 110 | scale: Tensor 111 | gate: Tensor 112 | 113 | 114 | class Modulation(nn.Module): 115 | def __init__(self, dim: int, double: bool): 116 | super().__init__() 117 | self.is_double = double 118 | self.multiplier = 6 if double else 3 119 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 120 | 121 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 122 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 123 | 124 | return ( 125 | ModulationOut(*out[:3]), 126 | ModulationOut(*out[3:]) if self.is_double else None, 127 | ) 128 | 129 | 130 | class DoubleStreamBlock(nn.Module): 131 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 132 | super().__init__() 133 | 134 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 135 | self.num_heads = num_heads 136 | self.hidden_size = hidden_size 137 | self.img_mod = Modulation(hidden_size, double=True) 138 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 139 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 140 | 141 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 142 | self.img_mlp = 
nn.Sequential( 143 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 144 | nn.GELU(approximate="tanh"), 145 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 146 | ) 147 | 148 | self.txt_mod = Modulation(hidden_size, double=True) 149 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 150 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 151 | 152 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 153 | self.txt_mlp = nn.Sequential( 154 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 155 | nn.GELU(approximate="tanh"), 156 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 157 | ) 158 | 159 | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: 160 | img_mod1, img_mod2 = self.img_mod(vec) 161 | txt_mod1, txt_mod2 = self.txt_mod(vec) 162 | 163 | # prepare image for attention 164 | img_modulated = self.img_norm1(img) 165 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 166 | img_qkv = self.img_attn.qkv(img_modulated) 167 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 168 | 169 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 170 | # prepare txt for attention 171 | txt_modulated = self.txt_norm1(txt) 172 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 173 | txt_qkv = self.txt_attn.qkv(txt_modulated) 174 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 175 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 176 | 177 | # run actual attention 178 | q = torch.cat((txt_q, img_q), dim=2) 179 | k = torch.cat((txt_k, img_k), dim=2) 180 | v = torch.cat((txt_v, img_v), dim=2) 181 | # import pdb;pdb.set_trace() 182 | attn = attention(q, k, v, pe=pe) 183 | 184 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] 185 | 186 | # calculate the img bloks 187 | img = img + img_mod1.gate * self.img_attn.proj(img_attn) 188 | img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) 189 | 190 | # calculate the txt bloks 191 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) 192 | txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) 193 | return img, txt 194 | class SingleStreamBlock(nn.Module): 195 | """ 196 | A DiT block with parallel linear layers as described in 197 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
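The fused linear1 produces the QKV tensors together with the MLP input, and linear2 maps the concatenated attention output and MLP activation back to hidden_size.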
198 | """ 199 | 200 | def __init__( 201 | self, 202 | hidden_size: int, 203 | num_heads: int, 204 | mlp_ratio: float = 4.0, 205 | qk_scale: float | None = None, 206 | ): 207 | super().__init__() 208 | self.hidden_dim = hidden_size 209 | self.num_heads = num_heads 210 | head_dim = hidden_size // num_heads 211 | self.scale = qk_scale or head_dim**-0.5 212 | 213 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 214 | # qkv and mlp_in 215 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 216 | # proj and mlp_out 217 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 218 | 219 | self.norm = QKNorm(head_dim) 220 | 221 | self.hidden_size = hidden_size 222 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 223 | 224 | self.mlp_act = nn.GELU(approximate="tanh") 225 | self.modulation = Modulation(hidden_size, double=False) 226 | 227 | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 228 | mod, _ = self.modulation(vec) 229 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift 230 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) 231 | 232 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 233 | q, k = self.norm(q, k, v) 234 | 235 | # compute attention 236 | attn = attention(q, k, v, pe=pe) 237 | # compute activation in mlp stream, cat again and run second linear layer 238 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) 239 | return x + mod.gate * output 240 | 241 | 242 | class LastLayer(nn.Module): 243 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 244 | super().__init__() 245 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 246 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 247 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) 248 | 249 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 250 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 251 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 252 | x = self.linear(x) 253 | return x 254 | 267 | 268 | 269 | class DoubleStreamBlock_kv(DoubleStreamBlock): 270 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 271 | super().__init__(hidden_size, num_heads, mlp_ratio, qkv_bias) 272 | 273 | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor]: 274 | img_mod1, img_mod2 = self.img_mod(vec) 275 | txt_mod1, txt_mod2 = self.txt_mod(vec) 276 | 277 | # prepare image for attention 278 | img_modulated = self.img_norm1(img) 279 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 280 | img_qkv = 
self.img_attn.qkv(img_modulated) 281 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 282 | 283 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 284 | # prepare txt for attention 285 | txt_modulated = self.txt_norm1(txt) 286 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 287 | txt_qkv = self.txt_attn.qkv(txt_modulated) 288 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 289 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 290 | 291 | feature_k_name = str(info['t']) + '_' + str(info['id']) + '_' + 'MB' + '_' + 'K' 292 | feature_v_name = str(info['t']) + '_' + str(info['id']) + '_' + 'MB' + '_' + 'V' 293 | if info['inverse']: 294 | info['feature'][feature_k_name] = img_k.cpu() 295 | info['feature'][feature_v_name] = img_v.cpu() 296 | q = torch.cat((txt_q, img_q), dim=2) 297 | k = torch.cat((txt_k, img_k), dim=2) 298 | v = torch.cat((txt_v, img_v), dim=2) 299 | if 'attention_mask' in info: 300 | attn = attention(q, k, v, pe=pe,attention_mask=info['attention_mask']) 301 | else: 302 | attn = attention(q, k, v, pe=pe) 303 | 304 | else: 305 | source_img_k = info['feature'][feature_k_name].to(img.device) 306 | source_img_v = info['feature'][feature_v_name].to(img.device) 307 | 308 | mask_indices = info['mask_indices'] 309 | source_img_k[:, :, mask_indices, ...] = img_k 310 | source_img_v[:, :, mask_indices, ...] = img_v 311 | 312 | q = torch.cat((txt_q, img_q), dim=2) 313 | k = torch.cat((txt_k, source_img_k), dim=2) 314 | v = torch.cat((txt_v, source_img_v), dim=2) 315 | attn = attention(q, k, v, pe=pe, pe_q = info['pe_mask'],attention_mask=info['attention_scale']) 316 | 317 | 318 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] 319 | 320 | # calculate the img bloks 321 | img = img + img_mod1.gate * self.img_attn.proj(img_attn) 322 | img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) 323 | 324 | # calculate the txt bloks 325 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) 326 | txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) 327 | return img, txt 328 | 329 | class SingleStreamBlock_kv(SingleStreamBlock): 330 | """ 331 | A DiT block with parallel linear layers as described in 332 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 333 | """ 334 | 335 | def __init__( 336 | self, 337 | hidden_size: int, 338 | num_heads: int, 339 | mlp_ratio: float = 4.0, 340 | qk_scale: float | None = None, 341 | ): 342 | super().__init__(hidden_size, num_heads, mlp_ratio, qk_scale) 343 | 344 | def forward(self,x: Tensor, vec: Tensor, pe: Tensor, info) -> Tensor: 345 | mod, _ = self.modulation(vec) 346 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift 347 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) 348 | 349 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 350 | q, k = self.norm(q, k, v) 351 | img_k = k[:, :, 512:, ...] 352 | img_v = v[:, :, 512:, ...] 353 | 354 | txt_k = k[:, :, :512, ...] 355 | txt_v = v[:, :, :512, ...] 
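# NOTE: the hard-coded 512 presumably matches the T5 text length (load_t5 uses
# max_length=512 for flux-dev, 256 for flux-schnell), so tokens [:512] along the
# sequence axis are text and tokens [512:] are packed image latents. For example,
# a 1024x1024 image gives a 128x128 latent and 64*64 = 4096 packed tokens, so k
# has shape (B, H, 512 + 4096, D). Only the image K/V are cached below.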
356 | 357 | 358 | feature_k_name = str(info['t']) + '_' + str(info['id']) + '_' + 'SB' + '_' + 'K' 359 | feature_v_name = str(info['t']) + '_' + str(info['id']) + '_' + 'SB' + '_' + 'V' 360 | if info['inverse']: 361 | info['feature'][feature_k_name] = img_k.cpu() 362 | info['feature'][feature_v_name] = img_v.cpu() 363 | if 'attention_mask' in info: 364 | attn = attention(q, k, v, pe=pe,attention_mask=info['attention_mask']) 365 | else: 366 | attn = attention(q, k, v, pe=pe) 367 | 368 | else: 369 | source_img_k = info['feature'][feature_k_name].to(x.device) 370 | source_img_v = info['feature'][feature_v_name].to(x.device) 371 | 372 | mask_indices = info['mask_indices'] 373 | source_img_k[:, :, mask_indices, ...] = img_k 374 | source_img_v[:, :, mask_indices, ...] = img_v 375 | 376 | k = torch.cat((txt_k, source_img_k), dim=2) 377 | v = torch.cat((txt_v, source_img_v), dim=2) 378 | attn = attention(q, k, v, pe=pe, pe_q = info['pe_mask'],attention_mask=info['attention_scale']) 379 | 380 | # compute activation in mlp stream, cat again and run second linear layer 381 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) 382 | return x + mod.gate * output -------------------------------------------------------------------------------- /flux/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from torch import Tensor 7 | 8 | from .model import Flux,Flux_kv 9 | from .modules.conditioner import HFEmbedder 10 | from tqdm import tqdm 11 | from tqdm.contrib import tzip 12 | 13 | def get_noise( 14 | num_samples: int, 15 | height: int, 16 | width: int, 17 | device: torch.device, 18 | dtype: torch.dtype, 19 | seed: int, 20 | ): 21 | return torch.randn( 22 | num_samples, 23 | 16, 24 | # allow for packing 25 | 2 * math.ceil(height / 16), 26 | 2 * math.ceil(width / 16), 27 | device=device, 28 | dtype=dtype, 29 | generator=torch.Generator(device=device).manual_seed(seed), 30 | ) 31 | 32 | 33 | def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: 34 | bs, c, h, w = img.shape 35 | if bs == 1 and not isinstance(prompt, str): 36 | bs = len(prompt) 37 | 38 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 39 | if img.shape[0] == 1 and bs > 1: 40 | img = repeat(img, "1 ... -> bs ...", bs=bs) 41 | 42 | img_ids = torch.zeros(h // 2, w // 2, 3) 43 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 44 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 45 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 46 | 47 | if isinstance(prompt, str): 48 | prompt = [prompt] 49 | txt = t5(prompt) 50 | if txt.shape[0] == 1 and bs > 1: 51 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 52 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 53 | 54 | vec = clip(prompt) 55 | if vec.shape[0] == 1 and bs > 1: 56 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 57 | 58 | return { 59 | "img": img, 60 | "img_ids": img_ids.to(img.device), 61 | "txt": txt.to(img.device), 62 | "txt_ids": txt_ids.to(img.device), 63 | "vec": vec.to(img.device), 64 | } 65 | 66 | def time_shift(mu: float, sigma: float, t: Tensor): 67 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 68 | 69 | 70 | def get_lin_function( 71 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 72 | ) -> Callable[[float], float]: 73 | m = (y2 - y1) / (x2 - x1) 74 | b = y1 - m * x1 75 | return lambda x: m * x + b 76 | 77 | 78 | def get_schedule( 79 | num_steps: int, 80 | image_seq_len: int, 81 | base_shift: float = 0.5, 82 | max_shift: float = 1.15, 83 | shift: bool = True, 84 | ) -> list[float]: 85 | # extra step for zero 86 | timesteps = torch.linspace(1, 0, num_steps + 1) 87 | 88 | # shifting the schedule to favor high timesteps for higher signal images 89 | if shift: 90 | # estimate mu based on linear estimation between two points 91 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 92 | timesteps = time_shift(mu, 1.0, timesteps) 93 | 94 | return timesteps.tolist() 95 | 96 | 97 | def denoise( 98 | model: Flux, 99 | # model input 100 | img: Tensor, 101 | img_ids: Tensor, 102 | txt: Tensor, 103 | txt_ids: Tensor, 104 | vec: Tensor, 105 | # sampling parameters 106 | timesteps: list[float], 107 | guidance: float = 4.0, 108 | ): 109 | # this is ignored for schnell 110 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) 111 | for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): 112 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 113 | pred = model( 114 | img=img, 115 | img_ids=img_ids, 116 | txt=txt, 117 | txt_ids=txt_ids, 118 | y=vec, 119 | timesteps=t_vec, 120 | guidance=guidance_vec, 121 | ) 122 | 123 | img = img + (t_prev - t_curr) * pred 124 | 125 | return img 126 | 127 | def unpack(x: Tensor, height: int, width: int) -> Tensor: 128 | return rearrange( 129 | x, 130 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 131 | h=math.ceil(height / 16), 132 | w=math.ceil(width / 16), 133 | ph=2, 134 | pw=2, 135 | ) 136 | 137 | def denoise_kv( 138 | model: Flux_kv, 139 | # model input 140 | img: Tensor, 141 | img_ids: Tensor, 142 | txt: Tensor, 143 | txt_ids: Tensor, 144 | vec: Tensor, 145 | # sampling parameters 146 | timesteps: list[float], 147 | inverse, 148 | info, 149 | guidance: float = 4.0 150 | ): 151 | 152 | if inverse: 153 | timesteps = timesteps[::-1] 154 | 155 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) 156 | 157 | for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])): 158 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 159 | info['t'] = t_prev if inverse else t_curr 160 | 161 | if inverse: 162 | img_name = str(info['t']) + '_' + 'img' 163 | info['feature'][img_name] = img.cpu() 164 | else: 165 | img_name = str(info['t']) + '_' + 'img' 166 | source_img = info['feature'][img_name].to(img.device) 167 | img = source_img[:, info['mask_indices'],...] * (1 - info['mask'][:, info['mask_indices'],...]) + img * info['mask'][:, info['mask_indices'],...] 
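# When editing (inverse=False), only the masked tokens are being denoised; the line
# above blends them with the latent cached at this timestep during inversion, so
# content outside the mask replays the source trajectory unchanged.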
168 | pred = model( 169 | img=img, 170 | img_ids=img_ids, 171 | txt=txt, 172 | txt_ids=txt_ids, 173 | y=vec, 174 | timesteps=t_vec, 175 | guidance=guidance_vec, 176 | info=info 177 | ) 178 | img = img + (t_prev - t_curr) * pred 179 | return img, info 180 | 181 | def denoise_kv_inf( 182 | model: Flux_kv, 183 | # model input 184 | img: Tensor, 185 | img_ids: Tensor, 186 | source_txt: Tensor, 187 | source_txt_ids: Tensor, 188 | source_vec: Tensor, 189 | target_txt: Tensor, 190 | target_txt_ids: Tensor, 191 | target_vec: Tensor, 192 | # sampling parameters 193 | timesteps: list[float], 194 | target_guidance: float = 4.0, 195 | source_guidance: float = 4.0, 196 | info: dict = {}, 197 | ): 198 | 199 | target_guidance_vec = torch.full((img.shape[0],), target_guidance, device=img.device, dtype=img.dtype) 200 | source_guidance_vec = torch.full((img.shape[0],), source_guidance, device=img.device, dtype=img.dtype) 201 | 202 | mask_indices = info['mask_indices'] 203 | init_img = img.clone() 204 | z_fe = img[:, mask_indices,...] 205 | 206 | noise_list = [] 207 | for i in range(len(timesteps)): 208 | noise = torch.randn(init_img.size(), dtype=init_img.dtype, 209 | layout=init_img.layout, device=init_img.device, 210 | generator=torch.Generator(device=init_img.device).manual_seed(0)) 211 | noise_list.append(noise) 212 | 213 | for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])): 214 | 215 | info['t'] = t_curr 216 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 217 | 218 | z_src = (1 - t_curr) * init_img + t_curr * noise_list[i] 219 | z_tar = z_src[:, mask_indices,...] - init_img[:, mask_indices,...] + z_fe 220 | 221 | info['inverse'] = True 222 | info['feature'] = {} 223 | v_src = model( 224 | img=z_src, 225 | img_ids=img_ids, 226 | txt=source_txt, 227 | txt_ids=source_txt_ids, 228 | y=source_vec, 229 | timesteps=t_vec, 230 | guidance=source_guidance_vec, 231 | info=info 232 | ) 233 | 234 | info['inverse'] = False 235 | v_tar = model( 236 | img=z_tar, 237 | img_ids=img_ids, 238 | txt=target_txt, 239 | txt_ids=target_txt_ids, 240 | y=target_vec, 241 | timesteps=t_vec, 242 | guidance=target_guidance_vec, 243 | info=info 244 | ) 245 | 246 | v_fe = v_tar - v_src[:, mask_indices,...] 247 | z_fe = z_fe + (t_prev - t_curr) * v_fe * info['mask'][:, mask_indices,...] 
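# Inversion-free update: z_src re-noises the source image to t_curr, v_src and v_tar
# are the model velocities under the source and target prompts, and the foreground
# latent z_fe is advanced by their difference, restricted to the masked region.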
248 | return z_fe, info 249 | -------------------------------------------------------------------------------- /flux/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from einops import rearrange 6 | from huggingface_hub import hf_hub_download 7 | from imwatermark import WatermarkEncoder 8 | from safetensors.torch import load_file as load_sft 9 | 10 | from flux.model import Flux, FluxParams 11 | from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams 12 | from flux.modules.conditioner import HFEmbedder 13 | 14 | 15 | @dataclass 16 | class ModelSpec: 17 | params: FluxParams 18 | ae_params: AutoEncoderParams 19 | ckpt_path: str | None 20 | ae_path: str | None 21 | repo_id: str | None 22 | repo_flow: str | None 23 | repo_ae: str | None 24 | 25 | configs = { 26 | "flux-dev": ModelSpec( 27 | repo_id="black-forest-labs/FLUX.1-dev", 28 | repo_flow="flux1-dev.safetensors", 29 | repo_ae="ae.safetensors", 30 | ckpt_path=os.getenv("FLUX_DEV"), 31 | params=FluxParams( 32 | in_channels=64, 33 | vec_in_dim=768, 34 | context_in_dim=4096, 35 | hidden_size=3072, 36 | mlp_ratio=4.0, 37 | num_heads=24, 38 | depth=19, 39 | depth_single_blocks=38, 40 | axes_dim=[16, 56, 56], 41 | theta=10_000, 42 | qkv_bias=True, 43 | guidance_embed=True, 44 | ), 45 | ae_path=os.getenv("AE"), 46 | ae_params=AutoEncoderParams( 47 | resolution=256, 48 | in_channels=3, 49 | ch=128, 50 | out_ch=3, 51 | ch_mult=[1, 2, 4, 4], 52 | num_res_blocks=2, 53 | z_channels=16, 54 | scale_factor=0.3611, 55 | shift_factor=0.1159, 56 | ), 57 | ), 58 | "flux-schnell": ModelSpec( 59 | repo_id="black-forest-labs/FLUX.1-schnell", 60 | repo_flow="flux1-schnell.safetensors", 61 | repo_ae="ae.safetensors", 62 | ckpt_path=os.getenv("FLUX_SCHNELL"), 63 | params=FluxParams( 64 | in_channels=64, 65 | vec_in_dim=768, 66 | context_in_dim=4096, 67 | hidden_size=3072, 68 | mlp_ratio=4.0, 69 | num_heads=24, 70 | depth=19, 71 | depth_single_blocks=38, 72 | axes_dim=[16, 56, 56], 73 | theta=10_000, 74 | qkv_bias=True, 75 | guidance_embed=False, 76 | ), 77 | ae_path=os.getenv("AE"), 78 | ae_params=AutoEncoderParams( 79 | resolution=256, 80 | in_channels=3, 81 | ch=128, 82 | out_ch=3, 83 | ch_mult=[1, 2, 4, 4], 84 | num_res_blocks=2, 85 | z_channels=16, 86 | scale_factor=0.3611, 87 | shift_factor=0.1159, 88 | ), 89 | ), 90 | } 91 | 92 | 93 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 94 | if len(missing) > 0 and len(unexpected) > 0: 95 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 96 | print("\n" + "-" * 79 + "\n") 97 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 98 | elif len(missing) > 0: 99 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 100 | elif len(unexpected) > 0: 101 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 102 | 103 | 104 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True, flux_cls=Flux) -> Flux: 105 | # Loading Flux 106 | print("Init model") 107 | 108 | ckpt_path = configs[name].ckpt_path 109 | if ( 110 | ckpt_path is None 111 | and configs[name].repo_id is not None 112 | and configs[name].repo_flow is not None 113 | and hf_download 114 | ): 115 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 116 | 117 | with torch.device("meta" if ckpt_path is not None else device): 118 | model = 
flux_cls(configs[name].params).to(torch.bfloat16) 119 | 120 | if ckpt_path is not None: 121 | print("Loading checkpoint") 122 | # load_sft doesn't support torch.device 123 | sd = load_sft(ckpt_path, device=str(device)) 124 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 125 | print_load_warning(missing, unexpected) 126 | return model 127 | 128 | 129 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 130 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 131 | # return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device) 132 | return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device) 133 | 134 | 135 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 136 | # return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device) 137 | return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device) 138 | 139 | 140 | def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: 141 | ckpt_path = configs[name].ae_path 142 | if ( 143 | ckpt_path is None 144 | and configs[name].repo_id is not None 145 | and configs[name].repo_ae is not None 146 | and hf_download 147 | ): 148 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) 149 | 150 | # Loading the autoencoder 151 | print("Init AE") 152 | with torch.device("meta" if ckpt_path is not None else device): 153 | ae = AutoEncoder(configs[name].ae_params) 154 | 155 | if ckpt_path is not None: 156 | sd = load_sft(ckpt_path, device=str(device)) 157 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 158 | print_load_warning(missing, unexpected) 159 | return ae 160 | 161 | 162 | class WatermarkEmbedder: 163 | def __init__(self, watermark): 164 | self.watermark = watermark 165 | self.num_bits = len(WATERMARK_BITS) 166 | self.encoder = WatermarkEncoder() 167 | self.encoder.set_watermark("bits", self.watermark) 168 | 169 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 170 | """ 171 | Adds a predefined watermark to the input image 172 | 173 | Args: 174 | image: ([N,] B, RGB, H, W) in range [-1, 1] 175 | 176 | Returns: 177 | same as input but watermarked 178 | """ 179 | image = 0.5 * image + 0.5 180 | squeeze = len(image.shape) == 4 181 | if squeeze: 182 | image = image[None, ...] 
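# A 4-D input (B, C, H, W) gains a leading axis here so the loop below can always
# treat the tensor as (N, B, C, H, W); `squeeze` records that this extra axis has
# to be dropped again before returning.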
183 | n = image.shape[0] 184 | image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] 185 | # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] 186 | # watermarking libary expects input as cv2 BGR format 187 | for k in range(image_np.shape[0]): 188 | image_np[k] = self.encoder.encode(image_np[k], "dwtDct") 189 | image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( 190 | image.device 191 | ) 192 | image = torch.clamp(image / 255, min=0.0, max=1.0) 193 | if squeeze: 194 | image = image[0] 195 | image = 2 * image - 1 196 | return image 197 | 198 | 199 | # A fixed 48-bit message that was chosen at random 200 | WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 201 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 202 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] 203 | embed_watermark = WatermarkEmbedder(WATERMARK_BITS) 204 | -------------------------------------------------------------------------------- /gradio_kv_edit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from dataclasses import dataclass 5 | from glob import iglob 6 | import argparse 7 | from einops import rearrange 8 | from PIL import ExifTags, Image 9 | import torch 10 | import gradio as gr 11 | import numpy as np 12 | from flux.sampling import prepare 13 | from flux.util import (configs, load_ae, load_clip, load_t5) 14 | from models.kv_edit import Flux_kv_edit 15 | 16 | @dataclass 17 | class SamplingOptions: 18 | source_prompt: str = '' 19 | target_prompt: str = '' 20 | width: int = 1366 21 | height: int = 768 22 | inversion_num_steps: int = 0 23 | denoise_num_steps: int = 0 24 | skip_step: int = 0 25 | inversion_guidance: float = 1.0 26 | denoise_guidance: float = 1.0 27 | seed: int = 42 28 | re_init: bool = False 29 | attn_mask: bool = False 30 | attn_scale: float = 1.0 31 | 32 | class FluxEditor_kv_demo: 33 | def __init__(self, args): 34 | self.args = args 35 | self.device = torch.device(args.device) 36 | self.offload = args.offload 37 | 38 | self.name = args.name 39 | self.is_schnell = args.name == "flux-schnell" 40 | 41 | self.output_dir = 'regress_result' 42 | 43 | self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512) 44 | self.clip = load_clip(self.device) 45 | self.model = Flux_kv_edit(device="cpu" if self.offload else self.device, name=self.name) 46 | self.ae = load_ae(self.name, device="cpu" if self.offload else self.device) 47 | 48 | self.t5.eval() 49 | self.clip.eval() 50 | self.ae.eval() 51 | self.model.eval() 52 | self.info = {} 53 | if self.offload: 54 | self.model.cpu() 55 | torch.cuda.empty_cache() 56 | self.ae.encoder.to(self.device) 57 | 58 | @torch.inference_mode() 59 | def inverse(self, brush_canvas, 60 | source_prompt, target_prompt, 61 | inversion_num_steps, denoise_num_steps, 62 | skip_step, 63 | inversion_guidance, denoise_guidance,seed, 64 | re_init, attn_mask 65 | ): 66 | self.z0 = None 67 | self.zt = None 68 | # self.info = {} 69 | # gc.collect() 70 | if 'feature' in self.info: 71 | key_list = list(self.info['feature'].keys()) 72 | for key in key_list: 73 | del self.info['feature'][key] 74 | self.info = {} 75 | 76 | rgba_init_image = brush_canvas["background"] 77 | init_image = rgba_init_image[:,:,:3] 78 | shape = init_image.shape 79 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16 80 | width = 
shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16 81 | init_image = init_image[:height, :width, :] 82 | rgba_init_image = rgba_init_image[:height, :width, :] 83 | 84 | opts = SamplingOptions( 85 | source_prompt=source_prompt, 86 | target_prompt=target_prompt, 87 | width=width, 88 | height=height, 89 | inversion_num_steps=inversion_num_steps, 90 | denoise_num_steps=denoise_num_steps, 91 | skip_step=0,# no skip step in inverse leads chance to adjest skip_step in edit 92 | inversion_guidance=inversion_guidance, 93 | denoise_guidance=denoise_guidance, 94 | seed=seed, 95 | re_init=re_init, 96 | attn_mask=attn_mask 97 | ) 98 | torch.manual_seed(opts.seed) 99 | if torch.cuda.is_available(): 100 | torch.cuda.manual_seed_all(opts.seed) 101 | torch.cuda.empty_cache() 102 | 103 | if opts.attn_mask: 104 | rgba_mask = brush_canvas["layers"][0][:height, :width, :] 105 | mask = rgba_mask[:,:,3]/255 106 | mask = mask.astype(int) 107 | 108 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device) 109 | else: 110 | mask = None 111 | 112 | self.init_image = self.encode(init_image, self.device).to(self.device) 113 | 114 | t0 = time.perf_counter() 115 | 116 | if self.offload: 117 | self.ae = self.ae.cpu() 118 | torch.cuda.empty_cache() 119 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 120 | 121 | with torch.no_grad(): 122 | inp = prepare(self.t5, self.clip,self.init_image, prompt=opts.source_prompt) 123 | 124 | if self.offload: 125 | self.t5, self.clip = self.t5.cpu(), self.clip.cpu() 126 | torch.cuda.empty_cache() 127 | self.model = self.model.to(self.device) 128 | self.z0,self.zt,self.info = self.model.inverse(inp,mask,opts) 129 | 130 | if self.offload: 131 | self.model.cpu() 132 | torch.cuda.empty_cache() 133 | 134 | t1 = time.perf_counter() 135 | print(f"inversion Done in {t1 - t0:.1f}s.") 136 | return None 137 | 138 | 139 | 140 | @torch.inference_mode() 141 | def edit(self, brush_canvas, 142 | source_prompt, target_prompt, 143 | inversion_num_steps, denoise_num_steps, 144 | skip_step, 145 | inversion_guidance, denoise_guidance,seed, 146 | re_init, attn_mask,attn_scale 147 | ): 148 | 149 | torch.cuda.empty_cache() 150 | 151 | rgba_init_image = brush_canvas["background"] 152 | init_image = rgba_init_image[:,:,:3] 153 | shape = init_image.shape 154 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16 155 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16 156 | init_image = init_image[:height, :width, :] 157 | rgba_init_image = rgba_init_image[:height, :width, :] 158 | 159 | rgba_mask = brush_canvas["layers"][0][:height, :width, :] 160 | mask = rgba_mask[:,:,3]/255 161 | mask = mask.astype(int) 162 | 163 | rgba_mask[:,:,3] = rgba_mask[:,:,3]//2 164 | masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA')) 165 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device) 166 | 167 | seed = int(seed) 168 | if seed == -1: 169 | seed = torch.randint(0, 2**32, (1,)).item() 170 | opts = SamplingOptions( 171 | source_prompt=source_prompt, 172 | target_prompt=target_prompt, 173 | width=width, 174 | height=height, 175 | inversion_num_steps=inversion_num_steps, 176 | denoise_num_steps=denoise_num_steps, 177 | skip_step=skip_step, 178 | inversion_guidance=inversion_guidance, 179 | denoise_guidance=denoise_guidance, 180 | seed=seed, 181 | re_init=re_init, 182 | attn_mask=attn_mask, 183 | attn_scale=attn_scale 184 | ) 185 
| if self.offload: 186 | 187 | torch.cuda.empty_cache() 188 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 189 | 190 | torch.manual_seed(opts.seed) 191 | if torch.cuda.is_available(): 192 | torch.cuda.manual_seed_all(opts.seed) 193 | 194 | t0 = time.perf_counter() 195 | 196 | with torch.no_grad(): 197 | inp_target = prepare(self.t5, self.clip, self.init_image, prompt=opts.target_prompt) 198 | 199 | if self.offload: 200 | self.t5, self.clip = self.t5.cpu(), self.clip.cpu() 201 | torch.cuda.empty_cache() 202 | self.model = self.model.to(self.device) 203 | 204 | x = self.model.denoise(self.z0,self.zt,inp_target,mask,opts,self.info) 205 | 206 | if self.offload: 207 | self.model.cpu() 208 | torch.cuda.empty_cache() 209 | self.ae.decoder.to(x.device) 210 | 211 | with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): 212 | x = self.ae.decode(x.to(self.device)) 213 | 214 | x = x.clamp(-1, 1) 215 | x = x.float().cpu() 216 | x = rearrange(x[0], "c h w -> h w c") 217 | 218 | if torch.cuda.is_available(): 219 | torch.cuda.synchronize() 220 | 221 | output_name = os.path.join(self.output_dir, "img_{idx}.jpg") 222 | if not os.path.exists(self.output_dir): 223 | os.makedirs(self.output_dir) 224 | idx = 0 225 | else: 226 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] 227 | if len(fns) > 0: 228 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 229 | else: 230 | idx = 0 231 | 232 | fn = output_name.format(idx=idx) 233 | 234 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) 235 | exif_data = Image.Exif() 236 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" 237 | exif_data[ExifTags.Base.Make] = "Black Forest Labs" 238 | exif_data[ExifTags.Base.Model] = self.name 239 | 240 | exif_data[ExifTags.Base.ImageDescription] = target_prompt 241 | img.save(fn, exif=exif_data, quality=95, subsampling=0) 242 | masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG') 243 | t1 = time.perf_counter() 244 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}") 245 | 246 | print("End Edit") 247 | return img 248 | 249 | 250 | @torch.inference_mode() 251 | def encode(self,init_image, torch_device): 252 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1 253 | init_image = init_image.unsqueeze(0) 254 | init_image = init_image.to(torch_device) 255 | self.ae.encoder.to(torch_device) 256 | 257 | init_image = self.ae.encode(init_image).to(torch.bfloat16) 258 | return init_image 259 | 260 | def create_demo(model_name: str): 261 | editor = FluxEditor_kv_demo(args) 262 | is_schnell = model_name == "flux-schnell" 263 | 264 | title = r""" 265 |

🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation

266 | """ 267 | 268 | description = r""" 269 | Official 🤗 Gradio demo for KV-Edit: Training-Free Image Editing for Precise Background Preservation.
270 | 271 | 💫💫 Here are the editing steps:
272 | 1️⃣ Upload your image that needs to be edited.
273 | 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion.
274 | 3️⃣ Use the brush tool to draw your mask area.
275 | 4️⃣ Fill in your target prompt, then adjust the hyperparameters.
276 | 5️⃣ Click the "Edit" button to generate your edited image!
277 | 278 | 🔔🔔 [Important] Fewer skip steps, "re_init" and "attn_mask" will enhance the editing performance, making the results more aligned with your text, but they may lead to discontinuous images.
279 | If the edit still fails with these three options, we recommend increasing "attn_scale" to strengthen the attention between the masked region and the background.
280 | """ 281 | article = r""" 282 | If our work is helpful, please help to ⭐ the Github Repo. Thanks! 283 | """ 284 | 285 | badge = r""" 286 | [![GitHub Stars](https://img.shields.io/github/stars/Xilluill/KV-Edit)](https://github.com/Xilluill/KV-Edit) 287 | """ 288 | 289 | with gr.Blocks() as demo: 290 | gr.HTML(title) 291 | gr.Markdown(description) 292 | 293 | with gr.Row(): 294 | with gr.Column(): 295 | source_prompt = gr.Textbox(label="Source Prompt", value='' ) 296 | inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps") 297 | target_prompt = gr.Textbox(label="Target Prompt", value='' ) 298 | denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps") 299 | brush_canvas = gr.ImageEditor(label="Brush Canvas", 300 | sources=('upload'), 301 | brush=gr.Brush(colors=["#ff0000"],color_mode='fixed'), 302 | interactive=True, 303 | transforms=[], 304 | container=True, 305 | format='png',scale=1) 306 | 307 | inv_btn = gr.Button("inverse") 308 | edit_btn = gr.Button("edit") 309 | 310 | 311 | with gr.Column(): 312 | with gr.Accordion("Advanced Options", open=True): 313 | 314 | skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps") 315 | inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell) 316 | denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell) 317 | attn_scale = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale") 318 | seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True) 319 | with gr.Row(): 320 | re_init = gr.Checkbox(label="re_init", value=False) 321 | attn_mask = gr.Checkbox(label="attn_mask", value=False) 322 | 323 | 324 | output_image = gr.Image(label="Generated Image") 325 | gr.Markdown(article) 326 | inv_btn.click( 327 | fn=editor.inverse, 328 | inputs=[brush_canvas, 329 | source_prompt, target_prompt, 330 | inversion_num_steps, denoise_num_steps, 331 | skip_step, 332 | inversion_guidance, 333 | denoise_guidance,seed, 334 | re_init, attn_mask 335 | ], 336 | outputs=[output_image] 337 | ) 338 | edit_btn.click( 339 | fn=editor.edit, 340 | inputs=[brush_canvas, 341 | source_prompt, target_prompt, 342 | inversion_num_steps, denoise_num_steps, 343 | skip_step, 344 | inversion_guidance, 345 | denoise_guidance,seed, 346 | re_init, attn_mask,attn_scale 347 | ], 348 | outputs=[output_image] 349 | ) 350 | return demo 351 | 352 | if __name__ == "__main__": 353 | import argparse 354 | parser = argparse.ArgumentParser(description="Flux") 355 | parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name") 356 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") 357 | parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") 358 | parser.add_argument("--share", action="store_true", help="Create a public link to your demo") 359 | parser.add_argument("--port", type=int, default=41032) 360 | args = parser.parse_args() 361 | 362 | demo = create_demo(args.name) 363 | 364 | demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port) -------------------------------------------------------------------------------- /gradio_kv_edit_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from dataclasses import dataclass 5 | from glob import iglob 6 | import argparse 
7 | from einops import rearrange 8 | from PIL import ExifTags, Image 9 | import torch 10 | import gradio as gr 11 | import numpy as np 12 | from flux.sampling import prepare 13 | from flux.util import (configs, load_ae, load_clip, load_t5) 14 | from models.kv_edit import Flux_kv_edit 15 | 16 | @dataclass 17 | class SamplingOptions: 18 | source_prompt: str = '' 19 | target_prompt: str = '' 20 | width: int = 1366 21 | height: int = 768 22 | inversion_num_steps: int = 0 23 | denoise_num_steps: int = 0 24 | skip_step: int = 0 25 | inversion_guidance: float = 1.0 26 | denoise_guidance: float = 1.0 27 | seed: int = 42 28 | re_init: bool = False 29 | attn_mask: bool = False 30 | attn_scale: float = 1.0 31 | 32 | def resize_image(image_array, max_width=1360, max_height=768): 33 | # 将numpy数组转换为PIL图像 34 | if image_array.shape[-1] == 4: 35 | mode = 'RGBA' 36 | else: 37 | mode = 'RGB' 38 | 39 | pil_image = Image.fromarray(image_array, mode=mode) 40 | 41 | # 获取原始图像的宽度和高度 42 | original_width, original_height = pil_image.size 43 | 44 | # 计算缩放比例 45 | width_ratio = max_width / original_width 46 | height_ratio = max_height / original_height 47 | 48 | # 选择较小的缩放比例以确保图像不超过最大宽度和高度 49 | scale_ratio = min(width_ratio, height_ratio) 50 | 51 | # 如果图像已经小于或等于最大分辨率,则不进行缩放 52 | if scale_ratio >= 1: 53 | return image_array 54 | 55 | # 计算新的宽度和高度 56 | new_width = int(original_width * scale_ratio) 57 | new_height = int(original_height * scale_ratio) 58 | 59 | # 缩放图像 60 | resized_image = pil_image.resize((new_width, new_height)) 61 | 62 | # 将PIL图像转换回numpy数组 63 | resized_array = np.array(resized_image) 64 | 65 | return resized_array 66 | 67 | class FluxEditor_kv_demo: 68 | def __init__(self, args): 69 | self.args = args 70 | self.gpus = args.gpus 71 | if self.gpus: 72 | self.device = [torch.device("cuda:0"), torch.device("cuda:1")] 73 | else: 74 | self.device = [torch.device(args.device), torch.device(args.device)] 75 | 76 | self.name = args.name 77 | self.is_schnell = args.name == "flux-schnell" 78 | 79 | self.output_dir = 'regress_result' 80 | 81 | self.t5 = load_t5(self.device[1], max_length=256 if self.name == "flux-schnell" else 512) 82 | self.clip = load_clip(self.device[1]) 83 | self.model = Flux_kv_edit(self.device[0], name=self.name) 84 | self.ae = load_ae(self.name, device=self.device[1]) 85 | 86 | self.t5.eval() 87 | self.clip.eval() 88 | self.ae.eval() 89 | self.model.eval() 90 | self.info = {} 91 | @torch.inference_mode() 92 | def inverse(self, brush_canvas, 93 | source_prompt, target_prompt, 94 | inversion_num_steps, denoise_num_steps, 95 | skip_step, 96 | inversion_guidance, denoise_guidance,seed, 97 | re_init, attn_mask 98 | ): 99 | if hasattr(self, 'z0'): 100 | del self.z0 101 | del self.zt 102 | # self.info = {} 103 | # gc.collect() 104 | 105 | if 'feature' in self.info: 106 | key_list = list(self.info['feature'].keys()) 107 | for key in key_list: 108 | del self.info['feature'][key] 109 | self.info = {} 110 | 111 | rgba_init_image = brush_canvas["background"] 112 | init_image = rgba_init_image[:,:,:3] 113 | # init_image = resize_image(init_image) 114 | shape = init_image.shape 115 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16 116 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16 117 | init_image = init_image[:height, :width, :] 118 | rgba_init_image = rgba_init_image[:height, :width, :] 119 | 120 | opts = SamplingOptions( 121 | source_prompt=source_prompt, 122 | target_prompt=target_prompt, 123 | width=width, 124 | height=height, 125 | 
inversion_num_steps=inversion_num_steps, 126 | denoise_num_steps=denoise_num_steps, 127 | skip_step=0, # no skip step in inverse leads chance to adjest skip_step in edit 128 | inversion_guidance=inversion_guidance, 129 | denoise_guidance=denoise_guidance, 130 | seed=seed, 131 | re_init=re_init, 132 | attn_mask=attn_mask 133 | ) 134 | torch.manual_seed(opts.seed) 135 | if torch.cuda.is_available(): 136 | torch.cuda.manual_seed_all(opts.seed) 137 | torch.cuda.empty_cache() 138 | 139 | if opts.attn_mask: 140 | # rgba_mask = resize_image(brush_canvas["layers"][0])[:height, :width, :] 141 | rgba_mask = brush_canvas["layers"][0][:height, :width, :] 142 | mask = rgba_mask[:,:,3]/255 143 | mask = mask.astype(int) 144 | 145 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device[0]) 146 | else: 147 | mask = None 148 | 149 | self.init_image = self.encode(init_image, self.device[1]).to(self.device[0]) 150 | 151 | t0 = time.perf_counter() 152 | 153 | with torch.no_grad(): 154 | inp = prepare(self.t5, self.clip,self.init_image, prompt=opts.source_prompt) 155 | self.z0,self.zt,self.info = self.model.inverse(inp,mask,opts) 156 | t1 = time.perf_counter() 157 | print(f"inversion Done in {t1 - t0:.1f}s.") 158 | return None 159 | 160 | 161 | 162 | @torch.inference_mode() 163 | def edit(self, brush_canvas, 164 | source_prompt, target_prompt, 165 | inversion_num_steps, denoise_num_steps, 166 | skip_step, 167 | inversion_guidance, denoise_guidance,seed, 168 | re_init, attn_mask,attn_scale 169 | ): 170 | 171 | torch.cuda.empty_cache() 172 | 173 | rgba_init_image = brush_canvas["background"] 174 | init_image = rgba_init_image[:,:,:3] 175 | shape = init_image.shape 176 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16 177 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16 178 | init_image = init_image[:height, :width, :] 179 | rgba_init_image = rgba_init_image[:height, :width, :] 180 | # brush_canvas = brush_canvas["composite"][:,:,:3][:height, :width, :] 181 | 182 | # if np.all(brush_canvas[:,:,0] == brush_canvas[:,:,1]) and np.all(brush_canvas[:,:,1] == brush_canvas[:,:,2]): 183 | rgba_mask = brush_canvas["layers"][0][:height, :width, :] 184 | mask = rgba_mask[:,:,3]/255 185 | mask = mask.astype(int) 186 | 187 | 188 | rgba_mask[:,:,3] = rgba_mask[:,:,3]//2 189 | masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA')) 190 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device[0]) 191 | 192 | seed = int(seed) 193 | if seed == -1: 194 | seed = torch.randint(0, 2**32, (1,)).item() 195 | opts = SamplingOptions( 196 | source_prompt=source_prompt, 197 | target_prompt=target_prompt, 198 | width=width, 199 | height=height, 200 | inversion_num_steps=inversion_num_steps, 201 | denoise_num_steps=denoise_num_steps, 202 | skip_step=skip_step, 203 | inversion_guidance=inversion_guidance, 204 | denoise_guidance=denoise_guidance, 205 | seed=seed, 206 | re_init=re_init, 207 | attn_mask=attn_mask, 208 | attn_scale=attn_scale 209 | ) 210 | torch.manual_seed(opts.seed) 211 | if torch.cuda.is_available(): 212 | torch.cuda.manual_seed_all(opts.seed) 213 | 214 | t0 = time.perf_counter() 215 | 216 | 217 | with torch.no_grad(): 218 | inp_target = prepare(self.t5, self.clip, self.init_image, prompt=opts.target_prompt) 219 | 220 | x = self.model.denoise(self.z0.clone(),self.zt,inp_target,mask,opts,self.info) 221 | 222 | with 
torch.autocast(device_type=self.device[1].type, dtype=torch.bfloat16): 223 | x = self.ae.decode(x.to(self.device[1])) 224 | 225 | x = x.clamp(-1, 1) 226 | x = x.float().cpu() 227 | x = rearrange(x[0], "c h w -> h w c") 228 | 229 | if torch.cuda.is_available(): 230 | torch.cuda.synchronize() 231 | 232 | output_name = os.path.join(self.output_dir, "img_{idx}.jpg") 233 | if not os.path.exists(self.output_dir): 234 | os.makedirs(self.output_dir) 235 | idx = 0 236 | else: 237 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] 238 | if len(fns) > 0: 239 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 240 | else: 241 | idx = 0 242 | 243 | 244 | fn = output_name.format(idx=idx) 245 | 246 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) 247 | exif_data = Image.Exif() 248 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" 249 | exif_data[ExifTags.Base.Make] = "Black Forest Labs" 250 | exif_data[ExifTags.Base.Model] = self.name 251 | exif_data[ExifTags.Base.ImageDescription] = target_prompt 252 | img.save(fn, exif=exif_data, quality=95, subsampling=0) 253 | masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG') 254 | t1 = time.perf_counter() 255 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}") 256 | 257 | print("End Edit") 258 | return img 259 | 260 | 261 | @torch.inference_mode() 262 | def encode(self,init_image, torch_device): 263 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1 264 | init_image = init_image.unsqueeze(0) 265 | init_image = init_image.to(torch_device) 266 | self.ae.encoder.to(torch_device) 267 | 268 | init_image = self.ae.encode(init_image).to(torch.bfloat16) 269 | return init_image 270 | 271 | def create_demo(model_name: str): 272 | editor = FluxEditor_kv_demo(args) 273 | is_schnell = model_name == "flux-schnell" 274 | 275 | title = r""" 276 |

🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation

277 | """ 278 | one = r""" 279 | We recommend that you try our code locally; you can try several different edits after inverting the image only once! 280 | """ 281 | 282 | description = r""" 283 | Official 🤗 Gradio demo for KV-Edit: Training-Free Image Editing for Precise Background Preservation.
284 | 285 | 💫💫 Here are the editing steps:
286 | 1️⃣ Upload your image that needs to be edited.
287 | 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion.
288 | 3️⃣ Use the brush tool to draw your mask area.
289 | 4️⃣ Fill in your target prompt, then adjust the hyperparameters.
290 | 5️⃣ Click the "Edit" button to generate your edited image!
291 | 292 | 🔔🔔 [Important] Fewer skip steps, "re_init" and "attn_mask" will enhance the editing performance, making the results more aligned with your text, but they may lead to discontinuous images.
293 | If the edit still fails with these three options, we recommend increasing "attn_scale" to strengthen the attention between the masked region and the background.
294 | """ 295 | article = r""" 296 | If our work is helpful, please help to ⭐ the Github Repo. Thanks! 297 | """ 298 | 299 | with gr.Blocks() as demo: 300 | gr.HTML(title) 301 | gr.Markdown(description) 302 | 303 | with gr.Row(): 304 | with gr.Column(): 305 | source_prompt = gr.Textbox(label="Source Prompt", value='' ) 306 | inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps") 307 | target_prompt = gr.Textbox(label="Target Prompt", value='' ) 308 | denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps") 309 | # init_image = gr.Image(label="Input Image", visible=True,scale=1) 310 | brush_canvas = gr.ImageEditor(label="Brush Canvas", 311 | sources=('upload'), 312 | brush=gr.Brush(colors=["#ff0000"],color_mode='fixed'), 313 | interactive=True, 314 | transforms=[], 315 | container=True, 316 | format='png') 317 | 318 | inv_btn = gr.Button("inverse") 319 | edit_btn = gr.Button("edit") 320 | 321 | 322 | with gr.Column(): 323 | with gr.Accordion("Advanced Options", open=True): 324 | skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps") 325 | inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell) 326 | denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell) 327 | attn_scale = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale") 328 | seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True) 329 | with gr.Row(): 330 | re_init = gr.Checkbox(label="re_init", value=False) 331 | attn_mask = gr.Checkbox(label="attn_mask", value=False) 332 | 333 | output_image = gr.Image(label="Generated Image") 334 | gr.Markdown(article) 335 | inv_btn.click( 336 | fn=editor.inverse, 337 | inputs=[brush_canvas, 338 | source_prompt, target_prompt, 339 | inversion_num_steps, denoise_num_steps, 340 | skip_step, 341 | inversion_guidance, 342 | denoise_guidance,seed, 343 | re_init, attn_mask 344 | ], 345 | outputs=[output_image] 346 | ) 347 | edit_btn.click( 348 | fn=editor.edit, 349 | inputs=[ brush_canvas, 350 | source_prompt, target_prompt, 351 | inversion_num_steps, denoise_num_steps, 352 | skip_step, 353 | inversion_guidance, 354 | denoise_guidance,seed, 355 | re_init, attn_mask,attn_scale 356 | ], 357 | outputs=[output_image] 358 | ) 359 | return demo 360 | 361 | if __name__ == "__main__": 362 | import argparse 363 | parser = argparse.ArgumentParser(description="Flux") 364 | parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name") 365 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") 366 | parser.add_argument("--gpus", action="store_true", help="2 gpu to use") 367 | parser.add_argument("--share", action="store_true", help="Create a public link to your demo") 368 | parser.add_argument("--port", type=int, default=41032) 369 | args = parser.parse_args() 370 | 371 | demo = create_demo(args.name) 372 | 373 | demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port) -------------------------------------------------------------------------------- /gradio_kv_edit_inf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from dataclasses import dataclass 5 | from glob import iglob 6 | import argparse 7 | from einops import rearrange 8 | from PIL import ExifTags, Image 9 | import torch 10 | import gradio as gr 11 | import 
numpy as np 12 | from flux.sampling import prepare, unpack 13 | from flux.util import (configs, load_ae, load_clip, load_t5) 14 | from models.kv_edit import Flux_kv_edit_inf 15 | 16 | @dataclass 17 | class SamplingOptions: 18 | source_prompt: str = '' 19 | target_prompt: str = '' 20 | width: int = 1366 21 | height: int = 768 22 | inversion_num_steps: int = 0 23 | denoise_num_steps: int = 0 24 | skip_step: int = 0 25 | inversion_guidance: float = 1.0 26 | denoise_guidance: float = 1.0 27 | seed: int = 42 28 | attn_mask: bool = False 29 | attn_scale: float = 1.0 30 | 31 | class FluxEditor_kv_demo: 32 | def __init__(self, args): 33 | self.args = args 34 | self.device = torch.device(args.device) 35 | self.offload = args.offload 36 | self.name = args.name 37 | self.is_schnell = args.name == "flux-schnell" 38 | 39 | self.output_dir = 'regress_result' 40 | 41 | self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512) 42 | self.clip = load_clip(self.device) 43 | self.model = Flux_kv_edit_inf(device="cpu" if self.offload else self.device, name=self.name) 44 | self.ae = load_ae(self.name, device="cpu" if self.offload else self.device) 45 | 46 | self.t5.eval() 47 | self.clip.eval() 48 | self.ae.eval() 49 | self.model.eval() 50 | 51 | if self.offload: 52 | self.model.cpu() 53 | torch.cuda.empty_cache() 54 | self.ae.encoder.to(self.device) 55 | 56 | @torch.inference_mode() 57 | def edit(self, brush_canvas, 58 | source_prompt, target_prompt, 59 | inversion_num_steps, denoise_num_steps, 60 | skip_step, 61 | inversion_guidance, denoise_guidance,seed, 62 | attn_mask,attn_scale 63 | ): 64 | 65 | torch.cuda.empty_cache() 66 | 67 | rgba_init_image = brush_canvas["background"] 68 | init_image = rgba_init_image[:,:,:3] 69 | shape = init_image.shape 70 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16 71 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16 72 | init_image = init_image[:height, :width, :] 73 | rgba_init_image = rgba_init_image[:height, :width, :] 74 | 75 | rgba_mask = brush_canvas["layers"][0][:height, :width, :] 76 | mask = rgba_mask[:,:,3]/255 77 | mask = mask.astype(int) 78 | 79 | rgba_mask[:,:,3] = rgba_mask[:,:,3]//2 80 | masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA')) 81 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device) 82 | 83 | init_image = self.encode(init_image, self.device).to(self.device) 84 | 85 | seed = int(seed) 86 | if seed == -1: 87 | seed = torch.randint(0, 2**32, (1,)).item() 88 | opts = SamplingOptions( 89 | source_prompt=source_prompt, 90 | target_prompt=target_prompt, 91 | width=width, 92 | height=height, 93 | inversion_num_steps=inversion_num_steps, 94 | denoise_num_steps=denoise_num_steps, 95 | skip_step=skip_step, 96 | inversion_guidance=inversion_guidance, 97 | denoise_guidance=denoise_guidance, 98 | seed=seed, 99 | attn_mask=attn_mask, 100 | attn_scale=attn_scale 101 | ) 102 | 103 | if self.offload: 104 | self.ae = self.ae.cpu() 105 | torch.cuda.empty_cache() 106 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 107 | 108 | torch.manual_seed(opts.seed) 109 | if torch.cuda.is_available(): 110 | torch.cuda.manual_seed_all(opts.seed) 111 | 112 | t0 = time.perf_counter() 113 | 114 | with torch.no_grad(): 115 | inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt) 116 | inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt) 117 | 
118 | if self.offload: 119 | self.t5, self.clip = self.t5.cpu(), self.clip.cpu() 120 | torch.cuda.empty_cache() 121 | self.model = self.model.to(self.device) 122 | 123 | x = self.model(inp, inp_target, mask, opts) 124 | 125 | if self.offload: 126 | self.model.cpu() 127 | torch.cuda.empty_cache() 128 | self.ae.decoder.to(x.device) 129 | 130 | with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): 131 | x = self.ae.decode(x) 132 | 133 | x = x.clamp(-1, 1) 134 | x = x.float().cpu() 135 | x = rearrange(x[0], "c h w -> h w c") 136 | 137 | if torch.cuda.is_available(): 138 | torch.cuda.synchronize() 139 | 140 | output_name = os.path.join(self.output_dir, "img_{idx}.jpg") 141 | if not os.path.exists(self.output_dir): 142 | os.makedirs(self.output_dir) 143 | idx = 0 144 | else: 145 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] 146 | if len(fns) > 0: 147 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 148 | else: 149 | idx = 0 150 | 151 | fn = output_name.format(idx=idx) 152 | 153 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) 154 | exif_data = Image.Exif() 155 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" 156 | exif_data[ExifTags.Base.Make] = "Black Forest Labs" 157 | exif_data[ExifTags.Base.Model] = self.name 158 | 159 | exif_data[ExifTags.Base.ImageDescription] = target_prompt 160 | img.save(fn, exif=exif_data, quality=95, subsampling=0) 161 | masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG') 162 | t1 = time.perf_counter() 163 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}") 164 | 165 | print("End Edit") 166 | return img 167 | 168 | @torch.inference_mode() 169 | def encode(self,init_image, torch_device): 170 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1 171 | init_image = init_image.unsqueeze(0) 172 | init_image = init_image.to(torch_device) 173 | self.ae.encoder.to(torch_device) 174 | 175 | init_image = self.ae.encode(init_image).to(torch.bfloat16) 176 | return init_image 177 | 178 | def create_demo(model_name: str): 179 | editor = FluxEditor_kv_demo(args) 180 | is_schnell = model_name == "flux-schnell" 181 | 182 | title = r""" 183 |

🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation

184 | """ 185 | 186 | description = r""" 187 | Inversion-free version of KV-Edit: Training-Free Image Editing for Precise Background Preservation.
188 | 189 | 💫💫 Here are the editing steps:
190 | # 💫💫 Here are the editing steps: (We highly recommend running our code locally!😘 Only one inversion is needed before multiple edits, which is very productive!)
191 | 1️⃣ Upload your image that needs to be edited.
192 | 2️⃣ Use the brush tool to draw your mask area.
193 | 3️⃣ Fill in your source prompt and target prompt, then adjust the hyperparameters.
194 | 4️⃣ Click the "Edit" button to generate your edited image!
195 | 196 | 🔔🔔 [Important] Fewer skip steps and "attn_mask" will enhance the editing performance, making the results more aligned with your text, but they may lead to discontinuous images.
197 | If the edit still fails with these two options, we recommend increasing "attn_scale" to strengthen the attention between the masked region and the background.
198 | """ 199 | article = r""" 200 | If our work is helpful, please help to ⭐ the Github Repo. Thanks! 201 | """ 202 | 203 | badge = r""" 204 | [![GitHub Stars](https://img.shields.io/github/stars/Xilluill/KV-Edit)](https://github.com/Xilluill/KV-Edit) 205 | """ 206 | 207 | with gr.Blocks() as demo: 208 | gr.HTML(title) 209 | gr.Markdown(description) 210 | # gr.Markdown(badge) 211 | 212 | with gr.Row(): 213 | with gr.Column(): 214 | source_prompt = gr.Textbox(label="Source Prompt", value='' ) 215 | inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps") 216 | target_prompt = gr.Textbox(label="Target Prompt", value='' ) 217 | denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps") 218 | brush_canvas = gr.ImageEditor(label="Brush Canvas", 219 | sources=('upload'), 220 | brush=gr.Brush(colors=["#ff0000"],color_mode='fixed'), 221 | interactive=True, 222 | transforms=[], 223 | container=True, 224 | format='png',scale=1) 225 | edit_btn = gr.Button("edit") 226 | 227 | 228 | with gr.Column(): 229 | with gr.Accordion("Advanced Options", open=True): 230 | 231 | skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps") 232 | inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell) 233 | denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell) 234 | attn_scale = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale") 235 | seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True) 236 | with gr.Row(): 237 | attn_mask = gr.Checkbox(label="attn_mask", value=False) 238 | 239 | 240 | output_image = gr.Image(label="Generated Image") 241 | gr.Markdown(article) 242 | edit_btn.click( 243 | fn=editor.edit, 244 | inputs=[brush_canvas, 245 | source_prompt, target_prompt, 246 | inversion_num_steps, denoise_num_steps, 247 | skip_step, 248 | inversion_guidance, 249 | denoise_guidance,seed, 250 | attn_mask,attn_scale 251 | ], 252 | outputs=[output_image] 253 | ) 254 | return demo 255 | 256 | if __name__ == "__main__": 257 | import argparse 258 | parser = argparse.ArgumentParser(description="Flux") 259 | parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name") 260 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") 261 | parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") 262 | parser.add_argument("--share", action="store_true", help="Create a public link to your demo") 263 | parser.add_argument("--port", type=int, default=41032) 264 | args = parser.parse_args() 265 | 266 | demo = create_demo(args.name) 267 | 268 | demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port) -------------------------------------------------------------------------------- /models/kv_edit.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from einops import rearrange,repeat 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from flux.sampling import get_schedule, unpack,denoise_kv,denoise_kv_inf 9 | from flux.util import load_flow_model 10 | from flux.model import Flux_kv 11 | 12 | class only_Flux(torch.nn.Module): 13 | def __init__(self, device,name='flux-dev'): 14 | self.device = device 15 | self.name = name 16 | super().__init__() 17 | self.model = 
load_flow_model(self.name, device=self.device,flux_cls=Flux_kv) 18 | 19 | def create_attention_mask(self,seq_len, mask_indices, text_len=512, device='cuda'): 20 | 21 | attention_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device) 22 | 23 | text_indices = torch.arange(0, text_len, device=device) 24 | 25 | mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device) 26 | 27 | all_indices = torch.arange(text_len, seq_len, device=device) 28 | background_token_indices = torch.tensor([idx for idx in all_indices if idx not in mask_token_indices]) 29 | 30 | # text setting 31 | attention_mask[text_indices.unsqueeze(1).expand(-1, seq_len)] = True 32 | attention_mask[text_indices.unsqueeze(1), text_indices] = True 33 | attention_mask[text_indices.unsqueeze(1), background_token_indices] = True 34 | 35 | 36 | # mask setting 37 | # attention_mask[mask_token_indices.unsqueeze(1), background_token_indices] = True 38 | attention_mask[mask_token_indices.unsqueeze(1), text_indices] = True 39 | attention_mask[mask_token_indices.unsqueeze(1), mask_token_indices] = True 40 | 41 | # background setting 42 | # attention_mask[background_token_indices.unsqueeze(1), mask_token_indices] = True 43 | attention_mask[background_token_indices.unsqueeze(1), text_indices] = True 44 | attention_mask[background_token_indices.unsqueeze(1), background_token_indices] = True 45 | 46 | return attention_mask.unsqueeze(0) 47 | 48 | def create_attention_scale(self,seq_len, mask_indices, text_len=512, device='cuda',scale = 0): 49 | 50 | attention_scale = torch.zeros(1, seq_len, dtype=torch.bfloat16, device=device) # 相加时广播 51 | 52 | 53 | text_indices = torch.arange(0, text_len, device=device) 54 | 55 | mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device) 56 | 57 | all_indices = torch.arange(text_len, seq_len, device=device) 58 | background_token_indices = torch.tensor([idx for idx in all_indices if idx not in mask_token_indices]) 59 | 60 | attention_scale[0, background_token_indices] = scale # 61 | 62 | return attention_scale.unsqueeze(0) 63 | 64 | class Flux_kv_edit_inf(only_Flux): 65 | def __init__(self, device,name): 66 | super().__init__(device,name) 67 | 68 | @torch.inference_mode() 69 | def forward(self,inp,inp_target,mask:Tensor,opts): 70 | 71 | info = {} 72 | info['feature'] = {} 73 | bs, L, d = inp["img"].shape 74 | h = opts.height // 8 75 | w = opts.width // 8 76 | mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False) 77 | mask[mask > 0] = 1 78 | 79 | mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16) 80 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 81 | info['mask'] = mask 82 | bool_mask = (mask.sum(dim=2) > 0.5) 83 | info['mask_indices'] = torch.nonzero(bool_mask)[:,1] 84 | #单独分离inversion 85 | if opts.attn_mask and (~bool_mask).any(): 86 | attention_mask = self.create_attention_mask(L+512, info['mask_indices'], device=self.device) 87 | else: 88 | attention_mask = None 89 | info['attention_mask'] = attention_mask 90 | 91 | if opts.attn_scale != 0 and (~bool_mask).any(): 92 | attention_scale = self.create_attention_scale(L+512, info['mask_indices'], device=mask.device,scale = opts.attn_scale) 93 | else: 94 | attention_scale = None 95 | info['attention_scale'] = attention_scale 96 | 97 | denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) 98 | denoise_timesteps = denoise_timesteps[opts.skip_step:] 99 | 100 | z0 = 
inp["img"] 101 | 102 | with torch.no_grad(): 103 | info['inject'] = True 104 | z_fe, info = denoise_kv_inf(self.model, img=inp["img"], img_ids=inp['img_ids'], 105 | source_txt=inp['txt'], source_txt_ids=inp['txt_ids'], source_vec=inp['vec'], 106 | target_txt=inp_target['txt'], target_txt_ids=inp_target['txt_ids'], target_vec=inp_target['vec'], 107 | timesteps=denoise_timesteps, source_guidance=opts.inversion_guidance, target_guidance=opts.denoise_guidance, 108 | info=info) 109 | mask_indices = info['mask_indices'] 110 | 111 | z0[:, mask_indices,...] = z_fe 112 | 113 | z0 = unpack(z0.float(), opts.height, opts.width) 114 | del info 115 | return z0 116 | 117 | class Flux_kv_edit(only_Flux): 118 | def __init__(self, device,name): 119 | super().__init__(device,name) 120 | 121 | @torch.inference_mode() 122 | def forward(self,inp,inp_target,mask:Tensor,opts): 123 | z0,zt,info = self.inverse(inp,mask,opts) 124 | z0 = self.denoise(z0,zt,inp_target,mask,opts,info) 125 | return z0 126 | @torch.inference_mode() 127 | def inverse(self,inp,mask,opts): 128 | info = {} 129 | info['feature'] = {} 130 | bs, L, d = inp["img"].shape 131 | h = opts.height // 8 132 | w = opts.width // 8 133 | 134 | if opts.attn_mask: 135 | mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False) 136 | mask[mask > 0] = 1 137 | 138 | mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16) 139 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 140 | bool_mask = (mask.sum(dim=2) > 0.5) 141 | mask_indices = torch.nonzero(bool_mask)[:,1] 142 | 143 | #单独分离inversion 144 | assert not (~bool_mask).all(), "mask is all false" 145 | assert not (bool_mask).all(), "mask is all true" 146 | attention_mask = self.create_attention_mask(L+512, mask_indices, device=mask.device) 147 | info['attention_mask'] = attention_mask 148 | 149 | 150 | denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) 151 | denoise_timesteps = denoise_timesteps[opts.skip_step:] 152 | 153 | # 加噪过程 154 | z0 = inp["img"].clone() 155 | info['inverse'] = True 156 | zt, info = denoise_kv(self.model, **inp, timesteps=denoise_timesteps, guidance=opts.inversion_guidance, inverse=True, info=info) 157 | return z0,zt,info 158 | 159 | @torch.inference_mode() 160 | def denoise(self,z0,zt,inp_target,mask:Tensor,opts,info): 161 | 162 | h = opts.height // 8 163 | w = opts.width // 8 164 | L = h * w // 4 165 | mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False) 166 | mask[mask > 0] = 1 167 | 168 | mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16) 169 | 170 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 171 | info['mask'] = mask 172 | bool_mask = (mask.sum(dim=2) > 0.5) 173 | info['mask_indices'] = torch.nonzero(bool_mask)[:,1] 174 | 175 | denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell")) 176 | denoise_timesteps = denoise_timesteps[opts.skip_step:] 177 | 178 | mask_indices = info['mask_indices'] 179 | if opts.re_init: 180 | noise = torch.randn_like(zt) 181 | t = denoise_timesteps[0] 182 | zt_noise = z0 *(1 - t) + noise * t 183 | inp_target["img"] = zt_noise[:, mask_indices,...] 184 | else: 185 | img_name = str(info['t']) + '_' + 'img' 186 | zt = info['feature'][img_name].to(zt.device) 187 | inp_target["img"] = zt[:, mask_indices,...] 
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | einops
3 | fire
4 | gradio==5.17.1
5 | huggingface-hub
6 | invisible-watermark
7 | matplotlib
8 | numpy
9 | opencv-python
10 | Pillow
11 | Requests
12 | safetensors
13 | scikit-learn
14 | scipy
15 | scikit-image
16 | torchvision
17 | tqdm
18 | transformers==4.49.0
19 | sentencepiece
--------------------------------------------------------------------------------
/resources/example.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xilluill/KV-Edit/77129a23eb8eff7408df7d8c91f1f8c523ce56ac/resources/example.jpeg
--------------------------------------------------------------------------------
/resources/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xilluill/KV-Edit/77129a23eb8eff7408df7d8c91f1f8c523ce56ac/resources/pipeline.jpg
--------------------------------------------------------------------------------
/resources/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xilluill/KV-Edit/77129a23eb8eff7408df7d8c91f1f8c523ce56ac/resources/teaser.jpg
--------------------------------------------------------------------------------
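requirements.txt pins only gradio (5.17.1) and transformers (4.49.0) exactly; everything else floats. A quick check that an installed environment matches those two pins (a sketch, not part of the repo):

```python
# Environment sanity check against the two exact pins in requirements.txt above.
import gradio
import transformers

for module, pinned in ((gradio, "5.17.1"), (transformers, "4.49.0")):
    assert module.__version__ == pinned, f"{module.__name__} {module.__version__} != pinned {pinned}"
print("pinned versions match")
```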