├── .github ├── example_cir.gif ├── example_cir_with_mask.gif ├── example_t2i.gif └── teaser.png ├── LICENSE ├── NOTICE ├── README.md ├── compodiff ├── __init__.py ├── model_loader.py └── models.py ├── demo_search.py ├── requirements.txt └── setup.py /.github/example_cir.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/example_cir.gif -------------------------------------------------------------------------------- /.github/example_cir_with_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/example_cir_with_mask.gif -------------------------------------------------------------------------------- /.github/example_t2i.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/example_t2i.gif -------------------------------------------------------------------------------- /.github/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/navervision/CompoDiff/860147c263f78e687e61748c228e8325b912b176/.github/teaser.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023-present NAVER Corp. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | CompoDiff 2 | Copyright 2023-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | -------------------------------------------------------------------------------------- 17 | 18 | This project contains subcomponents with separate copyright notices and license terms. 19 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 20 | 21 | ===== 22 | 23 | lucidrains/DALLE2-pytorch 24 | https://github.com/lucidrains/DALLE2-pytorch 25 | 26 | 27 | MIT License 28 | 29 | Copyright (c) 2021 Phil Wang 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in all 39 | copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE. 48 | 49 | ===== 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # \[TMLR\] CompoDiff: Versatile Composed Image Retrieval With Latent Diffusion 2 | 3 | 4 | 5 | Official Pytorch implementation of CompoDiff 6 | 7 | [\[Paper\]](https://arxiv.org/abs/2303.11916) [\[OpenReview\]](https://openreview.net/forum?id=mKtlzW0bWc) [\[Demo🤗\]](https://huggingface.co/spaces/navervision/CompoDiff-Aesthetic) 8 | 9 | **[Geonmo Gu](https://geonm.github.io/)\*1, [Sanghyuk Chun](https://sanghyukchun.github.io/home/)\*2, [Wonjae Kim](https://wonjae.kim)2, HeeJae Jun1, Yoohoon Kang1, [Sangdoo Yun](https://sangdooyun.github.io)2** 10 | 11 | 1 NAVER Vision 2 NAVER AI Lab 12 | 13 | \* First two authors contributed equally. 14 | 15 | ## ⭐ Overview 16 | 17 | CompoDiff is a model that utilizes diffusion models for Composed Image Retrieval (CIR) for the first time. 18 | 19 | The use of diffusion models has enabled the introduction of negative text and inpainting into CIR for the first time. 20 | 21 | Moreover, since it operates based on the image feature space of the CLIP-L/14 model, it can be used with various text-to-image models (such as [Graphit](https://huggingface.co/navervision/Graphit-SD), [SD-unCLIP](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip), and [unCLIP (Karlo)](https://huggingface.co/kakaobrain/karlo-v1-alpha-image-variations)). 
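For example, the CLIP-L/14 feature sampled by CompoDiff (`sampled_image_features` in the Usage section below) can be decoded into pixels by such a decoder. The snippet below is only a rough sketch and is not part of the official demo code: it assumes a recent `diffusers` release whose `UnCLIPImageVariationPipeline` accepts precomputed `image_embeddings`, together with the Karlo image-variations checkpoint linked above.

```python
import torch
from diffusers import UnCLIPImageVariationPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Karlo unCLIP decoder: turns CLIP-L/14 image embeddings into 256x256 images.
decoder = UnCLIPImageVariationPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=dtype
).to(device)

# `sampled_image_features` is the (unnormalized) CLIP-L/14 feature returned by
# compodiff_model.sample(...) as in the Usage examples below, e.g.
# sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, ...)
image_embeddings = sampled_image_features.reshape(1, 768).to(device=device, dtype=dtype)

images = decoder(image_embeddings=image_embeddings, decoder_guidance_scale=8.0).images
images[0].save("compodiff_to_image.png")
```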
22 | 23 | We overcame the challenge of collecting training data for CIR by generating the new SynthTriplets18M dataset, which is now publicly available (see below); we hope it will benefit the CIR research community. 24 | 25 | We believe that CompoDiff has great potential to contribute to the advancement of CIR and look forward to seeing how it performs in the field. 26 | 27 | ## 🚀 NEWS 28 | 29 | - 12/14/2023 - SynthTriplets18M released! 30 | - 04/28/2023 - CompoDiff-Aesthetic released! 31 | 32 | ## 🖼️ CompoDiff-Aesthetic 33 | 34 | In this repository, we first release the CompoDiff-Aesthetic model, one of the variants of CompoDiff. 35 | 36 | For text guidance, we use the textual embeddings of CLIP-G/14 trained by [open_clip](https://github.com/mlfoundations/open_clip), while the generated output feature is still a visual embedding of CLIP-L/14. 37 | 38 | We found that the textual embedding extractor of CLIP-G/14 not only performs better than that of CLIP-L/14 but also does not significantly increase computational cost. 39 | 40 | CompoDiff-Aesthetic was trained on high-quality and high-resolution images, and we believe it has the potential to benefit not only CIR but also various generation models. 41 | 42 | We hope that CompoDiff-Aesthetic will be a useful addition to the research community and look forward to seeing how it performs in various applications. 43 | 44 | ## 📚 SynthTriplets18M 45 | 46 | Using text-to-image diffusion models, we have created a large-scale synthesized triplet dataset. 47 | 48 | https://huggingface.co/datasets/navervision/SynthTriplets18M 49 | 50 | ## Search and Image generation demo 51 | We have set up a demo that can be tested in a local computing environment. 52 | 53 | It can be executed with the following commands: 54 | 55 | ```bash 56 | $ git clone https://github.com/navervision/compodiff 57 | $ cd compodiff 58 | $ python demo_search.py 59 | ``` 60 | 61 | The demo will be hosted at https://0.0.0.0:8000 62 | 63 | The unCLIP model used for image generation is from https://huggingface.co/kakaobrain/karlo-v1-alpha-image-variations. 64 | 65 | ### How to use the demo 66 | #### Usage 1. Project textual embeddings to visual embeddings 67 | 68 | 69 | #### Usage 2. Composed visual embeddings without mask for CIR 70 | 71 | 72 | #### Usage 3. Composed visual embeddings with mask for CIR 73 | 74 | 75 | ## 💡 Usage 76 | 77 | ### Install CompoDiff 78 | ```bash 79 | $ pip install git+https://github.com/navervision/compodiff.git 80 | ``` 81 | 82 | ### Build CompoDiff and CLIP models 83 | ```python 84 | import compodiff 85 | import torch 86 | from PIL import Image 87 | import requests 88 | 89 | device = "cuda" if torch.cuda.is_available() else "cpu" 90 | 91 | # build models 92 | compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model() 93 | 94 | compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device) 95 | 96 | if device != 'cpu': 97 | clip_model = clip_model.half() 98 | ``` 99 | 100 | ### Usage 1. Project textual embeddings to visual embeddings 101 | ```python 102 | cfg_image_scale = 0.0 103 | cfg_text_scale = 7.5 104 | 105 | cfg_scale = (cfg_image_scale, cfg_text_scale) 106 | 107 | input_text = "owl carved on the wooden wall" 108 | negative_text = "low quality" 109 | 110 | # tokenize the input_text first.
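# (the tokenizer loaded by build_model() is the CLIP-G/14 (bigG) tokenizer; with
#  padding='max_length' it pads to 77 tokens, so clip_model.encode_texts below
#  returns a [1, 77, 1280] text condition, while image conditions stay in the
#  768-dim CLIP-L/14 visual space)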
111 | text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True) 112 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device) 113 | 114 | negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True) 115 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device) 116 | 117 | with torch.no_grad(): 118 | # In the case of Usage 1, we do not use an image condition or a mask at all. 119 | image_cond = torch.zeros([1,1,768]).to(device) 120 | mask = torch.zeros([64, 64]).to(device).unsqueeze(0) 121 | 122 | text_cond = clip_model.encode_texts(text_tokens, text_attention_mask) 123 | negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask) 124 | 125 | # do denoising steps here 126 | timesteps = 10 127 | sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2) 128 | # NOTE: "sampled_image_features" is not L2-normalized 129 | ``` 130 | 131 | ### Usage 2. Composed visual embeddings without mask for CIR 132 | ```python 133 | cfg_image_scale = 1.5 134 | cfg_text_scale = 7.5 135 | 136 | cfg_scale = (cfg_image_scale, cfg_text_scale) 137 | 138 | input_text = "as pencil sketch" 139 | negative_text = "low quality" 140 | 141 | # tokenize the input_text first. 142 | text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True) 143 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device) 144 | 145 | negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True) 146 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device) 147 | 148 | # prepare a reference image 149 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 150 | image = Image.open(requests.get(url, stream=True).raw).resize((512, 512)) 151 | 152 | with torch.no_grad(): 153 | processed_image = img_preprocess(image, return_tensors='pt')['pixel_values'].to(device) 154 | 155 | # In the case of Usage 2, we do not use a mask at all. 156 | mask = torch.zeros([64, 64]).to(device).unsqueeze(0) 157 | 158 | image_cond = clip_model.encode_images(processed_image) 159 | 160 | text_cond = clip_model.encode_texts(text_tokens, text_attention_mask) 161 | negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask) 162 | 163 | timesteps = 10 164 | sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2) 165 | 166 | # NOTE: To keep more of the original image's context, increase source_weight above 0.1 (the demo exposes this as "source weight" under Advanced options); a larger value passes the original image's feature through as a stronger signal. 167 | source_weight = 0.1 168 | sampled_image_features = (1 - source_weight) * sampled_image_features + source_weight * image_cond[0] 169 | ``` 170 | 171 | ### Usage 3.
Composed visual embeddings with mask for CIR 172 | ```python 173 | from torchvision import transforms  # for resizing the mask to the 64x64 resolution expected by CompoDiff 174 | 175 | cfg_image_scale = 1.5 176 | cfg_text_scale = 7.5 177 | cfg_scale = (cfg_image_scale, cfg_text_scale) 178 | 179 | input_text = "as pencil sketch" 180 | negative_text = "low quality" 181 | # tokenize the input_text first. 182 | text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True) 183 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device) 184 | 185 | negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True) 186 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device) 187 | 188 | # prepare a reference image 189 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 190 | image = Image.open(requests.get(url, stream=True).raw).resize((512, 512)) 191 | 192 | # prepare a mask image (white pixels mark the region to be edited) 193 | url = "mask_url" 194 | mask = Image.open(requests.get(url, stream=True).raw).resize((512, 512)) 195 | 196 | with torch.no_grad(): 197 | processed_image = img_preprocess(image, return_tensors='pt')['pixel_values'].to(device) 198 | processed_mask = img_preprocess(mask, do_normalize=False, return_tensors='pt')['pixel_values'].to(device) 199 | processed_mask = processed_mask[:,:1,:,:] 200 | 201 | masked_processed_image = processed_image * (1 - (processed_mask > 0.5).float()) 202 | mask = transforms.Resize([64, 64])(processed_mask)[:,0,:,:] 203 | mask = (mask > 0.5).float() 204 | 205 | image_cond = clip_model.encode_images(masked_processed_image) 206 | 207 | text_cond = clip_model.encode_texts(text_tokens, text_attention_mask) 208 | negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask) 209 | 210 | timesteps = 10 211 | sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2) 212 | 213 | # NOTE: To keep more of the original image's context, increase source_weight (the demo exposes this as "source weight" under Advanced options); a larger value passes the original image's feature through as a stronger signal. 214 | source_weight = 0.05 215 | sampled_image_features = (1 - source_weight) * sampled_image_features + source_weight * image_cond[0] 216 | ``` 217 | 218 | ### Shoutout 219 | The k-NN index used for the retrieval results is built on the entire LAION-5B image set. You do not need to download any images for this retrieval; it is made possible thanks to the great work of [rom1504](https://github.com/rom1504/clip-retrieval). 220 | 221 | ## Citing CompoDiff 222 | If you find this repository useful, please consider giving a star ⭐ and a citation: 223 | ``` 224 | @article{gu2024compodiff, 225 | title={CompoDiff: Versatile Composed Image Retrieval With Latent Diffusion}, 226 | author={Geonmo Gu and Sanghyuk Chun and Wonjae Kim and HeeJae Jun and Yoohoon Kang and Sangdoo Yun}, 227 | journal={Transactions on Machine Learning Research}, 228 | issn={2835-8856}, 229 | year={2024}, 230 | url={https://openreview.net/forum?id=mKtlzW0bWc}, 231 | note={Expert Certification} 232 | } 233 | ``` 234 | 235 | ## License 236 | ``` 237 | CompoDiff 238 | Copyright 2023-present NAVER Corp. 239 | 240 | Licensed under the Apache License, Version 2.0 (the "License"); 241 | you may not use this file except in compliance with the License.
242 | You may obtain a copy of the License at 243 | 244 | http://www.apache.org/licenses/LICENSE-2.0 245 | 246 | Unless required by applicable law or agreed to in writing, software 247 | distributed under the License is distributed on an "AS IS" BASIS, 248 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 249 | See the License for the specific language governing permissions and 250 | limitations under the License. 251 | ``` 252 | -------------------------------------------------------------------------------- /compodiff/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | CompoDiff 3 | Copyright (c) 2023-present NAVER Corp. 4 | Apache-2.0 5 | """ 6 | from .model_loader import * 7 | -------------------------------------------------------------------------------- /compodiff/model_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | CompoDiff 3 | Copyright (c) 2023-present NAVER Corp. 4 | Apache-2.0 5 | """ 6 | import torch 7 | from transformers import PreTrainedModel, PretrainedConfig, CLIPTokenizer, CLIPImageProcessor 8 | try: 9 | from .models import build_compodiff, build_clip 10 | except: 11 | from models import build_compodiff, build_clip 12 | 13 | 14 | class CompoDiffConfig(PretrainedConfig): 15 | model_type = "CompoDiff" 16 | 17 | def __init__( 18 | self, 19 | embed_dim: int = 768, 20 | model_depth: int = 12, 21 | model_dim: int = 64, 22 | model_heads: int = 16, 23 | timesteps: int = 1000, 24 | **kwargs, 25 | ): 26 | self.embed_dim = embed_dim 27 | self.model_depth = model_depth 28 | self.model_dim = model_dim 29 | self.model_heads = model_heads 30 | self.timesteps = timesteps 31 | super().__init__(**kwargs) 32 | 33 | 34 | class CompoDiffModel(PreTrainedModel): 35 | config_class = CompoDiffConfig 36 | 37 | def __init__(self, config): 38 | super().__init__(config) 39 | self.model = build_compodiff( 40 | config.embed_dim, 41 | config.model_depth, 42 | config.model_dim, 43 | config.model_heads, 44 | config.timesteps, 45 | ) 46 | 47 | def _init_weights(self, module): 48 | pass 49 | 50 | def sample(self, image_cond, text_cond, negative_text_cond, input_mask, num_samples_per_batch=4, cond_scale=1., timesteps=None, random_seed=None): 51 | return self.model.sample(image_cond, text_cond, negative_text_cond, input_mask, num_samples_per_batch, cond_scale, timesteps, random_seed) 52 | 53 | 54 | def build_model(model_name='navervision/CompoDiff-Aesthetic'): 55 | tokenizer = CLIPTokenizer.from_pretrained('laion/CLIP-ViT-bigG-14-laion2B-39B-b160k') 56 | 57 | size_cond = {'shortest_edge': 224} 58 | preprocess = CLIPImageProcessor(crop_size={'height': 224, 'width': 224}, 59 | do_center_crop=True, 60 | do_convert_rgb=True, 61 | do_normalize=True, 62 | do_rescale=True, 63 | do_resize=True, 64 | image_mean=[0.48145466, 0.4578275, 0.40821073], 65 | image_std=[0.26862954, 0.26130258, 0.27577711], 66 | resample=3, 67 | size=size_cond, 68 | ) 69 | compodiff = CompoDiffModel.from_pretrained(model_name) 70 | 71 | clip_model = build_clip() 72 | 73 | return compodiff, clip_model, preprocess, tokenizer 74 | 75 | 76 | if __name__ == '__main__': 77 | #''' # convert CompoDiff 78 | compodiff_config = CompoDiffConfig() 79 | 80 | compodiff = CompoDiffModel(compodiff_config) 81 | compodiff.model.load_state_dict(torch.load('/data/data_zoo/logs/stage2_arch.depth12-heads16_lr1e-4_text-bigG_add-art-datasets/checkpoints/model_000710000.pt')['ema_model']) 82 | 
compodiff_config.save_pretrained('/data/CompoDiff_HF') 83 | compodiff.save_pretrained('/data/CompoDiff_HF') 84 | #''' 85 | #compodiff, clip_model, preprocess_img, tokenizer = build_model() 86 | -------------------------------------------------------------------------------- /compodiff/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | CompoDiff 3 | Copyright (c) 2023-present NAVER Corp. 4 | Apache-2.0 5 | """ 6 | import math 7 | import random 8 | from tqdm.auto import tqdm 9 | from functools import partial, wraps 10 | from contextlib import contextmanager 11 | from collections import namedtuple 12 | from pathlib import Path 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.utils.checkpoint import checkpoint 17 | from torch import nn, einsum 18 | import torchvision.transforms as T 19 | 20 | import einops 21 | from einops import rearrange, repeat, reduce 22 | from einops.layers.torch import Rearrange 23 | from einops_exts import rearrange_many, repeat_many, check_shape 24 | from einops_exts.torch import EinopsToAndFrom 25 | 26 | # rotary embeddings 27 | from rotary_embedding_torch import RotaryEmbedding 28 | 29 | from transformers import CLIPTextModel, CLIPVisionModelWithProjection, CLIPImageProcessor 30 | 31 | 32 | class CLIPHF(torch.nn.Module): 33 | def __init__(self, image_model_name = 'openai/clip-vit-large-patch14', text_model_name = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'): 34 | super().__init__() 35 | self.clip_text_model = CLIPTextModel.from_pretrained(text_model_name, torch_dtype=torch.float16)#.to(device).eval() 36 | self.clip_text_model = self.clip_text_model.to(torch.float32) 37 | self.clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(image_model_name, torch_dtype=torch.float16)#.to(device).eval() 38 | self.clip_vision_model = self.clip_vision_model.to(torch.float32) 39 | 40 | def encode_images(self, input_images): 41 | vision_outputs = self.clip_vision_model(pixel_values=input_images.to(self.clip_vision_model.dtype)) 42 | return vision_outputs.image_embeds.unsqueeze(1).float() 43 | 44 | def encode_texts(self, tokens, attention_masks): 45 | text_outputs = self.clip_text_model(input_ids=tokens, attention_mask=attention_masks) 46 | return text_outputs.last_hidden_state.float() 47 | 48 | def forward(self, input_images, tokens, attention_masks): 49 | return self.encode_images(input_images), self.encode_texts(tokens, attention_masks) 50 | 51 | 52 | def exists(val): 53 | return val is not None 54 | 55 | def identity(t, *args, **kwargs): 56 | return t 57 | 58 | def first(arr, d = None): 59 | if len(arr) == 0: 60 | return d 61 | return arr[0] 62 | 63 | def maybe(fn): 64 | @wraps(fn) 65 | def inner(x, *args, **kwargs): 66 | if not exists(x): 67 | return x 68 | return fn(x, *args, **kwargs) 69 | return inner 70 | 71 | def default(val, d): 72 | if exists(val): 73 | return val 74 | return d() if callable(d) else d 75 | 76 | def cast_tuple(val, length = None, validate = True): 77 | if isinstance(val, list): 78 | val = tuple(val) 79 | 80 | out = val if isinstance(val, tuple) else ((val,) * default(length, 1)) 81 | 82 | if exists(length) and validate: 83 | assert len(out) == length 84 | 85 | return out 86 | 87 | def module_device(module): 88 | if isinstance(module, nn.Identity): 89 | return 'cpu' # It doesn't matter 90 | return next(module.parameters()).device 91 | 92 | def zero_init_(m): 93 | nn.init.zeros_(m.weight) 94 | if exists(m.bias): 95 | nn.init.zeros_(m.bias) 96 | 97 | @contextmanager 98 
| def null_context(*args, **kwargs): 99 | yield 100 | 101 | def eval_decorator(fn): 102 | def inner(model, *args, **kwargs): 103 | was_training = model.training 104 | model.eval() 105 | out = fn(model, *args, **kwargs) 106 | model.train(was_training) 107 | return out 108 | return inner 109 | 110 | def is_float_dtype(dtype): 111 | return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)]) 112 | 113 | def is_list_str(x): 114 | if not isinstance(x, (list, tuple)): 115 | return False 116 | return all([type(el) == str for el in x]) 117 | 118 | def pad_tuple_to_length(t, length, fillvalue = None): 119 | remain_length = length - len(t) 120 | if remain_length <= 0: 121 | return t 122 | return (*t, *((fillvalue,) * remain_length)) 123 | 124 | # checkpointing helper function 125 | 126 | def make_checkpointable(fn, **kwargs): 127 | if isinstance(fn, nn.ModuleList): 128 | return [maybe(make_checkpointable)(el, **kwargs) for el in fn] 129 | 130 | condition = kwargs.pop('condition', None) 131 | 132 | if exists(condition) and not condition(fn): 133 | return fn 134 | 135 | @wraps(fn) 136 | def inner(*args): 137 | input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args]) 138 | 139 | if not input_needs_grad: 140 | return fn(*args) 141 | 142 | return checkpoint(fn, *args) 143 | 144 | return inner 145 | 146 | # for controlling freezing of CLIP 147 | 148 | def set_module_requires_grad_(module, requires_grad): 149 | for param in module.parameters(): 150 | param.requires_grad = requires_grad 151 | 152 | def freeze_all_layers_(module): 153 | set_module_requires_grad_(module, False) 154 | 155 | def unfreeze_all_layers_(module): 156 | set_module_requires_grad_(module, True) 157 | 158 | def freeze_model_and_make_eval_(model): 159 | model.eval() 160 | freeze_all_layers_(model) 161 | 162 | # tensor helpers 163 | 164 | def log(t, eps = 1e-12): 165 | return torch.log(t.clamp(min = eps)) 166 | 167 | def l2norm(t): 168 | return F.normalize(t, dim = -1) 169 | 170 | # image normalization functions 171 | # ddpms expect images to be in the range of -1 to 1 172 | # but CLIP may otherwise 173 | 174 | def normalize_neg_one_to_one(img): 175 | return img * 2 - 1 176 | 177 | def unnormalize_zero_to_one(normed_img): 178 | return (normed_img + 1) * 0.5 179 | 180 | # gaussian diffusion helper functions 181 | 182 | def extract(a, t, x_shape): 183 | b, *_ = t.shape 184 | out = a.gather(-1, t) 185 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 186 | 187 | def meanflat(x): 188 | return x.mean(dim = tuple(range(1, len(x.shape)))) 189 | 190 | def normal_kl(mean1, logvar1, mean2, logvar2): 191 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)) 192 | 193 | def approx_standard_normal_cdf(x): 194 | return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3)))) 195 | 196 | def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999): 197 | assert x.shape == means.shape == log_scales.shape 198 | 199 | # attempting to correct nan gradients when learned variance is turned on 200 | # in the setting of deepspeed fp16 201 | eps = 1e-12 if x.dtype == torch.float32 else 1e-3 202 | 203 | centered_x = x - means 204 | inv_stdv = torch.exp(-log_scales) 205 | plus_in = inv_stdv * (centered_x + 1. / 255.) 206 | cdf_plus = approx_standard_normal_cdf(plus_in) 207 | min_in = inv_stdv * (centered_x - 1. / 255.) 
208 | cdf_min = approx_standard_normal_cdf(min_in) 209 | log_cdf_plus = log(cdf_plus, eps = eps) 210 | log_one_minus_cdf_min = log(1. - cdf_min, eps = eps) 211 | cdf_delta = cdf_plus - cdf_min 212 | 213 | log_probs = torch.where(x < -thres, 214 | log_cdf_plus, 215 | torch.where(x > thres, 216 | log_one_minus_cdf_min, 217 | log(cdf_delta, eps = eps))) 218 | 219 | return log_probs 220 | 221 | def cosine_beta_schedule(timesteps, s = 0.008): 222 | """ 223 | cosine schedule 224 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 225 | """ 226 | steps = timesteps + 1 227 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 228 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 229 | alphas_cumprod = alphas_cumprod / first(alphas_cumprod) 230 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 231 | return torch.clip(betas, 0, 0.999) 232 | 233 | 234 | def linear_beta_schedule(timesteps): 235 | scale = 1000 / timesteps 236 | beta_start = scale * 0.0001 237 | beta_end = scale * 0.02 238 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 239 | 240 | 241 | def quadratic_beta_schedule(timesteps): 242 | scale = 1000 / timesteps 243 | beta_start = scale * 0.0001 244 | beta_end = scale * 0.02 245 | return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2 246 | 247 | 248 | def sigmoid_beta_schedule(timesteps): 249 | scale = 1000 / timesteps 250 | beta_start = scale * 0.0001 251 | beta_end = scale * 0.02 252 | betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64) 253 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 254 | 255 | 256 | class NoiseScheduler(nn.Module): 257 | def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1): 258 | super().__init__() 259 | 260 | if beta_schedule == "cosine": 261 | betas = cosine_beta_schedule(timesteps) 262 | elif beta_schedule == "linear": 263 | betas = linear_beta_schedule(timesteps) 264 | elif beta_schedule == "quadratic": 265 | betas = quadratic_beta_schedule(timesteps) 266 | elif beta_schedule == "jsd": 267 | betas = 1.0 / torch.linspace(timesteps, 1, timesteps) 268 | elif beta_schedule == "sigmoid": 269 | betas = sigmoid_beta_schedule(timesteps) 270 | else: 271 | raise NotImplementedError() 272 | 273 | alphas = 1. - betas 274 | alphas_cumprod = torch.cumprod(alphas, axis = 0) 275 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 276 | 277 | timesteps, = betas.shape 278 | self.num_timesteps = int(timesteps) 279 | 280 | if loss_type == 'l1': 281 | loss_fn = F.l1_loss 282 | elif loss_type == 'l2': 283 | loss_fn = F.mse_loss 284 | elif loss_type == 'huber': 285 | loss_fn = F.smooth_l1_loss 286 | else: 287 | raise NotImplementedError() 288 | 289 | self.loss_type = loss_type 290 | self.loss_fn = loss_fn 291 | 292 | # register buffer helper function to cast double back to float 293 | 294 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 295 | 296 | register_buffer('betas', betas) 297 | register_buffer('alphas_cumprod', alphas_cumprod) 298 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 299 | 300 | # calculations for diffusion q(x_t | x_{t-1}) and others 301 | 302 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 303 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 304 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. 
- alphas_cumprod)) 305 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 306 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 307 | 308 | # calculations for posterior q(x_{t-1} | x_t, x_0) 309 | 310 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 311 | 312 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 313 | 314 | register_buffer('posterior_variance', posterior_variance) 315 | 316 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 317 | 318 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 319 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 320 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) 321 | 322 | # p2 loss reweighting 323 | 324 | self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0. 325 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) 326 | 327 | def sample_random_times(self, batch): 328 | return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long) 329 | 330 | def q_posterior(self, x_start, x_t, t): 331 | posterior_mean = ( 332 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 333 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 334 | ) 335 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 336 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 337 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 338 | 339 | def q_sample(self, x_start, t, noise = None): 340 | noise = default(noise, lambda: torch.randn_like(x_start)) 341 | 342 | return ( 343 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 344 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 345 | ) 346 | 347 | def calculate_v(self, x_start, t, noise = None): 348 | return ( 349 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - 350 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start 351 | ) 352 | 353 | def q_sample_from_to(self, x_from, from_t, to_t, noise = None): 354 | shape = x_from.shape 355 | noise = default(noise, lambda: torch.randn_like(x_from)) 356 | 357 | alpha = extract(self.sqrt_alphas_cumprod, from_t, shape) 358 | sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape) 359 | alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape) 360 | sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape) 361 | 362 | return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha 363 | 364 | def predict_start_from_v(self, x_t, t, v): 365 | return ( 366 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 367 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 368 | ) 369 | 370 | def predict_start_from_noise(self, x_t, t, noise): 371 | return ( 372 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 373 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 374 | ) 375 | 376 | def predict_noise_from_start(self, x_t, t, x0): 377 | return ( 378 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ 379 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 380 | 
) 381 | 382 | def p2_reweigh_loss(self, loss, times): 383 | if not self.has_p2_loss_reweighting: 384 | return loss 385 | return loss * extract(self.p2_loss_weight, times, loss.shape) 386 | 387 | # diffusion prior 388 | 389 | class LayerNorm(nn.Module): 390 | def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False): 391 | super().__init__() 392 | self.eps = eps 393 | self.fp16_eps = fp16_eps 394 | self.stable = stable 395 | self.g = nn.Parameter(torch.ones(dim)) 396 | 397 | def forward(self, x): 398 | eps = self.eps if x.dtype == torch.float32 else self.fp16_eps 399 | 400 | if self.stable: 401 | x = x / x.amax(dim = -1, keepdim = True).detach() 402 | 403 | var = torch.var(x, dim = -1, unbiased = False, keepdim = True) 404 | mean = torch.mean(x, dim = -1, keepdim = True) 405 | return (x - mean) * (var + eps).rsqrt() * self.g 406 | 407 | class ChanLayerNorm(nn.Module): 408 | def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False): 409 | super().__init__() 410 | self.eps = eps 411 | self.fp16_eps = fp16_eps 412 | self.stable = stable 413 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 414 | 415 | def forward(self, x): 416 | eps = self.eps if x.dtype == torch.float32 else self.fp16_eps 417 | 418 | if self.stable: 419 | x = x / x.amax(dim = 1, keepdim = True).detach() 420 | 421 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 422 | mean = torch.mean(x, dim = 1, keepdim = True) 423 | return (x - mean) * (var + eps).rsqrt() * self.g 424 | 425 | class Residual(nn.Module): 426 | def __init__(self, fn): 427 | super().__init__() 428 | self.fn = fn 429 | 430 | def forward(self, x, **kwargs): 431 | return self.fn(x, **kwargs) + x 432 | 433 | # mlp 434 | 435 | class MLP(nn.Module): 436 | def __init__( 437 | self, 438 | dim_in, 439 | dim_out, 440 | *, 441 | expansion_factor = 2., 442 | depth = 2, 443 | norm = False, 444 | ): 445 | super().__init__() 446 | hidden_dim = int(expansion_factor * dim_out) 447 | norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity() 448 | 449 | layers = [nn.Sequential( 450 | nn.Linear(dim_in, hidden_dim), 451 | nn.SiLU(), 452 | norm_fn() 453 | )] 454 | 455 | for _ in range(depth - 1): 456 | layers.append(nn.Sequential( 457 | nn.Linear(hidden_dim, hidden_dim), 458 | nn.SiLU(), 459 | norm_fn() 460 | )) 461 | 462 | layers.append(nn.Linear(hidden_dim, dim_out)) 463 | self.net = nn.Sequential(*layers) 464 | 465 | def forward(self, x): 466 | return self.net(x.float()) 467 | 468 | 469 | class SinusoidalPosEmb(nn.Module): 470 | def __init__(self, dim): 471 | super().__init__() 472 | self.dim = dim 473 | 474 | def forward(self, x): 475 | dtype, device = x.dtype, x.device 476 | assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type' 477 | 478 | half_dim = self.dim // 2 479 | emb = math.log(10000) / (half_dim - 1) 480 | emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) 481 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') 482 | return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype) 483 | 484 | # relative positional bias for causal transformer 485 | 486 | class RelPosBias(nn.Module): 487 | def __init__( 488 | self, 489 | heads = 8, 490 | num_buckets = 32, 491 | max_distance = 128, 492 | ): 493 | super().__init__() 494 | self.num_buckets = num_buckets 495 | self.max_distance = max_distance 496 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 497 | 498 | @staticmethod 499 | def _relative_position_bucket( 500 | relative_position, 501 | num_buckets 
= 32, 502 | max_distance = 128 503 | ): 504 | n = -relative_position 505 | n = torch.max(n, torch.zeros_like(n)) 506 | 507 | max_exact = num_buckets // 2 508 | is_small = n < max_exact 509 | 510 | val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long() 511 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 512 | return torch.where(is_small, n, val_if_large) 513 | 514 | def forward(self, i, j, *, device): 515 | q_pos = torch.arange(i, dtype = torch.long, device = device) 516 | k_pos = torch.arange(j, dtype = torch.long, device = device) 517 | rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') 518 | rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) 519 | values = self.relative_attention_bias(rp_bucket) 520 | return rearrange(values, 'i j h -> h i j') 521 | 522 | # feedforward 523 | 524 | class SwiGLU(nn.Module): 525 | """ used successfully in https://arxiv.org/abs/2204.0231 """ 526 | def forward(self, x): 527 | x, gate = x.chunk(2, dim = -1) 528 | return x * F.silu(gate) 529 | 530 | def FeedForward( 531 | dim, 532 | mult = 4, 533 | dropout = 0., 534 | post_activation_norm = False 535 | ): 536 | """ post-activation norm https://arxiv.org/abs/2110.09456 """ 537 | 538 | inner_dim = int(mult * dim) 539 | return nn.Sequential( 540 | LayerNorm(dim), 541 | nn.Linear(dim, inner_dim * 2, bias = False), 542 | SwiGLU(), 543 | LayerNorm(inner_dim) if post_activation_norm else nn.Identity(), 544 | nn.Dropout(dropout), 545 | nn.Linear(inner_dim, dim, bias = False) 546 | ) 547 | 548 | # attention 549 | 550 | class Attention(nn.Module): 551 | def __init__( 552 | self, 553 | dim, 554 | *, 555 | dim_head = 64, 556 | heads = 8, 557 | dropout = 0., 558 | causal = False, 559 | rotary_emb = None, 560 | cosine_sim = True, 561 | cosine_sim_scale = 16 562 | ): 563 | super().__init__() 564 | self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5) 565 | self.cosine_sim = cosine_sim 566 | 567 | self.heads = heads 568 | inner_dim = dim_head * heads 569 | 570 | self.causal = causal 571 | self.norm = LayerNorm(dim) 572 | self.dropout = nn.Dropout(dropout) 573 | 574 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 575 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 576 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) 577 | 578 | self.rotary_emb = rotary_emb 579 | 580 | self.to_out = nn.Sequential( 581 | nn.Linear(inner_dim, dim, bias = False), 582 | LayerNorm(dim) 583 | ) 584 | 585 | def forward(self, x, mask = None, attn_bias = None): 586 | b, n, device = *x.shape[:2], x.device 587 | 588 | x = self.norm(x) 589 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 590 | 591 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) 592 | q = q * self.scale 593 | 594 | # rotary embeddings 595 | 596 | if exists(self.rotary_emb): 597 | q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k)) 598 | 599 | # add null key / value for classifier free guidance in prior net 600 | 601 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b) 602 | k = torch.cat((nk, k), dim = -2) 603 | v = torch.cat((nv, v), dim = -2) 604 | 605 | # whether to use cosine sim 606 | 607 | if self.cosine_sim: 608 | q, k = map(l2norm, (q, k)) 609 | 610 | q, k = map(lambda t: t * math.sqrt(self.scale), (q, k)) 611 | 612 | # calculate query / key similarities 613 | 614 | sim = einsum('b h i d, b j 
d -> b h i j', q, k) 615 | 616 | # relative positional encoding (T5 style) 617 | 618 | if exists(attn_bias): 619 | sim = sim + attn_bias 620 | 621 | # masking 622 | 623 | max_neg_value = -torch.finfo(sim.dtype).max 624 | 625 | if exists(mask): 626 | mask = F.pad(mask, (1, 0), value = True) 627 | mask = rearrange(mask, 'b j -> b 1 1 j') 628 | sim = sim.masked_fill(~mask, max_neg_value) 629 | 630 | if self.causal: 631 | i, j = sim.shape[-2:] 632 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1) 633 | sim = sim.masked_fill(causal_mask, max_neg_value) 634 | 635 | # attention 636 | 637 | attn = sim.softmax(dim = -1, dtype = torch.float32) 638 | attn = attn.type(sim.dtype) 639 | 640 | attn = self.dropout(attn) 641 | 642 | # aggregate values 643 | 644 | out = einsum('b h i j, b j d -> b h i d', attn, v) 645 | 646 | out = rearrange(out, 'b h n d -> b n (h d)') 647 | return self.to_out(out) 648 | 649 | class CrossAttention(nn.Module): 650 | def __init__( 651 | self, 652 | dim, 653 | *, 654 | context_dim = None, 655 | dim_head = 64, 656 | heads = 8, 657 | dropout = 0., 658 | norm_context = False, 659 | cosine_sim = False, 660 | cosine_sim_scale = 16 661 | ): 662 | super().__init__() 663 | self.cosine_sim = cosine_sim 664 | self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5) 665 | self.heads = heads 666 | inner_dim = dim_head * heads 667 | 668 | context_dim = default(context_dim, dim) 669 | 670 | self.norm = LayerNorm(dim) 671 | self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity() 672 | self.dropout = nn.Dropout(dropout) 673 | 674 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 675 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 676 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) 677 | 678 | self.to_out = nn.Sequential( 679 | nn.Linear(inner_dim, dim, bias = False), 680 | LayerNorm(dim) 681 | ) 682 | 683 | def forward(self, x, context, mask = None): 684 | b, n, device = *x.shape[:2], x.device 685 | 686 | x = self.norm(x) 687 | context = self.norm_context(context) 688 | 689 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 690 | 691 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads) 692 | 693 | # add null key / value for classifier free guidance in prior net 694 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b) 695 | 696 | k = torch.cat((nk, k), dim = -2) 697 | v = torch.cat((nv, v), dim = -2) 698 | 699 | if self.cosine_sim: 700 | q, k = map(l2norm, (q, k)) 701 | 702 | q, k = map(lambda t: t * math.sqrt(self.scale), (q, k)) 703 | 704 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 705 | max_neg_value = -torch.finfo(sim.dtype).max 706 | 707 | if exists(mask): 708 | mask = F.pad(mask, (1, 0), value = True) 709 | mask = rearrange(mask, 'b j -> b 1 1 j') 710 | sim = sim.masked_fill(~mask, max_neg_value) 711 | 712 | attn = sim.softmax(dim = -1, dtype = torch.float32) 713 | attn = attn.type(sim.dtype) 714 | 715 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 716 | out = rearrange(out, 'b h n d -> b n (h d)') 717 | return self.to_out(out) 718 | 719 | 720 | class CrossTransformer(nn.Module): 721 | def __init__( 722 | self, 723 | *, 724 | dim, 725 | depth, 726 | dim_head = 64, 727 | heads = 8, 728 | ff_mult = 4, 729 | norm_in = False, 730 | norm_out = True, 731 | attn_dropout = 0., 732 | ff_dropout = 0., 733 | final_proj = True, 734 | normformer = False, 735 | rotary_emb = True, 736 | causal = False, 
737 | context_dim = None, 738 | timesteps = None, 739 | ): 740 | super().__init__() 741 | self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM 742 | 743 | self.rel_pos_bias = RelPosBias(heads = heads) 744 | 745 | rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None 746 | 747 | self.layers = nn.ModuleList([]) 748 | for _ in range(depth): 749 | self.layers.append(nn.ModuleList([ 750 | Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb), 751 | CrossAttention(dim = dim, context_dim = context_dim, dim_head = dim_head, dropout = attn_dropout), 752 | FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer) 753 | ])) 754 | 755 | self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options 756 | self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity() 757 | 758 | def forward(self, x, context, mask = None): 759 | n, device = x.shape[1], x.device 760 | 761 | x = self.init_norm(x) 762 | 763 | attn_bias = self.rel_pos_bias(n, n + 1, device = device) 764 | 765 | for attn, cross_attn, ff in self.layers: 766 | x = attn(x, attn_bias = attn_bias) + x 767 | x = cross_attn(x, context, mask) + x 768 | x = ff(x) + x 769 | 770 | out = self.norm(x) 771 | return self.project_out(out) 772 | 773 | 774 | class CompoDiffNetwork(nn.Module): 775 | def __init__( 776 | self, 777 | dim, 778 | num_timesteps = None, 779 | max_text_len = 77, 780 | cross = False, 781 | text_model_name = None, 782 | **kwargs 783 | ): 784 | super().__init__() 785 | self.dim = dim 786 | 787 | self.to_text_embeds = nn.Sequential( 788 | nn.Linear(1280, self.dim), 789 | ) 790 | 791 | self.continuous_embedded_time = not exists(num_timesteps) 792 | 793 | num_time_embeds = 1 794 | self.to_time_embeds = nn.Sequential( 795 | nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP 796 | Rearrange('b (n d) -> b n d', n = num_time_embeds) 797 | ) 798 | self.to_mask_embeds = nn.Sequential( 799 | Rearrange('b h w -> b (h w)'), 800 | MLP(4096, dim), 801 | Rearrange('b (n d) -> b n d', n = 1) 802 | ) 803 | 804 | self.transformer = CrossTransformer(dim = dim, **kwargs) 805 | 806 | def forward_with_cond_scale( 807 | self, 808 | image_embed, 809 | image_cond, 810 | text_cond, 811 | input_mask, 812 | diffusion_timesteps, 813 | text_cond_uc, 814 | cond_scale = 1., 815 | ): 816 | if cond_scale == 1.: 817 | logits = self.forward(image_embed, image_cond, text_cond, input_mask, diffusion_timesteps, text_cond_uc) 818 | return logits 819 | else: 820 | # make it triple! 
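# Dual classifier-free guidance: the batch is tripled into
#   (1) full conditioning (reference image feature + input text),
#   (2) image conditioning with the negative text, and
#   (3) zeroed image conditioning with the negative text,
# then recombined as (3) + cond_scale[1] * ((1) - (2)) + cond_scale[0] * ((2) - (3)),
# where cond_scale = (cfg_image_scale, cfg_text_scale).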
821 | ''' 822 | logits, null_image_logits, null_text_logits 823 | ''' 824 | image_embed = torch.cat([image_embed] * 3) 825 | image_cond = torch.cat([image_cond, image_cond, torch.zeros_like(image_cond)]) 826 | text_cond = torch.cat([text_cond, text_cond_uc, text_cond_uc]) 827 | input_mask = torch.cat([input_mask] * 3) 828 | diffusion_timesteps = torch.cat([diffusion_timesteps] * 3) 829 | 830 | logits, null_text_logits, null_all_logits = self.forward(image_embed, image_cond, text_cond, input_mask, diffusion_timesteps).chunk(3) 831 | return null_all_logits + (logits - null_text_logits) * cond_scale[1] + (null_text_logits - null_all_logits) * cond_scale[0] 832 | 833 | def forward( 834 | self, 835 | image_embed, 836 | image_cond, 837 | text_cond, 838 | input_mask, 839 | diffusion_timesteps, 840 | ): 841 | batch, n_image_embed, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype 842 | 843 | text_cond = self.to_text_embeds(text_cond) 844 | 845 | if self.continuous_embedded_time: 846 | diffusion_timesteps = diffusion_timesteps.type(dtype) 847 | 848 | time_embed = self.to_time_embeds(diffusion_timesteps) 849 | 850 | mask_embed = self.to_mask_embeds(input_mask) 851 | 852 | tokens = torch.cat(( 853 | image_embed, 854 | time_embed, 855 | ), dim = -2) 856 | 857 | context_embed = torch.cat([text_cond, image_cond, mask_embed], dim=1) 858 | tokens = self.transformer(tokens, context=context_embed) 859 | 860 | pred_image_embed = tokens[..., :1, :] 861 | 862 | return pred_image_embed 863 | 864 | class CompoDiff(nn.Module): 865 | def __init__( 866 | self, 867 | net, 868 | *, 869 | image_embed_dim = None, 870 | timesteps = 1000, 871 | predict_x_start = True, 872 | loss_type = "l2", 873 | beta_schedule = "cosine", 874 | condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training 875 | sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs) 876 | sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm) 877 | training_clamp_l2norm = False, 878 | init_image_embed_l2norm = False, 879 | image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 880 | ): 881 | super().__init__() 882 | 883 | self.sample_timesteps = None 884 | 885 | self.noise_scheduler = NoiseScheduler( 886 | beta_schedule = beta_schedule, 887 | timesteps = timesteps, 888 | loss_type = loss_type 889 | ) 890 | 891 | self.net = net 892 | self.image_embed_dim = image_embed_dim 893 | 894 | self.condition_on_text_encodings = condition_on_text_encodings 895 | 896 | self.predict_x_start = predict_x_start 897 | 898 | self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5) 899 | 900 | # whether to force an l2norm, similar to clipping denoised, when sampling 901 | 902 | self.sampling_clamp_l2norm = sampling_clamp_l2norm 903 | self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm 904 | 905 | self.training_clamp_l2norm = training_clamp_l2norm 906 | self.init_image_embed_l2norm = init_image_embed_l2norm 907 | 908 | # device tracker 909 | 910 | self.register_buffer('_dummy', torch.tensor([True]), persistent = False) 911 | 912 | @property 913 | def device(self): 914 | 
return self._dummy.device 915 | 916 | def l2norm_clamp_embed(self, image_embed): 917 | return l2norm(image_embed) * self.image_embed_scale 918 | 919 | @torch.no_grad() 920 | def p_sample_loop_ddim(self, shape, image_cond, text_cond, negative_text_cond, input_mask, timesteps, eta = 1., cond_scale = 1., random_seed=None): 921 | batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps 922 | 923 | times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1] 924 | 925 | times = list(reversed(times.int().tolist())) 926 | time_pairs = list(zip(times[:-1], times[1:])) 927 | 928 | if random_seed is None: 929 | image_embed = noise = torch.randn(shape).to(device) 930 | else: 931 | image_embed = noise = torch.randn(shape, generator=torch.manual_seed(random_seed)).to(device) 932 | 933 | x_start = None # for self-conditioning 934 | 935 | if self.init_image_embed_l2norm: 936 | image_embed = l2norm(image_embed) * self.image_embed_scale 937 | 938 | for time, time_next in tqdm(time_pairs, desc = 'CompoDiff sampling loop'): 939 | alpha = alphas[time] 940 | alpha_next = alphas[time_next] 941 | 942 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long) 943 | 944 | pred = self.net.forward_with_cond_scale(image_embed, image_cond, text_cond, input_mask, time_cond, text_cond_uc=negative_text_cond, cond_scale = cond_scale) 945 | 946 | # derive x0 947 | 948 | if self.predict_x_start: 949 | x_start = pred 950 | else: 951 | x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred) 952 | 953 | # clip x0 before maybe predicting noise 954 | 955 | if not self.predict_x_start: 956 | x_start.clamp_(-1., 1.) 957 | 958 | if self.predict_x_start and self.sampling_clamp_l2norm: 959 | x_start = self.l2norm_clamp_embed(x_start) 960 | 961 | # predict noise 962 | 963 | if self.predict_x_start or getattr(self, 'predict_v', False): 964 | pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start) 965 | else: 966 | pred_noise = pred 967 | 968 | if time_next < 0: 969 | image_embed = x_start 970 | continue 971 | 972 | c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 973 | c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() 974 | noise = torch.randn_like(image_embed) if time_next > 0 else 0.
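# The update that follows is the standard DDIM step: with alpha the cumulative alpha at the current
# step and alpha_next at the next step,
#   image_embed <- sqrt(alpha_next) * x_start + c1 * noise + c2 * pred_noise
# where c1 = eta * sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)) is the stochastic
# DDIM sigma (eta = 0 would make the sampler deterministic; here eta defaults to 1.) and
# c2 = sqrt(1 - alpha_next - c1^2).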
975 | 976 | image_embed = x_start * alpha_next.sqrt() + \ 977 | c1 * noise + \ 978 | c2 * pred_noise 979 | 980 | if self.predict_x_start and self.sampling_final_clamp_l2norm: 981 | image_embed = self.l2norm_clamp_embed(image_embed) 982 | 983 | return image_embed 984 | 985 | @torch.no_grad() 986 | def p_sample_loop(self, shape, image_cond, text_cond, negative_text_cond, input_mask, cond_scale = 1., timesteps = None, random_seed = None): 987 | timesteps = default(timesteps, self.noise_scheduler.num_timesteps) 988 | assert timesteps <= self.noise_scheduler.num_timesteps 989 | is_ddim = timesteps < self.noise_scheduler.num_timesteps 990 | 991 | normalized_image_embed = self.p_sample_loop_ddim(shape, image_cond, text_cond, negative_text_cond, input_mask, cond_scale=cond_scale, timesteps=timesteps, random_seed=random_seed) 992 | 993 | image_embed = normalized_image_embed / self.image_embed_scale 994 | return image_embed 995 | 996 | @torch.no_grad() 997 | @eval_decorator 998 | def sample( 999 | self, 1000 | image_cond, 1001 | text_cond, 1002 | negative_text_cond, 1003 | input_mask, 1004 | num_samples_per_batch=4, 1005 | cond_scale=1., 1006 | timesteps=None, 1007 | random_seed=None): 1008 | 1009 | timesteps = default(timesteps, self.sample_timesteps) 1010 | 1011 | if image_cond is not None: 1012 | image_cond = repeat(image_cond, 'b ... -> (b r) ...', r = num_samples_per_batch) 1013 | text_cond = repeat(text_cond, 'b ... -> (b r) ...', r = num_samples_per_batch) 1014 | input_mask = repeat(input_mask, 'b ... -> (b r) ...', r = num_samples_per_batch) 1015 | negative_text_cond = repeat(negative_text_cond, 'b ... -> (b r) ...', r = num_samples_per_batch) 1016 | 1017 | batch_size = text_cond.shape[0] 1018 | image_embed_dim = self.image_embed_dim 1019 | 1020 | image_embeds = self.p_sample_loop((batch_size, 1, image_embed_dim), image_cond, text_cond, negative_text_cond, input_mask, cond_scale = cond_scale, timesteps = timesteps, random_seed = random_seed) 1021 | 1022 | image_embeds = rearrange(image_embeds, '(b r) 1 d -> b r d', r = num_samples_per_batch) 1023 | 1024 | return torch.mean(image_embeds, dim=1) 1025 | 1026 | def forward( 1027 | self, 1028 | input_image_embed = None, 1029 | target_image_embed = None, 1030 | image_cond = None, 1031 | text_cond = None, 1032 | input_mask = None, 1033 | text_cond_uc = None, 1034 | *args, 1035 | **kwargs 1036 | ): 1037 | # timestep conditioning from ddpm 1038 | batch, device = input_image_embed.shape[0], input_image_embed.device 1039 | times = self.noise_scheduler.sample_random_times(batch) 1040 | 1041 | # scale image embed (Katherine) 1042 | input_image_embed *= self.image_embed_scale 1043 | 1044 | # calculate forward loss 1045 | 1046 | loss = self.p_losses(input_image_embed = input_image_embed, 1047 | target_image_embed = target_image_embed, 1048 | image_cond = image_cond, 1049 | text_cond = text_cond, 1050 | input_mask = input_mask, 1051 | times = times, 1052 | text_cond_uc = text_cond_uc, 1053 | *args, **kwargs) 1054 | return loss 1055 | 1056 | def build_compodiff(embed_dim, 1057 | model_depth, 1058 | model_dim, 1059 | model_heads, 1060 | timesteps, 1061 | ): 1062 | 1063 | compodiff_network = CompoDiffNetwork( 1064 | dim = embed_dim, 1065 | depth = model_depth, 1066 | dim_head = model_dim, 1067 | heads = model_heads, 1068 | ) 1069 | 1070 | compodiff = CompoDiff( 1071 | net = compodiff_network, 1072 | image_embed_dim=embed_dim, 1073 | timesteps = timesteps, 1074 | condition_on_text_encodings = True, 1075 | image_embed_scale = 1.0, 1076 | sampling_clamp_l2norm = 
False, 1077 | training_clamp_l2norm = False, 1078 | init_image_embed_l2norm = False, 1079 | predict_x_start = True, 1080 | ) 1081 | 1082 | return compodiff 1083 | 1084 | def build_clip(): 1085 | clip_model = CLIPHF(image_model_name='openai/clip-vit-large-patch14', 1086 | text_model_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', 1087 | ) 1088 | return clip_model 1089 | 1090 | 1091 | -------------------------------------------------------------------------------- /demo_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | CompoDiff 3 | Copyright (c) 2023-present NAVER Corp. 4 | Apache-2.0 5 | """ 6 | import os 7 | import numpy as np 8 | import base64 9 | import requests 10 | import json 11 | import time 12 | import torch 13 | import torch.nn.functional as F 14 | import gradio as gr 15 | from clip_retrieval.clip_client import ClipClient 16 | import types 17 | from typing import Union, List, Optional, Callable 18 | import torch 19 | from diffusers import UnCLIPImageVariationPipeline 20 | from torchvision import transforms 21 | from torchvision.transforms.functional import to_pil_image, pil_to_tensor 22 | from PIL import Image 23 | import compodiff 24 | 25 | 26 | def load_models(): 27 | ### build model 28 | print("\tbuilding CompoDiff") 29 | 30 | compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model() 31 | 32 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 33 | 34 | compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device) 35 | 36 | if device != 'cpu': 37 | clip_model = clip_model.half() 38 | 39 | model_dict = {} 40 | model_dict['compodiff'] = compodiff_model 41 | model_dict['clip_model'] = clip_model 42 | model_dict['img_preprocess'] = img_preprocess 43 | model_dict['tokenizer'] = tokenizer 44 | model_dict['device'] = device 45 | 46 | return model_dict 47 | 48 | 49 | @torch.no_grad() 50 | def l2norm(features): 51 | return features / features.norm(p=2, dim=-1, keepdim=True) 52 | 53 | 54 | def predict(images, input_text, negative_text, step, cfg_image_scale, cfg_text_scale, do_generate, source_mixing_weight): 55 | ''' 56 | image_source, text_input, negative_text_input, steps_input, cfg_image_scale, cfg_text_scale, do_generate, source_mixing_weight 57 | ''' 58 | device = model_dict['device'] 59 | 60 | step = int(step) 61 | step = step + 1 if step < 1000 else step 62 | 63 | cfg_scale = (cfg_image_scale, cfg_text_scale) 64 | 65 | text = input_text 66 | 67 | if images is None: 68 | # t2i 69 | cfg_scale = (1.0, cfg_text_scale) 70 | text_token_dict = model_dict['tokenizer'](text=text, return_tensors='pt', padding='max_length', truncation=True) 71 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device) 72 | 73 | negative_text_token_dict = model_dict['tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True) 74 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device) 75 | 76 | with torch.no_grad(): 77 | image_cond = torch.zeros([1,1,768]).to(device) 78 | text_cond = model_dict['clip_model'].encode_texts(text_tokens, text_attention_mask) 79 | negative_text_cond = model_dict['clip_model'].encode_texts(negative_text_tokens, negative_text_attention_mask) 80 | 81 | sampling_start = time.time() 82 | mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0) 83 | sampled_image_features =
model_dict['compodiff'].sample(image_cond, text_cond, negative_text_cond, mask, timesteps=step, cond_scale=cfg_scale, num_samples_per_batch=2) 84 | sampling_end = time.time() 85 | 86 | sampled_image_features_org = sampled_image_features 87 | sampled_image_features = l2norm(sampled_image_features) 88 | else: 89 | # CIR 90 | image_source = images['image'].resize((512, 512)) 91 | mask = images['mask'].resize((512, 512)) 92 | mask = model_dict['img_preprocess'](mask, do_normalize=False, return_tensors='pt')['pixel_values'] 93 | mask = mask[:,:1,:,:] 94 | 95 | ## preprocess 96 | image_source = model_dict['img_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(device) 97 | 98 | mask = (mask > 0.5).float().to(device) 99 | image_source = image_source * (1 - mask) 100 | 101 | text_token_dict = model_dict['tokenizer'](text=text, return_tensors='pt', padding='max_length', truncation=True) 102 | text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device) 103 | 104 | negative_text_token_dict = model_dict['tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True) 105 | negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), negative_text_token_dict['attention_mask'].to(device) 106 | 107 | with torch.no_grad(): 108 | image_cond = model_dict['clip_model'].encode_images(image_source) 109 | 110 | text_cond = model_dict['clip_model'].encode_texts(text_tokens, text_attention_mask) 111 | 112 | negative_text_cond = model_dict['clip_model'].encode_texts(negative_text_tokens, negative_text_attention_mask) 113 | 114 | sampling_start = time.time() 115 | mask = transforms.Resize([64, 64])(mask)[:,0,:,:] 116 | mask = (mask > 0.5).float() 117 | if torch.sum(mask).item() == 0.0: 118 | mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0) 119 | sampled_image_features = model_dict['compodiff'].sample(image_cond, text_cond, negative_text_cond, mask, timesteps=step, cond_scale=cfg_scale, num_samples_per_batch=2) 120 | sampling_end = time.time() 121 | 122 | sampled_image_features_org = (1 - source_mixing_weight) * sampled_image_features + source_mixing_weight * image_cond[0] 123 | sampled_image_features = l2norm(sampled_image_features_org) 124 | 125 | if do_generate and image_decoder is not None: 126 | images = image_decoder(image_embeddings=sampled_image_features_org.half(), num_images_per_prompt=2).images 127 | else: 128 | images = [Image.fromarray(np.zeros([256,256,3], dtype='uint8'))] 129 | 130 | do_list = [['KNN results', sampled_image_features], 131 | ] 132 | 133 | output = '' 134 | top1_list = [] 135 | search_start = time.time() 136 | for name, features in do_list: 137 | results = client.query(embedding_input=features[0].tolist())[:15] 138 | output += f'
{name} outputs\n\n' 139 | for idx, result in enumerate(results): 140 | image_url = result['url'] 141 | if idx == 0: 142 | top1_list.append(f'{image_url}') 143 | output += f'![image]({image_url})\n' 144 | 145 | output += '\n
\n\n' 146 | 147 | search_end = time.time() 148 | 149 | return output, images 150 | 151 | 152 | if __name__ == "__main__": 153 | global model_dict, client, image_decoder 154 | 155 | model_dict = load_models() 156 | 157 | if 'cuda' in model_dict['device']: 158 | image_decoder = UnCLIPImageVariationPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16).to('cuda:0') 159 | else: 160 | image_decoder = None 161 | 162 | client = ClipClient(url="https://knn.laion.ai/knn-service", 163 | indice_name="laion5B-L-14", 164 | ) 165 | 166 | ### define gradio demo 167 | title = 'CompoDiff demo' 168 | 169 | md_title = f'''# {title} 170 | Diffusion on {model_dict["device"]}, K-NN Retrieval using https://rom1504.github.io/clip-retrieval. 171 | ''' 172 | md_below = f'''### Tips: 173 | Here are some tips for using the demo: 174 | + If you want to apply more of the original image's context, increase the source weight in the Advanced options from 0.1. This will convey the context of the original image as a strong signal. 175 | + If you want to exclude specific keywords, you can add them to the Negative text input. 176 | + Try using "generate image with unCLIP" to create images. You can see some interesting generated images that are as fascinating as search results. 177 | + If you only input text and no image, it will work like the prior of unCLIP. 178 | ''' 179 | 180 | 181 | with gr.Blocks(title=title) as demo: 182 | gr.Markdown(md_title) 183 | with gr.Row(): 184 | with gr.Column(): 185 | image_source = gr.Image(type='pil', label='Source image', tool='sketch') 186 | with gr.Row(): 187 | steps_input = gr.Radio(['2', '3', '5', '10'], value='10', label='denoising steps') 188 | if model_dict['device'] == 'cpu': 189 | do_generate = gr.Checkbox(value=False, label='generate image with unCLIP', visible=False) 190 | else: 191 | do_generate = gr.Checkbox(value=False, label='generate image with unCLIP', visible=True) 192 | with gr.Accordion('Advanced options', open=False): 193 | with gr.Row(): 194 | cfg_image_scale = gr.Number(value=1.5, label='image condition scale') 195 | cfg_text_scale = gr.Number(value=7.5, label='text condition scale') 196 | source_mixing_weight = gr.Number(value=0.1, label='source weight (0.0~1.0)') 197 | text_input = gr.Textbox(value='', label='Input text guidance') 198 | negative_text_input = gr.Textbox(value='', label='Negative text') # low quality, text overlay, logo 199 | submit_button = gr.Button('Submit') 200 | gr.Markdown(md_below) 201 | with gr.Column(): 202 | if model_dict['device'] == 'cpu': 203 | gallery = gr.Gallery(label='Generated images', visible=False).style(grid=[2]) 204 | else: 205 | gallery = gr.Gallery(label='Generated images', visible=True).style(grid=[2]) 206 | md_output = gr.Markdown(label='Output') 207 | submit_button.click(predict, inputs=[image_source, text_input, negative_text_input, steps_input, cfg_image_scale, cfg_text_scale, do_generate, source_mixing_weight], outputs=[md_output, gallery]) 208 | demo.launch(server_name='0.0.0.0', 209 | server_port=8000) 210 | 211 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.0 2 | torchvision>=0.9 3 | numpy 4 | transformers 5 | diffusers 6 | rotary-embedding-torch 7 | einops 8 | einops_exts 9 | gradio 10 | clip-retrieval 11 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | CompoDiff 3 | Copyright (c) 2023-present NAVER Corp. 4 | Apache-2.0 5 | """ 6 | import os 7 | import pkg_resources 8 | from setuptools import setup, find_packages 9 | 10 | setup( 11 | name='compodiff', 12 | version='0.1.1', 13 | description='Easy to use CompoDiff library', 14 | author='NAVER Corp.', 15 | author_email='dl_visionresearch@navercorp.com', 16 | url='https://github.com/navervision/CompoDiff', 17 | install_requires=[ 18 | str(r) 19 | for r in pkg_resources.parse_requirements( 20 | open(os.path.join(os.path.dirname(__file__), 'requirements.txt')) 21 | ) 22 | ], 23 | packages=find_packages(), 24 | ) 25 | --------------------------------------------------------------------------------
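A minimal usage sketch, distilled from the text-only ("t2i") path of demo_search.py above. It assumes the package has been installed from this repository (e.g. pip install -e .) and that compodiff.build_model() can fetch the pretrained weights as the demo does; the query string and the timesteps=10 / cond_scale=(1.0, 7.5) values are illustrative defaults taken from the demo, not requirements.

import torch
import compodiff

# load the pretrained CompoDiff prior, the CLIP encoders, the image preprocessor and the tokenizer
compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)

text = 'a living room with a fireplace'   # hypothetical query text
negative_text = ''                        # optional negative prompt

tokens = tokenizer(text=text, return_tensors='pt', padding='max_length', truncation=True)
neg_tokens = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)

with torch.no_grad():
    text_cond = clip_model.encode_texts(tokens['input_ids'].to(device), tokens['attention_mask'].to(device))
    negative_text_cond = clip_model.encode_texts(neg_tokens['input_ids'].to(device), neg_tokens['attention_mask'].to(device))

    # no source image: zero image condition and an all-zero 64x64 mask, as in the demo's t2i branch
    image_cond = torch.zeros([1, 1, 768]).to(device)
    mask = torch.zeros([1, 64, 64]).to(device)

    # returns a CLIP-space image embedding that can be sent to a k-NN index (clip-retrieval)
    # or decoded with an unCLIP decoder, as demo_search.py does
    image_embedding = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask,
                                             timesteps=10, cond_scale=(1.0, 7.5),
                                             num_samples_per_batch=2)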