├── .github
│   ├── example_cir.gif
│   ├── example_cir_with_mask.gif
│   ├── example_t2i.gif
│   └── teaser.png
├── LICENSE
├── NOTICE
├── README.md
├── compodiff
│   ├── __init__.py
│   ├── model_loader.py
│   └── models.py
├── demo_search.py
├── requirements.txt
└── setup.py
/.github/example_cir.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/example_cir.gif
--------------------------------------------------------------------------------
/.github/example_cir_with_mask.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/example_cir_with_mask.gif
--------------------------------------------------------------------------------
/.github/example_t2i.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/example_t2i.gif
--------------------------------------------------------------------------------
/.github/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/teaser.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023-present NAVER Corp.
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | CompoDiff
2 | Copyright 2023-present NAVER Corp.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 |
16 | --------------------------------------------------------------------------------------
17 |
18 | This project contains subcomponents with separate copyright notices and license terms.
19 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
20 |
21 | =====
22 |
23 | lucidrains/DALLE2-pytorch
24 | https://github.com/lucidrains/DALLE2-pytorch
25 |
26 |
27 | MIT License
28 |
29 | Copyright (c) 2021 Phil Wang
30 |
31 | Permission is hereby granted, free of charge, to any person obtaining a copy
32 | of this software and associated documentation files (the "Software"), to deal
33 | in the Software without restriction, including without limitation the rights
34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
35 | copies of the Software, and to permit persons to whom the Software is
36 | furnished to do so, subject to the following conditions:
37 |
38 | The above copyright notice and this permission notice shall be included in all
39 | copies or substantial portions of the Software.
40 |
41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 | SOFTWARE.
48 |
49 | =====
50 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # \[TMLR\] CompoDiff: Versatile Composed Image Retrieval With Latent Diffusion
2 |
3 | ![CompoDiff teaser](.github/teaser.png)
4 |
5 | Official PyTorch implementation of CompoDiff
6 |
7 | [\[Paper\]](https://arxiv.org/abs/2303.11916) [\[OpenReview\]](https://openreview.net/forum?id=mKtlzW0bWc) [\[Demo🤗\]](https://huggingface.co/spaces/navervision/CompoDiff-Aesthetic)
8 |
9 | **[Geonmo Gu](https://geonm.github.io/)<sup>\*1</sup>, [Sanghyuk Chun](https://sanghyukchun.github.io/home/)<sup>\*2</sup>, [Wonjae Kim](https://wonjae.kim)<sup>2</sup>, HeeJae Jun<sup>1</sup>, Yoohoon Kang<sup>1</sup>, [Sangdoo Yun](https://sangdooyun.github.io)<sup>2</sup>**
10 | 
11 | <sup>1</sup> NAVER Vision, <sup>2</sup> NAVER AI Lab
12 |
13 | \* First two authors contributed equally.
14 |
15 | ## ⭐ Overview
16 |
17 | CompoDiff is the first model to leverage latent diffusion for Composed Image Retrieval (CIR).
18 | 
19 | The diffusion-based design also brings negative text and inpainting masks to CIR for the first time.
20 | 
21 | Moreover, since CompoDiff operates in the image feature space of CLIP ViT-L/14, its outputs can be used directly with various text-to-image models (such as [Graphit](https://huggingface.co/navervision/Graphit-SD), [SD-unCLIP](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip), and [unCLIP (Karlo)](https://huggingface.co/kakaobrain/karlo-v1-alpha-image-variations)).
22 |
23 | We overcame the challenge of collecting CIR training data by generating the SynthTriplets18M dataset, which is now publicly available (see below); we hope it will benefit the CIR research community.
24 |
25 | We believe that CompoDiff has great potential to contribute to the advancement of CIR and look forward to seeing how it performs in the field.
26 |
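Because CompoDiff outputs live in the CLIP ViT-L/14 image embedding space, the retrieval step itself is just a nearest-neighbor search over pre-extracted gallery embeddings. The snippet below is only an illustrative sketch (the random gallery and the placeholder query stand in for real CLIP-L/14 features such as the `sampled_image_features` produced in the usage examples below); it is not part of the packaged API.

```python
import torch
import torch.nn.functional as F

# hypothetical gallery of CLIP ViT-L/14 image embeddings, shape (num_images, 768)
gallery_embeddings = F.normalize(torch.randn(10_000, 768), dim=-1)

# placeholder query; in practice use the (1, 1, 768) output of compodiff_model.sample(...)
sampled_image_features = torch.randn(1, 1, 768)

# the sampled features are not L2-normalized, so normalize them before cosine similarity
query = F.normalize(sampled_image_features.squeeze(1), dim=-1)   # (1, 768)

scores = query @ gallery_embeddings.t()                          # (1, num_images)
top_scores, top_indices = scores.topk(k=10, dim=-1)              # 10 best gallery matches
print(top_indices)
```
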
27 | ## 🚀 NEWS
28 |
29 | - 12/14/2023 - SynthTriplets18M released!
30 | - 04/28/2023 - CompoDiff-Aesthetic released!
31 |
32 | ## 🖼️ CompoDiff-Aesthetic
33 |
34 | In this repository, we first release the CompoDiff-Aesthetic model, a variant of CompoDiff.
35 | 
36 | For text guidance, it uses the textual embeddings of CLIP-G/14 trained by [open_clip](https://github.com/mlfoundations/open_clip), while the generated output features remain in the visual embedding space of CLIP-L/14.
37 | 
38 | We found that the textual embedding extractor of CLIP-G/14 not only performs better than that of CLIP-L/14 but also adds little computational cost.
39 |
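For reference, the sketch below loads the two public encoders involved and checks their embedding sizes, mirroring the `CLIPHF` wrapper in `compodiff/models.py`: CLIP-G/14 provides 1280-dimensional text token states, while the projected CLIP-L/14 image embedding is 768-dimensional. It is an illustration of the encoder pairing, not the packaged loading path (use `compodiff.build_model()` for that).

```python
import torch
from PIL import Image
from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
                          CLIPVisionModelWithProjection)

# text side: CLIP-G/14 (open_clip weights on the Hugging Face Hub), 1280-d token states
tokenizer = CLIPTokenizer.from_pretrained('laion/CLIP-ViT-bigG-14-laion2B-39B-b160k')
text_model = CLIPTextModel.from_pretrained('laion/CLIP-ViT-bigG-14-laion2B-39B-b160k')

# image side: CLIP-L/14, 768-d projected image embeddings (the space CompoDiff generates in)
processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
vision_model = CLIPVisionModelWithProjection.from_pretrained('openai/clip-vit-large-patch14')

tokens = tokenizer('owl carved on the wooden wall', return_tensors='pt',
                   padding='max_length', truncation=True)
pixels = processor(Image.new('RGB', (224, 224)), return_tensors='pt')['pixel_values']

with torch.no_grad():
    text_cond = text_model(**tokens).last_hidden_state            # (1, 77, 1280)
    image_embed = vision_model(pixel_values=pixels).image_embeds  # (1, 768)

print(text_cond.shape, image_embed.shape)
```
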
40 | CompoDiff-Aesthetic was trained on high-quality and high-resolution images, and we believe it has the potential to benefit not only CIR but also various generation models.
41 |
42 | We hope that CompoDiff-Aesthetic will be a useful addition to the research community and look forward to seeing how it performs in various applications.
43 |
44 | ## 📚 SynthTriplets18M
45 |
46 | Using text-to-image diffusion models, we created a large-scale synthetic triplet dataset:
47 |
48 | https://huggingface.co/datasets/navervision/SynthTriplets18M
49 |
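To take a quick look at the triplets, something like the following should work with the Hugging Face `datasets` library, assuming the repository is in a format `datasets` can load directly; stream it rather than downloading all ~18M rows, and inspect the first record instead of assuming a particular schema.

```python
from datasets import load_dataset

# stream the dataset instead of materializing ~18M triplets on disk
ds = load_dataset('navervision/SynthTriplets18M', streaming=True)
print(ds)                              # shows the available splits

split_name = next(iter(ds))            # pick the first split
first = next(iter(ds[split_name]))     # fetch a single record lazily
print(first.keys())                    # inspect the available fields
```
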
50 | ## Search and Image generation demo
51 | We provide a demo that you can run on a local machine.
52 |
53 | It can be executed with the following command:
54 |
55 | ```bash
56 | $ git clone https://github.com/navervision/compodiff
57 | $ cd compodiff
58 | $ python demo_search.py
59 | ```
60 |
61 | The demo will be hosted at https://0.0.0.0:8000
62 |
63 | The unCLIP model used for image generation is from https://huggingface.co/kakaobrain/karlo-v1-alpha-image-variations.
64 |
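For reference, this is roughly how a composed CLIP-L/14 embedding can be decoded into pixels with that checkpoint through `diffusers`. It is a sketch only, assuming a recent `diffusers` release whose `UnCLIPImageVariationPipeline` accepts precomputed `image_embeddings`; the demo wires this up for you.

```python
import torch
from diffusers import UnCLIPImageVariationPipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipe = UnCLIPImageVariationPipeline.from_pretrained(
    'kakaobrain/karlo-v1-alpha-image-variations',
    torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
).to(device)

# placeholder; in practice pass the (1, 768) CLIP-L/14 feature sampled by CompoDiff
# (e.g. sampled_image_features.squeeze(1) from the usage examples below)
image_embeddings = torch.randn(1, 768, device=device, dtype=pipe.decoder.dtype)

images = pipe(image_embeddings=image_embeddings, decoder_num_inference_steps=25).images
images[0].save('generated.png')
```
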
65 | ### How to use demo
66 | #### Usage 1. Project textual embeddings to visual embeddings
67 | ![Project textual embeddings to visual embeddings](.github/example_t2i.gif)
68 |
69 | #### Usage 2. Composed visual embeddings without mask for CIR
70 | ![Composed visual embeddings without mask for CIR](.github/example_cir.gif)
71 |
72 | #### Usage 3. Composed visual embeddings with mask for CIR
73 | ![Composed visual embeddings with mask for CIR](.github/example_cir_with_mask.gif)
74 |
75 | ## 💡 Usage
76 |
77 | ### Install CompoDiff
78 | ```bash
79 | $ pip install git+https://github.com/navervision/compodiff.git
80 | ```
81 |
82 | ### Build CompoDiff and CLIP models
83 | ```python
84 | import compodiff
85 | import torch
86 | from PIL import Image
87 | import requests
88 |
89 | device = "cuda" if torch.cuda.is_available() else "cpu"
90 |
91 | # build models
92 | compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model()
93 |
94 | compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)
95 |
96 | if device != 'cpu':
97 | clip_model = clip_model.half()
98 | ```
99 |
100 | ### Usage 1. Project textual embeddings to visual embeddings
101 | ```python
102 | cfg_image_scale = 0.0
103 | cfg_text_scale = 7.5
104 |
105 | cfg_scale = (cfg_image_scale, cfg_text_scale)
106 |
107 | input_text = "owl carved on the wooden wall"
108 | negative_text = "low quality"
109 |
110 | # tokenize the input_text first.
111 | text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)
112 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
113 |
114 | negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
115 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device)
116 |
117 | with torch.no_grad():
118 | # In the case of Usage 1, we do not use an image condition or a mask at all.
119 | image_cond = torch.zeros([1,1,768]).to(device)
120 | mask = torch.zeros([64, 64]).to(device).unsqueeze(0)
121 |
122 | text_cond = clip_model.encode_texts(text_tokens, text_attention_mask)
123 | negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask)
124 |
125 | # do denoising steps here
126 | timesteps = 10
127 | sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2)
128 | # NOTE: "sampled_image_features" is not L2-normalized
129 | ```
130 |
131 | ### Usage 2. Composed visual embeddings without mask for CIR
132 | ```python
133 | cfg_image_scale = 1.5
134 | cfg_text_scale = 7.5
135 |
136 | cfg_scale = (cfg_image_scale, cfg_text_scale)
137 |
138 | input_text = "as pencil sketch"
139 | negative_text = "low quality"
140 |
141 | # tokenize the input_text first.
142 | text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)
143 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
144 |
145 | negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
146 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device)
147 |
148 | # prepare a reference image
149 | url = "http://images.cocodataset.org/val2017/000000039769.jpg"
150 | image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
151 |
152 | with torch.no_grad():
153 | processed_image = img_preprocess(image, return_tensors='pt')['pixel_values'].to(device)
154 |
155 | # In the case of Usage 2, we do not use a mask at all.
156 | mask = torch.zeros([64, 64]).to(device).unsqueeze(0)
157 |
158 | image_cond = clip_model.encode_images(processed_image)
159 |
160 | text_cond = clip_model.encode_texts(text_tokens, text_attention_mask)
161 | negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask)
162 |
163 | timesteps = 10
164 | sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2)
165 |
166 | # NOTE: to preserve more of the original image's context, increase source_weight (0.1 here); a larger source_weight passes the original image feature through as a stronger signal.
167 | source_weight = 0.1
168 | sampled_image_features = (1 - source_weight) * sampled_image_features + source_weight * image_cond[0]
169 | ```
170 |
171 | ### Usage 3. Composed visual embeddings with mask for CIR
172 | ```python
173 | cfg_image_scale = 1.5
174 | cfg_text_scale = 7.5
175 |
176 | cfg_scale = (cfg_image_scale, cfg_text_scale)
177 |
178 | input_text = "as pencil sketch"
179 | negative_text = "low quality"
180 |
181 | # tokenize the input_text first.
182 | text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)
183 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
184 |
185 | negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
186 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device)
187 |
188 | # prepare a reference image
189 | url = "http://images.cocodataset.org/val2017/000000039769.jpg"
190 | image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
191 |
192 | # prepare a mask image
193 | url = "mask_url"
194 | mask = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
195 |
196 | with torch.no_grad():
197 | processed_image = img_preprocess(image, return_tensors='pt')['pixel_values'].to(device)
198 | processed_mask = img_preprocess(mask, do_normalize=False, return_tensors='pt')['pixel_values'].to(device)
199 | processed_mask = processed_mask[:,:1,:,:]
200 |
201 | masked_processed_image = processed_image * (1 - (processed_mask > 0.5).float())
202 | mask = torch.nn.functional.interpolate(processed_mask, size=(64, 64), mode='nearest')[:, 0, :, :]  # resize the binary mask to the 64x64 grid CompoDiff expects
203 | mask = (mask > 0.5).float()
204 |
205 | image_cond = clip_model.encode_images(masked_processed_image)
206 |
207 | text_cond = clip_model.encode_texts(text_tokens, text_attention_mask)
208 | negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask)
209 |
210 | timesteps = 10
211 | sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2)
212 |
213 | # NOTE: to preserve more of the original image's context, increase source_weight (0.05 here); a larger source_weight passes the original image feature through as a stronger signal.
214 | source_weight = 0.05
215 | sampled_image_features = (1 - source_weight) * sampled_image_features + source_weight * image_cond[0]
216 | ```
217 |
218 | ### Shoutout
219 | The k-NN index for the retrieval results is built over the entire LAION-5B image set. You do not need to download any images for this retrieval; it is made possible by the great work of [rom1504](https://github.com/rom1504/clip-retrieval).
220 |
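If you want to query that index outside the demo, `clip-retrieval` exposes a small client that accepts a raw CLIP-L/14 embedding. The snippet below is a sketch under the assumption that the public `knn.laion.ai` backend and the `laion5B-L-14` index are still online; the placeholder query stands in for a real CompoDiff output.

```python
import torch
from clip_retrieval.clip_client import ClipClient

client = ClipClient(
    url='https://knn.laion.ai/knn-service',
    indice_name='laion5B-L-14',
    num_images=10,
)

# placeholder; replace with the (1, 1, 768) output of compodiff_model.sample(...)
sampled_image_features = torch.randn(1, 1, 768)
query = sampled_image_features.squeeze().float().cpu().numpy().tolist()

results = client.query(embedding_input=query)
for result in results:
    print(result['similarity'], result['url'])
```
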
221 | ## Citing CompoDiff
222 | If you find this repository useful, please consider giving a star ⭐ and a citation:
223 | ```
224 | @article{gu2024compodiff,
225 | title={CompoDiff: Versatile Composed Image Retrieval With Latent Diffusion},
226 | author={Geonmo Gu and Sanghyuk Chun and Wonjae Kim and HeeJae Jun and Yoohoon Kang and Sangdoo Yun},
227 | journal={Transactions on Machine Learning Research},
228 | issn={2835-8856},
229 | year={2024},
230 | url={https://openreview.net/forum?id=mKtlzW0bWc},
231 | note={Expert Certification}
232 | }
233 | ```
234 |
235 | ## License
236 | ```
237 | CompoDiff
238 | Copyright 2023-present NAVER Corp.
239 |
240 | Licensed under the Apache License, Version 2.0 (the "License");
241 | you may not use this file except in compliance with the License.
242 | You may obtain a copy of the License at
243 |
244 | http://www.apache.org/licenses/LICENSE-2.0
245 |
246 | Unless required by applicable law or agreed to in writing, software
247 | distributed under the License is distributed on an "AS IS" BASIS,
248 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
249 | See the License for the specific language governing permissions and
250 | limitations under the License.
251 | ```
252 |
--------------------------------------------------------------------------------
/compodiff/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | CompoDiff
3 | Copyright (c) 2023-present NAVER Corp.
4 | Apache-2.0
5 | """
6 | from .model_loader import *
7 |
--------------------------------------------------------------------------------
/compodiff/model_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | CompoDiff
3 | Copyright (c) 2023-present NAVER Corp.
4 | Apache-2.0
5 | """
6 | import torch
7 | from transformers import PreTrainedModel, PretrainedConfig, CLIPTokenizer, CLIPImageProcessor
8 | try:
9 | from .models import build_compodiff, build_clip
10 | except ImportError:
11 | from models import build_compodiff, build_clip
12 |
13 |
14 | class CompoDiffConfig(PretrainedConfig):
15 | model_type = "CompoDiff"
16 |
17 | def __init__(
18 | self,
19 | embed_dim: int = 768,
20 | model_depth: int = 12,
21 | model_dim: int = 64,
22 | model_heads: int = 16,
23 | timesteps: int = 1000,
24 | **kwargs,
25 | ):
26 | self.embed_dim = embed_dim
27 | self.model_depth = model_depth
28 | self.model_dim = model_dim
29 | self.model_heads = model_heads
30 | self.timesteps = timesteps
31 | super().__init__(**kwargs)
32 |
33 |
34 | class CompoDiffModel(PreTrainedModel):
35 | config_class = CompoDiffConfig
36 |
37 | def __init__(self, config):
38 | super().__init__(config)
39 | self.model = build_compodiff(
40 | config.embed_dim,
41 | config.model_depth,
42 | config.model_dim,
43 | config.model_heads,
44 | config.timesteps,
45 | )
46 |
47 | def _init_weights(self, module):
48 | pass
49 |
50 | def sample(self, image_cond, text_cond, negative_text_cond, input_mask, num_samples_per_batch=4, cond_scale=1., timesteps=None, random_seed=None):
51 | return self.model.sample(image_cond, text_cond, negative_text_cond, input_mask, num_samples_per_batch, cond_scale, timesteps, random_seed)
52 |
53 |
54 | def build_model(model_name='navervision/CompoDiff-Aesthetic'):
55 | tokenizer = CLIPTokenizer.from_pretrained('laion/CLIP-ViT-bigG-14-laion2B-39B-b160k')
56 |
57 | size_cond = {'shortest_edge': 224}
58 | preprocess = CLIPImageProcessor(crop_size={'height': 224, 'width': 224},
59 | do_center_crop=True,
60 | do_convert_rgb=True,
61 | do_normalize=True,
62 | do_rescale=True,
63 | do_resize=True,
64 | image_mean=[0.48145466, 0.4578275, 0.40821073],
65 | image_std=[0.26862954, 0.26130258, 0.27577711],
66 | resample=3,
67 | size=size_cond,
68 | )
69 | compodiff = CompoDiffModel.from_pretrained(model_name)
70 |
71 | clip_model = build_clip()
72 |
73 | return compodiff, clip_model, preprocess, tokenizer
74 |
75 |
76 | if __name__ == '__main__':
77 | #''' # convert CompoDiff
78 | compodiff_config = CompoDiffConfig()
79 |
80 | compodiff = CompoDiffModel(compodiff_config)
81 | compodiff.model.load_state_dict(torch.load('/data/data_zoo/logs/stage2_arch.depth12-heads16_lr1e-4_text-bigG_add-art-datasets/checkpoints/model_000710000.pt')['ema_model'])
82 | compodiff_config.save_pretrained('/data/CompoDiff_HF')
83 | compodiff.save_pretrained('/data/CompoDiff_HF')
84 | #'''
85 | #compodiff, clip_model, preprocess_img, tokenizer = build_model()
86 |
--------------------------------------------------------------------------------
/compodiff/models.py:
--------------------------------------------------------------------------------
1 | """
2 | CompoDiff
3 | Copyright (c) 2023-present NAVER Corp.
4 | Apache-2.0
5 | """
6 | import math
7 | import random
8 | from tqdm.auto import tqdm
9 | from functools import partial, wraps
10 | from contextlib import contextmanager
11 | from collections import namedtuple
12 | from pathlib import Path
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from torch.utils.checkpoint import checkpoint
17 | from torch import nn, einsum
18 | import torchvision.transforms as T
19 |
20 | import einops
21 | from einops import rearrange, repeat, reduce
22 | from einops.layers.torch import Rearrange
23 | from einops_exts import rearrange_many, repeat_many, check_shape
24 | from einops_exts.torch import EinopsToAndFrom
25 |
26 | # rotary embeddings
27 | from rotary_embedding_torch import RotaryEmbedding
28 |
29 | from transformers import CLIPTextModel, CLIPVisionModelWithProjection, CLIPImageProcessor
30 |
31 |
32 | class CLIPHF(torch.nn.Module):
33 | def __init__(self, image_model_name = 'openai/clip-vit-large-patch14', text_model_name = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'):
34 | super().__init__()
35 | self.clip_text_model = CLIPTextModel.from_pretrained(text_model_name, torch_dtype=torch.float16)#.to(device).eval()
36 | self.clip_text_model = self.clip_text_model.to(torch.float32)
37 | self.clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(image_model_name, torch_dtype=torch.float16)#.to(device).eval()
38 | self.clip_vision_model = self.clip_vision_model.to(torch.float32)
39 |
40 | def encode_images(self, input_images):
41 | vision_outputs = self.clip_vision_model(pixel_values=input_images.to(self.clip_vision_model.dtype))
42 | return vision_outputs.image_embeds.unsqueeze(1).float()
43 |
44 | def encode_texts(self, tokens, attention_masks):
45 | text_outputs = self.clip_text_model(input_ids=tokens, attention_mask=attention_masks)
46 | return text_outputs.last_hidden_state.float()
47 |
48 | def forward(self, input_images, tokens, attention_masks):
49 | return self.encode_images(input_images), self.encode_texts(tokens, attention_masks)
50 |
51 |
52 | def exists(val):
53 | return val is not None
54 |
55 | def identity(t, *args, **kwargs):
56 | return t
57 |
58 | def first(arr, d = None):
59 | if len(arr) == 0:
60 | return d
61 | return arr[0]
62 |
63 | def maybe(fn):
64 | @wraps(fn)
65 | def inner(x, *args, **kwargs):
66 | if not exists(x):
67 | return x
68 | return fn(x, *args, **kwargs)
69 | return inner
70 |
71 | def default(val, d):
72 | if exists(val):
73 | return val
74 | return d() if callable(d) else d
75 |
76 | def cast_tuple(val, length = None, validate = True):
77 | if isinstance(val, list):
78 | val = tuple(val)
79 |
80 | out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
81 |
82 | if exists(length) and validate:
83 | assert len(out) == length
84 |
85 | return out
86 |
87 | def module_device(module):
88 | if isinstance(module, nn.Identity):
89 | return 'cpu' # It doesn't matter
90 | return next(module.parameters()).device
91 |
92 | def zero_init_(m):
93 | nn.init.zeros_(m.weight)
94 | if exists(m.bias):
95 | nn.init.zeros_(m.bias)
96 |
97 | @contextmanager
98 | def null_context(*args, **kwargs):
99 | yield
100 |
101 | def eval_decorator(fn):
102 | def inner(model, *args, **kwargs):
103 | was_training = model.training
104 | model.eval()
105 | out = fn(model, *args, **kwargs)
106 | model.train(was_training)
107 | return out
108 | return inner
109 |
110 | def is_float_dtype(dtype):
111 | return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
112 |
113 | def is_list_str(x):
114 | if not isinstance(x, (list, tuple)):
115 | return False
116 | return all([type(el) == str for el in x])
117 |
118 | def pad_tuple_to_length(t, length, fillvalue = None):
119 | remain_length = length - len(t)
120 | if remain_length <= 0:
121 | return t
122 | return (*t, *((fillvalue,) * remain_length))
123 |
124 | # checkpointing helper function
125 |
126 | def make_checkpointable(fn, **kwargs):
127 | if isinstance(fn, nn.ModuleList):
128 | return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
129 |
130 | condition = kwargs.pop('condition', None)
131 |
132 | if exists(condition) and not condition(fn):
133 | return fn
134 |
135 | @wraps(fn)
136 | def inner(*args):
137 | input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
138 |
139 | if not input_needs_grad:
140 | return fn(*args)
141 |
142 | return checkpoint(fn, *args)
143 |
144 | return inner
145 |
146 | # for controlling freezing of CLIP
147 |
148 | def set_module_requires_grad_(module, requires_grad):
149 | for param in module.parameters():
150 | param.requires_grad = requires_grad
151 |
152 | def freeze_all_layers_(module):
153 | set_module_requires_grad_(module, False)
154 |
155 | def unfreeze_all_layers_(module):
156 | set_module_requires_grad_(module, True)
157 |
158 | def freeze_model_and_make_eval_(model):
159 | model.eval()
160 | freeze_all_layers_(model)
161 |
162 | # tensor helpers
163 |
164 | def log(t, eps = 1e-12):
165 | return torch.log(t.clamp(min = eps))
166 |
167 | def l2norm(t):
168 | return F.normalize(t, dim = -1)
169 |
170 | # image normalization functions
171 | # ddpms expect images to be in the range of -1 to 1
172 | # but CLIP embeddings may not be
173 |
174 | def normalize_neg_one_to_one(img):
175 | return img * 2 - 1
176 |
177 | def unnormalize_zero_to_one(normed_img):
178 | return (normed_img + 1) * 0.5
179 |
180 | # gaussian diffusion helper functions
181 |
182 | def extract(a, t, x_shape):
183 | b, *_ = t.shape
184 | out = a.gather(-1, t)
185 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
186 |
187 | def meanflat(x):
188 | return x.mean(dim = tuple(range(1, len(x.shape))))
189 |
190 | def normal_kl(mean1, logvar1, mean2, logvar2):
191 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
192 |
193 | def approx_standard_normal_cdf(x):
194 | return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
195 |
196 | def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
197 | assert x.shape == means.shape == log_scales.shape
198 |
199 | # attempting to correct nan gradients when learned variance is turned on
200 | # in the setting of deepspeed fp16
201 | eps = 1e-12 if x.dtype == torch.float32 else 1e-3
202 |
203 | centered_x = x - means
204 | inv_stdv = torch.exp(-log_scales)
205 | plus_in = inv_stdv * (centered_x + 1. / 255.)
206 | cdf_plus = approx_standard_normal_cdf(plus_in)
207 | min_in = inv_stdv * (centered_x - 1. / 255.)
208 | cdf_min = approx_standard_normal_cdf(min_in)
209 | log_cdf_plus = log(cdf_plus, eps = eps)
210 | log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
211 | cdf_delta = cdf_plus - cdf_min
212 |
213 | log_probs = torch.where(x < -thres,
214 | log_cdf_plus,
215 | torch.where(x > thres,
216 | log_one_minus_cdf_min,
217 | log(cdf_delta, eps = eps)))
218 |
219 | return log_probs
220 |
221 | def cosine_beta_schedule(timesteps, s = 0.008):
222 | """
223 | cosine schedule
224 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
225 | """
226 | steps = timesteps + 1
227 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
228 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
229 | alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
230 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
231 | return torch.clip(betas, 0, 0.999)
232 |
233 |
234 | def linear_beta_schedule(timesteps):
235 | scale = 1000 / timesteps
236 | beta_start = scale * 0.0001
237 | beta_end = scale * 0.02
238 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
239 |
240 |
241 | def quadratic_beta_schedule(timesteps):
242 | scale = 1000 / timesteps
243 | beta_start = scale * 0.0001
244 | beta_end = scale * 0.02
245 | return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
246 |
247 |
248 | def sigmoid_beta_schedule(timesteps):
249 | scale = 1000 / timesteps
250 | beta_start = scale * 0.0001
251 | beta_end = scale * 0.02
252 | betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
253 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
254 |
255 |
256 | class NoiseScheduler(nn.Module):
257 | def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
258 | super().__init__()
259 |
260 | if beta_schedule == "cosine":
261 | betas = cosine_beta_schedule(timesteps)
262 | elif beta_schedule == "linear":
263 | betas = linear_beta_schedule(timesteps)
264 | elif beta_schedule == "quadratic":
265 | betas = quadratic_beta_schedule(timesteps)
266 | elif beta_schedule == "jsd":
267 | betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
268 | elif beta_schedule == "sigmoid":
269 | betas = sigmoid_beta_schedule(timesteps)
270 | else:
271 | raise NotImplementedError()
272 |
273 | alphas = 1. - betas
274 | alphas_cumprod = torch.cumprod(alphas, axis = 0)
275 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
276 |
277 | timesteps, = betas.shape
278 | self.num_timesteps = int(timesteps)
279 |
280 | if loss_type == 'l1':
281 | loss_fn = F.l1_loss
282 | elif loss_type == 'l2':
283 | loss_fn = F.mse_loss
284 | elif loss_type == 'huber':
285 | loss_fn = F.smooth_l1_loss
286 | else:
287 | raise NotImplementedError()
288 |
289 | self.loss_type = loss_type
290 | self.loss_fn = loss_fn
291 |
292 | # register buffer helper function to cast double back to float
293 |
294 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
295 |
296 | register_buffer('betas', betas)
297 | register_buffer('alphas_cumprod', alphas_cumprod)
298 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
299 |
300 | # calculations for diffusion q(x_t | x_{t-1}) and others
301 |
302 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
303 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
304 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
305 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
306 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
307 |
308 | # calculations for posterior q(x_{t-1} | x_t, x_0)
309 |
310 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
311 |
312 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
313 |
314 | register_buffer('posterior_variance', posterior_variance)
315 |
316 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
317 |
318 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
319 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
320 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
321 |
322 | # p2 loss reweighting
323 |
324 | self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
325 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
326 |
327 | def sample_random_times(self, batch):
328 | return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
329 |
330 | def q_posterior(self, x_start, x_t, t):
331 | posterior_mean = (
332 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
333 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
334 | )
335 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
336 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
337 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
338 |
339 | def q_sample(self, x_start, t, noise = None):
340 | noise = default(noise, lambda: torch.randn_like(x_start))
341 |
342 | return (
343 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
344 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
345 | )
346 |
347 | def calculate_v(self, x_start, t, noise = None):
348 | return (
349 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
350 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
351 | )
352 |
353 | def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
354 | shape = x_from.shape
355 | noise = default(noise, lambda: torch.randn_like(x_from))
356 |
357 | alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
358 | sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
359 | alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
360 | sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
361 |
362 | return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
363 |
364 | def predict_start_from_v(self, x_t, t, v):
365 | return (
366 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
367 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
368 | )
369 |
370 | def predict_start_from_noise(self, x_t, t, noise):
371 | return (
372 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
373 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
374 | )
375 |
376 | def predict_noise_from_start(self, x_t, t, x0):
377 | return (
378 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
379 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
380 | )
381 |
382 | def p2_reweigh_loss(self, loss, times):
383 | if not self.has_p2_loss_reweighting:
384 | return loss
385 | return loss * extract(self.p2_loss_weight, times, loss.shape)
386 |
387 | # diffusion prior
388 |
389 | class LayerNorm(nn.Module):
390 | def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
391 | super().__init__()
392 | self.eps = eps
393 | self.fp16_eps = fp16_eps
394 | self.stable = stable
395 | self.g = nn.Parameter(torch.ones(dim))
396 |
397 | def forward(self, x):
398 | eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
399 |
400 | if self.stable:
401 | x = x / x.amax(dim = -1, keepdim = True).detach()
402 |
403 | var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
404 | mean = torch.mean(x, dim = -1, keepdim = True)
405 | return (x - mean) * (var + eps).rsqrt() * self.g
406 |
407 | class ChanLayerNorm(nn.Module):
408 | def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
409 | super().__init__()
410 | self.eps = eps
411 | self.fp16_eps = fp16_eps
412 | self.stable = stable
413 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
414 |
415 | def forward(self, x):
416 | eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
417 |
418 | if self.stable:
419 | x = x / x.amax(dim = 1, keepdim = True).detach()
420 |
421 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
422 | mean = torch.mean(x, dim = 1, keepdim = True)
423 | return (x - mean) * (var + eps).rsqrt() * self.g
424 |
425 | class Residual(nn.Module):
426 | def __init__(self, fn):
427 | super().__init__()
428 | self.fn = fn
429 |
430 | def forward(self, x, **kwargs):
431 | return self.fn(x, **kwargs) + x
432 |
433 | # mlp
434 |
435 | class MLP(nn.Module):
436 | def __init__(
437 | self,
438 | dim_in,
439 | dim_out,
440 | *,
441 | expansion_factor = 2.,
442 | depth = 2,
443 | norm = False,
444 | ):
445 | super().__init__()
446 | hidden_dim = int(expansion_factor * dim_out)
447 | norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
448 |
449 | layers = [nn.Sequential(
450 | nn.Linear(dim_in, hidden_dim),
451 | nn.SiLU(),
452 | norm_fn()
453 | )]
454 |
455 | for _ in range(depth - 1):
456 | layers.append(nn.Sequential(
457 | nn.Linear(hidden_dim, hidden_dim),
458 | nn.SiLU(),
459 | norm_fn()
460 | ))
461 |
462 | layers.append(nn.Linear(hidden_dim, dim_out))
463 | self.net = nn.Sequential(*layers)
464 |
465 | def forward(self, x):
466 | return self.net(x.float())
467 |
468 |
469 | class SinusoidalPosEmb(nn.Module):
470 | def __init__(self, dim):
471 | super().__init__()
472 | self.dim = dim
473 |
474 | def forward(self, x):
475 | dtype, device = x.dtype, x.device
476 | assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
477 |
478 | half_dim = self.dim // 2
479 | emb = math.log(10000) / (half_dim - 1)
480 | emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
481 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
482 | return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
483 |
484 | # relative positional bias for causal transformer
485 |
486 | class RelPosBias(nn.Module):
487 | def __init__(
488 | self,
489 | heads = 8,
490 | num_buckets = 32,
491 | max_distance = 128,
492 | ):
493 | super().__init__()
494 | self.num_buckets = num_buckets
495 | self.max_distance = max_distance
496 | self.relative_attention_bias = nn.Embedding(num_buckets, heads)
497 |
498 | @staticmethod
499 | def _relative_position_bucket(
500 | relative_position,
501 | num_buckets = 32,
502 | max_distance = 128
503 | ):
504 | n = -relative_position
505 | n = torch.max(n, torch.zeros_like(n))
506 |
507 | max_exact = num_buckets // 2
508 | is_small = n < max_exact
509 |
510 | val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
511 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
512 | return torch.where(is_small, n, val_if_large)
513 |
514 | def forward(self, i, j, *, device):
515 | q_pos = torch.arange(i, dtype = torch.long, device = device)
516 | k_pos = torch.arange(j, dtype = torch.long, device = device)
517 | rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
518 | rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
519 | values = self.relative_attention_bias(rp_bucket)
520 | return rearrange(values, 'i j h -> h i j')
521 |
522 | # feedforward
523 |
524 | class SwiGLU(nn.Module):
525 | """ used successfully in https://arxiv.org/abs/2204.0231 """
526 | def forward(self, x):
527 | x, gate = x.chunk(2, dim = -1)
528 | return x * F.silu(gate)
529 |
530 | def FeedForward(
531 | dim,
532 | mult = 4,
533 | dropout = 0.,
534 | post_activation_norm = False
535 | ):
536 | """ post-activation norm https://arxiv.org/abs/2110.09456 """
537 |
538 | inner_dim = int(mult * dim)
539 | return nn.Sequential(
540 | LayerNorm(dim),
541 | nn.Linear(dim, inner_dim * 2, bias = False),
542 | SwiGLU(),
543 | LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
544 | nn.Dropout(dropout),
545 | nn.Linear(inner_dim, dim, bias = False)
546 | )
547 |
548 | # attention
549 |
550 | class Attention(nn.Module):
551 | def __init__(
552 | self,
553 | dim,
554 | *,
555 | dim_head = 64,
556 | heads = 8,
557 | dropout = 0.,
558 | causal = False,
559 | rotary_emb = None,
560 | cosine_sim = True,
561 | cosine_sim_scale = 16
562 | ):
563 | super().__init__()
564 | self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
565 | self.cosine_sim = cosine_sim
566 |
567 | self.heads = heads
568 | inner_dim = dim_head * heads
569 |
570 | self.causal = causal
571 | self.norm = LayerNorm(dim)
572 | self.dropout = nn.Dropout(dropout)
573 |
574 | self.null_kv = nn.Parameter(torch.randn(2, dim_head))
575 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
576 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
577 |
578 | self.rotary_emb = rotary_emb
579 |
580 | self.to_out = nn.Sequential(
581 | nn.Linear(inner_dim, dim, bias = False),
582 | LayerNorm(dim)
583 | )
584 |
585 | def forward(self, x, mask = None, attn_bias = None):
586 | b, n, device = *x.shape[:2], x.device
587 |
588 | x = self.norm(x)
589 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
590 |
591 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
592 | q = q * self.scale
593 |
594 | # rotary embeddings
595 |
596 | if exists(self.rotary_emb):
597 | q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
598 |
599 | # add null key / value for classifier free guidance in prior net
600 |
601 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
602 | k = torch.cat((nk, k), dim = -2)
603 | v = torch.cat((nv, v), dim = -2)
604 |
605 | # whether to use cosine sim
606 |
607 | if self.cosine_sim:
608 | q, k = map(l2norm, (q, k))
609 |
610 | q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
611 |
612 | # calculate query / key similarities
613 |
614 | sim = einsum('b h i d, b j d -> b h i j', q, k)
615 |
616 | # relative positional encoding (T5 style)
617 |
618 | if exists(attn_bias):
619 | sim = sim + attn_bias
620 |
621 | # masking
622 |
623 | max_neg_value = -torch.finfo(sim.dtype).max
624 |
625 | if exists(mask):
626 | mask = F.pad(mask, (1, 0), value = True)
627 | mask = rearrange(mask, 'b j -> b 1 1 j')
628 | sim = sim.masked_fill(~mask, max_neg_value)
629 |
630 | if self.causal:
631 | i, j = sim.shape[-2:]
632 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
633 | sim = sim.masked_fill(causal_mask, max_neg_value)
634 |
635 | # attention
636 |
637 | attn = sim.softmax(dim = -1, dtype = torch.float32)
638 | attn = attn.type(sim.dtype)
639 |
640 | attn = self.dropout(attn)
641 |
642 | # aggregate values
643 |
644 | out = einsum('b h i j, b j d -> b h i d', attn, v)
645 |
646 | out = rearrange(out, 'b h n d -> b n (h d)')
647 | return self.to_out(out)
648 |
649 | class CrossAttention(nn.Module):
650 | def __init__(
651 | self,
652 | dim,
653 | *,
654 | context_dim = None,
655 | dim_head = 64,
656 | heads = 8,
657 | dropout = 0.,
658 | norm_context = False,
659 | cosine_sim = False,
660 | cosine_sim_scale = 16
661 | ):
662 | super().__init__()
663 | self.cosine_sim = cosine_sim
664 | self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
665 | self.heads = heads
666 | inner_dim = dim_head * heads
667 |
668 | context_dim = default(context_dim, dim)
669 |
670 | self.norm = LayerNorm(dim)
671 | self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
672 | self.dropout = nn.Dropout(dropout)
673 |
674 | self.null_kv = nn.Parameter(torch.randn(2, dim_head))
675 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
676 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
677 |
678 | self.to_out = nn.Sequential(
679 | nn.Linear(inner_dim, dim, bias = False),
680 | LayerNorm(dim)
681 | )
682 |
683 | def forward(self, x, context, mask = None):
684 | b, n, device = *x.shape[:2], x.device
685 |
686 | x = self.norm(x)
687 | context = self.norm_context(context)
688 |
689 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
690 |
691 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
692 |
693 | # add null key / value for classifier free guidance in prior net
694 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
695 |
696 | k = torch.cat((nk, k), dim = -2)
697 | v = torch.cat((nv, v), dim = -2)
698 |
699 | if self.cosine_sim:
700 | q, k = map(l2norm, (q, k))
701 |
702 | q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
703 |
704 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
705 | max_neg_value = -torch.finfo(sim.dtype).max
706 |
707 | if exists(mask):
708 | mask = F.pad(mask, (1, 0), value = True)
709 | mask = rearrange(mask, 'b j -> b 1 1 j')
710 | sim = sim.masked_fill(~mask, max_neg_value)
711 |
712 | attn = sim.softmax(dim = -1, dtype = torch.float32)
713 | attn = attn.type(sim.dtype)
714 |
715 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
716 | out = rearrange(out, 'b h n d -> b n (h d)')
717 | return self.to_out(out)
718 |
719 |
720 | class CrossTransformer(nn.Module):
721 | def __init__(
722 | self,
723 | *,
724 | dim,
725 | depth,
726 | dim_head = 64,
727 | heads = 8,
728 | ff_mult = 4,
729 | norm_in = False,
730 | norm_out = True,
731 | attn_dropout = 0.,
732 | ff_dropout = 0.,
733 | final_proj = True,
734 | normformer = False,
735 | rotary_emb = True,
736 | causal = False,
737 | context_dim = None,
738 | timesteps = None,
739 | ):
740 | super().__init__()
741 | self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
742 |
743 | self.rel_pos_bias = RelPosBias(heads = heads)
744 |
745 | rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
746 |
747 | self.layers = nn.ModuleList([])
748 | for _ in range(depth):
749 | self.layers.append(nn.ModuleList([
750 | Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
751 | CrossAttention(dim = dim, context_dim = context_dim, dim_head = dim_head, dropout = attn_dropout),
752 | FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
753 | ]))
754 |
755 | self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
756 | self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
757 |
758 | def forward(self, x, context, mask = None):
759 | n, device = x.shape[1], x.device
760 |
761 | x = self.init_norm(x)
762 |
763 | attn_bias = self.rel_pos_bias(n, n + 1, device = device)
764 |
765 | for attn, cross_attn, ff in self.layers:
766 | x = attn(x, attn_bias = attn_bias) + x
767 | x = cross_attn(x, context, mask) + x
768 | x = ff(x) + x
769 |
770 | out = self.norm(x)
771 | return self.project_out(out)
772 |
773 |
774 | class CompoDiffNetwork(nn.Module):
775 | def __init__(
776 | self,
777 | dim,
778 | num_timesteps = None,
779 | max_text_len = 77,
780 | cross = False,
781 | text_model_name = None,
782 | **kwargs
783 | ):
784 | super().__init__()
785 | self.dim = dim
786 |
787 | self.to_text_embeds = nn.Sequential(
788 | nn.Linear(1280, self.dim),
789 | )
790 |
791 | self.continuous_embedded_time = not exists(num_timesteps)
792 |
793 | num_time_embeds = 1
794 | self.to_time_embeds = nn.Sequential(
795 | nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
796 | Rearrange('b (n d) -> b n d', n = num_time_embeds)
797 | )
798 | self.to_mask_embeds = nn.Sequential(
799 | Rearrange('b h w -> b (h w)'),
800 | MLP(4096, dim),
801 | Rearrange('b (n d) -> b n d', n = 1)
802 | )
803 |
804 | self.transformer = CrossTransformer(dim = dim, **kwargs)
805 |
806 | def forward_with_cond_scale(
807 | self,
808 | image_embed,
809 | image_cond,
810 | text_cond,
811 | input_mask,
812 | diffusion_timesteps,
813 | text_cond_uc,
814 | cond_scale = 1.,
815 | ):
816 | if cond_scale == 1.:
817 |             logits = self.forward(image_embed, image_cond, text_cond, input_mask, diffusion_timesteps)
818 | return logits
819 | else:
820 |             # classifier-free guidance with two scales: run three passes in a single batch
821 |             '''
822 |             logits (full condition), null_text_logits (text condition dropped), null_all_logits (all conditions dropped)
823 |             '''
824 | image_embed = torch.cat([image_embed] * 3)
825 | image_cond = torch.cat([image_cond, image_cond, torch.zeros_like(image_cond)])
826 | text_cond = torch.cat([text_cond, text_cond_uc, text_cond_uc])
827 | input_mask = torch.cat([input_mask] * 3)
828 | diffusion_timesteps = torch.cat([diffusion_timesteps] * 3)
829 |
830 | logits, null_text_logits, null_all_logits = self.forward(image_embed, image_cond, text_cond, input_mask, diffusion_timesteps).chunk(3)
831 |             return null_all_logits + (logits - null_text_logits) * cond_scale[1] + (null_text_logits - null_all_logits) * cond_scale[0] # cond_scale = (image_scale, text_scale)
832 |
833 | def forward(
834 | self,
835 | image_embed,
836 | image_cond,
837 | text_cond,
838 | input_mask,
839 | diffusion_timesteps,
840 | ):
841 | batch, n_image_embed, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
842 |
843 | text_cond = self.to_text_embeds(text_cond)
844 |
845 | if self.continuous_embedded_time:
846 | diffusion_timesteps = diffusion_timesteps.type(dtype)
847 |
848 | time_embed = self.to_time_embeds(diffusion_timesteps)
849 |
850 | mask_embed = self.to_mask_embeds(input_mask)
851 |
852 | tokens = torch.cat((
853 | image_embed,
854 | time_embed,
855 | ), dim = -2)
856 |
857 | context_embed = torch.cat([text_cond, image_cond, mask_embed], dim=1)
858 | tokens = self.transformer(tokens, context=context_embed)
859 |
860 | pred_image_embed = tokens[..., :1, :]
861 |
862 | return pred_image_embed
863 |
864 | class CompoDiff(nn.Module):
865 | def __init__(
866 | self,
867 | net,
868 | *,
869 | image_embed_dim = None,
870 | timesteps = 1000,
871 | predict_x_start = True,
872 | loss_type = "l2",
873 | beta_schedule = "cosine",
874 | condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
875 | sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
876 | sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
877 | training_clamp_l2norm = False,
878 | init_image_embed_l2norm = False,
879 | image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
880 | ):
881 | super().__init__()
882 |
883 | self.sample_timesteps = None
884 |
885 | self.noise_scheduler = NoiseScheduler(
886 | beta_schedule = beta_schedule,
887 | timesteps = timesteps,
888 | loss_type = loss_type
889 | )
890 |
891 | self.net = net
892 | self.image_embed_dim = image_embed_dim
893 |
894 | self.condition_on_text_encodings = condition_on_text_encodings
895 |
896 | self.predict_x_start = predict_x_start
897 |
898 | self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
899 |
900 | # whether to force an l2norm, similar to clipping denoised, when sampling
901 |
902 | self.sampling_clamp_l2norm = sampling_clamp_l2norm
903 | self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
904 |
905 | self.training_clamp_l2norm = training_clamp_l2norm
906 | self.init_image_embed_l2norm = init_image_embed_l2norm
907 |
908 | # device tracker
909 |
910 | self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
911 |
912 | @property
913 | def device(self):
914 | return self._dummy.device
915 |
916 | def l2norm_clamp_embed(self, image_embed):
917 | return l2norm(image_embed) * self.image_embed_scale
918 |
919 | @torch.no_grad()
920 | def p_sample_loop_ddim(self, shape, image_cond, text_cond, negative_text_cond, input_mask, timesteps, eta = 1., cond_scale = 1., random_seed=None):
921 | batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
922 |
923 | times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
924 |
925 | times = list(reversed(times.int().tolist()))
926 | time_pairs = list(zip(times[:-1], times[1:]))
927 |
928 | if random_seed is None:
929 | image_embed = noise = torch.randn(shape).to(device)
930 | else:
931 | image_embed = noise = torch.randn(shape, generator=torch.manual_seed(random_seed)).to(device)
932 |
933 | x_start = None # for self-conditioning
934 |
935 | if self.init_image_embed_l2norm:
936 | image_embed = l2norm(image_embed) * self.image_embed_scale
937 |
938 | for time, time_next in tqdm(time_pairs, desc = 'CompoDiff sampling loop'):
939 | alpha = alphas[time]
940 | alpha_next = alphas[time_next]
941 |
942 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
943 |
944 | pred = self.net.forward_with_cond_scale(image_embed, image_cond, text_cond, input_mask, time_cond, text_cond_uc=negative_text_cond, cond_scale = cond_scale)
945 |
946 | # derive x0
947 |
948 | if self.predict_x_start:
949 | x_start = pred
950 | else:
951 |                 x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
952 |
953 | # clip x0 before maybe predicting noise
954 |
955 | if not self.predict_x_start:
956 | x_start.clamp_(-1., 1.)
957 |
958 | if self.predict_x_start and self.sampling_clamp_l2norm:
959 | x_start = self.l2norm_clamp_embed(x_start)
960 |
961 | # predict noise
962 |
963 |             if self.predict_x_start:
964 | pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
965 | else:
966 | pred_noise = pred
967 |
968 | if time_next < 0:
969 | image_embed = x_start
970 | continue
971 |
972 |             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() # DDIM sigma; eta = 0. gives the deterministic DDIM update
973 | c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
974 | noise = torch.randn_like(image_embed) if time_next > 0 else 0.
975 |
976 | image_embed = x_start * alpha_next.sqrt() + \
977 | c1 * noise + \
978 | c2 * pred_noise
979 |
980 | if self.predict_x_start and self.sampling_final_clamp_l2norm:
981 | image_embed = self.l2norm_clamp_embed(image_embed)
982 |
983 | return image_embed
984 |
985 | @torch.no_grad()
986 | def p_sample_loop(self, shape, image_cond, text_cond, negative_text_cond, input_mask, cond_scale = 1., timesteps = None, random_seed = None):
987 | timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
988 | assert timesteps <= self.noise_scheduler.num_timesteps
989 | is_ddim = timesteps < self.noise_scheduler.num_timesteps
990 |
991 | normalized_image_embed = self.p_sample_loop_ddim(shape, image_cond, text_cond, negative_text_cond, input_mask, cond_scale=cond_scale, timesteps=timesteps, random_seed=random_seed)
992 |
993 | image_embed = normalized_image_embed / self.image_embed_scale
994 | return image_embed
995 |
996 | @torch.no_grad()
997 | @eval_decorator
998 | def sample(
999 | self,
1000 | image_cond,
1001 | text_cond,
1002 | negative_text_cond,
1003 | input_mask,
1004 | num_samples_per_batch=4,
1005 | cond_scale=1.,
1006 | timesteps=None,
1007 | random_seed=None):
1008 |
1009 | timesteps = default(timesteps, self.sample_timesteps)
1010 |
1011 | if image_cond is not None:
1012 | image_cond = repeat(image_cond, 'b ... -> (b r) ...', r = num_samples_per_batch)
1013 | text_cond = repeat(text_cond, 'b ... -> (b r) ...', r = num_samples_per_batch)
1014 | input_mask = repeat(input_mask, 'b ... -> (b r) ...', r = num_samples_per_batch)
1015 | negative_text_cond = repeat(negative_text_cond, 'b ... -> (b r) ...', r = num_samples_per_batch)
1016 |
1017 | batch_size = text_cond.shape[0]
1018 | image_embed_dim = self.image_embed_dim
1019 |
1020 | image_embeds = self.p_sample_loop((batch_size, 1, image_embed_dim), image_cond, text_cond, negative_text_cond, input_mask, cond_scale = cond_scale, timesteps = timesteps, random_seed = random_seed)
1021 |
1022 | image_embeds = rearrange(image_embeds, '(b r) 1 d -> b r d', r = num_samples_per_batch)
1023 |
1024 | return torch.mean(image_embeds, dim=1)
1025 |
1026 | def forward(
1027 | self,
1028 | input_image_embed = None,
1029 | target_image_embed = None,
1030 | image_cond = None,
1031 | text_cond = None,
1032 | input_mask = None,
1033 | text_cond_uc = None,
1034 | *args,
1035 | **kwargs
1036 | ):
1037 | # timestep conditioning from ddpm
1038 | batch, device = input_image_embed.shape[0], input_image_embed.device
1039 | times = self.noise_scheduler.sample_random_times(batch)
1040 |
1041 | # scale image embed (Katherine)
1042 | input_image_embed *= self.image_embed_scale
1043 |
1044 | # calculate forward loss
1045 |
1046 | loss = self.p_losses(input_image_embed = input_image_embed,
1047 | target_image_embed = target_image_embed,
1048 | image_cond = image_cond,
1049 | text_cond = text_cond,
1050 | input_mask = input_mask,
1051 | times = times,
1052 | text_cond_uc = text_cond_uc,
1053 | *args, **kwargs)
1054 | return loss
1055 |
1056 | def build_compodiff(embed_dim,
1057 | model_depth,
1058 | model_dim,
1059 | model_heads,
1060 | timesteps,
1061 | ):
1062 |
1063 | compodiff_network = CompoDiffNetwork(
1064 | dim = embed_dim,
1065 | depth = model_depth,
1066 | dim_head = model_dim,
1067 | heads = model_heads,
1068 | )
1069 |
1070 | compodiff = CompoDiff(
1071 | net = compodiff_network,
1072 | image_embed_dim=embed_dim,
1073 | timesteps = timesteps,
1074 | condition_on_text_encodings = True,
1075 | image_embed_scale = 1.0,
1076 | sampling_clamp_l2norm = False,
1077 | training_clamp_l2norm = False,
1078 | init_image_embed_l2norm = False,
1079 | predict_x_start = True,
1080 | )
1081 |
1082 | return compodiff
1083 |
1084 | def build_clip():
1085 | clip_model = CLIPHF(image_model_name='openai/clip-vit-large-patch14',
1086 | text_model_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
1087 | )
1088 | return clip_model
1089 |
1090 |
1091 |
--------------------------------------------------------------------------------
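
A minimal usage sketch for the model code above (not part of the repository). It builds a CompoDiff instance through build_compodiff with placeholder depth/head values and calls sample() on random stand-in tensors, purely to illustrate the expected tensor shapes and the calling convention; in practice the released weights, the CLIP encoders and the preprocessing come from compodiff.build_model(), as demo_search.py below shows. The guidance tuple follows the cond_scale = (image_scale, text_scale) convention used in forward_with_cond_scale.

import torch
from compodiff.models import build_compodiff

# placeholder architecture values; the real configuration is fixed by the released checkpoint
model = build_compodiff(embed_dim=768, model_depth=12, model_dim=64, model_heads=16, timesteps=1000)

image_cond = torch.zeros(1, 1, 768)            # no reference image (text-only conditioning)
text_cond = torch.randn(1, 77, 1280)           # stand-in for CLIP text-encoder states (projected by to_text_embeds)
negative_text_cond = torch.randn(1, 77, 1280)  # stand-in for the negative / unconditional text encoding
mask = torch.zeros(1, 64, 64)                  # empty 64x64 mask, flattened to 4096 by to_mask_embeds

with torch.no_grad():
    image_embed = model.sample(image_cond, text_cond, negative_text_cond, mask,
                               timesteps=10,
                               cond_scale=(1.0, 7.5),  # (image scale, text scale); 1.0 since there is no reference image
                               num_samples_per_batch=2)

print(image_embed.shape)  # torch.Size([1, 768]): a denoised CLIP image embedding, ready for k-NN retrieval
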
/demo_search.py:
--------------------------------------------------------------------------------
1 | """
2 | CompoDiff
3 | Copyright (c) 2023-present NAVER Corp.
4 | Apache-2.0
5 | """
6 | import os
7 | import numpy as np
8 | import base64
9 | import requests
10 | import json
11 | import time
12 | import torch
13 | import torch.nn.functional as F
14 | import gradio as gr
15 | from clip_retrieval.clip_client import ClipClient
16 | import types
17 | from typing import Union, List, Optional, Callable
18 | import torch
19 | from diffusers import UnCLIPImageVariationPipeline
20 | from torchvision import transforms
21 | from torchvision.transforms.functional import to_pil_image, pil_to_tensor
22 | from PIL import Image
23 | import compodiff
24 |
25 |
26 | def load_models():
27 | ### build model
28 | print("\tbuilding CompoDiff")
29 |
30 | compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model()
31 |
32 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
33 |
34 | compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)
35 |
36 | if device != 'cpu':
37 | clip_model = clip_model.half()
38 |
39 | model_dict = {}
40 | model_dict['compodiff'] = compodiff_model
41 | model_dict['clip_model'] = clip_model
42 | model_dict['img_preprocess'] = img_preprocess
43 | model_dict['tokenizer'] = tokenizer
44 | model_dict['device'] = device
45 |
46 | return model_dict
47 |
48 |
49 | @torch.no_grad()
50 | def l2norm(features):
51 | return features / features.norm(p=2, dim=-1, keepdim=True)
52 |
53 |
54 | def predict(images, input_text, negative_text, step, cfg_image_scale, cfg_text_scale, do_generate, source_mixing_weight):
55 | '''
56 |     images, input_text, negative_text, step, cfg_image_scale, cfg_text_scale, do_generate, source_mixing_weight
57 | '''
58 | device = model_dict['device']
59 |
60 | step = int(step)
61 | step = step + 1 if step < 1000 else step
62 |
63 | cfg_scale = (cfg_image_scale, cfg_text_scale)
64 |
65 | text = input_text
66 |
67 | if images is None:
68 | # t2i
69 | cfg_scale = (1.0, cfg_text_scale)
70 | text_token_dict = model_dict['tokenizer'](text=text, return_tensors='pt', padding='max_length', truncation=True)
71 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
72 |
73 | negative_text_token_dict = model_dict['tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
74 |         negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device)
75 |
76 | with torch.no_grad():
77 | image_cond = torch.zeros([1,1,768]).to(device)
78 | text_cond = model_dict['clip_model'].encode_texts(text_tokens, text_attention_mask)
79 | negative_text_cond = model_dict['clip_model'].encode_texts(negative_text_tokens, negative_text_attention_mask)
80 |
81 | sampling_start = time.time()
82 | mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0)
83 | sampled_image_features = model_dict['compodiff'].sample(image_cond, text_cond, negative_text_cond, mask, timesteps=step, cond_scale=cfg_scale, num_samples_per_batch=2)
84 | sampling_end = time.time()
85 |
86 | sampled_image_features_org = sampled_image_features
87 | sampled_image_features = l2norm(sampled_image_features)
88 | else:
89 | # CIR
90 | image_source = images['image'].resize((512, 512))
91 | mask = images['mask'].resize((512, 512))
92 | mask = model_dict['img_preprocess'](mask, do_normalize=False, return_tensors='pt')['pixel_values']
93 | mask = mask[:,:1,:,:]
94 |
95 | ## preprocess
96 | image_source = model_dict['img_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(device)
97 |
98 | mask = (mask > 0.5).float().to(device)
99 | image_source = image_source * (1 - mask)
100 |
101 | text_token_dict = model_dict['tokenizer'](text=text, return_tensors='pt', padding='max_length', truncation=True)
102 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
103 |
104 | negative_text_token_dict = model_dict['tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
105 |         negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device)
106 |
107 | with torch.no_grad():
108 | image_cond = model_dict['clip_model'].encode_images(image_source)
109 |
110 | text_cond = model_dict['clip_model'].encode_texts(text_tokens, text_attention_mask)
111 |
112 | negative_text_cond = model_dict['clip_model'].encode_texts(negative_text_tokens, negative_text_attention_mask)
113 |
114 | sampling_start = time.time()
115 | mask = transforms.Resize([64, 64])(mask)[:,0,:,:]
116 | mask = (mask > 0.5).float()
117 | if torch.sum(mask).item() == 0.0:
118 | mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0)
119 | sampled_image_features = model_dict['compodiff'].sample(image_cond, text_cond, negative_text_cond, mask, timesteps=step, cond_scale=cfg_scale, num_samples_per_batch=2)
120 | sampling_end = time.time()
121 |
122 | sampled_image_features_org = (1 - source_mixing_weight) * sampled_image_features + source_mixing_weight * image_cond[0]
123 | sampled_image_features = l2norm(sampled_image_features_org)
124 |
125 | if do_generate and image_decoder is not None:
126 | images = image_decoder(image_embeddings=sampled_image_features_org.half(), num_images_per_prompt=2).images
127 | else:
128 | images = [Image.fromarray(np.zeros([256,256,3], dtype='uint8'))]
129 |
130 | do_list = [['KNN results', sampled_image_features],
131 | ]
132 |
133 | output = ''
134 | top1_list = []
135 | search_start = time.time()
136 | for name, features in do_list:
137 | results = client.query(embedding_input=features[0].tolist())[:15]
138 |         output += f'{name} outputs<br>\n\n'
139 | for idx, result in enumerate(results):
140 | image_url = result['url']
141 | if idx == 0:
142 | top1_list.append(f'{image_url}')
143 |             output += f'<img src="{image_url}">\n'
144 |
145 | output += '\n \n\n'
146 |
147 | search_end = time.time()
148 |
149 | return output, images
150 |
151 |
152 | if __name__ == "__main__":
153 | global model_dict, client, image_decoder
154 |
155 | model_dict = load_models()
156 |
157 | if 'cuda' in model_dict['device']:
158 | image_decoder = UnCLIPImageVariationPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16).to('cuda:0')
159 | else:
160 | image_decoder = None
161 |
162 | client = ClipClient(url="https://knn.laion.ai/knn-service",
163 | indice_name="laion5B-L-14",
164 | )
165 |
166 | ### define gradio demo
167 | title = 'CompoDiff demo'
168 |
169 | md_title = f'''# {title}
170 | Diffusion on {model_dict["device"]}, K-NN Retrieval using https://rom1504.github.io/clip-retrieval.
171 | '''
172 | md_below = f'''### Tips:
173 | Here are some tips for using the demo:
174 |     + To keep more of the original image's context, increase the source weight in the Advanced options above its default of 0.1. A higher weight passes the original image through as a stronger signal.
175 | + If you want to exclude specific keywords, you can add them to the Negative text input.
176 |     + Try "generate image with unCLIP" to create images; the generated results can be as interesting as the retrieved ones.
177 |     + If you provide only text and no image, the model behaves like the unCLIP prior.
178 | '''
179 |
180 |
181 | with gr.Blocks(title=title) as demo:
182 | gr.Markdown(md_title)
183 | with gr.Row():
184 | with gr.Column():
185 | image_source = gr.Image(type='pil', label='Source image', tool='sketch')
186 | with gr.Row():
187 | steps_input = gr.Radio(['2', '3', '5', '10'], value='10', label='denoising steps')
188 | if model_dict['device'] == 'cpu':
189 | do_generate = gr.Checkbox(value=False, label='generate image with unCLIP', visible=False)
190 | else:
191 | do_generate = gr.Checkbox(value=False, label='generate image with unCLIP', visible=True)
192 | with gr.Accordion('Advanced options', open=False):
193 | with gr.Row():
194 | cfg_image_scale = gr.Number(value=1.5, label='image condition scale')
195 | cfg_text_scale = gr.Number(value=7.5, label='text condition scale')
196 | source_mixing_weight = gr.Number(value=0.1, label='source weight (0.0~1.0)')
197 | text_input = gr.Textbox(value='', label='Input text guidance')
198 | negative_text_input = gr.Textbox(value='', label='Negative text') # low quality, text overlay, logo
199 | submit_button = gr.Button('Submit')
200 | gr.Markdown(md_below)
201 | with gr.Column():
202 | if model_dict['device'] == 'cpu':
203 | gallery = gr.Gallery(label='Generated images', visible=False).style(grid=[2])
204 | else:
205 | gallery = gr.Gallery(label='Generated images', visible=True).style(grid=[2])
206 | md_output = gr.Markdown(label='Output')
207 | submit_button.click(predict, inputs=[image_source, text_input, negative_text_input, steps_input, cfg_image_scale, cfg_text_scale, do_generate, source_mixing_weight], outputs=[md_output, gallery])
208 | demo.launch(server_name='0.0.0.0',
209 | server_port=8000)
210 |
211 |
--------------------------------------------------------------------------------
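
For reference, here is a condensed, hypothetical sketch of the text-only branch of predict() with the Gradio UI, masking and unCLIP decoding stripped out. Every call it relies on (compodiff.build_model, the tokenizer, encode_texts, sample, ClipClient.query) and the knn.laion.ai endpoint appear in the demo above; the helper name text_to_knn_urls and its default arguments are invented for illustration.

import torch
import compodiff
from clip_retrieval.clip_client import ClipClient

compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)
client = ClipClient(url='https://knn.laion.ai/knn-service', indice_name='laion5B-L-14')

@torch.no_grad()
def text_to_knn_urls(text, negative_text='', steps=10, cfg_scale=(1.0, 7.5), top_k=15):
    # encode the positive and negative prompts with the CLIP text tower
    tokens = tokenizer(text=text, return_tensors='pt', padding='max_length', truncation=True)
    neg_tokens = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
    text_cond = clip_model.encode_texts(tokens['input_ids'].to(device), tokens['attention_mask'].to(device))
    negative_text_cond = clip_model.encode_texts(neg_tokens['input_ids'].to(device), neg_tokens['attention_mask'].to(device))
    # no reference image and an empty 64x64 mask, as in the t2i branch of predict()
    image_cond = torch.zeros([1, 1, 768]).to(device)
    mask = torch.zeros([1, 64, 64]).to(device)
    feats = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask,
                                   timesteps=steps, cond_scale=cfg_scale, num_samples_per_batch=2)
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)  # l2-normalize before querying the index
    # k-NN retrieval against the LAION-5B index used by the demo
    return [r['url'] for r in client.query(embedding_input=feats[0].tolist())[:top_k]]

urls = text_to_knn_urls('a shiba inu wearing a beret')
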
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.13.0
2 | torchvision>=0.9
3 | numpy
4 | transformers
5 | diffusers
6 | rotary-embedding-torch
7 | einops
8 | einops_exts
9 | gradio
10 | clip-retrieval
11 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | CompoDiff
3 | Copyright (c) 2023-present NAVER Corp.
4 | Apache-2.0
5 | """
6 | import os
7 | import pkg_resources
8 | from setuptools import setup, find_packages
9 |
10 | setup(
11 | name='compodiff',
12 | version='0.1.1',
13 | description='Easy to use CompoDiff library',
14 | author='NAVER Corp.',
15 | author_email='dl_visionresearch@navercorp.com',
16 | url='https://github.com/navervision/CompoDiff',
17 | install_requires=[
18 | str(r)
19 | for r in pkg_resources.parse_requirements(
20 | open(os.path.join(os.path.dirname(__file__), 'requirements.txt'))
21 | )
22 | ],
23 | packages=find_packages(),
24 | )
25 |
--------------------------------------------------------------------------------