├── .gitignore
├── LICENSE
├── README.md
├── flux
│   ├── __init__.py
│   ├── __main__.py
│   ├── _version.py
│   ├── api.py
│   ├── math.py
│   ├── model.py
│   ├── modules
│   │   ├── autoencoder.py
│   │   ├── conditioner.py
│   │   └── layers.py
│   ├── sampling.py
│   └── util.py
├── gradio_kv_edit.py
├── gradio_kv_edit_gpu.py
├── gradio_kv_edit_inf.py
├── models
│   └── kv_edit.py
├── requirements.txt
└── resources
    ├── example.jpeg
    ├── pipeline.jpg
    └── teaser.jpg
/.gitignore:
--------------------------------------------------------------------------------
1 | script/
2 | *.symlink
3 | output/
4 | regress_result/
5 | google
6 | checkpoints
7 | huggingface_models
8 | openai
9 | __pycache__/
10 | .vscode/
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # KV-Edit: Training-Free Image Editing for Precise Background Preservation
4 |
5 | [Tianrui Zhu](https://github.com/Xilluill)<sup>1*</sup>, [Shiyi Zhang](https://shiyi-zh0408.github.io/)<sup>1*</sup>, [Jiawei Shao](https://shaojiawei07.github.io/)<sup>2</sup>, [Yansong Tang](https://andytang15.github.io/)<sup>1†</sup>
6 |
7 | <sup>1</sup>Tsinghua University, <sup>2</sup>Institute of Artificial Intelligence (TeleAI)
8 |
9 |
10 |
11 | [arXiv](https://arxiv.org/abs/2502.17363)
12 | [Hugging Face Space](https://huggingface.co/spaces/xilluill/KV-Edit)
13 | [GitHub](https://github.com/Xilluill/KV-Edit)
14 |
15 | [PIE-Bench Leaderboard](https://paperswithcode.com/sota/text-based-image-editing-on-pie-bench?p=kv-edit-training-free-image-editing-for)
16 | [ComfyUI Integration](https://github.com/smthemex/ComfyUI_KV_Edit)
17 |
18 |
19 |
20 |
21 | We propose KV-Edit, a training-free image editing approach that strictly preserves background consistency between the original and edited images. Our method achieves impressive performance on various editing tasks, including object addition, removal, and replacement.
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | # 🔥 News
31 | - [2025.3.12] Thanks to @[smthemex](https://github.com/smthemex) for integrating KV-Edit into [ComfyUI](https://github.com/smthemex/ComfyUI_KV_Edit)!
32 | - [2025.3.4] We added an "attention scale" feature to reduce discontinuity with the background.
33 | - [2025.2.26] Our paper is featured on [Hugging Face Papers](https://huggingface.co/papers/2502.17363)!
34 | - [2025.2.25] Code for image editing is released!
35 | - [2025.2.25] Paper released!
36 | - [2025.2.25] More results can be found in our [project page](https://xilluill.github.io/projectpages/KV-Edit/)!
37 |
38 | # 👨‍💻 ToDo
39 | - ☑️ Release the gradio demo
40 | - ☑️ Release the huggingface space for image editing
41 | - ☑️ Release the paper
42 |
43 |
44 | # 📖 Pipeline
45 |
46 |
47 | We implement a KV cache in our DiT-based generative model: it stores the key-value pairs of the background tokens during the inversion process and concatenates them with the foreground content during denoising. Since the background tokens are preserved rather than regenerated, KV-Edit strictly maintains background consistency while generating seamlessly integrated new content.
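
Below is a minimal, self-contained sketch of this idea with toy shapes and random tensors; it is not the actual pipeline code (see `DoubleStreamBlock_kv` and `SingleStreamBlock_kv` in `flux/modules/layers.py` for the real implementation).
```
import torch
import torch.nn.functional as F

B, H, L, D = 1, 4, 16, 32                  # batch, heads, image tokens, head dim
mask_indices = torch.tensor([3, 4, 5])     # tokens inside the editing mask

# Inversion pass: cache K/V of all image tokens (background included).
k_cache = torch.randn(B, H, L, D)
v_cache = torch.randn(B, H, L, D)

# Denoising pass: only the masked (foreground) tokens get fresh Q/K/V.
q_fg = torch.randn(B, H, len(mask_indices), D)
k_fg = torch.randn(B, H, len(mask_indices), D)
v_fg = torch.randn(B, H, len(mask_indices), D)

k, v = k_cache.clone(), v_cache.clone()
k[:, :, mask_indices] = k_fg               # overwrite only the masked slots
v[:, :, mask_indices] = v_fg

# Foreground queries attend to cached background K/V plus fresh foreground K/V.
out = F.scaled_dot_product_attention(q_fg, k, v)
print(out.shape)                           # torch.Size([1, 4, 3, 32])
```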
48 |
49 | # 🚀 Getting Started
50 | ## Environment Requirement 🌍
51 | Our code uses the same environment as FLUX. You can refer to the [official repo](https://github.com/black-forest-labs/flux/tree/main) of FLUX, or run the following commands to set up a simplified environment.
52 |
53 | Clone the repo:
54 | ```
55 | git clone https://github.com/Xilluill/KV-Edit
56 | ```
57 | We recommend first creating a virtual environment with conda, then running:
58 | ```
59 | conda create --name KV-Edit python=3.10
60 | conda activate KV-Edit
61 | pip install -r requirements.txt
62 | ```
63 | ## Running Gradio demo 🛫
64 | We provide three demo scripts for different hardware configurations. For users with server access and sufficient CPU/GPU memory (>40 GB RAM / >24 GB VRAM), we recommend:
65 | ```
66 | python gradio_kv_edit.py
67 | ```
68 | For users with 2 GPUs (e.g., 3090/4090), which avoids offloading models and speeds up inference, you can use:
69 | ```
70 | python gradio_kv_edit_gpu.py --gpus
71 | ```
72 | For users with limited GPU memory, we recommend:
73 | ```
74 | python gradio_kv_edit.py --offload
75 | ```
76 | For users with limited CPU memory (e.g., a personal PC), we recommend:
77 | ```
78 | python gradio_kv_edit_inf.py --offload
79 | ```
80 | Here's a sample workflow for our demo:
81 |
82 | 1️⃣ Upload the image you want to edit.
83 | 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion.
84 | 3️⃣ Use the brush tool to draw your mask area.
85 | 4️⃣ Fill in your target prompt, then adjust the hyperparameters.
86 | 5️⃣ Click the "Edit" button to generate your edited image!
87 |
88 |
89 |
90 |
91 |
92 | ### 💡Important Notes:
93 | - 🎨 When using the inversion-based version, you only need to perform the inversion once for each image. You can then repeat steps 3-5 for multiple editing attempts!
94 | - 🎨 "re_init" means generating new content from the image blended with noise rather than from the inversion result.
95 | - 🎨 When the "attn_mask" option is checked, you need to input the mask before performing the inversion.
96 | - 🎨 When the mask is large and fewer skip steps or "re_init" are used, the content in the mask area may be discontinuous with the background; try increasing "attn_scale".
97 |
98 |
99 | # 🖋️ Citation
100 |
101 | If you find our work helpful, please **star 🌟** this repo and **cite 📑** our paper. Thanks for your support!
102 | ```
103 | @article{zhu2025kv,
104 | title={KV-Edit: Training-Free Image Editing for Precise Background Preservation},
105 | author={Zhu, Tianrui and Zhang, Shiyi and Shao, Jiawei and Tang, Yansong},
106 | journal={arXiv preprint arXiv:2502.17363},
107 | year={2025}
108 | }
109 | ```
110 |
111 | # 👍🏻 Acknowledgements
112 | Our code is modified from [FLUX](https://github.com/black-forest-labs/flux) and [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit). Special thanks to [Wenke Huang](https://wenkehuang.github.io/) for his early inspiration and helpful guidance on this project!
113 |
114 | # 📧 Contact
115 | This repository is currently under active development and restructuring. The codebase is being optimized for better stability and reproducibility. While we strive to maintain code quality, you may encounter temporary issues during this transition period. For any questions or technical discussions, feel free to open an issue or contact us via email at xilluill070513@gmail.com.
--------------------------------------------------------------------------------
/flux/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from ._version import version as __version__ # type: ignore
3 | from ._version import version_tuple
4 | except ImportError:
5 | __version__ = "unknown (no version information available)"
6 | version_tuple = (0, 0, "unknown", "noinfo")
7 |
8 | from pathlib import Path
9 |
10 | PACKAGE = __package__.replace("_", "-")
11 | PACKAGE_ROOT = Path(__file__).parent
12 |
--------------------------------------------------------------------------------
/flux/__main__.py:
--------------------------------------------------------------------------------
1 | from .cli import app
2 |
3 | if __name__ == "__main__":
4 | app()
5 |
--------------------------------------------------------------------------------
/flux/_version.py:
--------------------------------------------------------------------------------
1 | # file generated by setuptools_scm
2 | # don't change, don't track in version control
3 | TYPE_CHECKING = False
4 | if TYPE_CHECKING:
5 | from typing import Tuple, Union
6 | VERSION_TUPLE = Tuple[Union[int, str], ...]
7 | else:
8 | VERSION_TUPLE = object
9 |
10 | version: str
11 | __version__: str
12 | __version_tuple__: VERSION_TUPLE
13 | version_tuple: VERSION_TUPLE
14 |
15 | __version__ = version = '0.0.post6+ge52e00f.d20250111'
16 | __version_tuple__ = version_tuple = (0, 0, 'ge52e00f.d20250111')
17 |
--------------------------------------------------------------------------------
/flux/api.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import time
4 | from pathlib import Path
5 |
6 | import requests
7 | from PIL import Image
8 |
9 | API_ENDPOINT = "https://api.bfl.ml"
10 |
11 |
12 | class ApiException(Exception):
13 | def __init__(self, status_code: int, detail: str | list[dict] | None = None):
14 | super().__init__()
15 | self.detail = detail
16 | self.status_code = status_code
17 |
18 | def __str__(self) -> str:
19 | return self.__repr__()
20 |
21 | def __repr__(self) -> str:
22 | if self.detail is None:
23 | message = None
24 | elif isinstance(self.detail, str):
25 | message = self.detail
26 | else:
27 | message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28 | return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29 |
30 |
31 | class ImageRequest:
32 | def __init__(
33 | self,
34 | prompt: str,
35 | width: int = 1024,
36 | height: int = 1024,
37 | name: str = "flux.1-pro",
38 | num_steps: int = 50,
39 | prompt_upsampling: bool = False,
40 | seed: int | None = None,
41 | validate: bool = True,
42 | launch: bool = True,
43 | api_key: str | None = None,
44 | ):
45 | """
46 | Manages an image generation request to the API.
47 |
48 | Args:
49 | prompt: Prompt to sample
50 | width: Width of the image in pixel
51 | height: Height of the image in pixel
52 | name: Name of the model
53 | num_steps: Number of network evaluations
54 | prompt_upsampling: Use prompt upsampling
55 | seed: Fix the generation seed
56 | validate: Run input validation
57 | launch: Directly launches request
58 | api_key: Your API key if not provided by the environment
59 |
60 | Raises:
61 | ValueError: For invalid input
62 | ApiException: For errors raised from the API
63 | """
64 | if validate:
65 | if name not in ["flux.1-pro"]:
66 | raise ValueError(f"Invalid model {name}")
67 | elif width % 32 != 0:
68 | raise ValueError(f"width must be divisible by 32, got {width}")
69 | elif not (256 <= width <= 1440):
70 | raise ValueError(f"width must be between 256 and 1440, got {width}")
71 | elif height % 32 != 0:
72 | raise ValueError(f"height must be divisible by 32, got {height}")
73 | elif not (256 <= height <= 1440):
74 | raise ValueError(f"height must be between 256 and 1440, got {height}")
75 | elif not (1 <= num_steps <= 50):
76 | raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77 |
78 | self.request_json = {
79 | "prompt": prompt,
80 | "width": width,
81 | "height": height,
82 | "variant": name,
83 | "steps": num_steps,
84 | "prompt_upsampling": prompt_upsampling,
85 | }
86 | if seed is not None:
87 | self.request_json["seed"] = seed
88 |
89 | self.request_id: str | None = None
90 | self.result: dict | None = None
91 | self._image_bytes: bytes | None = None
92 | self._url: str | None = None
93 | if api_key is None:
94 | self.api_key = os.environ.get("BFL_API_KEY")
95 | else:
96 | self.api_key = api_key
97 |
98 | if launch:
99 | self.request()
100 |
101 | def request(self):
102 | """
103 | Request to generate the image.
104 | """
105 | if self.request_id is not None:
106 | return
107 | response = requests.post(
108 | f"{API_ENDPOINT}/v1/image",
109 | headers={
110 | "accept": "application/json",
111 | "x-key": self.api_key,
112 | "Content-Type": "application/json",
113 | },
114 | json=self.request_json,
115 | )
116 | result = response.json()
117 | if response.status_code != 200:
118 | raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119 | self.request_id = result["id"]
120 |
121 | def retrieve(self) -> dict:
122 | """
123 | Wait for the generation to finish and retrieve response.
124 | """
125 | if self.request_id is None:
126 | self.request()
127 | while self.result is None:
128 | response = requests.get(
129 | f"{API_ENDPOINT}/v1/get_result",
130 | headers={
131 | "accept": "application/json",
132 | "x-key": self.api_key,
133 | },
134 | params={
135 | "id": self.request_id,
136 | },
137 | )
138 | result = response.json()
139 | if "status" not in result:
140 | raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141 | elif result["status"] == "Ready":
142 | self.result = result["result"]
143 | elif result["status"] == "Pending":
144 | time.sleep(0.5)
145 | else:
146 | raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147 | return self.result
148 |
149 | @property
150 | def bytes(self) -> bytes:
151 | """
152 | Generated image as bytes.
153 | """
154 | if self._image_bytes is None:
155 | response = requests.get(self.url)
156 | if response.status_code == 200:
157 | self._image_bytes = response.content
158 | else:
159 | raise ApiException(status_code=response.status_code)
160 | return self._image_bytes
161 |
162 | @property
163 | def url(self) -> str:
164 | """
165 | Public url to retrieve the image from
166 | """
167 | if self._url is None:
168 | result = self.retrieve()
169 | self._url = result["sample"]
170 | return self._url
171 |
172 | @property
173 | def image(self) -> Image.Image:
174 | """
175 | Load the image as a PIL Image
176 | """
177 | return Image.open(io.BytesIO(self.bytes))
178 |
179 | def save(self, path: str):
180 | """
181 | Save the generated image to a local path
182 | """
183 | suffix = Path(self.url).suffix
184 | if not path.endswith(suffix):
185 | path = path + suffix
186 | Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187 | with open(path, "wb") as file:
188 | file.write(self.bytes)
189 |
190 |
191 | if __name__ == "__main__":
192 | from fire import Fire
193 |
194 | Fire(ImageRequest)
195 |
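# --- Usage sketch (illustrative addition, not part of the original file) ----
# Minimal example of driving ImageRequest programmatically; it assumes a valid
# BFL_API_KEY in the environment and network access to api.bfl.ml.
#
#   from flux.api import ImageRequest
#
#   request = ImageRequest(
#       prompt="a photo of a forest with mist",
#       width=1024,
#       height=768,
#       launch=False,           # build the payload without submitting it yet
#   )
#   request.request()           # submit the generation job
#   request.retrieve()          # poll until the result is ready
#   request.save("forest.jpg")  # download the image and write it to disk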
--------------------------------------------------------------------------------
/flux/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import Tensor
4 |
5 |
6 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor,pe_q = None, attention_mask = None) -> Tensor:
7 | if pe_q is None:
8 | q, k = apply_rope(q, k, pe)
9 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v,attn_mask=attention_mask)
10 | x = rearrange(x, "B H L D -> B L (H D)")
11 | return x
12 | else:
13 | q, k = apply_rope_qk(q, k, pe_q, pe)
14 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v,attn_mask=attention_mask)
15 | x = rearrange(x, "B H L D -> B L (H D)")
16 | return x
17 |
18 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
19 | assert dim % 2 == 0
20 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim # dim =16 + 56 + 56
21 | omega = 1.0 / (theta**scale) # 64 omega
22 | out = torch.einsum("...n,d->...nd", pos, omega)
23 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
24 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
25 | return out.float()
26 |
27 |
28 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
29 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
30 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
31 | xq_out = freqs_cis[:, :, :xq_.shape[2], :, :, 0] * xq_[..., 0] + freqs_cis[:, :, :xq_.shape[2], :, :, 1] * xq_[..., 1]
32 | xk_out = freqs_cis[:, :, :xk_.shape[2], :, :, 0] * xk_[..., 0] + freqs_cis[:, :, :xk_.shape[2], :, :, 1] * xk_[..., 1]
33 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
34 |
35 | def apply_rope_qk(xq: Tensor, xk: Tensor, freqs_cis_q: Tensor,freqs_cis_k: Tensor) -> tuple[Tensor, Tensor]:
36 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
37 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
38 | xq_out = freqs_cis_q[:, :, :xq_.shape[2], :, :, 0] * xq_[..., 0] + freqs_cis_q[:, :, :xq_.shape[2], :, :, 1] * xq_[..., 1]
39 | xk_out = freqs_cis_k[:, :, :xk_.shape[2], :, :, 0] * xk_[..., 0] + freqs_cis_k[:, :, :xk_.shape[2], :, :, 1] * xk_[..., 1]
40 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
41 |
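# --- Shape sketch (illustrative addition, not part of the original file) ----
# rope(pos, dim, theta) builds a (batch, n, dim/2, 2, 2) rotation table:
#
#   import torch
#   pos = torch.arange(8)[None].float()      # (1, 8) token positions
#   pe = rope(pos, dim=16, theta=10_000)     # (1, 8, 8, 2, 2)
#
# apply_rope / apply_rope_qk then rotate query and key heads with a table of
# shape (B, 1, L, head_dim/2, 2, 2), which is what EmbedND in
# flux/modules/layers.py produces.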
--------------------------------------------------------------------------------
/flux/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | from torch import Tensor, nn
5 |
6 | from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
7 | MLPEmbedder, SingleStreamBlock,
8 | SingleStreamBlock_kv,DoubleStreamBlock_kv,
9 | timestep_embedding)
10 |
11 |
12 | @dataclass
13 | class FluxParams:
14 | in_channels: int
15 | vec_in_dim: int
16 | context_in_dim: int
17 | hidden_size: int
18 | mlp_ratio: float
19 | num_heads: int
20 | depth: int
21 | depth_single_blocks: int
22 | axes_dim: list[int]
23 | theta: int
24 | qkv_bias: bool
25 | guidance_embed: bool
26 |
27 |
28 | class Flux(nn.Module):
29 | """
30 | Transformer model for flow matching on sequences.
31 | """
32 |
33 | def __init__(self, params: FluxParams,double_block_cls=DoubleStreamBlock,single_block_cls=SingleStreamBlock):
34 | super().__init__()
35 |
36 | self.params = params
37 | self.in_channels = params.in_channels
38 | self.out_channels = self.in_channels
39 | if params.hidden_size % params.num_heads != 0:
40 | raise ValueError(
41 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
42 | )
43 | pe_dim = params.hidden_size // params.num_heads
44 | if sum(params.axes_dim) != pe_dim:
45 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
46 | self.hidden_size = params.hidden_size
47 | self.num_heads = params.num_heads
48 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
49 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
50 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
51 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
52 | self.guidance_in = (
53 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
54 | )
55 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
56 |
57 | self.double_blocks = nn.ModuleList(
58 | [
59 | double_block_cls(
60 | self.hidden_size,
61 | self.num_heads,
62 | mlp_ratio=params.mlp_ratio,
63 | qkv_bias=params.qkv_bias,
64 | )
65 | for _ in range(params.depth)
66 | ]
67 | )
68 |
69 | self.single_blocks = nn.ModuleList(
70 | [
71 | single_block_cls(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
72 | for _ in range(params.depth_single_blocks)
73 | ]
74 | )
75 |
76 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
77 |
78 | def forward(
79 | self,
80 | img: Tensor,
81 | img_ids: Tensor,
82 | txt: Tensor,
83 | txt_ids: Tensor,
84 | timesteps: Tensor,
85 | y: Tensor,
86 | guidance: Tensor | None = None,
87 | ) -> Tensor:
88 | if img.ndim != 3 or txt.ndim != 3:
89 | raise ValueError("Input img and txt tensors must have 3 dimensions.")
90 |
91 | # running on sequences img
92 | img = self.img_in(img)
93 | vec = self.time_in(timestep_embedding(timesteps, 256))
94 | if self.params.guidance_embed:
95 | if guidance is None:
96 | raise ValueError("Didn't get guidance strength for guidance distilled model.")
97 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
98 | vec = vec + self.vector_in(y)
99 | txt = self.txt_in(txt)
100 |
101 | ids = torch.cat((txt_ids, img_ids), dim=1)
102 | pe = self.pe_embedder(ids)
103 |
104 | for block in self.double_blocks:
105 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
106 |
107 | img = torch.cat((txt, img), 1)
108 | for block in self.single_blocks:
109 | img = block(img, vec=vec, pe=pe)
110 | img = img[:, txt.shape[1] :, ...]
111 |
112 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
113 | return img
114 |
115 | class Flux_kv(Flux):
116 | """
117 | 继承Flux类,重写forward方法
118 | """
119 |
120 | def __init__(self, params: FluxParams,double_block_cls=DoubleStreamBlock_kv,single_block_cls=SingleStreamBlock_kv):
121 | super().__init__(params,double_block_cls,single_block_cls)
122 |
123 | def forward(
124 | self,
125 | img: Tensor,
126 | img_ids: Tensor,
127 | txt: Tensor,
128 | txt_ids: Tensor,
129 | timesteps: Tensor,
130 | y: Tensor,
131 | guidance: Tensor | None = None,
132 | info: dict = {},
133 | ) -> Tensor:
134 | if img.ndim != 3 or txt.ndim != 3:
135 | raise ValueError("Input img and txt tensors must have 3 dimensions.")
136 |
137 | # running on sequences img
138 | img = self.img_in(img)
139 | vec = self.time_in(timestep_embedding(timesteps, 256))
140 | if self.params.guidance_embed:
141 | if guidance is None:
142 | raise ValueError("Didn't get guidance strength for guidance distilled model.")
143 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
144 | vec = vec + self.vector_in(y)
145 | txt = self.txt_in(txt)
146 |
147 | ids = torch.cat((txt_ids, img_ids), dim=1)
148 | pe = self.pe_embedder(ids)
149 | if not info['inverse']:
150 | mask_indices = info['mask_indices']
151 | info['pe_mask'] = torch.cat((pe[:, :, :512, ...],pe[:, :, mask_indices+512, ...]),dim=2)
152 |
153 | cnt = 0
154 | for block in self.double_blocks:
155 | info['id'] = cnt
156 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe, info=info)
157 | cnt += 1
158 |
159 | cnt = 0
160 | x = torch.cat((txt, img), 1)
161 | for block in self.single_blocks:
162 | info['id'] = cnt
163 | x = block(x, vec=vec, pe=pe, info=info)
164 | cnt += 1
165 |
166 | img = x[:, txt.shape[1] :, ...]
167 |
168 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
169 |
170 | return img
171 |
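# --- `info` dict sketch (illustrative addition, not part of the original file)
# Flux_kv.forward threads a mutable `info` dict through every block. Based on
# the keys read here and in flux/modules/layers.py, a caller roughly prepares:
#
#   info = {
#       "inverse": True,           # True: cache background K/V; False: reuse it
#       "feature": {},             # filled with "<t>_<id>_{MB,SB}_{K,V}" tensors
#       "t": t,                    # current timestep, used to key the cache
#       "mask_indices": mask_idx,  # image-token indices inside the editing mask
#   }
#
# "attention_mask" (inversion) and "attention_scale" (denoising) are further
# entries consumed by the _kv blocks; "id" and "pe_mask" are written by
# Flux_kv.forward itself. See models/kv_edit.py for the actual construction.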
--------------------------------------------------------------------------------
/flux/modules/autoencoder.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | from einops import rearrange
5 | from torch import Tensor, nn
6 |
7 |
8 | @dataclass
9 | class AutoEncoderParams:
10 | resolution: int
11 | in_channels: int
12 | ch: int
13 | out_ch: int
14 | ch_mult: list[int]
15 | num_res_blocks: int
16 | z_channels: int
17 | scale_factor: float
18 | shift_factor: float
19 |
20 |
21 | def swish(x: Tensor) -> Tensor:
22 | return x * torch.sigmoid(x)
23 |
24 |
25 | class AttnBlock(nn.Module):
26 | def __init__(self, in_channels: int):
27 | super().__init__()
28 | self.in_channels = in_channels
29 |
30 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31 |
32 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36 |
37 | def attention(self, h_: Tensor) -> Tensor:
38 | h_ = self.norm(h_)
39 | q = self.q(h_)
40 | k = self.k(h_)
41 | v = self.v(h_)
42 |
43 | b, c, h, w = q.shape
44 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47 | h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48 |
49 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | return x + self.proj_out(self.attention(x))
53 |
54 |
55 | class ResnetBlock(nn.Module):
56 | def __init__(self, in_channels: int, out_channels: int):
57 | super().__init__()
58 | self.in_channels = in_channels
59 | out_channels = in_channels if out_channels is None else out_channels
60 | self.out_channels = out_channels
61 |
62 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66 | if self.in_channels != self.out_channels:
67 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68 |
69 | def forward(self, x):
70 | h = x
71 | h = self.norm1(h)
72 | h = swish(h)
73 | h = self.conv1(h)
74 |
75 | h = self.norm2(h)
76 | h = swish(h)
77 | h = self.conv2(h)
78 |
79 | if self.in_channels != self.out_channels:
80 | x = self.nin_shortcut(x)
81 |
82 | return x + h
83 |
84 |
85 | class Downsample(nn.Module):
86 | def __init__(self, in_channels: int):
87 | super().__init__()
88 | # no asymmetric padding in torch conv, must do it ourselves
89 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90 |
91 | def forward(self, x: Tensor):
92 | pad = (0, 1, 0, 1)
93 | x = nn.functional.pad(x, pad, mode="constant", value=0)
94 | x = self.conv(x)
95 | return x
96 |
97 |
98 | class Upsample(nn.Module):
99 | def __init__(self, in_channels: int):
100 | super().__init__()
101 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102 |
103 | def forward(self, x: Tensor):
104 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105 | x = self.conv(x)
106 | return x
107 |
108 |
109 | class Encoder(nn.Module):
110 | def __init__(
111 | self,
112 | resolution: int,
113 | in_channels: int,
114 | ch: int,
115 | ch_mult: list[int],
116 | num_res_blocks: int,
117 | z_channels: int,
118 | ):
119 | super().__init__()
120 | self.ch = ch
121 | self.num_resolutions = len(ch_mult)
122 | self.num_res_blocks = num_res_blocks
123 | self.resolution = resolution
124 | self.in_channels = in_channels
125 | # downsampling
126 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127 |
128 | curr_res = resolution
129 | in_ch_mult = (1,) + tuple(ch_mult)
130 | self.in_ch_mult = in_ch_mult
131 | self.down = nn.ModuleList()
132 | block_in = self.ch
133 | for i_level in range(self.num_resolutions):
134 | block = nn.ModuleList()
135 | attn = nn.ModuleList()
136 | block_in = ch * in_ch_mult[i_level]
137 | block_out = ch * ch_mult[i_level]
138 | for _ in range(self.num_res_blocks):
139 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140 | block_in = block_out
141 | down = nn.Module()
142 | down.block = block
143 | down.attn = attn
144 | if i_level != self.num_resolutions - 1:
145 | down.downsample = Downsample(block_in)
146 | curr_res = curr_res // 2
147 | self.down.append(down)
148 |
149 | # middle
150 | self.mid = nn.Module()
151 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152 | self.mid.attn_1 = AttnBlock(block_in)
153 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154 |
155 | # end
156 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158 |
159 | def forward(self, x: Tensor) -> Tensor:
160 | # downsampling
161 | hs = [self.conv_in(x)]
162 | for i_level in range(self.num_resolutions):
163 | for i_block in range(self.num_res_blocks):
164 | h = self.down[i_level].block[i_block](hs[-1])
165 | if len(self.down[i_level].attn) > 0:
166 | h = self.down[i_level].attn[i_block](h)
167 | hs.append(h)
168 | if i_level != self.num_resolutions - 1:
169 | hs.append(self.down[i_level].downsample(hs[-1]))
170 |
171 | # middle
172 | h = hs[-1]
173 | h = self.mid.block_1(h)
174 | h = self.mid.attn_1(h)
175 | h = self.mid.block_2(h)
176 | # end
177 | h = self.norm_out(h)
178 | h = swish(h)
179 | h = self.conv_out(h)
180 | return h
181 |
182 |
183 | class Decoder(nn.Module):
184 | def __init__(
185 | self,
186 | ch: int,
187 | out_ch: int,
188 | ch_mult: list[int],
189 | num_res_blocks: int,
190 | in_channels: int,
191 | resolution: int,
192 | z_channels: int,
193 | ):
194 | super().__init__()
195 | self.ch = ch
196 | self.num_resolutions = len(ch_mult)
197 | self.num_res_blocks = num_res_blocks
198 | self.resolution = resolution
199 | self.in_channels = in_channels
200 | self.ffactor = 2 ** (self.num_resolutions - 1)
201 |
202 | # compute in_ch_mult, block_in and curr_res at lowest res
203 | block_in = ch * ch_mult[self.num_resolutions - 1]
204 | curr_res = resolution // 2 ** (self.num_resolutions - 1)
205 | self.z_shape = (1, z_channels, curr_res, curr_res)
206 |
207 | # z to block_in
208 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209 |
210 | # middle
211 | self.mid = nn.Module()
212 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213 | self.mid.attn_1 = AttnBlock(block_in)
214 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215 |
216 | # upsampling
217 | self.up = nn.ModuleList()
218 | for i_level in reversed(range(self.num_resolutions)):
219 | block = nn.ModuleList()
220 | attn = nn.ModuleList()
221 | block_out = ch * ch_mult[i_level]
222 | for _ in range(self.num_res_blocks + 1):
223 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224 | block_in = block_out
225 | up = nn.Module()
226 | up.block = block
227 | up.attn = attn
228 | if i_level != 0:
229 | up.upsample = Upsample(block_in)
230 | curr_res = curr_res * 2
231 | self.up.insert(0, up) # prepend to get consistent order
232 |
233 | # end
234 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236 |
237 | def forward(self, z: Tensor) -> Tensor:
238 | # z to block_in
239 | h = self.conv_in(z)
240 |
241 | # middle
242 | h = self.mid.block_1(h)
243 | h = self.mid.attn_1(h)
244 | h = self.mid.block_2(h)
245 |
246 | # upsampling
247 | for i_level in reversed(range(self.num_resolutions)):
248 | for i_block in range(self.num_res_blocks + 1):
249 | h = self.up[i_level].block[i_block](h)
250 | if len(self.up[i_level].attn) > 0:
251 | h = self.up[i_level].attn[i_block](h)
252 | if i_level != 0:
253 | h = self.up[i_level].upsample(h)
254 |
255 | # end
256 | h = self.norm_out(h)
257 | h = swish(h)
258 | h = self.conv_out(h)
259 | return h
260 |
261 |
262 | class DiagonalGaussian(nn.Module):
263 | def __init__(self, sample: bool = True, chunk_dim: int = 1):
264 | super().__init__()
265 | self.sample = sample
266 | self.chunk_dim = chunk_dim
267 |
268 | def forward(self, z: Tensor) -> Tensor:
269 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270 | # import pdb;pdb.set_trace()
271 | if self.sample:
272 | std = torch.exp(0.5 * logvar)
273 | return mean #+ std * torch.randn_like(mean)
274 | else:
275 | return mean
276 |
277 |
278 | class AutoEncoder(nn.Module):
279 | def __init__(self, params: AutoEncoderParams):
280 | super().__init__()
281 | self.encoder = Encoder(
282 | resolution=params.resolution,
283 | in_channels=params.in_channels,
284 | ch=params.ch,
285 | ch_mult=params.ch_mult,
286 | num_res_blocks=params.num_res_blocks,
287 | z_channels=params.z_channels,
288 | )
289 | self.decoder = Decoder(
290 | resolution=params.resolution,
291 | in_channels=params.in_channels,
292 | ch=params.ch,
293 | out_ch=params.out_ch,
294 | ch_mult=params.ch_mult,
295 | num_res_blocks=params.num_res_blocks,
296 | z_channels=params.z_channels,
297 | )
298 | self.reg = DiagonalGaussian()
299 |
300 | self.scale_factor = params.scale_factor
301 | self.shift_factor = params.shift_factor
302 |
303 | def encode(self, x: Tensor) -> Tensor:
304 | z = self.reg(self.encoder(x))
305 | z = self.scale_factor * (z - self.shift_factor)
306 | return z
307 |
308 | def decode(self, z: Tensor) -> Tensor:
309 | z = z / self.scale_factor + self.shift_factor
310 | return self.decoder(z)
311 |
312 | def forward(self, x: Tensor) -> Tensor:
313 | return self.decode(self.encode(x))
314 |
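# --- Usage sketch (illustrative addition, not part of the original file) ----
# The parameter values below follow the FLUX.1 autoencoder configuration used
# elsewhere in this codebase (see flux/util.py); treat them as an assumption.
#
#   import torch
#   params = AutoEncoderParams(
#       resolution=256, in_channels=3, ch=128, out_ch=3,
#       ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
#       scale_factor=0.3611, shift_factor=0.1159,
#   )
#   ae = AutoEncoder(params)
#   x = torch.randn(1, 3, 256, 256)
#   z = ae.encode(x)    # (1, 16, 32, 32): 8x spatial downsampling
#   rec = ae.decode(z)  # (1, 3, 256, 256)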
--------------------------------------------------------------------------------
/flux/modules/conditioner.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, nn
2 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
3 | T5Tokenizer)
4 |
5 |
6 | class HFEmbedder(nn.Module):
7 | def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
8 | super().__init__()
9 | self.is_clip = is_clip
10 | self.max_length = max_length
11 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12 |
13 | if version == 'black-forest-labs/FLUX.1-dev':
14 | if self.is_clip:
15 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer")
16 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, subfolder="text_encoder", **hf_kwargs)
17 | else:
18 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, subfolder="tokenizer_2")
19 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version,subfolder='text_encoder_2' , **hf_kwargs)
20 | else:
21 | if self.is_clip:
22 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
23 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
24 | else:
25 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
26 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
27 |
28 | self.hf_module = self.hf_module.eval().requires_grad_(False)
29 |
30 | def forward(self, text: list[str]) -> Tensor:
31 | batch_encoding = self.tokenizer(
32 | text,
33 | truncation=True,
34 | max_length=self.max_length,
35 | return_length=False,
36 | return_overflowing_tokens=False,
37 | padding="max_length",
38 | return_tensors="pt",
39 | )
40 |
41 | outputs = self.hf_module(
42 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
43 | attention_mask=None,
44 | output_hidden_states=False,
45 | )
46 | return outputs[self.output_key]
47 |
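# --- Usage sketch (illustrative addition, not part of the original file) ----
# FLUX pairs a CLIP embedder (77 tokens, pooled output) with a T5 embedder
# (512 tokens, per-token hidden states). The model names and dtypes below are
# assumptions as far as this file is concerned.
#
#   import torch
#   clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77,
#                     is_clip=True, torch_dtype=torch.bfloat16)
#   t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512,
#                   is_clip=False, torch_dtype=torch.bfloat16)
#   vec = clip(["a cat on a table"])  # (1, 768) pooled CLIP embedding
#   txt = t5(["a cat on a table"])    # (1, 512, 4096) T5 hidden states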
--------------------------------------------------------------------------------
/flux/modules/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 |
4 | import torch
5 | from einops import rearrange
6 | from torch import Tensor, nn
7 |
8 | from flux.math import attention, rope,apply_rope
9 |
10 | import os
11 |
12 | class EmbedND(nn.Module):
13 | def __init__(self, dim: int, theta: int, axes_dim: list[int]):
14 | super().__init__()
15 | self.dim = dim
16 | self.theta = theta
17 | self.axes_dim = axes_dim
18 |
19 | def forward(self, ids: Tensor) -> Tensor:
20 | n_axes = ids.shape[-1]
21 | emb = torch.cat(
22 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
23 | dim=-3,
24 | )
25 |
26 | return emb.unsqueeze(1)
27 |
28 |
29 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
30 | """
31 | Create sinusoidal timestep embeddings.
32 | :param t: a 1-D Tensor of N indices, one per batch element.
33 | These may be fractional.
34 | :param dim: the dimension of the output.
35 | :param max_period: controls the minimum frequency of the embeddings.
36 | :return: an (N, D) Tensor of positional embeddings.
37 | """
38 | t = time_factor * t
39 | half = dim // 2
40 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
41 | t.device
42 | )
43 |
44 | args = t[:, None].float() * freqs[None]
45 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46 | if dim % 2:
47 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48 | if torch.is_floating_point(t):
49 | embedding = embedding.to(t)
50 | return embedding
51 |
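# --- Shape note (illustrative addition, not part of the original file) ------
#   t = torch.rand(4)                     # one timestep per batch element
#   emb = timestep_embedding(t, dim=256)  # (4, 256) sinusoidal features
# Flux feeds this through MLPEmbedder (below) to obtain the modulation vector.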
52 |
53 | class MLPEmbedder(nn.Module):
54 | def __init__(self, in_dim: int, hidden_dim: int):
55 | super().__init__()
56 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
57 | self.silu = nn.SiLU()
58 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
59 |
60 | def forward(self, x: Tensor) -> Tensor:
61 | return self.out_layer(self.silu(self.in_layer(x)))
62 |
63 |
64 | class RMSNorm(torch.nn.Module):
65 | def __init__(self, dim: int):
66 | super().__init__()
67 | self.scale = nn.Parameter(torch.ones(dim))
68 |
69 | def forward(self, x: Tensor):
70 | x_dtype = x.dtype
71 | x = x.float()
72 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
73 | return (x * rrms).to(dtype=x_dtype) * self.scale
74 |
75 |
76 | class QKNorm(torch.nn.Module):
77 | def __init__(self, dim: int):
78 | super().__init__()
79 | self.query_norm = RMSNorm(dim)
80 | self.key_norm = RMSNorm(dim)
81 |
82 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
83 | q = self.query_norm(q)
84 | k = self.key_norm(k)
85 | return q.to(v), k.to(v)
86 |
87 |
88 | class SelfAttention(nn.Module):
89 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
90 | super().__init__()
91 | self.num_heads = num_heads
92 | head_dim = dim // num_heads
93 |
94 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
95 | self.norm = QKNorm(head_dim)
96 | self.proj = nn.Linear(dim, dim)
97 |
98 | def forward(self, x: Tensor, pe: Tensor) -> Tensor:
99 | qkv = self.qkv(x)
100 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
101 | q, k = self.norm(q, k, v)
102 | x = attention(q, k, v, pe=pe)
103 | x = self.proj(x)
104 | return x
105 |
106 |
107 | @dataclass
108 | class ModulationOut:
109 | shift: Tensor
110 | scale: Tensor
111 | gate: Tensor
112 |
113 |
114 | class Modulation(nn.Module):
115 | def __init__(self, dim: int, double: bool):
116 | super().__init__()
117 | self.is_double = double
118 | self.multiplier = 6 if double else 3
119 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
120 |
121 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
122 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
123 |
124 | return (
125 | ModulationOut(*out[:3]),
126 | ModulationOut(*out[3:]) if self.is_double else None,
127 | )
128 |
129 |
130 | class DoubleStreamBlock(nn.Module):
131 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
132 | super().__init__()
133 |
134 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
135 | self.num_heads = num_heads
136 | self.hidden_size = hidden_size
137 | self.img_mod = Modulation(hidden_size, double=True)
138 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
139 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
140 |
141 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142 | self.img_mlp = nn.Sequential(
143 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
144 | nn.GELU(approximate="tanh"),
145 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
146 | )
147 |
148 | self.txt_mod = Modulation(hidden_size, double=True)
149 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
150 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
151 |
152 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153 | self.txt_mlp = nn.Sequential(
154 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
155 | nn.GELU(approximate="tanh"),
156 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
157 | )
158 |
159 | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
160 | img_mod1, img_mod2 = self.img_mod(vec)
161 | txt_mod1, txt_mod2 = self.txt_mod(vec)
162 |
163 | # prepare image for attention
164 | img_modulated = self.img_norm1(img)
165 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
166 | img_qkv = self.img_attn.qkv(img_modulated)
167 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
168 |
169 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
170 | # prepare txt for attention
171 | txt_modulated = self.txt_norm1(txt)
172 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
173 | txt_qkv = self.txt_attn.qkv(txt_modulated)
174 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
175 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
176 |
177 | # run actual attention
178 | q = torch.cat((txt_q, img_q), dim=2)
179 | k = torch.cat((txt_k, img_k), dim=2)
180 | v = torch.cat((txt_v, img_v), dim=2)
181 | # import pdb;pdb.set_trace()
182 | attn = attention(q, k, v, pe=pe)
183 |
184 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
185 |
186 | # calculate the img blocks
187 | img = img + img_mod1.gate * self.img_attn.proj(img_attn)
188 | img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
189 |
190 | # calculate the txt blocks
191 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
192 | txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
193 | return img, txt
194 | class SingleStreamBlock(nn.Module):
195 | """
196 | A DiT block with parallel linear layers as described in
197 | https://arxiv.org/abs/2302.05442 and adapted modulation interface.
198 | """
199 |
200 | def __init__(
201 | self,
202 | hidden_size: int,
203 | num_heads: int,
204 | mlp_ratio: float = 4.0,
205 | qk_scale: float | None = None,
206 | ):
207 | super().__init__()
208 | self.hidden_dim = hidden_size
209 | self.num_heads = num_heads
210 | head_dim = hidden_size // num_heads
211 | self.scale = qk_scale or head_dim**-0.5
212 |
213 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214 | # qkv and mlp_in
215 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216 | # proj and mlp_out
217 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218 |
219 | self.norm = QKNorm(head_dim)
220 |
221 | self.hidden_size = hidden_size
222 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223 |
224 | self.mlp_act = nn.GELU(approximate="tanh")
225 | self.modulation = Modulation(hidden_size, double=False)
226 |
227 | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228 | mod, _ = self.modulation(vec)
229 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231 |
232 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233 | q, k = self.norm(q, k, v)
234 |
235 | # compute attention
236 | attn = attention(q, k, v, pe=pe)
237 | # compute activation in mlp stream, cat again and run second linear layer
238 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239 | return x + mod.gate * output
240 |
241 |
242 | class LastLayer(nn.Module):
243 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244 | super().__init__()
245 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248 |
249 | def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252 | x = self.linear(x)
253 | return x
254 |
268 |
269 | class DoubleStreamBlock_kv(DoubleStreamBlock):
270 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
271 | super().__init__(hidden_size, num_heads, mlp_ratio, qkv_bias)
272 |
273 | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor]:
274 | img_mod1, img_mod2 = self.img_mod(vec)
275 | txt_mod1, txt_mod2 = self.txt_mod(vec)
276 |
277 | # prepare image for attention
278 | img_modulated = self.img_norm1(img)
279 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
280 | img_qkv = self.img_attn.qkv(img_modulated)
281 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
282 |
283 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
284 | # prepare txt for attention
285 | txt_modulated = self.txt_norm1(txt)
286 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
287 | txt_qkv = self.txt_attn.qkv(txt_modulated)
288 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
289 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
290 |
291 | feature_k_name = str(info['t']) + '_' + str(info['id']) + '_' + 'MB' + '_' + 'K'
292 | feature_v_name = str(info['t']) + '_' + str(info['id']) + '_' + 'MB' + '_' + 'V'
293 | if info['inverse']:
294 | info['feature'][feature_k_name] = img_k.cpu()
295 | info['feature'][feature_v_name] = img_v.cpu()
296 | q = torch.cat((txt_q, img_q), dim=2)
297 | k = torch.cat((txt_k, img_k), dim=2)
298 | v = torch.cat((txt_v, img_v), dim=2)
299 | if 'attention_mask' in info:
300 |                 attn = attention(q, k, v, pe=pe, attention_mask=info['attention_mask'])
301 | else:
302 | attn = attention(q, k, v, pe=pe)
303 |
304 | else:
305 | source_img_k = info['feature'][feature_k_name].to(img.device)
306 | source_img_v = info['feature'][feature_v_name].to(img.device)
307 |
308 | mask_indices = info['mask_indices']
309 | source_img_k[:, :, mask_indices, ...] = img_k
310 | source_img_v[:, :, mask_indices, ...] = img_v
311 |
312 | q = torch.cat((txt_q, img_q), dim=2)
313 | k = torch.cat((txt_k, source_img_k), dim=2)
314 | v = torch.cat((txt_v, source_img_v), dim=2)
315 |             attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'], attention_mask=info['attention_scale'])
316 |
317 |
318 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
319 |
320 |         # calculate the img blocks
321 | img = img + img_mod1.gate * self.img_attn.proj(img_attn)
322 | img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
323 |
324 |         # calculate the txt blocks
325 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
326 | txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
327 | return img, txt
328 |
329 | class SingleStreamBlock_kv(SingleStreamBlock):
330 | """
331 |     A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442
332 |     and an adapted modulation interface; it caches image K/V during inversion and re-injects them during denoising.
333 |     """
334 |
335 | def __init__(
336 | self,
337 | hidden_size: int,
338 | num_heads: int,
339 | mlp_ratio: float = 4.0,
340 | qk_scale: float | None = None,
341 | ):
342 | super().__init__(hidden_size, num_heads, mlp_ratio, qk_scale)
343 |
344 | def forward(self,x: Tensor, vec: Tensor, pe: Tensor, info) -> Tensor:
345 | mod, _ = self.modulation(vec)
346 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
347 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
348 |
349 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
350 | q, k = self.norm(q, k, v)
351 |         img_k = k[:, :, 512:, ...]  # image tokens follow the first 512 text tokens (T5 max_length)
352 | img_v = v[:, :, 512:, ...]
353 |
354 | txt_k = k[:, :, :512, ...]
355 | txt_v = v[:, :, :512, ...]
356 |
357 |
358 | feature_k_name = str(info['t']) + '_' + str(info['id']) + '_' + 'SB' + '_' + 'K'
359 | feature_v_name = str(info['t']) + '_' + str(info['id']) + '_' + 'SB' + '_' + 'V'
360 | if info['inverse']:
361 | info['feature'][feature_k_name] = img_k.cpu()
362 | info['feature'][feature_v_name] = img_v.cpu()
363 | if 'attention_mask' in info:
364 |                 attn = attention(q, k, v, pe=pe, attention_mask=info['attention_mask'])
365 | else:
366 | attn = attention(q, k, v, pe=pe)
367 |
368 | else:
369 | source_img_k = info['feature'][feature_k_name].to(x.device)
370 | source_img_v = info['feature'][feature_v_name].to(x.device)
371 |
372 | mask_indices = info['mask_indices']
373 | source_img_k[:, :, mask_indices, ...] = img_k
374 | source_img_v[:, :, mask_indices, ...] = img_v
375 |
376 | k = torch.cat((txt_k, source_img_k), dim=2)
377 | v = torch.cat((txt_v, source_img_v), dim=2)
378 |             attn = attention(q, k, v, pe=pe, pe_q=info['pe_mask'], attention_mask=info['attention_scale'])
379 |
380 | # compute activation in mlp stream, cat again and run second linear layer
381 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
382 | return x + mod.gate * output
--------------------------------------------------------------------------------
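Note on the `_kv` blocks above: during inversion each block stores its image-token K/V on the CPU under a key built from the current timestep, the block index, the block type ("MB" for double-stream, "SB" for single-stream) and "K"/"V"; during denoising those tensors are moved back to the GPU and only the masked token positions are overwritten with freshly computed K/V. A minimal sketch of that convention (the helper names below are illustrative and not part of this repository):

    import torch

    def cache_key(t: float, block_id: int, block_type: str, which: str) -> str:
        # e.g. "0.7183_5_MB_K": timestep, block index, double-stream ("MB") block, keys
        return f"{t}_{block_id}_{block_type}_{which}"

    def inject_cached_kv(feature: dict, t: float, block_id: int, block_type: str,
                         img_k: torch.Tensor, img_v: torch.Tensor,
                         mask_indices: torch.Tensor, device: torch.device):
        # restore the background K/V cached during inversion, then overwrite only the
        # masked (edited) token positions with the freshly computed K/V
        src_k = feature[cache_key(t, block_id, block_type, "K")].to(device)
        src_v = feature[cache_key(t, block_id, block_type, "V")].to(device)
        src_k[:, :, mask_indices, ...] = img_k
        src_v[:, :, mask_indices, ...] = img_v
        return src_k, src_v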
/flux/sampling.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Callable
3 |
4 | import torch
5 | from einops import rearrange, repeat
6 | from torch import Tensor
7 |
8 | from .model import Flux, Flux_kv
9 | from .modules.conditioner import HFEmbedder
10 | from tqdm import tqdm
11 | from tqdm.contrib import tzip
12 |
13 | def get_noise(
14 | num_samples: int,
15 | height: int,
16 | width: int,
17 | device: torch.device,
18 | dtype: torch.dtype,
19 | seed: int,
20 | ):
21 | return torch.randn(
22 | num_samples,
23 | 16,
24 | # allow for packing
25 | 2 * math.ceil(height / 16),
26 | 2 * math.ceil(width / 16),
27 | device=device,
28 | dtype=dtype,
29 | generator=torch.Generator(device=device).manual_seed(seed),
30 | )
31 |
32 |
33 | def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
34 | bs, c, h, w = img.shape
35 | if bs == 1 and not isinstance(prompt, str):
36 | bs = len(prompt)
37 |
38 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
39 | if img.shape[0] == 1 and bs > 1:
40 | img = repeat(img, "1 ... -> bs ...", bs=bs)
41 |
42 | img_ids = torch.zeros(h // 2, w // 2, 3)
43 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
44 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
45 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
46 |
47 | if isinstance(prompt, str):
48 | prompt = [prompt]
49 | txt = t5(prompt)
50 | if txt.shape[0] == 1 and bs > 1:
51 | txt = repeat(txt, "1 ... -> bs ...", bs=bs)
52 | txt_ids = torch.zeros(bs, txt.shape[1], 3)
53 |
54 | vec = clip(prompt)
55 | if vec.shape[0] == 1 and bs > 1:
56 | vec = repeat(vec, "1 ... -> bs ...", bs=bs)
57 |
58 | return {
59 | "img": img,
60 | "img_ids": img_ids.to(img.device),
61 | "txt": txt.to(img.device),
62 | "txt_ids": txt_ids.to(img.device),
63 | "vec": vec.to(img.device),
64 | }
65 |
66 | def time_shift(mu: float, sigma: float, t: Tensor):
67 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
68 |
69 |
70 | def get_lin_function(
71 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
72 | ) -> Callable[[float], float]:
73 | m = (y2 - y1) / (x2 - x1)
74 | b = y1 - m * x1
75 | return lambda x: m * x + b
76 |
77 |
78 | def get_schedule(
79 | num_steps: int,
80 | image_seq_len: int,
81 | base_shift: float = 0.5,
82 | max_shift: float = 1.15,
83 | shift: bool = True,
84 | ) -> list[float]:
85 | # extra step for zero
86 | timesteps = torch.linspace(1, 0, num_steps + 1)
87 |
88 | # shifting the schedule to favor high timesteps for higher signal images
89 | if shift:
90 | # estimate mu based on linear estimation between two points
91 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
92 | timesteps = time_shift(mu, 1.0, timesteps)
93 |
94 | return timesteps.tolist()
95 |
96 |
97 | def denoise(
98 | model: Flux,
99 | # model input
100 | img: Tensor,
101 | img_ids: Tensor,
102 | txt: Tensor,
103 | txt_ids: Tensor,
104 | vec: Tensor,
105 | # sampling parameters
106 | timesteps: list[float],
107 | guidance: float = 4.0,
108 | ):
109 | # this is ignored for schnell
110 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
111 | for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
112 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
113 | pred = model(
114 | img=img,
115 | img_ids=img_ids,
116 | txt=txt,
117 | txt_ids=txt_ids,
118 | y=vec,
119 | timesteps=t_vec,
120 | guidance=guidance_vec,
121 | )
122 |
123 | img = img + (t_prev - t_curr) * pred
124 |
125 | return img
126 |
127 | def unpack(x: Tensor, height: int, width: int) -> Tensor:
128 | return rearrange(
129 | x,
130 | "b (h w) (c ph pw) -> b c (h ph) (w pw)",
131 | h=math.ceil(height / 16),
132 | w=math.ceil(width / 16),
133 | ph=2,
134 | pw=2,
135 | )
136 |
137 | def denoise_kv(
138 | model: Flux_kv,
139 | # model input
140 | img: Tensor,
141 | img_ids: Tensor,
142 | txt: Tensor,
143 | txt_ids: Tensor,
144 | vec: Tensor,
145 | # sampling parameters
146 | timesteps: list[float],
147 | inverse,
148 | info,
149 | guidance: float = 4.0
150 | ):
151 |
152 | if inverse:
153 | timesteps = timesteps[::-1]
154 |
155 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
156 |
157 | for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])):
158 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
159 | info['t'] = t_prev if inverse else t_curr
160 |
161 | if inverse:
162 | img_name = str(info['t']) + '_' + 'img'
163 | info['feature'][img_name] = img.cpu()
164 | else:
165 | img_name = str(info['t']) + '_' + 'img'
166 | source_img = info['feature'][img_name].to(img.device)
167 | img = source_img[:, info['mask_indices'],...] * (1 - info['mask'][:, info['mask_indices'],...]) + img * info['mask'][:, info['mask_indices'],...]
168 | pred = model(
169 | img=img,
170 | img_ids=img_ids,
171 | txt=txt,
172 | txt_ids=txt_ids,
173 | y=vec,
174 | timesteps=t_vec,
175 | guidance=guidance_vec,
176 | info=info
177 | )
178 | img = img + (t_prev - t_curr) * pred
179 | return img, info
180 |
181 | def denoise_kv_inf(
182 | model: Flux_kv,
183 | # model input
184 | img: Tensor,
185 | img_ids: Tensor,
186 | source_txt: Tensor,
187 | source_txt_ids: Tensor,
188 | source_vec: Tensor,
189 | target_txt: Tensor,
190 | target_txt_ids: Tensor,
191 | target_vec: Tensor,
192 | # sampling parameters
193 | timesteps: list[float],
194 | target_guidance: float = 4.0,
195 | source_guidance: float = 4.0,
196 |     info: dict | None = None,
197 | ):
198 |     info = info if info is not None else {}
199 | target_guidance_vec = torch.full((img.shape[0],), target_guidance, device=img.device, dtype=img.dtype)
200 | source_guidance_vec = torch.full((img.shape[0],), source_guidance, device=img.device, dtype=img.dtype)
201 |
202 | mask_indices = info['mask_indices']
203 | init_img = img.clone()
204 | z_fe = img[:, mask_indices,...]
205 |
206 | noise_list = []
207 | for i in range(len(timesteps)):
208 | noise = torch.randn(init_img.size(), dtype=init_img.dtype,
209 | layout=init_img.layout, device=init_img.device,
210 | generator=torch.Generator(device=init_img.device).manual_seed(0))
211 | noise_list.append(noise)
212 |
213 | for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])):
214 |
215 | info['t'] = t_curr
216 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
217 |
218 | z_src = (1 - t_curr) * init_img + t_curr * noise_list[i]
219 | z_tar = z_src[:, mask_indices,...] - init_img[:, mask_indices,...] + z_fe
220 |
221 | info['inverse'] = True
222 | info['feature'] = {}
223 | v_src = model(
224 | img=z_src,
225 | img_ids=img_ids,
226 | txt=source_txt,
227 | txt_ids=source_txt_ids,
228 | y=source_vec,
229 | timesteps=t_vec,
230 | guidance=source_guidance_vec,
231 | info=info
232 | )
233 |
234 | info['inverse'] = False
235 | v_tar = model(
236 | img=z_tar,
237 | img_ids=img_ids,
238 | txt=target_txt,
239 | txt_ids=target_txt_ids,
240 | y=target_vec,
241 | timesteps=t_vec,
242 | guidance=target_guidance_vec,
243 | info=info
244 | )
245 |
246 | v_fe = v_tar - v_src[:, mask_indices,...]
247 | z_fe = z_fe + (t_prev - t_curr) * v_fe * info['mask'][:, mask_indices,...]
248 | return z_fe, info
249 |
--------------------------------------------------------------------------------
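A note on `denoise_kv_inf` above: instead of inverting the image first, it re-noises the source latent at every step and integrates only the masked tokens, using the difference between the target-prompt and source-prompt velocity predictions. In the notation of the code (a schematic summary written as comments, not additional repository code):

    # z_src = (1 - t) * z0 + t * noise            re-noised source latent at time t
    # z_tar = z_src[mask] - z0[mask] + z_fe       current edited tokens moved to time t
    # v_fe  = v_tar - v_src[mask]                 velocity difference on the masked tokens
    # z_fe <- z_fe + (t_prev - t) * v_fe * mask   Euler step applied only inside the mask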
/flux/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 |
4 | import torch
5 | from einops import rearrange
6 | from huggingface_hub import hf_hub_download
7 | from imwatermark import WatermarkEncoder
8 | from safetensors.torch import load_file as load_sft
9 |
10 | from flux.model import Flux, FluxParams
11 | from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
12 | from flux.modules.conditioner import HFEmbedder
13 |
14 |
15 | @dataclass
16 | class ModelSpec:
17 | params: FluxParams
18 | ae_params: AutoEncoderParams
19 | ckpt_path: str | None
20 | ae_path: str | None
21 | repo_id: str | None
22 | repo_flow: str | None
23 | repo_ae: str | None
24 |
25 | configs = {
26 | "flux-dev": ModelSpec(
27 | repo_id="black-forest-labs/FLUX.1-dev",
28 | repo_flow="flux1-dev.safetensors",
29 | repo_ae="ae.safetensors",
30 | ckpt_path=os.getenv("FLUX_DEV"),
31 | params=FluxParams(
32 | in_channels=64,
33 | vec_in_dim=768,
34 | context_in_dim=4096,
35 | hidden_size=3072,
36 | mlp_ratio=4.0,
37 | num_heads=24,
38 | depth=19,
39 | depth_single_blocks=38,
40 | axes_dim=[16, 56, 56],
41 | theta=10_000,
42 | qkv_bias=True,
43 | guidance_embed=True,
44 | ),
45 | ae_path=os.getenv("AE"),
46 | ae_params=AutoEncoderParams(
47 | resolution=256,
48 | in_channels=3,
49 | ch=128,
50 | out_ch=3,
51 | ch_mult=[1, 2, 4, 4],
52 | num_res_blocks=2,
53 | z_channels=16,
54 | scale_factor=0.3611,
55 | shift_factor=0.1159,
56 | ),
57 | ),
58 | "flux-schnell": ModelSpec(
59 | repo_id="black-forest-labs/FLUX.1-schnell",
60 | repo_flow="flux1-schnell.safetensors",
61 | repo_ae="ae.safetensors",
62 | ckpt_path=os.getenv("FLUX_SCHNELL"),
63 | params=FluxParams(
64 | in_channels=64,
65 | vec_in_dim=768,
66 | context_in_dim=4096,
67 | hidden_size=3072,
68 | mlp_ratio=4.0,
69 | num_heads=24,
70 | depth=19,
71 | depth_single_blocks=38,
72 | axes_dim=[16, 56, 56],
73 | theta=10_000,
74 | qkv_bias=True,
75 | guidance_embed=False,
76 | ),
77 | ae_path=os.getenv("AE"),
78 | ae_params=AutoEncoderParams(
79 | resolution=256,
80 | in_channels=3,
81 | ch=128,
82 | out_ch=3,
83 | ch_mult=[1, 2, 4, 4],
84 | num_res_blocks=2,
85 | z_channels=16,
86 | scale_factor=0.3611,
87 | shift_factor=0.1159,
88 | ),
89 | ),
90 | }
91 |
92 |
93 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
94 | if len(missing) > 0 and len(unexpected) > 0:
95 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
96 | print("\n" + "-" * 79 + "\n")
97 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
98 | elif len(missing) > 0:
99 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
100 | elif len(unexpected) > 0:
101 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
102 |
103 |
104 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True, flux_cls=Flux) -> Flux:
105 | # Loading Flux
106 | print("Init model")
107 |
108 | ckpt_path = configs[name].ckpt_path
109 | if (
110 | ckpt_path is None
111 | and configs[name].repo_id is not None
112 | and configs[name].repo_flow is not None
113 | and hf_download
114 | ):
115 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
116 |
117 | with torch.device("meta" if ckpt_path is not None else device):
118 | model = flux_cls(configs[name].params).to(torch.bfloat16)
119 |
120 | if ckpt_path is not None:
121 | print("Loading checkpoint")
122 | # load_sft doesn't support torch.device
123 | sd = load_sft(ckpt_path, device=str(device))
124 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
125 | print_load_warning(missing, unexpected)
126 | return model
127 |
128 |
129 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
130 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131 | # return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
132 | return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
133 |
134 |
135 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
136 | # return HFEmbedder("black-forest-labs/FLUX.1-dev", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
137 | return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
138 |
139 |
140 | def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
141 | ckpt_path = configs[name].ae_path
142 | if (
143 | ckpt_path is None
144 | and configs[name].repo_id is not None
145 | and configs[name].repo_ae is not None
146 | and hf_download
147 | ):
148 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
149 |
150 | # Loading the autoencoder
151 | print("Init AE")
152 | with torch.device("meta" if ckpt_path is not None else device):
153 | ae = AutoEncoder(configs[name].ae_params)
154 |
155 | if ckpt_path is not None:
156 | sd = load_sft(ckpt_path, device=str(device))
157 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
158 | print_load_warning(missing, unexpected)
159 | return ae
160 |
161 |
162 | class WatermarkEmbedder:
163 | def __init__(self, watermark):
164 | self.watermark = watermark
165 | self.num_bits = len(WATERMARK_BITS)
166 | self.encoder = WatermarkEncoder()
167 | self.encoder.set_watermark("bits", self.watermark)
168 |
169 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
170 | """
171 | Adds a predefined watermark to the input image
172 |
173 | Args:
174 | image: ([N,] B, RGB, H, W) in range [-1, 1]
175 |
176 | Returns:
177 | same as input but watermarked
178 | """
179 | image = 0.5 * image + 0.5
180 | squeeze = len(image.shape) == 4
181 | if squeeze:
182 | image = image[None, ...]
183 | n = image.shape[0]
184 | image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
185 | # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
186 |         # watermarking library expects input as cv2 BGR format
187 | for k in range(image_np.shape[0]):
188 | image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
189 | image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
190 | image.device
191 | )
192 | image = torch.clamp(image / 255, min=0.0, max=1.0)
193 | if squeeze:
194 | image = image[0]
195 | image = 2 * image - 1
196 | return image
197 |
198 |
199 | # A fixed 48-bit message that was chosen at random
200 | WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
201 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
202 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
203 | embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
204 |
--------------------------------------------------------------------------------
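Taken together with flux/sampling.py, the loaders in this file compose into a plain text-to-image loop roughly as follows. This is a minimal sketch, assuming a CUDA device and either the FLUX_DEV/AE environment variables or Hugging Face download access; it is not a script shipped with this repository:

    import torch
    from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
    from flux.util import load_ae, load_clip, load_flow_model, load_t5

    device = "cuda"
    t5, clip = load_t5(device, max_length=512), load_clip(device)
    model, ae = load_flow_model("flux-dev", device=device), load_ae("flux-dev", device=device)

    height, width = 768, 1360
    x = get_noise(1, height, width, device=torch.device(device), dtype=torch.bfloat16, seed=42)
    inp = prepare(t5, clip, x, prompt="a photo of a forest in autumn")
    timesteps = get_schedule(28, inp["img"].shape[1], shift=True)  # shift=False for flux-schnell
    with torch.no_grad():
        x = denoise(model, **inp, timesteps=timesteps, guidance=4.0)
        img = ae.decode(unpack(x.float(), height, width))  # (1, 3, H, W), roughly in [-1, 1]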
/gradio_kv_edit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | from dataclasses import dataclass
5 | from glob import iglob
6 | import argparse
7 | from einops import rearrange
8 | from PIL import ExifTags, Image
9 | import torch
10 | import gradio as gr
11 | import numpy as np
12 | from flux.sampling import prepare
13 | from flux.util import (configs, load_ae, load_clip, load_t5)
14 | from models.kv_edit import Flux_kv_edit
15 |
16 | @dataclass
17 | class SamplingOptions:
18 | source_prompt: str = ''
19 | target_prompt: str = ''
20 | width: int = 1366
21 | height: int = 768
22 | inversion_num_steps: int = 0
23 | denoise_num_steps: int = 0
24 | skip_step: int = 0
25 | inversion_guidance: float = 1.0
26 | denoise_guidance: float = 1.0
27 | seed: int = 42
28 | re_init: bool = False
29 | attn_mask: bool = False
30 | attn_scale: float = 1.0
31 |
32 | class FluxEditor_kv_demo:
33 | def __init__(self, args):
34 | self.args = args
35 | self.device = torch.device(args.device)
36 | self.offload = args.offload
37 |
38 | self.name = args.name
39 | self.is_schnell = args.name == "flux-schnell"
40 |
41 | self.output_dir = 'regress_result'
42 |
43 | self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
44 | self.clip = load_clip(self.device)
45 | self.model = Flux_kv_edit(device="cpu" if self.offload else self.device, name=self.name)
46 | self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
47 |
48 | self.t5.eval()
49 | self.clip.eval()
50 | self.ae.eval()
51 | self.model.eval()
52 | self.info = {}
53 | if self.offload:
54 | self.model.cpu()
55 | torch.cuda.empty_cache()
56 | self.ae.encoder.to(self.device)
57 |
58 | @torch.inference_mode()
59 | def inverse(self, brush_canvas,
60 | source_prompt, target_prompt,
61 | inversion_num_steps, denoise_num_steps,
62 | skip_step,
63 | inversion_guidance, denoise_guidance,seed,
64 | re_init, attn_mask
65 | ):
66 | self.z0 = None
67 | self.zt = None
68 | # self.info = {}
69 | # gc.collect()
70 | if 'feature' in self.info:
71 | key_list = list(self.info['feature'].keys())
72 | for key in key_list:
73 | del self.info['feature'][key]
74 | self.info = {}
75 |
76 | rgba_init_image = brush_canvas["background"]
77 | init_image = rgba_init_image[:,:,:3]
78 | shape = init_image.shape
79 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
80 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
81 | init_image = init_image[:height, :width, :]
82 | rgba_init_image = rgba_init_image[:height, :width, :]
83 |
84 | opts = SamplingOptions(
85 | source_prompt=source_prompt,
86 | target_prompt=target_prompt,
87 | width=width,
88 | height=height,
89 | inversion_num_steps=inversion_num_steps,
90 | denoise_num_steps=denoise_num_steps,
91 |             skip_step=0,  # no skip steps during inversion, leaving room to adjust skip_step at edit time
92 | inversion_guidance=inversion_guidance,
93 | denoise_guidance=denoise_guidance,
94 | seed=seed,
95 | re_init=re_init,
96 | attn_mask=attn_mask
97 | )
98 | torch.manual_seed(opts.seed)
99 | if torch.cuda.is_available():
100 | torch.cuda.manual_seed_all(opts.seed)
101 | torch.cuda.empty_cache()
102 |
103 | if opts.attn_mask:
104 | rgba_mask = brush_canvas["layers"][0][:height, :width, :]
105 | mask = rgba_mask[:,:,3]/255
106 | mask = mask.astype(int)
107 |
108 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device)
109 | else:
110 | mask = None
111 |
112 | self.init_image = self.encode(init_image, self.device).to(self.device)
113 |
114 | t0 = time.perf_counter()
115 |
116 | if self.offload:
117 | self.ae = self.ae.cpu()
118 | torch.cuda.empty_cache()
119 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
120 |
121 | with torch.no_grad():
122 | inp = prepare(self.t5, self.clip,self.init_image, prompt=opts.source_prompt)
123 |
124 | if self.offload:
125 | self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
126 | torch.cuda.empty_cache()
127 | self.model = self.model.to(self.device)
128 | self.z0,self.zt,self.info = self.model.inverse(inp,mask,opts)
129 |
130 | if self.offload:
131 | self.model.cpu()
132 | torch.cuda.empty_cache()
133 |
134 | t1 = time.perf_counter()
135 |         print(f"Inversion done in {t1 - t0:.1f}s.")
136 | return None
137 |
138 |
139 |
140 | @torch.inference_mode()
141 | def edit(self, brush_canvas,
142 | source_prompt, target_prompt,
143 | inversion_num_steps, denoise_num_steps,
144 | skip_step,
145 | inversion_guidance, denoise_guidance,seed,
146 | re_init, attn_mask,attn_scale
147 | ):
148 |
149 | torch.cuda.empty_cache()
150 |
151 | rgba_init_image = brush_canvas["background"]
152 | init_image = rgba_init_image[:,:,:3]
153 | shape = init_image.shape
154 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
155 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
156 | init_image = init_image[:height, :width, :]
157 | rgba_init_image = rgba_init_image[:height, :width, :]
158 |
159 | rgba_mask = brush_canvas["layers"][0][:height, :width, :]
160 | mask = rgba_mask[:,:,3]/255
161 | mask = mask.astype(int)
162 |
163 | rgba_mask[:,:,3] = rgba_mask[:,:,3]//2
164 | masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA'))
165 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device)
166 |
167 | seed = int(seed)
168 | if seed == -1:
169 | seed = torch.randint(0, 2**32, (1,)).item()
170 | opts = SamplingOptions(
171 | source_prompt=source_prompt,
172 | target_prompt=target_prompt,
173 | width=width,
174 | height=height,
175 | inversion_num_steps=inversion_num_steps,
176 | denoise_num_steps=denoise_num_steps,
177 | skip_step=skip_step,
178 | inversion_guidance=inversion_guidance,
179 | denoise_guidance=denoise_guidance,
180 | seed=seed,
181 | re_init=re_init,
182 | attn_mask=attn_mask,
183 | attn_scale=attn_scale
184 | )
185 | if self.offload:
186 |
187 | torch.cuda.empty_cache()
188 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
189 |
190 | torch.manual_seed(opts.seed)
191 | if torch.cuda.is_available():
192 | torch.cuda.manual_seed_all(opts.seed)
193 |
194 | t0 = time.perf_counter()
195 |
196 | with torch.no_grad():
197 | inp_target = prepare(self.t5, self.clip, self.init_image, prompt=opts.target_prompt)
198 |
199 | if self.offload:
200 | self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
201 | torch.cuda.empty_cache()
202 | self.model = self.model.to(self.device)
203 |
204 | x = self.model.denoise(self.z0,self.zt,inp_target,mask,opts,self.info)
205 |
206 | if self.offload:
207 | self.model.cpu()
208 | torch.cuda.empty_cache()
209 | self.ae.decoder.to(x.device)
210 |
211 | with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
212 | x = self.ae.decode(x.to(self.device))
213 |
214 | x = x.clamp(-1, 1)
215 | x = x.float().cpu()
216 | x = rearrange(x[0], "c h w -> h w c")
217 |
218 | if torch.cuda.is_available():
219 | torch.cuda.synchronize()
220 |
221 | output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
222 | if not os.path.exists(self.output_dir):
223 | os.makedirs(self.output_dir)
224 | idx = 0
225 | else:
226 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
227 | if len(fns) > 0:
228 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
229 | else:
230 | idx = 0
231 |
232 | fn = output_name.format(idx=idx)
233 |
234 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
235 | exif_data = Image.Exif()
236 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
237 | exif_data[ExifTags.Base.Make] = "Black Forest Labs"
238 | exif_data[ExifTags.Base.Model] = self.name
239 |
240 | exif_data[ExifTags.Base.ImageDescription] = target_prompt
241 | img.save(fn, exif=exif_data, quality=95, subsampling=0)
242 | masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG')
243 | t1 = time.perf_counter()
244 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
245 |
246 | print("End Edit")
247 | return img
248 |
249 |
250 | @torch.inference_mode()
251 | def encode(self,init_image, torch_device):
252 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
253 | init_image = init_image.unsqueeze(0)
254 | init_image = init_image.to(torch_device)
255 | self.ae.encoder.to(torch_device)
256 |
257 | init_image = self.ae.encode(init_image).to(torch.bfloat16)
258 | return init_image
259 |
260 | def create_demo(model_name: str):
261 | editor = FluxEditor_kv_demo(args)
262 | is_schnell = model_name == "flux-schnell"
263 |
264 | title = r"""
265 | 🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation
266 | """
267 |
268 | description = r"""
269 | Official 🤗 Gradio demo for KV-Edit: Training-Free Image Editing for Precise Background Preservation.
270 |
271 | 💫💫 Here are the editing steps:
272 | 1️⃣ Upload your image that needs to be edited.
273 | 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion.
274 | 3️⃣ Use the brush tool to draw your mask area.
275 | 4️⃣ Fill in your target prompt, then adjust the hyperparameters.
276 | 5️⃣ Click the "Edit" button to generate your edited image!
277 |
278 | 🔔🔔 [Important] Fewer skip steps, "re_init" and "attn_mask" enhance editing performance, making the results more aligned with your text, but may lead to discontinuous images.
279 | If these three options cause failures, we recommend increasing "attn_scale" to strengthen attention between the masked region and the background.
280 | """
281 | article = r"""
282 | If our work is helpful, please help to ⭐ the GitHub repo. Thanks!
283 | """
284 |
285 | badge = r"""
286 | [](https://github.com/Xilluill/KV-Edit)
287 | """
288 |
289 | with gr.Blocks() as demo:
290 | gr.HTML(title)
291 | gr.Markdown(description)
292 |
293 | with gr.Row():
294 | with gr.Column():
295 | source_prompt = gr.Textbox(label="Source Prompt", value='' )
296 | inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps")
297 | target_prompt = gr.Textbox(label="Target Prompt", value='' )
298 | denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps")
299 | brush_canvas = gr.ImageEditor(label="Brush Canvas",
300 | sources=('upload'),
301 | brush=gr.Brush(colors=["#ff0000"],color_mode='fixed'),
302 | interactive=True,
303 | transforms=[],
304 | container=True,
305 | format='png',scale=1)
306 |
307 | inv_btn = gr.Button("inverse")
308 | edit_btn = gr.Button("edit")
309 |
310 |
311 | with gr.Column():
312 | with gr.Accordion("Advanced Options", open=True):
313 |
314 | skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps")
315 | inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
316 | denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
317 | attn_scale = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale")
318 | seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
319 | with gr.Row():
320 | re_init = gr.Checkbox(label="re_init", value=False)
321 | attn_mask = gr.Checkbox(label="attn_mask", value=False)
322 |
323 |
324 | output_image = gr.Image(label="Generated Image")
325 | gr.Markdown(article)
326 | inv_btn.click(
327 | fn=editor.inverse,
328 | inputs=[brush_canvas,
329 | source_prompt, target_prompt,
330 | inversion_num_steps, denoise_num_steps,
331 | skip_step,
332 | inversion_guidance,
333 | denoise_guidance,seed,
334 | re_init, attn_mask
335 | ],
336 | outputs=[output_image]
337 | )
338 | edit_btn.click(
339 | fn=editor.edit,
340 | inputs=[brush_canvas,
341 | source_prompt, target_prompt,
342 | inversion_num_steps, denoise_num_steps,
343 | skip_step,
344 | inversion_guidance,
345 | denoise_guidance,seed,
346 | re_init, attn_mask,attn_scale
347 | ],
348 | outputs=[output_image]
349 | )
350 | return demo
351 |
352 | if __name__ == "__main__":
353 | import argparse
354 | parser = argparse.ArgumentParser(description="Flux")
355 | parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
356 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
357 | parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
358 | parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
359 | parser.add_argument("--port", type=int, default=41032)
360 | args = parser.parse_args()
361 |
362 | demo = create_demo(args.name)
363 |
364 | demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
--------------------------------------------------------------------------------
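The `inverse` and `edit` handlers above derive the editing mask from the alpha channel of the brush layer drawn on the `gr.ImageEditor` canvas. A minimal equivalent helper (illustrative only; the demos inline this logic rather than defining such a function):

    import numpy as np
    import torch

    def canvas_to_mask(rgba_layer: np.ndarray, device: torch.device) -> torch.Tensor:
        # alpha > 0 wherever the user painted; the caller crops H and W to multiples of 16 first
        mask = (rgba_layer[:, :, 3] / 255).astype(int)
        return torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(device)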
/gradio_kv_edit_gpu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | from dataclasses import dataclass
5 | from glob import iglob
6 | import argparse
7 | from einops import rearrange
8 | from PIL import ExifTags, Image
9 | import torch
10 | import gradio as gr
11 | import numpy as np
12 | from flux.sampling import prepare
13 | from flux.util import (configs, load_ae, load_clip, load_t5)
14 | from models.kv_edit import Flux_kv_edit
15 |
16 | @dataclass
17 | class SamplingOptions:
18 | source_prompt: str = ''
19 | target_prompt: str = ''
20 | width: int = 1366
21 | height: int = 768
22 | inversion_num_steps: int = 0
23 | denoise_num_steps: int = 0
24 | skip_step: int = 0
25 | inversion_guidance: float = 1.0
26 | denoise_guidance: float = 1.0
27 | seed: int = 42
28 | re_init: bool = False
29 | attn_mask: bool = False
30 | attn_scale: float = 1.0
31 |
32 | def resize_image(image_array, max_width=1360, max_height=768):
33 |     # convert the numpy array to a PIL image
34 |     if image_array.shape[-1] == 4:
35 |         mode = 'RGBA'
36 |     else:
37 |         mode = 'RGB'
38 |
39 |     pil_image = Image.fromarray(image_array, mode=mode)
40 |
41 |     # get the original image width and height
42 |     original_width, original_height = pil_image.size
43 |
44 |     # compute the scaling ratios
45 |     width_ratio = max_width / original_width
46 |     height_ratio = max_height / original_height
47 |
48 |     # use the smaller ratio so the image does not exceed the maximum width and height
49 |     scale_ratio = min(width_ratio, height_ratio)
50 |
51 |     # if the image is already within the maximum resolution, return it unchanged
52 |     if scale_ratio >= 1:
53 |         return image_array
54 |
55 |     # compute the new width and height
56 |     new_width = int(original_width * scale_ratio)
57 |     new_height = int(original_height * scale_ratio)
58 |
59 |     # resize the image
60 |     resized_image = pil_image.resize((new_width, new_height))
61 |
62 |     # convert the PIL image back to a numpy array
63 | resized_array = np.array(resized_image)
64 |
65 | return resized_array
66 |
67 | class FluxEditor_kv_demo:
68 | def __init__(self, args):
69 | self.args = args
70 | self.gpus = args.gpus
71 | if self.gpus:
72 | self.device = [torch.device("cuda:0"), torch.device("cuda:1")]
73 | else:
74 | self.device = [torch.device(args.device), torch.device(args.device)]
75 |
76 | self.name = args.name
77 | self.is_schnell = args.name == "flux-schnell"
78 |
79 | self.output_dir = 'regress_result'
80 |
81 | self.t5 = load_t5(self.device[1], max_length=256 if self.name == "flux-schnell" else 512)
82 | self.clip = load_clip(self.device[1])
83 | self.model = Flux_kv_edit(self.device[0], name=self.name)
84 | self.ae = load_ae(self.name, device=self.device[1])
85 |
86 | self.t5.eval()
87 | self.clip.eval()
88 | self.ae.eval()
89 | self.model.eval()
90 | self.info = {}
91 | @torch.inference_mode()
92 | def inverse(self, brush_canvas,
93 | source_prompt, target_prompt,
94 | inversion_num_steps, denoise_num_steps,
95 | skip_step,
96 | inversion_guidance, denoise_guidance,seed,
97 | re_init, attn_mask
98 | ):
99 | if hasattr(self, 'z0'):
100 | del self.z0
101 | del self.zt
102 | # self.info = {}
103 | # gc.collect()
104 |
105 | if 'feature' in self.info:
106 | key_list = list(self.info['feature'].keys())
107 | for key in key_list:
108 | del self.info['feature'][key]
109 | self.info = {}
110 |
111 | rgba_init_image = brush_canvas["background"]
112 | init_image = rgba_init_image[:,:,:3]
113 | # init_image = resize_image(init_image)
114 | shape = init_image.shape
115 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
116 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
117 | init_image = init_image[:height, :width, :]
118 | rgba_init_image = rgba_init_image[:height, :width, :]
119 |
120 | opts = SamplingOptions(
121 | source_prompt=source_prompt,
122 | target_prompt=target_prompt,
123 | width=width,
124 | height=height,
125 | inversion_num_steps=inversion_num_steps,
126 | denoise_num_steps=denoise_num_steps,
127 |             skip_step=0,  # no skip steps during inversion, leaving room to adjust skip_step at edit time
128 | inversion_guidance=inversion_guidance,
129 | denoise_guidance=denoise_guidance,
130 | seed=seed,
131 | re_init=re_init,
132 | attn_mask=attn_mask
133 | )
134 | torch.manual_seed(opts.seed)
135 | if torch.cuda.is_available():
136 | torch.cuda.manual_seed_all(opts.seed)
137 | torch.cuda.empty_cache()
138 |
139 | if opts.attn_mask:
140 | # rgba_mask = resize_image(brush_canvas["layers"][0])[:height, :width, :]
141 | rgba_mask = brush_canvas["layers"][0][:height, :width, :]
142 | mask = rgba_mask[:,:,3]/255
143 | mask = mask.astype(int)
144 |
145 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device[0])
146 | else:
147 | mask = None
148 |
149 | self.init_image = self.encode(init_image, self.device[1]).to(self.device[0])
150 |
151 | t0 = time.perf_counter()
152 |
153 | with torch.no_grad():
154 | inp = prepare(self.t5, self.clip,self.init_image, prompt=opts.source_prompt)
155 | self.z0,self.zt,self.info = self.model.inverse(inp,mask,opts)
156 | t1 = time.perf_counter()
157 |         print(f"Inversion done in {t1 - t0:.1f}s.")
158 | return None
159 |
160 |
161 |
162 | @torch.inference_mode()
163 | def edit(self, brush_canvas,
164 | source_prompt, target_prompt,
165 | inversion_num_steps, denoise_num_steps,
166 | skip_step,
167 | inversion_guidance, denoise_guidance,seed,
168 | re_init, attn_mask,attn_scale
169 | ):
170 |
171 | torch.cuda.empty_cache()
172 |
173 | rgba_init_image = brush_canvas["background"]
174 | init_image = rgba_init_image[:,:,:3]
175 | shape = init_image.shape
176 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
177 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
178 | init_image = init_image[:height, :width, :]
179 | rgba_init_image = rgba_init_image[:height, :width, :]
180 | # brush_canvas = brush_canvas["composite"][:,:,:3][:height, :width, :]
181 |
182 | # if np.all(brush_canvas[:,:,0] == brush_canvas[:,:,1]) and np.all(brush_canvas[:,:,1] == brush_canvas[:,:,2]):
183 | rgba_mask = brush_canvas["layers"][0][:height, :width, :]
184 | mask = rgba_mask[:,:,3]/255
185 | mask = mask.astype(int)
186 |
187 |
188 | rgba_mask[:,:,3] = rgba_mask[:,:,3]//2
189 | masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA'))
190 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device[0])
191 |
192 | seed = int(seed)
193 | if seed == -1:
194 | seed = torch.randint(0, 2**32, (1,)).item()
195 | opts = SamplingOptions(
196 | source_prompt=source_prompt,
197 | target_prompt=target_prompt,
198 | width=width,
199 | height=height,
200 | inversion_num_steps=inversion_num_steps,
201 | denoise_num_steps=denoise_num_steps,
202 | skip_step=skip_step,
203 | inversion_guidance=inversion_guidance,
204 | denoise_guidance=denoise_guidance,
205 | seed=seed,
206 | re_init=re_init,
207 | attn_mask=attn_mask,
208 | attn_scale=attn_scale
209 | )
210 | torch.manual_seed(opts.seed)
211 | if torch.cuda.is_available():
212 | torch.cuda.manual_seed_all(opts.seed)
213 |
214 | t0 = time.perf_counter()
215 |
216 |
217 | with torch.no_grad():
218 | inp_target = prepare(self.t5, self.clip, self.init_image, prompt=opts.target_prompt)
219 |
220 | x = self.model.denoise(self.z0.clone(),self.zt,inp_target,mask,opts,self.info)
221 |
222 | with torch.autocast(device_type=self.device[1].type, dtype=torch.bfloat16):
223 | x = self.ae.decode(x.to(self.device[1]))
224 |
225 | x = x.clamp(-1, 1)
226 | x = x.float().cpu()
227 | x = rearrange(x[0], "c h w -> h w c")
228 |
229 | if torch.cuda.is_available():
230 | torch.cuda.synchronize()
231 |
232 | output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
233 | if not os.path.exists(self.output_dir):
234 | os.makedirs(self.output_dir)
235 | idx = 0
236 | else:
237 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
238 | if len(fns) > 0:
239 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
240 | else:
241 | idx = 0
242 |
243 |
244 | fn = output_name.format(idx=idx)
245 |
246 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
247 | exif_data = Image.Exif()
248 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
249 | exif_data[ExifTags.Base.Make] = "Black Forest Labs"
250 | exif_data[ExifTags.Base.Model] = self.name
251 | exif_data[ExifTags.Base.ImageDescription] = target_prompt
252 | img.save(fn, exif=exif_data, quality=95, subsampling=0)
253 | masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG')
254 | t1 = time.perf_counter()
255 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
256 |
257 | print("End Edit")
258 | return img
259 |
260 |
261 | @torch.inference_mode()
262 | def encode(self,init_image, torch_device):
263 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
264 | init_image = init_image.unsqueeze(0)
265 | init_image = init_image.to(torch_device)
266 | self.ae.encoder.to(torch_device)
267 |
268 | init_image = self.ae.encode(init_image).to(torch.bfloat16)
269 | return init_image
270 |
271 | def create_demo(model_name: str):
272 | editor = FluxEditor_kv_demo(args)
273 | is_schnell = model_name == "flux-schnell"
274 |
275 | title = r"""
276 | 🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation
277 | """
278 | one = r"""
279 |     We recommend trying our code locally: you can run several different edits after inverting the image only once!
280 | """
281 |
282 | description = r"""
283 | Official 🤗 Gradio demo for KV-Edit: Training-Free Image Editing for Precise Background Preservation.
284 |
285 | 💫💫 Here are the editing steps:
286 | 1️⃣ Upload your image that needs to be edited.
287 | 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion.
288 | 3️⃣ Use the brush tool to draw your mask area.
289 | 4️⃣ Fill in your target prompt, then adjust the hyperparameters.
290 | 5️⃣ Click the "Edit" button to generate your edited image!
291 |
292 | 🔔🔔 [Important] Fewer skip steps, "re_init" and "attn_mask" enhance editing performance, making the results more aligned with your text, but may lead to discontinuous images.
293 | If these three options cause failures, we recommend increasing "attn_scale" to strengthen attention between the masked region and the background.
294 | """
295 | article = r"""
296 | If our work is helpful, please help to ⭐ the GitHub repo. Thanks!
297 | """
298 |
299 | with gr.Blocks() as demo:
300 | gr.HTML(title)
301 | gr.Markdown(description)
302 |
303 | with gr.Row():
304 | with gr.Column():
305 | source_prompt = gr.Textbox(label="Source Prompt", value='' )
306 | inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps")
307 | target_prompt = gr.Textbox(label="Target Prompt", value='' )
308 | denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps")
309 | # init_image = gr.Image(label="Input Image", visible=True,scale=1)
310 | brush_canvas = gr.ImageEditor(label="Brush Canvas",
311 | sources=('upload'),
312 | brush=gr.Brush(colors=["#ff0000"],color_mode='fixed'),
313 | interactive=True,
314 | transforms=[],
315 | container=True,
316 | format='png')
317 |
318 | inv_btn = gr.Button("inverse")
319 | edit_btn = gr.Button("edit")
320 |
321 |
322 | with gr.Column():
323 | with gr.Accordion("Advanced Options", open=True):
324 | skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps")
325 | inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
326 | denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
327 | attn_scale = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale")
328 | seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
329 | with gr.Row():
330 | re_init = gr.Checkbox(label="re_init", value=False)
331 | attn_mask = gr.Checkbox(label="attn_mask", value=False)
332 |
333 | output_image = gr.Image(label="Generated Image")
334 | gr.Markdown(article)
335 | inv_btn.click(
336 | fn=editor.inverse,
337 | inputs=[brush_canvas,
338 | source_prompt, target_prompt,
339 | inversion_num_steps, denoise_num_steps,
340 | skip_step,
341 | inversion_guidance,
342 | denoise_guidance,seed,
343 | re_init, attn_mask
344 | ],
345 | outputs=[output_image]
346 | )
347 | edit_btn.click(
348 | fn=editor.edit,
349 | inputs=[ brush_canvas,
350 | source_prompt, target_prompt,
351 | inversion_num_steps, denoise_num_steps,
352 | skip_step,
353 | inversion_guidance,
354 | denoise_guidance,seed,
355 | re_init, attn_mask,attn_scale
356 | ],
357 | outputs=[output_image]
358 | )
359 | return demo
360 |
361 | if __name__ == "__main__":
362 | import argparse
363 | parser = argparse.ArgumentParser(description="Flux")
364 | parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
365 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
366 |     parser.add_argument("--gpus", action="store_true", help="Use two GPUs (flow model on cuda:0, text encoders and AE on cuda:1)")
367 | parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
368 | parser.add_argument("--port", type=int, default=41032)
369 | args = parser.parse_args()
370 |
371 | demo = create_demo(args.name)
372 |
373 | demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
--------------------------------------------------------------------------------
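All of the demos crop the uploaded image down to multiples of 16 pixels before encoding, since the autoencoder downsamples by a factor of 8 and the resulting latent is packed into 2x2 patches. An equivalent helper for the inline expressions used above (illustrative only):

    def crop_to_multiple_of_16(height: int, width: int) -> tuple[int, int]:
        # mirrors `shape[0] - shape[0] % 16` / `shape[1] - shape[1] % 16` in the demos
        return height - height % 16, width - width % 16

    # usage: h, w = crop_to_multiple_of_16(*image.shape[:2]); image = image[:h, :w, :]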
/gradio_kv_edit_inf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | from dataclasses import dataclass
5 | from glob import iglob
6 | import argparse
7 | from einops import rearrange
8 | from PIL import ExifTags, Image
9 | import torch
10 | import gradio as gr
11 | import numpy as np
12 | from flux.sampling import prepare, unpack
13 | from flux.util import (configs, load_ae, load_clip, load_t5)
14 | from models.kv_edit import Flux_kv_edit_inf
15 |
16 | @dataclass
17 | class SamplingOptions:
18 | source_prompt: str = ''
19 | target_prompt: str = ''
20 | width: int = 1366
21 | height: int = 768
22 | inversion_num_steps: int = 0
23 | denoise_num_steps: int = 0
24 | skip_step: int = 0
25 | inversion_guidance: float = 1.0
26 | denoise_guidance: float = 1.0
27 | seed: int = 42
28 | attn_mask: bool = False
29 | attn_scale: float = 1.0
30 |
31 | class FluxEditor_kv_demo:
32 | def __init__(self, args):
33 | self.args = args
34 | self.device = torch.device(args.device)
35 | self.offload = args.offload
36 | self.name = args.name
37 | self.is_schnell = args.name == "flux-schnell"
38 |
39 | self.output_dir = 'regress_result'
40 |
41 | self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
42 | self.clip = load_clip(self.device)
43 | self.model = Flux_kv_edit_inf(device="cpu" if self.offload else self.device, name=self.name)
44 | self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
45 |
46 | self.t5.eval()
47 | self.clip.eval()
48 | self.ae.eval()
49 | self.model.eval()
50 |
51 | if self.offload:
52 | self.model.cpu()
53 | torch.cuda.empty_cache()
54 | self.ae.encoder.to(self.device)
55 |
56 | @torch.inference_mode()
57 | def edit(self, brush_canvas,
58 | source_prompt, target_prompt,
59 | inversion_num_steps, denoise_num_steps,
60 | skip_step,
61 | inversion_guidance, denoise_guidance,seed,
62 | attn_mask,attn_scale
63 | ):
64 |
65 | torch.cuda.empty_cache()
66 |
67 | rgba_init_image = brush_canvas["background"]
68 | init_image = rgba_init_image[:,:,:3]
69 | shape = init_image.shape
70 | height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
71 | width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
72 | init_image = init_image[:height, :width, :]
73 | rgba_init_image = rgba_init_image[:height, :width, :]
74 |
75 | rgba_mask = brush_canvas["layers"][0][:height, :width, :]
76 | mask = rgba_mask[:,:,3]/255
77 | mask = mask.astype(int)
78 |
79 | rgba_mask[:,:,3] = rgba_mask[:,:,3]//2
80 | masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA'))
81 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device)
82 |
83 | init_image = self.encode(init_image, self.device).to(self.device)
84 |
85 | seed = int(seed)
86 | if seed == -1:
87 | seed = torch.randint(0, 2**32, (1,)).item()
88 | opts = SamplingOptions(
89 | source_prompt=source_prompt,
90 | target_prompt=target_prompt,
91 | width=width,
92 | height=height,
93 | inversion_num_steps=inversion_num_steps,
94 | denoise_num_steps=denoise_num_steps,
95 | skip_step=skip_step,
96 | inversion_guidance=inversion_guidance,
97 | denoise_guidance=denoise_guidance,
98 | seed=seed,
99 | attn_mask=attn_mask,
100 | attn_scale=attn_scale
101 | )
102 |
103 | if self.offload:
104 | self.ae = self.ae.cpu()
105 | torch.cuda.empty_cache()
106 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
107 |
108 | torch.manual_seed(opts.seed)
109 | if torch.cuda.is_available():
110 | torch.cuda.manual_seed_all(opts.seed)
111 |
112 | t0 = time.perf_counter()
113 |
114 | with torch.no_grad():
115 | inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt)
116 | inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
117 |
118 | if self.offload:
119 | self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
120 | torch.cuda.empty_cache()
121 | self.model = self.model.to(self.device)
122 |
123 | x = self.model(inp, inp_target, mask, opts)
124 |
125 | if self.offload:
126 | self.model.cpu()
127 | torch.cuda.empty_cache()
128 | self.ae.decoder.to(x.device)
129 |
130 | with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
131 | x = self.ae.decode(x)
132 |
133 | x = x.clamp(-1, 1)
134 | x = x.float().cpu()
135 | x = rearrange(x[0], "c h w -> h w c")
136 |
137 | if torch.cuda.is_available():
138 | torch.cuda.synchronize()
139 |
140 | output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
141 | if not os.path.exists(self.output_dir):
142 | os.makedirs(self.output_dir)
143 | idx = 0
144 | else:
145 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
146 | if len(fns) > 0:
147 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
148 | else:
149 | idx = 0
150 |
151 | fn = output_name.format(idx=idx)
152 |
153 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
154 | exif_data = Image.Exif()
155 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
156 | exif_data[ExifTags.Base.Make] = "Black Forest Labs"
157 | exif_data[ExifTags.Base.Model] = self.name
158 |
159 | exif_data[ExifTags.Base.ImageDescription] = target_prompt
160 | img.save(fn, exif=exif_data, quality=95, subsampling=0)
161 | masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG')
162 | t1 = time.perf_counter()
163 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
164 |
165 | print("End Edit")
166 | return img
167 |
168 | @torch.inference_mode()
169 | def encode(self,init_image, torch_device):
170 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
171 | init_image = init_image.unsqueeze(0)
172 | init_image = init_image.to(torch_device)
173 | self.ae.encoder.to(torch_device)
174 |
175 | init_image = self.ae.encode(init_image).to(torch.bfloat16)
176 | return init_image
177 |
178 | def create_demo(model_name: str):
179 | editor = FluxEditor_kv_demo(args)
180 | is_schnell = model_name == "flux-schnell"
181 |
182 | title = r"""
183 | 🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation
184 | """
185 |
186 | description = r"""
187 | Inversion-free version of KV-Edit: Training-Free Image Editing for Precise Background Preservation.
188 |
189 | 💫💫 Here are the editing steps (we highly recommend running our code locally! 😘 Only one inversion is needed before multiple edits, which is very productive!):
190 |
191 | 1️⃣ Upload your image that needs to be edited.
192 | 2️⃣ Use the brush tool to draw your mask area.
193 | 3️⃣ Fill in your source prompt and target prompt, then adjust the hyperparameters.
194 | 4️⃣ Click the "Edit" button to generate your edited image!
195 |
196 | 🔔🔔 [Important] Fewer skip steps and "attn_mask" enhance editing performance, making the results more aligned with your text, but may lead to discontinuous images.
197 | If these two options cause failures, we recommend increasing "attn_scale" to strengthen attention between the masked region and the background.
198 | """
199 | article = r"""
200 | If our work is helpful, please help to ⭐ the GitHub repo. Thanks!
201 | """
202 |
203 | badge = r"""
204 | [](https://github.com/Xilluill/KV-Edit)
205 | """
206 |
207 | with gr.Blocks() as demo:
208 | gr.HTML(title)
209 | gr.Markdown(description)
210 | # gr.Markdown(badge)
211 |
212 | with gr.Row():
213 | with gr.Column():
214 | source_prompt = gr.Textbox(label="Source Prompt", value='' )
215 | inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps")
216 | target_prompt = gr.Textbox(label="Target Prompt", value='' )
217 | denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps")
218 | brush_canvas = gr.ImageEditor(label="Brush Canvas",
219 | sources=('upload'),
220 | brush=gr.Brush(colors=["#ff0000"],color_mode='fixed'),
221 | interactive=True,
222 | transforms=[],
223 | container=True,
224 | format='png',scale=1)
225 | edit_btn = gr.Button("edit")
226 |
227 |
228 | with gr.Column():
229 | with gr.Accordion("Advanced Options", open=True):
230 |
231 | skip_step = gr.Slider(0, 30, 4, step=1, label="Number of skip steps")
232 | inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
233 | denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
234 | attn_scale = gr.Slider(0.0, 5.0, 1, step=0.1, label="attn_scale")
235 | seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
236 | with gr.Row():
237 | attn_mask = gr.Checkbox(label="attn_mask", value=False)
238 |
239 |
240 | output_image = gr.Image(label="Generated Image")
241 | gr.Markdown(article)
242 | edit_btn.click(
243 | fn=editor.edit,
244 | inputs=[brush_canvas,
245 | source_prompt, target_prompt,
246 | inversion_num_steps, denoise_num_steps,
247 | skip_step,
248 | inversion_guidance,
249 | denoise_guidance,seed,
250 | attn_mask,attn_scale
251 | ],
252 | outputs=[output_image]
253 | )
254 | return demo
255 |
256 | if __name__ == "__main__":
257 | import argparse
258 | parser = argparse.ArgumentParser(description="Flux")
259 | parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
260 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
261 | parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
262 | parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
263 | parser.add_argument("--port", type=int, default=41032)
264 | args = parser.parse_args()
265 |
266 | demo = create_demo(args.name)
267 |
268 | demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
--------------------------------------------------------------------------------
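The three demos differ mainly in how they drive the wrapper classes defined in models/kv_edit.py below: the standard demos invert once and can then denoise repeatedly with different target prompts, while the inversion-free demo performs everything in a single forward pass. A condensed summary of the two call patterns (written as comments; the argument objects are the `inp`/`inp_target`/`mask`/`opts` built in the demos above):

    # KV-Edit with explicit inversion (gradio_kv_edit.py / gradio_kv_edit_gpu.py):
    #     z0, zt, info = model.inverse(inp, mask, opts)                      # once per image
    #     latent       = model.denoise(z0, zt, inp_target, mask, opts, info) # once per edit
    #
    # Inversion-free variant (gradio_kv_edit_inf.py, using Flux_kv_edit_inf):
    #     latent = model(inp, inp_target, mask, opts)                        # single pass per edit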
/models/kv_edit.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from einops import rearrange,repeat
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from flux.sampling import get_schedule, unpack, denoise_kv, denoise_kv_inf
9 | from flux.util import load_flow_model
10 | from flux.model import Flux_kv
11 |
12 | class only_Flux(torch.nn.Module):
13 | def __init__(self, device,name='flux-dev'):
14 | self.device = device
15 | self.name = name
16 | super().__init__()
17 | self.model = load_flow_model(self.name, device=self.device,flux_cls=Flux_kv)
18 |
19 | def create_attention_mask(self,seq_len, mask_indices, text_len=512, device='cuda'):
20 |
21 | attention_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
22 |
23 | text_indices = torch.arange(0, text_len, device=device)
24 |
25 | mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device)
26 |
27 | all_indices = torch.arange(text_len, seq_len, device=device)
28 | background_token_indices = torch.tensor([idx for idx in all_indices if idx not in mask_token_indices])
29 |
30 | # text setting
31 | attention_mask[text_indices.unsqueeze(1).expand(-1, seq_len)] = True
32 | attention_mask[text_indices.unsqueeze(1), text_indices] = True
33 | attention_mask[text_indices.unsqueeze(1), background_token_indices] = True
34 |
35 |
36 | # mask setting
37 | # attention_mask[mask_token_indices.unsqueeze(1), background_token_indices] = True
38 | attention_mask[mask_token_indices.unsqueeze(1), text_indices] = True
39 | attention_mask[mask_token_indices.unsqueeze(1), mask_token_indices] = True
40 |
41 | # background setting
42 | # attention_mask[background_token_indices.unsqueeze(1), mask_token_indices] = True
43 | attention_mask[background_token_indices.unsqueeze(1), text_indices] = True
44 | attention_mask[background_token_indices.unsqueeze(1), background_token_indices] = True
45 |
46 | return attention_mask.unsqueeze(0)
47 |
48 | def create_attention_scale(self,seq_len, mask_indices, text_len=512, device='cuda',scale = 0):
49 |
50 |         attention_scale = torch.zeros(1, seq_len, dtype=torch.bfloat16, device=device) # broadcast when added
51 |
52 |
53 | text_indices = torch.arange(0, text_len, device=device)
54 |
55 | mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device)
56 |
57 | all_indices = torch.arange(text_len, seq_len, device=device)
58 |         background_token_indices = all_indices[~torch.isin(all_indices, mask_token_indices)]
59 | 
60 |         attention_scale[0, background_token_indices] = scale
61 |
62 | return attention_scale.unsqueeze(0)
63 |
64 | class Flux_kv_edit_inf(only_Flux):
65 | def __init__(self, device,name):
66 | super().__init__(device,name)
67 |
68 | @torch.inference_mode()
69 | def forward(self,inp,inp_target,mask:Tensor,opts):
70 |
71 | info = {}
72 | info['feature'] = {}
73 | bs, L, d = inp["img"].shape
74 | h = opts.height // 8
75 | w = opts.width // 8
76 | mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
77 | mask[mask > 0] = 1
78 |
79 | mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
80 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
81 | info['mask'] = mask
82 | bool_mask = (mask.sum(dim=2) > 0.5)
83 | info['mask_indices'] = torch.nonzero(bool_mask)[:,1]
84 |         # handle the inversion separately
85 | if opts.attn_mask and (~bool_mask).any():
86 | attention_mask = self.create_attention_mask(L+512, info['mask_indices'], device=self.device)
87 | else:
88 | attention_mask = None
89 | info['attention_mask'] = attention_mask
90 |
91 | if opts.attn_scale != 0 and (~bool_mask).any():
92 | attention_scale = self.create_attention_scale(L+512, info['mask_indices'], device=mask.device,scale = opts.attn_scale)
93 | else:
94 | attention_scale = None
95 | info['attention_scale'] = attention_scale
96 |
97 | denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
98 | denoise_timesteps = denoise_timesteps[opts.skip_step:]
99 |
100 | z0 = inp["img"]
101 |
102 | with torch.no_grad():
103 | info['inject'] = True
104 | z_fe, info = denoise_kv_inf(self.model, img=inp["img"], img_ids=inp['img_ids'],
105 | source_txt=inp['txt'], source_txt_ids=inp['txt_ids'], source_vec=inp['vec'],
106 | target_txt=inp_target['txt'], target_txt_ids=inp_target['txt_ids'], target_vec=inp_target['vec'],
107 | timesteps=denoise_timesteps, source_guidance=opts.inversion_guidance, target_guidance=opts.denoise_guidance,
108 | info=info)
109 | mask_indices = info['mask_indices']
110 |
111 | z0[:, mask_indices,...] = z_fe
112 |
113 | z0 = unpack(z0.float(), opts.height, opts.width)
114 | del info
115 | return z0
116 |
117 | class Flux_kv_edit(only_Flux):
118 | def __init__(self, device,name):
119 | super().__init__(device,name)
120 |
121 | @torch.inference_mode()
122 | def forward(self,inp,inp_target,mask:Tensor,opts):
123 | z0,zt,info = self.inverse(inp,mask,opts)
124 | z0 = self.denoise(z0,zt,inp_target,mask,opts,info)
125 | return z0
126 | @torch.inference_mode()
127 | def inverse(self,inp,mask,opts):
128 | info = {}
129 | info['feature'] = {}
130 | bs, L, d = inp["img"].shape
131 | h = opts.height // 8
132 | w = opts.width // 8
133 |
134 | if opts.attn_mask:
135 | mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
136 | mask[mask > 0] = 1
137 |
138 | mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
139 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
140 | bool_mask = (mask.sum(dim=2) > 0.5)
141 | mask_indices = torch.nonzero(bool_mask)[:,1]
142 |
143 |             # handle the inversion separately
144 | assert not (~bool_mask).all(), "mask is all false"
145 | assert not (bool_mask).all(), "mask is all true"
146 | attention_mask = self.create_attention_mask(L+512, mask_indices, device=mask.device)
147 | info['attention_mask'] = attention_mask
148 |
149 |
150 | denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
151 | denoise_timesteps = denoise_timesteps[opts.skip_step:]
152 |
153 |         # noising process (inversion)
154 | z0 = inp["img"].clone()
155 | info['inverse'] = True
156 | zt, info = denoise_kv(self.model, **inp, timesteps=denoise_timesteps, guidance=opts.inversion_guidance, inverse=True, info=info)
157 | return z0,zt,info
158 |
159 | @torch.inference_mode()
160 | def denoise(self,z0,zt,inp_target,mask:Tensor,opts,info):
161 |
162 | h = opts.height // 8
163 | w = opts.width // 8
164 | L = h * w // 4
165 | mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
166 | mask[mask > 0] = 1
167 |
168 | mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
169 |
170 | mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
171 | info['mask'] = mask
172 | bool_mask = (mask.sum(dim=2) > 0.5)
173 | info['mask_indices'] = torch.nonzero(bool_mask)[:,1]
174 |
175 | denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
176 | denoise_timesteps = denoise_timesteps[opts.skip_step:]
177 |
178 | mask_indices = info['mask_indices']
179 | if opts.re_init:
180 | noise = torch.randn_like(zt)
181 | t = denoise_timesteps[0]
182 |             zt_noise = z0 * (1 - t) + noise * t  # re-noise the clean latent to the first kept timestep: (1 - t) * z0 + t * noise
183 | inp_target["img"] = zt_noise[:, mask_indices,...]
184 | else:
185 | img_name = str(info['t']) + '_' + 'img'
186 | zt = info['feature'][img_name].to(zt.device)
187 | inp_target["img"] = zt[:, mask_indices,...]
188 |
189 | if opts.attn_scale != 0 and (~bool_mask).any():
190 | attention_scale = self.create_attention_scale(L+512, mask_indices, device=mask.device,scale = opts.attn_scale)
191 | else:
192 | attention_scale = None
193 | info['attention_scale'] = attention_scale
194 |
195 | info['inverse'] = False
196 | x, _ = denoise_kv(self.model, **inp_target, timesteps=denoise_timesteps, guidance=opts.denoise_guidance, inverse=False, info=info)
197 |
198 | z0[:, mask_indices,...] = z0[:, mask_indices,...] * (1 - info['mask'][:, mask_indices,...]) + x * info['mask'][:, mask_indices,...]
199 |
200 | z0 = unpack(z0.float(), opts.height, opts.width)
201 | del info
202 | return z0
--------------------------------------------------------------------------------
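To make the masking logic in kv_edit.py easier to follow, here is a small self-contained sketch (not the repo's API; the toy sizes and the painted region are hypothetical) reproducing its two key steps: mapping a pixel mask onto the packed 2x2 latent tokens, and building the block-structured attention mask in which text tokens attend everywhere, masked tokens attend to text plus other masked tokens, and background tokens attend to text plus other background tokens.

```python
# Minimal sketch of the mask -> token-index mapping and the block-structured
# attention mask used in kv_edit.py. Toy sizes keep the tensors printable.
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

H = W = 64                 # toy image size; real inputs are much larger
h, w = H // 8, W // 8      # 8x VAE downsampling, as in kv_edit.py
text_len = 4               # stands in for the 512 text tokens

# 1) Pixel mask -> per-token boolean mask (same ops as in kv_edit.py).
mask = torch.zeros(1, 1, H, W)
mask[:, :, 16:48, 16:48] = 1.0                      # user-painted region (toy)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask[mask > 0] = 1
mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
mask = rearrange(mask, 'b c (h ph) (w pw) -> b (h w) (c ph pw)', ph=2, pw=2)
bool_mask = mask.sum(dim=2) > 0.5                   # (1, num_image_tokens)
mask_indices = torch.nonzero(bool_mask)[:, 1]       # token ids inside the mask
L = bool_mask.shape[1]

# 2) Block-structured attention mask, mirroring create_attention_mask.
seq_len = text_len + L
attn = torch.zeros(seq_len, seq_len, dtype=torch.bool)
text_idx = torch.arange(text_len)
mask_idx = mask_indices + text_len
bg_idx = torch.tensor([i for i in range(text_len, seq_len)
                       if i not in set(mask_idx.tolist())])

attn[text_idx] = True                               # text attends to everything
attn[mask_idx.unsqueeze(1), text_idx] = True        # mask -> text
attn[mask_idx.unsqueeze(1), mask_idx] = True        # mask -> mask
attn[bg_idx.unsqueeze(1), text_idx] = True          # background -> text
attn[bg_idx.unsqueeze(1), bg_idx] = True            # background -> background

print(f'{len(mask_indices)} of {L} image tokens fall inside the edit region')
print('mask -> background allowed?', attn[mask_idx[0], bg_idx[0]].item())   # False
print('background -> text allowed?', attn[bg_idx[0], text_idx[0]].item())   # True
```

Because masked and background image tokens never attend to each other, this separation matches what Flux_kv_edit relies on when it re-denoises only the masked tokens (`inp_target["img"] = zt[:, mask_indices, ...]` above) and pastes the result back over the unchanged background latents.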
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | einops
3 | fire
4 | gradio==5.17.1
5 | huggingface-hub
6 | invisible-watermark
7 | matplotlib
8 | numpy
9 | opencv-python
10 | Pillow
11 | Requests
12 | safetensors
13 | scikit-learn
14 | scipy
15 | scikit-image
16 | torchvision
17 | tqdm
18 | transformers==4.49.0
19 | sentencepiece
--------------------------------------------------------------------------------
/resources/example.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xilluill/KV-Edit/77129a23eb8eff7408df7d8c91f1f8c523ce56ac/resources/example.jpeg
--------------------------------------------------------------------------------
/resources/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xilluill/KV-Edit/77129a23eb8eff7408df7d8c91f1f8c523ce56ac/resources/pipeline.jpg
--------------------------------------------------------------------------------
/resources/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xilluill/KV-Edit/77129a23eb8eff7408df7d8c91f1f8c523ce56ac/resources/teaser.jpg
--------------------------------------------------------------------------------