├── .gitignore
├── LICENSE
├── README.md
├── assets
    ├── A man with a face of avocado, in the drawing style of Rene Magritte..png
    ├── A teddy bear on a skateboard, children drawing style..png
    ├── Goryeo celadon in the shape of bird.png
    ├── Photo of a business woman, silver hair.png
    ├── a black porcelain in the shape of pikachu.png
    ├── a portrait of an old monk, highly detailed.png
    ├── example.gif
    ├── improved_sr_arch.jpg
    ├── variation_A man with a face of avocado, in the drawing style of Rene Magritte..png
    └── variation_a black porcelain in the shape of pikachu.png
├── configs
    ├── decoder_900M_vit_l.yaml
    ├── improved_sr_64_256_1.4B.yaml
    └── prior_1B_vit_l.yaml
├── demo
    ├── components.py
    └── product_demo.py
├── example.py
├── karlo
    ├── __init__.py
    ├── models
    │   ├── __init__.py
    │   ├── clip.py
    │   ├── decoder_model.py
    │   ├── prior_model.py
    │   ├── sr_256_1k.py
    │   └── sr_64_256.py
    ├── modules
    │   ├── __init__.py
    │   ├── diffusion
    │   │   ├── gaussian_diffusion.py
    │   │   └── respace.py
    │   ├── nn.py
    │   ├── resample.py
    │   ├── unet.py
    │   └── xf.py
    ├── sampler
    │   ├── i2i.py
    │   ├── t2i.py
    │   └── template.py
    └── utils
    │   └── util.py
├── requirements.txt
├── setup.cfg
└── setup.sh

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
.idea/
__pycache__/

results
results.*
notebooks/
outputs
outputs/
configs/temp/
pytests/
_cache/
scripts/

*.ckpt
core.*

__pycache__/
**/__pycache__/
*.py[cod]
**/*.py[cod]
**/*.pyc
result*/
results*/
backup*/
test.*/
.nfs*

.ipynb_*/

# MacOSX
**/*.DS_Store

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2022 Kakao Brain

CreativeML Open RAIL-M
dated August 22, 2022

Section I: PREAMBLE

Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.

Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.

In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.

Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.

This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.

NOW THEREFORE, You and Licensor agree as follows:

1. Definitions

- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.

Section II: INTELLECTUAL PROPERTY RIGHTS

Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.

Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION

4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
You must cause any modified files to carry prominent notices stating that You changed the files;
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.

Section IV: OTHER PROVISIONS

7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

END OF TERMS AND CONDITIONS




Attachment A

Use Restrictions

You agree not to use the Model or Derivatives of the Model:
- In any way that violates any applicable national, federal, state, local or international law or regulation;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate personal identifiable information that can be used to harm an individual;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
- To provide medical advice and medical results interpretation;
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Karlo-v1.0.alpha on COYO-100M and CC15M

Karlo is a text-conditional image generation model based on OpenAI's unCLIP architecture, with an improved super-resolution module that upscales from 64px to 256px and recovers high-frequency details in only a small number of denoising steps.

Sample generations (images omitted), with their prompts:

* "a portrait of an old monk, highly detailed."
* "Photo of a business woman, silver hair"
* "A teddy bear on a skateboard, children drawing style."
* "Goryeo celadon in the shape of bird"

This alpha version of Karlo is trained on 115M image-text pairs, including the [COYO](https://github.com/kakaobrain/coyo-dataset)-100M high-quality subset, CC3M, and CC12M. If you are interested in a better version of Karlo trained on larger-scale, high-quality datasets, please visit the landing page of our application [B^DISCOVER](https://bdiscover.kakaobrain.com/).

### Updates
* [2022-12-01] Karlo-v1.0.alpha is released!
* [2022-12-19] Karlo-v1.0.alpha was [integrated into the 🧨 diffusers library](#-diffusers-integration)
* [2022-12-20] Karlo-v1.0.alpha was integrated into Huggingface Spaces 🤗 using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/kakaobrain/karlo)

## Model Architecture

### Overview
Karlo is a text-conditional diffusion model based on unCLIP, composed of prior, decoder, and super-resolution modules. In this repository, we include an improved version of the standard super-resolution module that upscales 64px to 256px in only 7 reverse steps, as illustrated in the figure below:

[Figure: architecture of the improved super-resolution module]


Specifically, the standard SR module, trained with the DDPM objective, upscales 64px to 256px in the first 6 denoising steps using the respacing technique. Then, an additional fine-tuned SR module, trained with a [VQ-GAN](https://compvis.github.io/taming-transformers/)-style loss, performs the final reverse step to recover high-frequency details. We observe that this approach is very effective for upscaling low-resolution images in a small number of reverse steps.

### Details
We train all components from scratch on 115M image-text pairs including COYO-100M, CC3M, and CC12M. For the Prior and Decoder, we use ViT-L/14 provided by OpenAI’s [CLIP repository](https://github.com/openai/CLIP). Unlike the original implementation of unCLIP, we replace the trainable transformer in the decoder with the text encoder of ViT-L/14 for efficiency. For the SR module, we first train the model with the DDPM objective for 1M steps, followed by an additional 234K steps to fine-tune the additional component.
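
As an illustration of the respacing idea, here is a minimal sketch that picks 7 evenly spaced timesteps out of 1000 training steps and splits them into the 6 reverse steps handled by the standard SR module and the final step handled by the fine-tuned module. The function names are hypothetical and the even spacing is an assumption; the repository's own logic lives in `karlo/modules/diffusion/respace.py` and may select timesteps differently.

```python
from typing import List, Tuple


def evenly_spaced_timesteps(num_train_steps: int, num_sampling_steps: int) -> List[int]:
    """Pick `num_sampling_steps` timesteps spread evenly over [0, num_train_steps - 1]."""
    if num_sampling_steps < 2:
        raise ValueError("need at least 2 sampling steps")
    stride = (num_train_steps - 1) / (num_sampling_steps - 1)
    return [round(i * stride) for i in range(num_sampling_steps)]


def split_sr_schedule(num_train_steps: int, num_sampling_steps: int) -> Tuple[List[int], int]:
    """Return (timesteps for the standard SR module, timestep for the fine-tuned module).

    Sampling runs in reverse, from the noisiest timestep down to 0, so the
    last entry of the reversed schedule is the final reverse step.
    """
    reverse_order = evenly_spaced_timesteps(num_train_steps, num_sampling_steps)[::-1]
    return reverse_order[:-1], reverse_order[-1]


standard_steps, final_step = split_sr_schedule(1000, 7)
print(len(standard_steps), standard_steps[0], final_step)
```

With 1000 training steps and 7 sampling steps, `standard_steps` holds six timesteps in descending order starting at 999, and `final_step` is 0.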
The table below summarizes the important statistics of our components:

| | Prior | Decoder | SR |
|:------|----:|----:|----:|
| CLIP | ViT-L/14 | ViT-L/14 | - |
| #param | 1B | 900M | 700M + 700M |
| #optimization steps | 1M | 1M | 1M + 0.2M |
| #sampling steps | 25 | 50 (default), 25 (fast) | 7 |
|Checkpoint links| [ViT-L-14](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt), [ViT-L-14 stats](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th), [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt) | [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt)| [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt) |

In the checkpoint links, ViT-L-14 is identical to the original OpenAI version; we include it for convenience. Note that the ViT-L-14 stats file is required to normalize the outputs of the prior module.

### Evaluation
We quantitatively measure the performance of Karlo-v1.0.alpha on the validation splits of CC3M and MS-COCO. The tables below present CLIP-score and FID. To measure FID, we resize the shorter side of each image to 256px and then center-crop it. We set the classifier-free guidance scales for the prior and decoder to 4 and 8, respectively, in all cases. We observe that our model achieves reasonable performance even with only 25 decoder sampling steps.
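
The FID preprocessing described above (resize the shorter side to 256px, then center-crop) can be sketched in plain Python as follows. This is only an illustration of the geometry under our reading of that sentence; the function name is hypothetical and this is not the evaluation code used for the reported numbers.

```python
from typing import Tuple


def resize_and_center_crop_box(
    width: int, height: int, target: int = 256
) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Compute the resized size and the center-crop box for one image.

    The shorter side is scaled to `target` pixels (aspect ratio preserved),
    then a `target` x `target` square is cropped from the center.
    Returns ((new_width, new_height), (left, top, right, bottom)).
    """
    scale = target / min(width, height)
    # max() guards against rounding leaving the shorter side below the target.
    new_w = max(target, round(width * scale))
    new_h = max(target, round(height * scale))
    left = (new_w - target) // 2
    top = (new_h - target) // 2
    return (new_w, new_h), (left, top, left + target, top + target)


size, box = resize_and_center_crop_box(640, 480)
print(size, box)  # (341, 256) (42, 0, 298, 256)
```

With Pillow, the two results could be applied as `img.resize(size).crop(box)`; the resampling filter is left at the library default here, which is an assumption rather than something the README specifies.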

CC3M

| Sampling step | CLIP-s (ViT-B/16) | FID (13k from val)|
|:------|----:|----:|
| Prior (25) + Decoder (25) + SR (7) | 0.3081 | 14.37 |
| Prior (25) + Decoder (50) + SR (7) | 0.3086 | 13.95 |

MS-COCO

| Sampling step | CLIP-s (ViT-B/16) | FID (30k from val)|
|:------|----:|----:|
| Prior (25) + Decoder (25) + SR (7) | 0.3192 | 15.24 |
| Prior (25) + Decoder (50) + SR (7) | 0.3192 | 14.43 |


For more information, please refer to the upcoming technical report.

## 🧨 Diffusers integration
Our unCLIP implementation is officially integrated in the [🧨 diffusers library](https://huggingface.co/docs/diffusers/api/pipelines/unclip)
```
# Requirements to run Karlo unCLIP on diffusers
pip install diffusers transformers accelerate safetensors
```

```py
from diffusers import UnCLIPPipeline
import torch

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to('cuda')

prompt = "a high-resolution photograph of a big red frog on a green leaf."
image = pipe(prompt).images[0]
image.save("./frog.png")
```
Check out the [diffusers docs](https://huggingface.co/docs/diffusers/api/pipelines/unclip) for the full usage of the `UnCLIPPipeline`.

## Environment Setup
We use a single V100 with 32GB of VRAM for sampling under PyTorch >= 1.10 and CUDA >= 11. The following commands install additional Python packages and download the pretrained model checkpoints.
Alternatively, you can install the package and download the weights via [setup.sh](setup.sh).
- Additional python packages
```
pip install -r requirements.txt
```
- Model checkpoints
```
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt -P $KARLO_ROOT_DIR  # same as the official ViT-L/14 from OpenAI
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th -P $KARLO_ROOT_DIR
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt -P $KARLO_ROOT_DIR
```

## Sampling

### Gradio demo (T2I and Image variation)
The following command launches the Gradio demo for text-to-image generation and image variation. We notice that the second run in the Gradio demo is unexpectedly slower than usual with PyTorch >= 1.12. We suspect this is because launching the CUDA kernels takes some time, usually up to 2 minutes.
```
python demo/product_demo.py --host 0.0.0.0 --port $PORT --root-dir $KARLO_ROOT_DIR
```

The samples below are non-cherry-picked T2I and image variation examples generated with random seed 0.
In each case, the first row shows T2I samples and the second row shows image variation samples of the leftmost image in the first row.
* [T2I + Image variation] "A man with a face of avocado, in the drawing style of Rene Magritte." (images omitted)
* [T2I + Image variation] "a black porcelain in the shape of pikachu" (images omitted)


### T2I command line example
Here, we include a command line example of T2I. For image variation, refer to [karlo/sampler/i2i.py](karlo/sampler/i2i.py) to see how the prior is replaced with the CLIP image feature.
```
python example.py --root-dir=$KARLO_ROOT_DIR \
                  --prompt="A man with a face of avocado, in the drawing style of Rene Magritte" \
                  --output-dir=$OUTPUT_DIR \
                  --max-bsz=2 \
                  --sampling-type=fast
```

## License and Disclaimer
This project, including the weights, is distributed under the [CreativeML Open RAIL-M license](LICENSE), equivalent to the version used by [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion/blob/main/LICENSE). You may use this model in commercial applications, but we highly recommend adopting a strong safety checker as a post-processing step. We also remark that we are not responsible for any kind of use of the generated images.

## BibTex
If you find this repository useful in your research, please cite:
```
@misc{kakaobrain2022karlo-v1-alpha,
  title         = {Karlo-v1.0.alpha on COYO-100M and CC15M},
  author        = {Donghoon Lee and Jiseob Kim and Jisu Choi and Jongmin Kim and Minwoo Byeon and Woonhyuk Baek and Saehoon Kim},
  year          = {2022},
  howpublished  = {\url{https://github.com/kakaobrain/karlo}},
}
```

## Acknowledgement
* We deeply appreciate all the contributors to OpenAI’s [Guided-Diffusion](https://github.com/openai/guided-diffusion) project.
* We also greatly appreciate [Apolinário Passos](https://github.com/apolinario) and [Will Berman](https://github.com/williamberman) from Huggingface for integrating this model into [diffusers](https://github.com/huggingface/diffusers).

## Contact
If you would like to collaborate with us or share feedback, please e-mail us at contact@kakaobrain.com

--------------------------------------------------------------------------------
/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png

--------------------------------------------------------------------------------
/assets/A teddy bear on a skateboard, children drawing style..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/A teddy bear on a skateboard, children drawing style..png

--------------------------------------------------------------------------------
/assets/Goryeo celadon in the shape of bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/Goryeo celadon in the shape of bird.png

--------------------------------------------------------------------------------
/assets/Photo of a business woman, silver hair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/Photo of a business woman, silver hair.png

--------------------------------------------------------------------------------
/assets/a black porcelain in the shape of pikachu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/a black porcelain in the shape of pikachu.png

--------------------------------------------------------------------------------
/assets/a portrait of an old monk, highly detailed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/a portrait of an old monk, highly detailed.png

--------------------------------------------------------------------------------
/assets/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/example.gif

--------------------------------------------------------------------------------
/assets/improved_sr_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/improved_sr_arch.jpg

--------------------------------------------------------------------------------
/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png

--------------------------------------------------------------------------------
/assets/variation_a black porcelain in the shape of pikachu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/variation_a black porcelain in the shape of pikachu.png

--------------------------------------------------------------------------------
/configs/decoder_900M_vit_l.yaml:
--------------------------------------------------------------------------------
model:
    type: t2i-decoder
    diffusion_sampler: uniform
    hparams:
        image_size: 64
        num_channels: 320
        num_res_blocks: 3
        channel_mult: ''
        attention_resolutions: 32,16,8
        num_heads: -1
        num_head_channels: 64
        num_heads_upsample: -1
        use_scale_shift_norm: true
        dropout: 0.1
        clip_dim: 768
        clip_emb_mult: 4
        text_ctx: 77
        xf_width: 1536
        xf_layers: 0
        xf_heads: 0
        xf_final_ln: false
        resblock_updown: true
        learn_sigma: true
        text_drop: 0.3
        clip_emb_type: image
        clip_emb_drop: 0.1
        use_plm: true

diffusion:
    steps: 1000
    learn_sigma: true
    sigma_small: false
    noise_schedule: squaredcos_cap_v2
    use_kl: false
    predict_xstart: false
    rescale_learned_sigmas: true
    timestep_respacing: ''

--------------------------------------------------------------------------------
/configs/improved_sr_64_256_1.4B.yaml:
--------------------------------------------------------------------------------
model:
    type: improved_sr_64_256
    diffusion_sampler: uniform
    hparams:
        channels: 320
        depth: 3
        channels_multiple:
            - 1
            - 2
            - 3
            - 4
        dropout: 0.0

diffusion:
    steps: 1000
    learn_sigma: false
    sigma_small: true
    noise_schedule: squaredcos_cap_v2
    use_kl: false
    predict_xstart: false
    rescale_learned_sigmas: true
    timestep_respacing: '7'


sampling:
    timestep_respacing: '7' # fix
    clip_denoise: true

--------------------------------------------------------------------------------
/configs/prior_1B_vit_l.yaml:
--------------------------------------------------------------------------------
model:
    type: prior
    diffusion_sampler: uniform
    hparams:
        text_ctx: 77
        xf_width: 2048
        xf_layers: 20
        xf_heads: 32
        xf_final_ln: true
        text_drop: 0.2
        clip_dim: 768

diffusion:
    steps: 1000
    learn_sigma: false
    sigma_small: true
    noise_schedule: squaredcos_cap_v2
    use_kl: false
    predict_xstart: true
    rescale_learned_sigmas: false
    timestep_respacing: ''

--------------------------------------------------------------------------------
/demo/components.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import time
import sys
import os
import threading
import logging
from queue import Queue
from PIL import Image

import gradio as gr
import numpy as np
import torch

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from karlo.sampler.template import CKPT_PATH, BaseSampler
from karlo.sampler.t2i import T2ISampler
from karlo.sampler.i2i import I2ISampler
from karlo.utils.util import set_seed


def tensor_to_images(tensor: torch.Tensor, output_res=(1024, 1024)):
    assert tensor.ndim == 4
    tensor = torch.clone(tensor)
    # NCHW -> NHWC
    images = torch.permute(tensor * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()
    concat_image = np.concatenate(images, axis=1)
    target_size = (output_res[1] * tensor.shape[0], output_res[0])
    concat_image = Image.fromarray(concat_image).resize(
        target_size, resample=Image.NEAREST
    )
    return images, concat_image


class GradioSampler:
    def __init__(
        self,
        root_dir,
        max_bsz,
        progressive,
        sampling_type: str,
    ):
        self._root_dir = root_dir
        self._max_bsz = max_bsz
        self._progressive = progressive
        self._sampling_type = sampling_type

        self.load_ckpt()
        self.set_options_from_sampler()

        self.result_queue = Queue()

    def load_ckpt(self):
        base_sampler = BaseSampler(root_dir=self._root_dir)
        base_sampler.load_clip(clip_path="ViT-L-14.pt")
        base_sampler.load_prior(
            f"{CKPT_PATH['prior']}",
            clip_stat_path="ViT-L-14_stats.th",
        )
        base_sampler.load_decoder(f"{CKPT_PATH['decoder']}")
        base_sampler.load_sr_64_256(f"{CKPT_PATH['sr_256']}")

        self.t2i_sampler = T2ISampler(
            root_dir=self._root_dir, sampling_type=self._sampling_type
        )
        self.i2i_sampler = I2ISampler(
            root_dir=self._root_dir, sampling_type=self._sampling_type
        )

        self.t2i_sampler._clip = base_sampler._clip
        self.t2i_sampler._tokenizer = base_sampler._tokenizer
        self.t2i_sampler._prior = base_sampler._prior
        self.t2i_sampler._decoder = base_sampler._decoder
        self.t2i_sampler._sr_64_256 = base_sampler._sr_64_256

        self.i2i_sampler._clip = base_sampler._clip
        self.i2i_sampler._tokenizer = base_sampler._tokenizer
        self.i2i_sampler._prior = base_sampler._prior
        self.i2i_sampler._decoder = base_sampler._decoder
        self.i2i_sampler._sr_64_256 = base_sampler._sr_64_256

        self.ckpt_info = f"""
        * **prior**: `{self._root_dir}/{CKPT_PATH['prior']}`
        * **decoder**: `{self._root_dir}/{CKPT_PATH['decoder']}`
        * **sr_64_256**: `{self._root_dir}/{CKPT_PATH['sr_256']}`
        """

    def set_options_from_sampler(self):
        self.global_options = {"seed": 0, "max_bsz": self._max_bsz}

        self.prior_options = {
            "sm": self.t2i_sampler._prior_sm,
            "cf_scale": self.t2i_sampler._prior_cf_scale,
        }
        self.decoder_options = {
            "sm": self.t2i_sampler._decoder_sm,
            "cf_scale": self.t2i_sampler._decoder_cf_scale,
        }
        self.sr_64_256_options = {
            "sm": self.t2i_sampler._sr_sm,
        }

    def make_global_options(self):
        gr.Markdown("Global Options")
        with gr.Row():
            return [
                gr.Slider(
                    label="seed",
113 | value=self.global_options["seed"], 114 | minimum=np.iinfo(np.uint32).min, 115 | maximum=np.iinfo(np.uint32).max, 116 | step=1, 117 | ), 118 | gr.Slider( 119 | label="maximum batch size", 120 | value=self.global_options["max_bsz"], 121 | minimum=1, 122 | maximum=4, 123 | step=1, 124 | ), 125 | ] 126 | 127 | def make_prior_options(self): 128 | gr.Markdown("Prior Options") 129 | return [ 130 | gr.Textbox( 131 | label="sampling method", 132 | value=self.prior_options["sm"], 133 | ), 134 | gr.Slider( 135 | label="Classifier-free guidance scales", 136 | value=self.prior_options["cf_scale"], 137 | minimum=0.1, 138 | maximum=24, 139 | ), 140 | ] 141 | 142 | def make_decoder_options(self): 143 | gr.Markdown("Decoder Options") 144 | with gr.Row(): 145 | return [ 146 | gr.Textbox( 147 | label="sampling method", 148 | value=self.decoder_options["sm"], 149 | ), 150 | gr.Slider( 151 | label="Classifier-free guidance scales", 152 | value=self.decoder_options["cf_scale"], 153 | minimum=0.1, 154 | maximum=24, 155 | ), 156 | ] 157 | 158 | def make_sr_64_256_options(self): 159 | return [gr.Variable(self.sr_64_256_options["sm"])] 160 | 161 | def make_basic_options(self): 162 | self.global_options_gr = self.make_global_options() 163 | self.prior_optios_gr = self.make_prior_options() 164 | self.decoder_options_gr = self.make_decoder_options() 165 | self.sr_64_256_options_gr = self.make_sr_64_256_options() 166 | 167 | def seed(self, seed): 168 | set_seed(seed) 169 | 170 | def _sample(self, output_generator): 171 | for k, out in enumerate(output_generator): 172 | self.result_queue.put((out, False)) 173 | self.result_queue.put((None, True)) 174 | 175 | def t2i_sample( 176 | self, 177 | text_input, 178 | prior_sm, 179 | prior_cf_scale, 180 | decoder_sm, 181 | decoder_cf_scale, 182 | sr_sm, 183 | seed, 184 | max_bsz, 185 | ): 186 | t0 = time.time() 187 | assert hasattr(self.t2i_sampler, "_prior_sm") 188 | assert hasattr(self.t2i_sampler, "_prior_cf_scale") 189 | assert 
hasattr(self.t2i_sampler, "_decoder_sm") 190 | assert hasattr(self.t2i_sampler, "_decoder_cf_scale") 191 | assert hasattr(self.t2i_sampler, "_sr_sm") 192 | 193 | print("-" * 100) 194 | print(f"text_input: {text_input}") 195 | print(f"prior_sm: {prior_sm}") 196 | print(f"prior_cf_scale: {prior_cf_scale}") 197 | print(f"decoder_sm: {decoder_sm}") 198 | print(f"decoder_cf_scale: {decoder_cf_scale}") 199 | print(f"sr_sm: {sr_sm}") 200 | print(f"seed: {seed}") 201 | print(f"max_bsz: {max_bsz}") 202 | 203 | self.t2i_sampler._prior_sm = prior_sm 204 | self.t2i_sampler._prior_cf_scale = prior_cf_scale 205 | 206 | self.t2i_sampler._decoder_sm = decoder_sm 207 | self.t2i_sampler._decoder_cf_scale = decoder_cf_scale 208 | 209 | self.t2i_sampler._sr_sm = sr_sm 210 | 211 | self.seed(seed) 212 | 213 | output_generator = self.t2i_sampler( 214 | prompt=text_input, 215 | bsz=max_bsz, 216 | progressive_mode=self._progressive, 217 | ) 218 | 219 | thread = threading.Thread(target=self._sample, args=(output_generator,)) 220 | thread.start() 221 | done = False 222 | 223 | while not done: 224 | if self.result_queue.empty(): 225 | time.sleep(0.1) 226 | else: 227 | while not self.result_queue.empty(): 228 | _out, done = self.result_queue.get(0) # get last item to display 229 | if not done: 230 | out = _out 231 | images, concat_image = tensor_to_images(out, (256, 256)) 232 | yield (text_input, images), concat_image 233 | 234 | thread.join() 235 | yield (text_input, images), concat_image 236 | 237 | t1 = time.time() 238 | execution_time = t1 - t0 239 | logging.info(f"Generation done. 
{text_input} -- {execution_time:.6f}secs") 240 | print("-" * 100) 241 | 242 | def i2i_sample( 243 | self, 244 | image_input, 245 | decoder_sm, 246 | decoder_cf_scale, 247 | sr_sm, 248 | seed, 249 | max_bsz, 250 | ): 251 | t0 = time.time() 252 | assert hasattr(self.i2i_sampler, "_decoder_sm") 253 | assert hasattr(self.i2i_sampler, "_decoder_cf_scale") 254 | assert hasattr(self.i2i_sampler, "_sr_sm") 255 | 256 | print("-" * 100) 257 | print(f"decoder_sm: {decoder_sm}") 258 | print(f"decoder_cf_scale: {decoder_cf_scale}") 259 | print(f"sr_sm: {sr_sm}") 260 | print(f"seed: {seed}") 261 | print(f"max_bsz: {max_bsz}") 262 | 263 | self.i2i_sampler._decoder_sm = decoder_sm 264 | self.i2i_sampler._decoder_cf_scale = decoder_cf_scale 265 | 266 | self.i2i_sampler._sr_sm = sr_sm 267 | 268 | self.seed(seed) 269 | 270 | output_generator = self.i2i_sampler( 271 | image=image_input, 272 | bsz=max_bsz, 273 | progressive_mode=self._progressive, 274 | ) 275 | 276 | thread = threading.Thread(target=self._sample, args=(output_generator,)) 277 | thread.start() 278 | done = False 279 | 280 | while not done: 281 | if self.result_queue.empty(): 282 | time.sleep(0.1) 283 | else: 284 | while not self.result_queue.empty(): 285 | _out, done = self.result_queue.get(0) # get last item to display 286 | if not done: 287 | out = _out 288 | images, concat_image = tensor_to_images(out, (256, 256)) 289 | yield ("", images), concat_image 290 | 291 | thread.join() 292 | yield ("", images), concat_image 293 | 294 | t1 = time.time() 295 | execution_time = t1 - t0 296 | logging.info(f"Variation done. 
{execution_time:.6f}secs") 297 | print("-" * 100) 298 | 299 | 300 | class ImageSelecter: 301 | @classmethod 302 | def make_basic_ui(cls, max_bsz): 303 | with gr.Box(): 304 | i2i_select_idx = gr.Radio( 305 | choices=[str(i) for i in range(0, max_bsz)], 306 | value="0", 307 | label="Image index", 308 | ) 309 | i2i_select_button = gr.Button( 310 | "Select for Image Variation", variant="primary" 311 | ) 312 | return { 313 | "i2i_select_idx": i2i_select_idx, 314 | "i2i_select_button": i2i_select_button, 315 | } 316 | 317 | @classmethod 318 | def select_fn(cls, stash, idx): 319 | if stash is not None: 320 | return Image.fromarray(stash[1][int(idx)].copy()) 321 | 322 | @classmethod 323 | def setup_button_click( 324 | cls, 325 | selector_ui, 326 | stash, 327 | i2i_input_images, 328 | ): 329 | selector_ui["i2i_select_button"].click( 330 | fn=cls.select_fn, 331 | inputs=[stash, selector_ui["i2i_select_idx"]], 332 | outputs=[i2i_input_images], 333 | ) 334 | -------------------------------------------------------------------------------- /demo/product_demo.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import argparse 7 | import logging 8 | import gradio as gr 9 | import os 10 | import sys 11 | 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | 14 | from karlo import __version__ as karlo_ver 15 | from demo.components import GradioSampler, ImageSelecter 16 | 17 | 18 | class GradioDemo: 19 | def __init__( 20 | self, 21 | root_dir: str, 22 | max_bsz: int, 23 | progressive: str, 24 | sampling_type: str, 25 | ): 26 | sampler = GradioSampler( 27 | root_dir=root_dir, 28 | max_bsz=max_bsz, 29 | progressive=progressive, 30 | sampling_type=sampling_type, 31 | ) 32 | 33 | demo = gr.Blocks() 34 | with demo: 35 | gr.Markdown(f"# Karlo Demo {karlo_ver}") 36 | with gr.Box(): 37 | gr.Markdown("## Generate 64px images + Upscaling to 256px") 38 | 39 | with gr.Tabs(): 40 | with gr.TabItem("Image Generation"): 41 | t2i_text_input = gr.Textbox( 42 | lines=1, 43 | placeholder="Type text prompt...", 44 | label="Text prompts", 45 | ) 46 | t2i_button = gr.Button("Generate", variant="primary") 47 | with gr.TabItem("Image Variation"): 48 | i2i_img_input = gr.Image(label="Image input", type="pil") 49 | i2i_button = gr.Button("Generate", variant="primary") 50 | 51 | with gr.Box(): 52 | outputs = gr.Image(label="Generated", type="pil") 53 | stash = gr.Variable() 54 | with gr.Row(): 55 | selector_ui = ImageSelecter.make_basic_ui(max_bsz=max_bsz) 56 | 57 | with gr.Box(): 58 | with gr.Accordion(label="Advanced Options", open=False): 59 | sampler.make_basic_options() 60 | 61 | with gr.Box(): 62 | with gr.Accordion(label="Checkpoint Information", open=False): 63 | gr.Markdown(sampler.ckpt_info) 64 | 65 | t2i_button.click( 66 | fn=sampler.t2i_sample, 67 | inputs=[t2i_text_input] 68 | + sampler.prior_optios_gr 69 | + sampler.decoder_options_gr 70 | + sampler.sr_64_256_options_gr 71 | + sampler.global_options_gr, 72 | outputs=[stash, outputs], 73 | ) 74 | 
i2i_button.click( 75 | fn=sampler.i2i_sample, 76 | inputs=[i2i_img_input] 77 | + sampler.decoder_options_gr 78 | + sampler.sr_64_256_options_gr 79 | + sampler.global_options_gr, 80 | outputs=[stash, outputs], 81 | ) 82 | 83 | ImageSelecter.setup_button_click(selector_ui, stash, i2i_img_input) 84 | 85 | demo.queue() 86 | self.demo = demo 87 | 88 | 89 | def default_parser(): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument("--root-dir", type=str, default=None) 92 | parser.add_argument("--max_bsz", type=int, default=1) 93 | parser.add_argument( 94 | "--progressive", type=str, default="loop", choices=("loop", "stage", "final") 95 | ) 96 | parser.add_argument("--host", type=str, default="localhost") 97 | parser.add_argument("--port", type=int, default=6006) 98 | 99 | parser.add_argument( 100 | "--sampling-type", 101 | type=str, 102 | default="fast", 103 | choices=("fast", "default"), 104 | ) 105 | 106 | return parser 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = default_parser() 111 | args = parser.parse_args() 112 | logging.getLogger().setLevel(logging.INFO) 113 | 114 | assert ( 115 | args.root_dir is not None 116 | ), "--root-dir argument should be specified to load the pretrained ckpt" 117 | 118 | """Making Gradio""" 119 | gradio_demo = GradioDemo( 120 | root_dir=args.root_dir, 121 | max_bsz=args.max_bsz, 122 | progressive=args.progressive, 123 | sampling_type=args.sampling_type, 124 | ) 125 | gradio_demo.demo.launch(server_name=args.host, server_port=args.port) 126 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import os 7 | import argparse 8 | import logging 9 | import time 10 | from datetime import datetime 11 | 12 | import torch 13 | from PIL import Image 14 | 15 | from karlo.sampler.t2i import T2ISampler 16 | from karlo.utils.util import set_seed 17 | 18 | 19 | def default_parser(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--root-dir", type=str, required=True, help="path for model checkpoints" 23 | ) 24 | parser.add_argument("--max-bsz", type=int, default=1, help="#images to generate") 25 | parser.add_argument( 26 | "--output-dir", 27 | type=str, 28 | default="outputs", 29 | help="output path for generated images", 30 | ) 31 | parser.add_argument( 32 | "--sampling-type", 33 | type=str, 34 | default="fast", 35 | choices=("fast", "default"), 36 | ) 37 | parser.add_argument( 38 | "--prompt", type=str, default="A photo of a baby puppy waiting for her mom." 39 | ) 40 | parser.add_argument("--seed", type=int, default=0) 41 | 42 | return parser 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = default_parser() 47 | args = parser.parse_args() 48 | 49 | set_seed(args.seed) 50 | logging.getLogger().setLevel(logging.INFO) 51 | 52 | save_dir = os.path.join(args.output_dir, datetime.now().strftime("%d%m%Y_%H%M%S")) 53 | if not os.path.exists(save_dir): 54 | os.makedirs(save_dir) 55 | 56 | model = T2ISampler.from_pretrained( 57 | root_dir=args.root_dir, 58 | clip_model_path="ViT-L-14.pt", 59 | clip_stat_path="ViT-L-14_stats.th", 60 | sampling_type=args.sampling_type, 61 | ) 62 | 63 | for i in range(5): 64 | t1 = time.time() 65 | 66 | images = iter( 67 | model( 68 | prompt=args.prompt, 69 | bsz=args.max_bsz, 70 | progressive_mode="final", 71 | ) 72 | ).__next__() 73 | 74 | # NCHW, [0, 1], float32 -> NHWC, [0, 255], uint8 75 | images = ( 76 | torch.permute(images * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy() 77 | ) 78 | 79 | t2 = time.time() 
80 | execution_time = t2 - t1 81 | logging.info(f"Iteration {i} -- {execution_time:.6f}secs") 82 | 83 | # Select the first one 84 | image = Image.fromarray(images[0]) 85 | image_name = "_".join(args.prompt.split(" ")) 86 | image.save(f"{save_dir}/{image_name}_{i:02d}.jpg") 87 | -------------------------------------------------------------------------------- /karlo/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.alpha" 2 | -------------------------------------------------------------------------------- /karlo/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/models/__init__.py -------------------------------------------------------------------------------- /karlo/models/clip.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | # ------------------------------------------------------------------------------------ 6 | # Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/) 7 | # ------------------------------------------------------------------------------------ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import clip 14 | 15 | from clip.model import CLIP, convert_weights 16 | from clip.simple_tokenizer import SimpleTokenizer, default_bpe 17 | 18 | 19 | """===== Monkey-Patching original CLIP for JIT compile =====""" 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """Subclass torch's LayerNorm to handle fp16.""" 24 | 25 | def forward(self, x: torch.Tensor): 26 | orig_type = x.dtype 27 | ret = F.layer_norm( 28 | x.type(torch.float32), 29 | self.normalized_shape, 30 | self.weight, 31 | self.bias, 32 | self.eps, 33 | ) 34 | return ret.type(orig_type) 35 | 36 | 37 | clip.model.LayerNorm = LayerNorm 38 | delattr(clip.model.CLIP, "forward") 39 | 40 | """===== End of Monkey-Patching =====""" 41 | 42 | 43 | class CustomizedCLIP(CLIP): 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | 47 | @torch.jit.export 48 | def encode_image(self, image): 49 | return self.visual(image) 50 | 51 | @torch.jit.export 52 | def encode_text(self, text): 53 | # re-define this function to return unpooled text features 54 | 55 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 56 | 57 | x = x + self.positional_embedding.type(self.dtype) 58 | x = x.permute(1, 0, 2) # NLD -> LND 59 | x = self.transformer(x) 60 | x = x.permute(1, 0, 2) # LND -> NLD 61 | x = self.ln_final(x).type(self.dtype) 62 | 63 | x_seq = x 64 | # x.shape = [batch_size, n_ctx, transformer.width] 65 | # take features from the eot embedding (eot_token is the highest number in each sequence) 66 | x_out = x[torch.arange(x.shape[0]), 
text.argmax(dim=-1)] @ self.text_projection 67 | 68 | return x_out, x_seq 69 | 70 | @torch.jit.ignore 71 | def forward(self, image, text): 72 | super().forward(image, text) 73 | 74 | @classmethod 75 | def load_from_checkpoint(cls, ckpt_path: str): 76 | state_dict = torch.load(ckpt_path, map_location="cpu").state_dict() 77 | 78 | vit = "visual.proj" in state_dict 79 | if vit: 80 | vision_width = state_dict["visual.conv1.weight"].shape[0] 81 | vision_layers = len( 82 | [ 83 | k 84 | for k in state_dict.keys() 85 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") 86 | ] 87 | ) 88 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 89 | grid_size = round( 90 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 91 | ) 92 | image_resolution = vision_patch_size * grid_size 93 | else: 94 | counts: list = [ 95 | len( 96 | set( 97 | k.split(".")[2] 98 | for k in state_dict 99 | if k.startswith(f"visual.layer{b}") 100 | ) 101 | ) 102 | for b in [1, 2, 3, 4] 103 | ] 104 | vision_layers = tuple(counts) 105 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 106 | output_width = round( 107 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 108 | ) 109 | vision_patch_size = None 110 | assert ( 111 | output_width**2 + 1 112 | == state_dict["visual.attnpool.positional_embedding"].shape[0] 113 | ) 114 | image_resolution = output_width * 32 115 | 116 | embed_dim = state_dict["text_projection"].shape[1] 117 | context_length = state_dict["positional_embedding"].shape[0] 118 | vocab_size = state_dict["token_embedding.weight"].shape[0] 119 | transformer_width = state_dict["ln_final.weight"].shape[0] 120 | transformer_heads = transformer_width // 64 121 | transformer_layers = len( 122 | set( 123 | k.split(".")[2] 124 | for k in state_dict 125 | if k.startswith("transformer.resblocks") 126 | ) 127 | ) 128 | 129 | model = cls( 130 | embed_dim, 131 | image_resolution, 132 | vision_layers, 133 | vision_width, 
134 | vision_patch_size, 135 | context_length, 136 | vocab_size, 137 | transformer_width, 138 | transformer_heads, 139 | transformer_layers, 140 | ) 141 | 142 | for key in ["input_resolution", "context_length", "vocab_size"]: 143 | if key in state_dict: 144 | del state_dict[key] 145 | 146 | convert_weights(model) 147 | model.load_state_dict(state_dict) 148 | model.eval() 149 | model.float() 150 | return model 151 | 152 | 153 | class CustomizedTokenizer(SimpleTokenizer): 154 | def __init__(self): 155 | super().__init__(bpe_path=default_bpe()) 156 | 157 | self.sot_token = self.encoder["<|startoftext|>"] 158 | self.eot_token = self.encoder["<|endoftext|>"] 159 | 160 | def padded_tokens_and_mask(self, texts, text_ctx): 161 | assert isinstance(texts, list) and all( 162 | isinstance(elem, str) for elem in texts 163 | ), "texts should be a list of strings" 164 | 165 | all_tokens = [ 166 | [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts 167 | ] 168 | 169 | mask = [ 170 | [True] * min(text_ctx, len(tokens)) 171 | + [False] * max(text_ctx - len(tokens), 0) 172 | for tokens in all_tokens 173 | ] 174 | mask = torch.tensor(mask, dtype=torch.bool) 175 | result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int) 176 | for i, tokens in enumerate(all_tokens): 177 | if len(tokens) > text_ctx: 178 | tokens = tokens[:text_ctx] 179 | tokens[-1] = self.eot_token 180 | result[i, : len(tokens)] = torch.tensor(tokens) 181 | 182 | return result, mask 183 | -------------------------------------------------------------------------------- /karlo/models/decoder_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ..modules import create_gaussian_diffusion 10 | from ..modules.unet import PLMImUNet 11 | 12 | 13 | class Text2ImProgressiveModel(torch.nn.Module): 14 | """ 15 | A decoder that generates 64x64px images based on the text prompt. 16 | 17 | :param config: yaml config to define the decoder. 18 | :param tokenizer: tokenizer used in clip. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | config, 24 | tokenizer, 25 | ): 26 | super().__init__() 27 | 28 | self._conf = config 29 | self._model_conf = config.model.hparams 30 | self._diffusion_kwargs = dict( 31 | steps=config.diffusion.steps, 32 | learn_sigma=config.diffusion.learn_sigma, 33 | sigma_small=config.diffusion.sigma_small, 34 | noise_schedule=config.diffusion.noise_schedule, 35 | use_kl=config.diffusion.use_kl, 36 | predict_xstart=config.diffusion.predict_xstart, 37 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 38 | timestep_respacing=config.diffusion.timestep_respacing, 39 | ) 40 | self._tokenizer = tokenizer 41 | 42 | self.model = self.create_plm_dec_model() 43 | 44 | cf_token, cf_mask = self.set_cf_text_tensor() 45 | self.register_buffer("cf_token", cf_token, persistent=False) 46 | self.register_buffer("cf_mask", cf_mask, persistent=False) 47 | 48 | @classmethod 49 | def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True): 50 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 51 | 52 | model = cls(config, tokenizer) 53 | model.load_state_dict(ckpt, strict=strict) 54 | return model 55 | 56 | def create_plm_dec_model(self): 57 | image_size = self._model_conf.image_size 58 | if self._model_conf.channel_mult == "": 59 | if image_size == 256: 60 | channel_mult = (1, 1, 2, 2, 4, 4) 61 | elif image_size == 128: 62 | channel_mult = (1, 1, 2, 3, 4) 63 | elif image_size == 64: 64 | channel_mult = (1, 2, 3, 4) 65 | else: 66 | 
raise ValueError(f"unsupported image size: {image_size}") 67 | else: 68 | channel_mult = tuple( 69 | int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",") 70 | ) 71 | assert 2 ** (len(channel_mult) + 2) == image_size 72 | 73 | attention_ds = [] 74 | for res in self._model_conf.attention_resolutions.split(","): 75 | attention_ds.append(image_size // int(res)) 76 | 77 | return PLMImUNet( 78 | text_ctx=self._model_conf.text_ctx, 79 | xf_width=self._model_conf.xf_width, 80 | in_channels=3, 81 | model_channels=self._model_conf.num_channels, 82 | out_channels=6 if self._model_conf.learn_sigma else 3, 83 | num_res_blocks=self._model_conf.num_res_blocks, 84 | attention_resolutions=tuple(attention_ds), 85 | dropout=self._model_conf.dropout, 86 | channel_mult=channel_mult, 87 | num_heads=self._model_conf.num_heads, 88 | num_head_channels=self._model_conf.num_head_channels, 89 | num_heads_upsample=self._model_conf.num_heads_upsample, 90 | use_scale_shift_norm=self._model_conf.use_scale_shift_norm, 91 | resblock_updown=self._model_conf.resblock_updown, 92 | clip_dim=self._model_conf.clip_dim, 93 | clip_emb_mult=self._model_conf.clip_emb_mult, 94 | clip_emb_type=self._model_conf.clip_emb_type, 95 | clip_emb_drop=self._model_conf.clip_emb_drop, 96 | ) 97 | 98 | def set_cf_text_tensor(self): 99 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) 100 | 101 | def get_sample_fn(self, timestep_respacing): 102 | use_ddim = timestep_respacing.startswith(("ddim", "fast")) 103 | 104 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 105 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 106 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 107 | sample_fn = ( 108 | diffusion.ddim_sample_loop_progressive 109 | if use_ddim 110 | else diffusion.p_sample_loop_progressive 111 | ) 112 | 113 | return sample_fn 114 | 115 | def forward( 116 | self, 117 | txt_feat, 118 | txt_feat_seq, 119 | tok, 120 | mask, 121 | img_feat=None, 
122 | cf_guidance_scales=None, 123 | timestep_respacing=None, 124 | ): 125 | # cfg should be enabled in inference 126 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) 127 | assert img_feat is not None 128 | 129 | bsz = txt_feat.shape[0] 130 | img_sz = self._model_conf.image_size 131 | 132 | def guided_model_fn(x_t, ts, **kwargs): 133 | half = x_t[: len(x_t) // 2] 134 | combined = torch.cat([half, half], dim=0) 135 | model_out = self.model(combined, ts, **kwargs) 136 | eps, rest = model_out[:, :3], model_out[:, 3:] 137 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 138 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * ( 139 | cond_eps - uncond_eps 140 | ) 141 | eps = torch.cat([half_eps, half_eps], dim=0) 142 | return torch.cat([eps, rest], dim=1) 143 | 144 | cf_feat = self.model.cf_param.unsqueeze(0) 145 | cf_feat = cf_feat.expand(bsz // 2, -1) 146 | feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0) 147 | 148 | cond = { 149 | "y": feat, 150 | "txt_feat": txt_feat, 151 | "txt_feat_seq": txt_feat_seq, 152 | "mask": mask, 153 | } 154 | sample_fn = self.get_sample_fn(timestep_respacing) 155 | sample_outputs = sample_fn( 156 | guided_model_fn, 157 | (bsz, 3, img_sz, img_sz), 158 | noise=None, 159 | device=txt_feat.device, 160 | clip_denoised=True, 161 | model_kwargs=cond, 162 | ) 163 | 164 | for out in sample_outputs: 165 | sample = out["sample"] 166 | yield sample if cf_guidance_scales is None else sample[ 167 | : sample.shape[0] // 2 168 | ] 169 | 170 | 171 | class Text2ImModel(Text2ImProgressiveModel): 172 | def forward( 173 | self, 174 | txt_feat, 175 | txt_feat_seq, 176 | tok, 177 | mask, 178 | img_feat=None, 179 | cf_guidance_scales=None, 180 | timestep_respacing=None, 181 | ): 182 | last_out = None 183 | for out in super().forward( 184 | txt_feat, 185 | txt_feat_seq, 186 | tok, 187 | mask, 188 | img_feat, 189 | cf_guidance_scales, 190 | timestep_respacing, 191 | ): 192 | last_out = out 193 | 
return last_out 194 | -------------------------------------------------------------------------------- /karlo/models/prior_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ..modules import create_gaussian_diffusion 10 | from ..modules.xf import PriorTransformer 11 | 12 | 13 | class PriorDiffusionModel(torch.nn.Module): 14 | """ 15 | A prior that generates a CLIP image feature based on the text prompt. 16 | 17 | :param config: yaml config to define the prior. 18 | :param tokenizer: tokenizer used in clip. 19 | :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance). 20 | :param clip_std: std to normalize the clip image feature (zero-mean, unit variance).
21 | """ 22 | 23 | def __init__(self, config, tokenizer, clip_mean, clip_std): 24 | super().__init__() 25 | 26 | self._conf = config 27 | self._model_conf = config.model.hparams 28 | self._diffusion_kwargs = dict( 29 | steps=config.diffusion.steps, 30 | learn_sigma=config.diffusion.learn_sigma, 31 | sigma_small=config.diffusion.sigma_small, 32 | noise_schedule=config.diffusion.noise_schedule, 33 | use_kl=config.diffusion.use_kl, 34 | predict_xstart=config.diffusion.predict_xstart, 35 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 36 | timestep_respacing=config.diffusion.timestep_respacing, 37 | ) 38 | self._tokenizer = tokenizer 39 | 40 | self.register_buffer("clip_mean", clip_mean[None, :], persistent=False) 41 | self.register_buffer("clip_std", clip_std[None, :], persistent=False) 42 | 43 | causal_mask = self.get_causal_mask() 44 | self.register_buffer("causal_mask", causal_mask, persistent=False) 45 | 46 | self.model = PriorTransformer( 47 | text_ctx=self._model_conf.text_ctx, 48 | xf_width=self._model_conf.xf_width, 49 | xf_layers=self._model_conf.xf_layers, 50 | xf_heads=self._model_conf.xf_heads, 51 | xf_final_ln=self._model_conf.xf_final_ln, 52 | clip_dim=self._model_conf.clip_dim, 53 | ) 54 | 55 | cf_token, cf_mask = self.set_cf_text_tensor() 56 | self.register_buffer("cf_token", cf_token, persistent=False) 57 | self.register_buffer("cf_mask", cf_mask, persistent=False) 58 | 59 | @classmethod 60 | def load_from_checkpoint( 61 | cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True 62 | ): 63 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 64 | 65 | model = cls(config, tokenizer, clip_mean, clip_std) 66 | model.load_state_dict(ckpt, strict=strict) 67 | return model 68 | 69 | def set_cf_text_tensor(self): 70 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) 71 | 72 | def get_sample_fn(self, timestep_respacing): 73 | use_ddim = timestep_respacing.startswith(("ddim", "fast")) 74 
| 75 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 76 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 77 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 78 | sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop 79 | 80 | return sample_fn 81 | 82 | def get_causal_mask(self): 83 | seq_len = self._model_conf.text_ctx + 4 84 | mask = torch.empty(seq_len, seq_len) 85 | mask.fill_(float("-inf")) 86 | mask.triu_(1) 87 | mask = mask[None, ...] 88 | return mask 89 | 90 | def forward( 91 | self, 92 | txt_feat, 93 | txt_feat_seq, 94 | mask, 95 | cf_guidance_scales=None, 96 | timestep_respacing=None, 97 | denoised_fn=True, 98 | ): 99 | # cfg should be enabled in inference 100 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) 101 | 102 | bsz_ = txt_feat.shape[0] 103 | bsz = bsz_ // 2 104 | 105 | def guided_model_fn(x_t, ts, **kwargs): 106 | half = x_t[: len(x_t) // 2] 107 | combined = torch.cat([half, half], dim=0) 108 | model_out = self.model(combined, ts, **kwargs) 109 | eps, rest = ( 110 | model_out[:, : int(x_t.shape[1])], 111 | model_out[:, int(x_t.shape[1]) :], 112 | ) 113 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 114 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * ( 115 | cond_eps - uncond_eps 116 | ) 117 | eps = torch.cat([half_eps, half_eps], dim=0) 118 | return torch.cat([eps, rest], dim=1) 119 | 120 | cond = { 121 | "text_emb": txt_feat, 122 | "text_enc": txt_feat_seq, 123 | "mask": mask, 124 | "causal_mask": self.causal_mask, 125 | } 126 | sample_fn = self.get_sample_fn(timestep_respacing) 127 | sample = sample_fn( 128 | guided_model_fn, 129 | (bsz_, self.model.clip_dim), 130 | noise=None, 131 | device=txt_feat.device, 132 | clip_denoised=False, 133 | denoised_fn=lambda x: torch.clamp(x, -10, 10), 134 | model_kwargs=cond, 135 | ) 136 | sample = (sample * self.clip_std) + self.clip_mean 137 | 138 | return sample[:bsz] 139 | 
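# The classifier-free guidance mix performed in guided_model_fn above can be restated without tensors.
# This is a minimal sketch for illustration only and is not part of Karlo: mix_guidance is a
# hypothetical helper name, and nested Python lists stand in for the batched eps tensors.
# It assumes one guidance scale per sample, matching cf_guidance_scales.view(-1, 1).

```python
def mix_guidance(cond_eps, uncond_eps, scales):
    """Return uncond + scale * (cond - uncond), row by row.

    cond_eps, uncond_eps: lists of equal-length rows (one row per sample).
    scales: one guidance scale per sample (hypothetical stand-in for
    cf_guidance_scales).
    """
    if not (len(cond_eps) == len(uncond_eps) == len(scales)):
        raise ValueError("batch sizes must match")
    mixed = []
    for cond_row, uncond_row, scale in zip(cond_eps, uncond_eps, scales):
        # Move away from the unconditional prediction along the
        # (cond - uncond) direction, scaled per sample.
        mixed.append([u + scale * (c - u) for c, u in zip(cond_row, uncond_row)])
    return mixed


if __name__ == "__main__":
    cond = [[1.0, 2.0], [0.5, -0.5]]
    uncond = [[0.0, 1.0], [0.5, 0.5]]
    print(mix_guidance(cond, uncond, scales=[1.0, 4.0]))
```

# With a scale of 1.0 the result equals the conditional prediction; larger scales push further
# from the unconditional one. forward() rejects non-positive scales through its assertion.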
-------------------------------------------------------------------------------- /karlo/models/sr_256_1k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | from .sr_64_256 import SupRes64to256Progressive 7 | 8 | 9 | class SupRes256to1kProgressive(SupRes64to256Progressive): 10 | pass # no difference currently 11 | -------------------------------------------------------------------------------- /karlo/models/sr_64_256.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import copy 7 | import torch 8 | 9 | from ..modules.unet import SuperResUNetModel 10 | from ..modules import create_gaussian_diffusion 11 | 12 | 13 | class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module): 14 | """ 15 | The ImprovedSR model fine-tunes the pretrained DDPM-based SR model using adversarial and perceptual losses. 16 | Specifically, the low-resolution sample is iteratively recovered over 6 steps with the frozen pretrained SR model. 17 | In one additional final step, a separate fine-tuned model recovers high-frequency details. 18 | This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps. 
19 | """ 20 | 21 | def __init__(self, config): 22 | super().__init__() 23 | 24 | self._config = config 25 | self._diffusion_kwargs = dict( 26 | steps=config.diffusion.steps, 27 | learn_sigma=config.diffusion.learn_sigma, 28 | sigma_small=config.diffusion.sigma_small, 29 | noise_schedule=config.diffusion.noise_schedule, 30 | use_kl=config.diffusion.use_kl, 31 | predict_xstart=config.diffusion.predict_xstart, 32 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, 33 | ) 34 | 35 | self.model_first_steps = SuperResUNetModel( 36 | in_channels=3, # auto-changed to 6 inside the model 37 | model_channels=config.model.hparams.channels, 38 | out_channels=3, 39 | num_res_blocks=config.model.hparams.depth, 40 | attention_resolutions=(), # no attention 41 | dropout=config.model.hparams.dropout, 42 | channel_mult=config.model.hparams.channels_multiple, 43 | resblock_updown=True, 44 | use_middle_attention=False, 45 | ) 46 | self.model_last_step = SuperResUNetModel( 47 | in_channels=3, # auto-changed to 6 inside the model 48 | model_channels=config.model.hparams.channels, 49 | out_channels=3, 50 | num_res_blocks=config.model.hparams.depth, 51 | attention_resolutions=(), # no attention 52 | dropout=config.model.hparams.dropout, 53 | channel_mult=config.model.hparams.channels_multiple, 54 | resblock_updown=True, 55 | use_middle_attention=False, 56 | ) 57 | 58 | @classmethod 59 | def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True): 60 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] 61 | 62 | model = cls(config) 63 | model.load_state_dict(ckpt, strict=strict) 64 | return model 65 | 66 | def get_sample_fn(self, timestep_respacing): 67 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) 68 | diffusion_kwargs.update(timestep_respacing=timestep_respacing) 69 | diffusion = create_gaussian_diffusion(**diffusion_kwargs) 70 | return diffusion.p_sample_loop_progressive_for_improved_sr 71 | 72 | def forward(self, low_res, 
timestep_respacing="7", **kwargs): 73 | assert ( 74 | timestep_respacing == "7" 75 | ), "other respacing values may work, but this is not guaranteed" 76 | 77 | sample_fn = self.get_sample_fn(timestep_respacing) 78 | sample_outputs = sample_fn( 79 | self.model_first_steps, 80 | self.model_last_step, 81 | shape=low_res.shape, 82 | clip_denoised=True, 83 | model_kwargs=dict(low_res=low_res), 84 | **kwargs, 85 | ) 86 | for x in sample_outputs: 87 | sample = x["sample"] 88 | yield sample 89 | -------------------------------------------------------------------------------- /karlo/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | from .diffusion import gaussian_diffusion as gd 7 | from .diffusion.respace import ( 8 | SpacedDiffusion, 9 | space_timesteps, 10 | ) 11 | 12 | 13 | def create_gaussian_diffusion( 14 | steps, 15 | learn_sigma, 16 | sigma_small, 17 | noise_schedule, 18 | use_kl, 19 | predict_xstart, 20 | rescale_learned_sigmas, 21 | timestep_respacing, 22 | ): 23 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 24 | if use_kl: 25 | loss_type = gd.LossType.RESCALED_KL 26 | elif rescale_learned_sigmas: 27 | loss_type = gd.LossType.RESCALED_MSE 28 | else: 29 | loss_type = gd.LossType.MSE 30 | if not timestep_respacing: 31 | timestep_respacing = [steps] 32 | 33 | return SpacedDiffusion( 34 | use_timesteps=space_timesteps(steps, timestep_respacing), 35 | betas=betas, 36 | model_mean_type=( 37 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 38 | ), 39 | model_var_type=( 40 | ( 41 | gd.ModelVarType.FIXED_LARGE 42 | if not sigma_small 43 | else gd.ModelVarType.FIXED_SMALL 44 | ) 45 | if not learn_sigma 46 | 
else gd.ModelVarType.LEARNED_RANGE 47 | ), 48 | loss_type=loss_type, 49 | ) 50 | -------------------------------------------------------------------------------- /karlo/modules/diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | import enum 6 | import math 7 | 8 | import numpy as np 9 | import torch as th 10 | 11 | 12 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 13 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 14 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 15 | betas[:warmup_time] = np.linspace( 16 | beta_start, beta_end, warmup_time, dtype=np.float64 17 | ) 18 | return betas 19 | 20 | 21 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 22 | """ 23 | This is the deprecated API for creating beta schedules. 24 | See get_named_beta_schedule() for the new library of schedules. 
25 | """ 26 | if beta_schedule == "quad": 27 | betas = ( 28 | np.linspace( 29 | beta_start**0.5, 30 | beta_end**0.5, 31 | num_diffusion_timesteps, 32 | dtype=np.float64, 33 | ) 34 | ** 2 35 | ) 36 | elif beta_schedule == "linear": 37 | betas = np.linspace( 38 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 39 | ) 40 | elif beta_schedule == "warmup10": 41 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 42 | elif beta_schedule == "warmup50": 43 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 44 | elif beta_schedule == "const": 45 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 46 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 47 | betas = 1.0 / np.linspace( 48 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 49 | ) 50 | else: 51 | raise NotImplementedError(beta_schedule) 52 | assert betas.shape == (num_diffusion_timesteps,) 53 | return betas 54 | 55 | 56 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 57 | """ 58 | Get a pre-defined beta schedule for the given name. 59 | The beta schedule library consists of beta schedules which remain similar 60 | in the limit of num_diffusion_timesteps. 61 | Beta schedules may be added, but should not be removed or changed once 62 | they are committed to maintain backwards compatibility. 63 | """ 64 | if schedule_name == "linear": 65 | # Linear schedule from Ho et al, extended to work for any number of 66 | # diffusion steps. 
67 | scale = 1000 / num_diffusion_timesteps 68 | return get_beta_schedule( 69 | "linear", 70 | beta_start=scale * 0.0001, 71 | beta_end=scale * 0.02, 72 | num_diffusion_timesteps=num_diffusion_timesteps, 73 | ) 74 | elif schedule_name == "squaredcos_cap_v2": 75 | return betas_for_alpha_bar( 76 | num_diffusion_timesteps, 77 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 78 | ) 79 | else: 80 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 81 | 82 | 83 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 84 | """ 85 | Create a beta schedule that discretizes the given alpha_t_bar function, 86 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 87 | :param num_diffusion_timesteps: the number of betas to produce. 88 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 89 | produces the cumulative product of (1-beta) up to that 90 | part of the diffusion process. 91 | :param max_beta: the maximum beta to use; use values lower than 1 to 92 | prevent singularities. 93 | """ 94 | betas = [] 95 | for i in range(num_diffusion_timesteps): 96 | t1 = i / num_diffusion_timesteps 97 | t2 = (i + 1) / num_diffusion_timesteps 98 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 99 | return np.array(betas) 100 | 101 | 102 | class ModelMeanType(enum.Enum): 103 | """ 104 | Which type of output the model predicts. 105 | """ 106 | 107 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 108 | START_X = enum.auto() # the model predicts x_0 109 | EPSILON = enum.auto() # the model predicts epsilon 110 | 111 | 112 | class ModelVarType(enum.Enum): 113 | """ 114 | What is used as the model's output variance. 115 | The LEARNED_RANGE option has been added to allow the model to predict 116 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
117 | """ 118 | 119 | LEARNED = enum.auto() 120 | FIXED_SMALL = enum.auto() 121 | FIXED_LARGE = enum.auto() 122 | LEARNED_RANGE = enum.auto() 123 | 124 | 125 | class LossType(enum.Enum): 126 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 127 | RESCALED_MSE = ( 128 | enum.auto() 129 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 130 | KL = enum.auto() # use the variational lower-bound 131 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 132 | 133 | def is_vb(self): 134 | return self == LossType.KL or self == LossType.RESCALED_KL 135 | 136 | 137 | class GaussianDiffusion(th.nn.Module): 138 | """ 139 | Utilities for training and sampling diffusion models. 140 | Originally ported from this codebase: 141 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 142 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 143 | starting at T and going to 1. 144 | """ 145 | 146 | def __init__( 147 | self, 148 | *, 149 | betas, 150 | model_mean_type, 151 | model_var_type, 152 | loss_type, 153 | ): 154 | super(GaussianDiffusion, self).__init__() 155 | self.model_mean_type = model_mean_type 156 | self.model_var_type = model_var_type 157 | self.loss_type = loss_type 158 | 159 | # Use float64 for accuracy. 
160 | betas = np.array(betas, dtype=np.float64) 161 | assert len(betas.shape) == 1, "betas must be 1-D" 162 | assert (betas > 0).all() and (betas <= 1).all() 163 | 164 | self.num_timesteps = int(betas.shape[0]) 165 | 166 | alphas = 1.0 - betas 167 | alphas_cumprod = np.cumprod(alphas, axis=0) 168 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 169 | alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0) 170 | assert alphas_cumprod_prev.shape == (self.num_timesteps,) 171 | 172 | # calculations for diffusion q(x_t | x_{t-1}) and others 173 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 174 | sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) 175 | log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod) 176 | sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) 177 | sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) 178 | 179 | # calculations for posterior q(x_{t-1} | x_t, x_0) 180 | posterior_variance = ( 181 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 182 | ) 183 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 184 | posterior_log_variance_clipped = np.log( 185 | np.append(posterior_variance[1], posterior_variance[1:]) 186 | ) 187 | posterior_mean_coef1 = ( 188 | betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) 189 | ) 190 | posterior_mean_coef2 = ( 191 | (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) 192 | ) 193 | 194 | self.register_buffer("betas", th.from_numpy(betas), persistent=False) 195 | self.register_buffer( 196 | "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False 197 | ) 198 | self.register_buffer( 199 | "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False 200 | ) 201 | self.register_buffer( 202 | "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False 203 | ) 204 | 205 | self.register_buffer( 206 | "sqrt_alphas_cumprod", 
th.from_numpy(sqrt_alphas_cumprod), persistent=False 207 | ) 208 | self.register_buffer( 209 | "sqrt_one_minus_alphas_cumprod", 210 | th.from_numpy(sqrt_one_minus_alphas_cumprod), 211 | persistent=False, 212 | ) 213 | self.register_buffer( 214 | "log_one_minus_alphas_cumprod", 215 | th.from_numpy(log_one_minus_alphas_cumprod), 216 | persistent=False, 217 | ) 218 | self.register_buffer( 219 | "sqrt_recip_alphas_cumprod", 220 | th.from_numpy(sqrt_recip_alphas_cumprod), 221 | persistent=False, 222 | ) 223 | self.register_buffer( 224 | "sqrt_recipm1_alphas_cumprod", 225 | th.from_numpy(sqrt_recipm1_alphas_cumprod), 226 | persistent=False, 227 | ) 228 | 229 | self.register_buffer( 230 | "posterior_variance", th.from_numpy(posterior_variance), persistent=False 231 | ) 232 | self.register_buffer( 233 | "posterior_log_variance_clipped", 234 | th.from_numpy(posterior_log_variance_clipped), 235 | persistent=False, 236 | ) 237 | self.register_buffer( 238 | "posterior_mean_coef1", 239 | th.from_numpy(posterior_mean_coef1), 240 | persistent=False, 241 | ) 242 | self.register_buffer( 243 | "posterior_mean_coef2", 244 | th.from_numpy(posterior_mean_coef2), 245 | persistent=False, 246 | ) 247 | 248 | def q_mean_variance(self, x_start, t): 249 | """ 250 | Get the distribution q(x_t | x_0). 251 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 252 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 253 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 254 | """ 255 | mean = ( 256 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 257 | ) 258 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 259 | log_variance = _extract_into_tensor( 260 | self.log_one_minus_alphas_cumprod, t, x_start.shape 261 | ) 262 | return mean, variance, log_variance 263 | 264 | def q_sample(self, x_start, t, noise=None): 265 | """ 266 | Diffuse the data for a given number of diffusion steps. 
267 | In other words, sample from q(x_t | x_0). 268 | :param x_start: the initial data batch. 269 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 270 | :param noise: if specified, the split-out normal noise. 271 | :return: A noisy version of x_start. 272 | """ 273 | if noise is None: 274 | noise = th.randn_like(x_start) 275 | assert noise.shape == x_start.shape 276 | return ( 277 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 278 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 279 | * noise 280 | ) 281 | 282 | def q_posterior_mean_variance(self, x_start, x_t, t): 283 | """ 284 | Compute the mean and variance of the diffusion posterior: 285 | q(x_{t-1} | x_t, x_0) 286 | """ 287 | assert x_start.shape == x_t.shape 288 | posterior_mean = ( 289 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 290 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 291 | ) 292 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 293 | posterior_log_variance_clipped = _extract_into_tensor( 294 | self.posterior_log_variance_clipped, t, x_t.shape 295 | ) 296 | assert ( 297 | posterior_mean.shape[0] 298 | == posterior_variance.shape[0] 299 | == posterior_log_variance_clipped.shape[0] 300 | == x_start.shape[0] 301 | ) 302 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 303 | 304 | def p_mean_variance( 305 | self, 306 | model, 307 | x, 308 | t, 309 | clip_denoised=True, 310 | denoised_fn=None, 311 | model_kwargs=None, 312 | **ignore_kwargs, 313 | ): 314 | """ 315 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 316 | the initial x, x_0. 317 | :param model: the model, which takes a signal and a batch of timesteps 318 | as input. 319 | :param x: the [N x C x ...] tensor at time t. 320 | :param t: a 1-D Tensor of timesteps. 
321 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 322 | :param denoised_fn: if not None, a function which applies to the 323 | x_start prediction before it is used to sample. Applies before 324 | clip_denoised. 325 | :param model_kwargs: if not None, a dict of extra keyword arguments to 326 | pass to the model. This can be used for conditioning. 327 | :return: a dict with the following keys: 328 | - 'mean': the model mean output. 329 | - 'variance': the model variance output. 330 | - 'log_variance': the log of 'variance'. 331 | - 'pred_xstart': the prediction for x_0. 332 | """ 333 | if model_kwargs is None: 334 | model_kwargs = {} 335 | 336 | B, C = x.shape[:2] 337 | assert t.shape == (B,) 338 | model_output = model(x, t, **model_kwargs) 339 | if isinstance(model_output, tuple): 340 | model_output, extra = model_output 341 | else: 342 | extra = None 343 | 344 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 345 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 346 | model_output, model_var_values = th.split(model_output, C, dim=1) 347 | if self.model_var_type == ModelVarType.LEARNED: 348 | model_log_variance = model_var_values 349 | model_variance = th.exp(model_log_variance) 350 | else: 351 | min_log = _extract_into_tensor( 352 | self.posterior_log_variance_clipped, t, x.shape 353 | ) 354 | max_log = _extract_into_tensor(th.log(self.betas), t, x.shape) 355 | # The model_var_values is [-1, 1] for [min_var, max_var]. 356 | frac = (model_var_values + 1) / 2 357 | model_log_variance = frac * max_log + (1 - frac) * min_log 358 | model_variance = th.exp(model_log_variance) 359 | else: 360 | model_variance, model_log_variance = { 361 | # for fixedlarge, we set the initial (log-)variance like so 362 | # to get a better decoder log likelihood. 
363 | ModelVarType.FIXED_LARGE: ( 364 | th.cat([self.posterior_variance[1][None], self.betas[1:]]), 365 | th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])), 366 | ), 367 | ModelVarType.FIXED_SMALL: ( 368 | self.posterior_variance, 369 | self.posterior_log_variance_clipped, 370 | ), 371 | }[self.model_var_type] 372 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 373 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 374 | 375 | def process_xstart(x): 376 | if denoised_fn is not None: 377 | x = denoised_fn(x) 378 | if clip_denoised: 379 | return x.clamp(-1, 1) 380 | return x 381 | 382 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 383 | pred_xstart = process_xstart( 384 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 385 | ) 386 | model_mean = model_output 387 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 388 | if self.model_mean_type == ModelMeanType.START_X: 389 | pred_xstart = process_xstart(model_output) 390 | else: 391 | pred_xstart = process_xstart( 392 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 393 | ) 394 | model_mean, _, _ = self.q_posterior_mean_variance( 395 | x_start=pred_xstart, x_t=x, t=t 396 | ) 397 | else: 398 | raise NotImplementedError(self.model_mean_type) 399 | 400 | assert ( 401 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 402 | ) 403 | return { 404 | "mean": model_mean, 405 | "variance": model_variance, 406 | "log_variance": model_log_variance, 407 | "pred_xstart": pred_xstart, 408 | } 409 | 410 | def _predict_xstart_from_eps(self, x_t, t, eps): 411 | assert x_t.shape == eps.shape 412 | return ( 413 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 414 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 415 | ) 416 | 417 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 418 | return ( 419 | 
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 420 | - pred_xstart 421 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 422 | 423 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 424 | """ 425 | Compute the mean for the previous step, given a function cond_fn that 426 | computes the gradient of a conditional log probability with respect to 427 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 428 | condition on y. 429 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 430 | """ 431 | gradient = cond_fn(x, t, **(model_kwargs or {})) 432 | new_mean = ( 433 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 434 | ) 435 | return new_mean 436 | 437 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 438 | """ 439 | Compute what the p_mean_variance output would have been, should the 440 | model's score function be conditioned by cond_fn. 441 | See condition_mean() for details on cond_fn. 442 | Unlike condition_mean(), this instead uses the conditioning strategy 443 | from Song et al (2020). 444 | """ 445 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 446 | 447 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 448 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {})) 449 | 450 | out = p_mean_var.copy() 451 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 452 | out["mean"], _, _ = self.q_posterior_mean_variance( 453 | x_start=out["pred_xstart"], x_t=x, t=t 454 | ) 455 | return out 456 | 457 | def p_sample( 458 | self, 459 | model, 460 | x, 461 | t, 462 | clip_denoised=True, 463 | denoised_fn=None, 464 | cond_fn=None, 465 | model_kwargs=None, 466 | ): 467 | """ 468 | Sample x_{t-1} from the model at the given timestep. 469 | :param model: the model to sample from. 470 | :param x: the current tensor x_t at timestep t. 
471 | :param t: the value of t, starting at 0 for the first diffusion step. 472 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 473 | :param denoised_fn: if not None, a function which applies to the 474 | x_start prediction before it is used to sample. 475 | :param cond_fn: if not None, this is a gradient function that acts 476 | similarly to the model. 477 | :param model_kwargs: if not None, a dict of extra keyword arguments to 478 | pass to the model. This can be used for conditioning. 479 | :return: a dict containing the following keys: 480 | - 'sample': a random sample from the model. 481 | - 'pred_xstart': a prediction of x_0. 482 | """ 483 | out = self.p_mean_variance( 484 | model, 485 | x, 486 | t, 487 | clip_denoised=clip_denoised, 488 | denoised_fn=denoised_fn, 489 | model_kwargs=model_kwargs, 490 | ) 491 | noise = th.randn_like(x) 492 | nonzero_mask = ( 493 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 494 | ) # no noise when t == 0 495 | if cond_fn is not None: 496 | out["mean"] = self.condition_mean( 497 | cond_fn, out, x, t, model_kwargs=model_kwargs 498 | ) 499 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 500 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 501 | 502 | def p_sample_loop( 503 | self, 504 | model, 505 | shape, 506 | noise=None, 507 | clip_denoised=True, 508 | denoised_fn=None, 509 | cond_fn=None, 510 | model_kwargs=None, 511 | device=None, 512 | progress=False, 513 | ): 514 | """ 515 | Generate samples from the model. 516 | :param model: the model module. 517 | :param shape: the shape of the samples, (N, C, H, W). 518 | :param noise: if specified, the noise from the encoder to sample. 519 | Should be of the same shape as `shape`. 520 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 521 | :param denoised_fn: if not None, a function which applies to the 522 | x_start prediction before it is used to sample. 
523 | :param cond_fn: if not None, this is a gradient function that acts 524 | similarly to the model. 525 | :param model_kwargs: if not None, a dict of extra keyword arguments to 526 | pass to the model. This can be used for conditioning. 527 | :param device: if specified, the device to create the samples on. 528 | If not specified, use a model parameter's device. 529 | :param progress: if True, show a tqdm progress bar. 530 | :return: a non-differentiable batch of samples. 531 | """ 532 | final = None 533 | for sample in self.p_sample_loop_progressive( 534 | model, 535 | shape, 536 | noise=noise, 537 | clip_denoised=clip_denoised, 538 | denoised_fn=denoised_fn, 539 | cond_fn=cond_fn, 540 | model_kwargs=model_kwargs, 541 | device=device, 542 | progress=progress, 543 | ): 544 | final = sample 545 | return final["sample"] 546 | 547 | def p_sample_loop_progressive( 548 | self, 549 | model, 550 | shape, 551 | noise=None, 552 | clip_denoised=True, 553 | denoised_fn=None, 554 | cond_fn=None, 555 | model_kwargs=None, 556 | device=None, 557 | progress=False, 558 | ): 559 | """ 560 | Generate samples from the model and yield intermediate samples from 561 | each timestep of diffusion. 562 | Arguments are the same as p_sample_loop(). 563 | Returns a generator over dicts, where each dict is the return value of 564 | p_sample(). 565 | """ 566 | if device is None: 567 | device = next(model.parameters()).device 568 | assert isinstance(shape, (tuple, list)) 569 | if noise is not None: 570 | img = noise 571 | else: 572 | img = th.randn(*shape, device=device) 573 | indices = list(range(self.num_timesteps))[::-1] 574 | 575 | if progress: 576 | # Lazy import so that we don't depend on tqdm. 
577 | from tqdm.auto import tqdm 578 | 579 | indices = tqdm(indices) 580 | 581 | for idx, i in enumerate(indices): 582 | t = th.tensor([i] * shape[0], device=device) 583 | with th.no_grad(): 584 | out = self.p_sample( 585 | model, 586 | img, 587 | t, 588 | clip_denoised=clip_denoised, 589 | denoised_fn=denoised_fn, 590 | cond_fn=cond_fn, 591 | model_kwargs=model_kwargs, 592 | ) 593 | yield out 594 | img = out["sample"] 595 | 596 | def p_sample_loop_progressive_for_improved_sr( 597 | self, 598 | model, 599 | model_aux, 600 | shape, 601 | noise=None, 602 | clip_denoised=True, 603 | denoised_fn=None, 604 | cond_fn=None, 605 | model_kwargs=None, 606 | device=None, 607 | progress=False, 608 | ): 609 | """ 610 | Modified version of p_sample_loop_progressive for sampling from the improved sr model 611 | """ 612 | 613 | if device is None: 614 | device = next(model.parameters()).device 615 | assert isinstance(shape, (tuple, list)) 616 | if noise is not None: 617 | img = noise 618 | else: 619 | img = th.randn(*shape, device=device) 620 | indices = list(range(self.num_timesteps))[::-1] 621 | 622 | if progress: 623 | # Lazy import so that we don't depend on tqdm. 624 | from tqdm.auto import tqdm 625 | 626 | indices = tqdm(indices) 627 | 628 | for idx, i in enumerate(indices): 629 | t = th.tensor([i] * shape[0], device=device) 630 | with th.no_grad(): 631 | out = self.p_sample( 632 | model_aux if len(indices) - 1 == idx else model, 633 | img, 634 | t, 635 | clip_denoised=clip_denoised, 636 | denoised_fn=denoised_fn, 637 | cond_fn=cond_fn, 638 | model_kwargs=model_kwargs, 639 | ) 640 | yield out 641 | img = out["sample"] 642 | 643 | def ddim_sample( 644 | self, 645 | model, 646 | x, 647 | t, 648 | clip_denoised=True, 649 | denoised_fn=None, 650 | cond_fn=None, 651 | model_kwargs=None, 652 | eta=0.0, 653 | ): 654 | """ 655 | Sample x_{t-1} from the model using DDIM. 656 | Same usage as p_sample(). 
657 | """ 658 | out = self.p_mean_variance( 659 | model, 660 | x, 661 | t, 662 | clip_denoised=clip_denoised, 663 | denoised_fn=denoised_fn, 664 | model_kwargs=model_kwargs, 665 | ) 666 | if cond_fn is not None: 667 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 668 | 669 | # Usually our model outputs epsilon, but we re-derive it 670 | # in case we used x_start or x_prev prediction. 671 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 672 | 673 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 674 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 675 | sigma = ( 676 | eta 677 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 678 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 679 | ) 680 | # Equation 12. 681 | noise = th.randn_like(x) 682 | mean_pred = ( 683 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 684 | + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps 685 | ) 686 | nonzero_mask = ( 687 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 688 | ) # no noise when t == 0 689 | sample = mean_pred + nonzero_mask * sigma * noise 690 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 691 | 692 | def ddim_reverse_sample( 693 | self, 694 | model, 695 | x, 696 | t, 697 | clip_denoised=True, 698 | denoised_fn=None, 699 | cond_fn=None, 700 | model_kwargs=None, 701 | eta=0.0, 702 | ): 703 | """ 704 | Sample x_{t+1} from the model using DDIM reverse ODE. 705 | """ 706 | assert eta == 0.0, "Reverse ODE only for deterministic path" 707 | out = self.p_mean_variance( 708 | model, 709 | x, 710 | t, 711 | clip_denoised=clip_denoised, 712 | denoised_fn=denoised_fn, 713 | model_kwargs=model_kwargs, 714 | ) 715 | if cond_fn is not None: 716 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 717 | # Usually our model outputs epsilon, but we re-derive it 718 | # in case we used x_start or x_prev prediction. 
719 | eps = ( 720 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 721 | - out["pred_xstart"] 722 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 723 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 724 | 725 | # Equation 12. reversed 726 | mean_pred = ( 727 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 728 | + th.sqrt(1 - alpha_bar_next) * eps 729 | ) 730 | 731 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 732 | 733 | def ddim_sample_loop( 734 | self, 735 | model, 736 | shape, 737 | noise=None, 738 | clip_denoised=True, 739 | denoised_fn=None, 740 | cond_fn=None, 741 | model_kwargs=None, 742 | device=None, 743 | progress=False, 744 | eta=0.0, 745 | ): 746 | """ 747 | Generate samples from the model using DDIM. 748 | Same usage as p_sample_loop(). 749 | """ 750 | final = None 751 | for sample in self.ddim_sample_loop_progressive( 752 | model, 753 | shape, 754 | noise=noise, 755 | clip_denoised=clip_denoised, 756 | denoised_fn=denoised_fn, 757 | cond_fn=cond_fn, 758 | model_kwargs=model_kwargs, 759 | device=device, 760 | progress=progress, 761 | eta=eta, 762 | ): 763 | final = sample 764 | return final["sample"] 765 | 766 | def ddim_sample_loop_progressive( 767 | self, 768 | model, 769 | shape, 770 | noise=None, 771 | clip_denoised=True, 772 | denoised_fn=None, 773 | cond_fn=None, 774 | model_kwargs=None, 775 | device=None, 776 | progress=False, 777 | eta=0.0, 778 | ): 779 | """ 780 | Use DDIM to sample from the model and yield intermediate samples from 781 | each timestep of DDIM. 782 | Same usage as p_sample_loop_progressive(). 
783 | """ 784 | if device is None: 785 | device = next(model.parameters()).device 786 | assert isinstance(shape, (tuple, list)) 787 | if noise is not None: 788 | img = noise 789 | else: 790 | img = th.randn(*shape, device=device) 791 | indices = list(range(self.num_timesteps))[::-1] 792 | 793 | if progress: 794 | # Lazy import so that we don't depend on tqdm. 795 | from tqdm.auto import tqdm 796 | 797 | indices = tqdm(indices) 798 | 799 | for i in indices: 800 | t = th.tensor([i] * shape[0], device=device) 801 | with th.no_grad(): 802 | out = self.ddim_sample( 803 | model, 804 | img, 805 | t, 806 | clip_denoised=clip_denoised, 807 | denoised_fn=denoised_fn, 808 | cond_fn=cond_fn, 809 | model_kwargs=model_kwargs, 810 | eta=eta, 811 | ) 812 | yield out 813 | img = out["sample"] 814 | 815 | 816 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 817 | """ 818 | Extract values from a 1-D tensor for a batch of indices. 819 | :param arr: the 1-D tensor (moved to the device of timesteps before indexing). 820 | :param timesteps: a tensor of indices into arr to extract. 821 | :param broadcast_shape: a larger shape of K dimensions with the batch 822 | dimension equal to the length of timesteps. 823 | :return: a float tensor of shape broadcast_shape (K dims), with the extracted values broadcast over the non-batch dims.
824 | """ 825 | res = arr.to(device=timesteps.device)[timesteps].float() 826 | while len(res.shape) < len(broadcast_shape): 827 | res = res[..., None] 828 | return res + th.zeros(broadcast_shape, device=timesteps.device) 829 | -------------------------------------------------------------------------------- /karlo/modules/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | 6 | import torch as th 7 | 8 | from .gaussian_diffusion import GaussianDiffusion 9 | 10 | 11 | def space_timesteps(num_timesteps, section_counts): 12 | """ 13 | Create a list of timesteps to use from an original diffusion process, 14 | given the number of timesteps we want to take from equally-sized portions 15 | of the original process. 16 | 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 
29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim") :]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {desired_count} steps with an integer stride" 38 | ) 39 | elif section_counts == "fast27": 40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2") 41 | # Help reduce DDIM artifacts from noisiest timesteps. 42 | steps.remove(num_timesteps - 1) 43 | steps.add(num_timesteps - 3) 44 | return steps 45 | section_counts = [int(x) for x in section_counts.split(",")] 46 | size_per = num_timesteps // len(section_counts) 47 | extra = num_timesteps % len(section_counts) 48 | start_idx = 0 49 | all_steps = [] 50 | for i, section_count in enumerate(section_counts): 51 | size = size_per + (1 if i < extra else 0) 52 | if size < section_count: 53 | raise ValueError( 54 | f"cannot divide section of {size} steps into {section_count}" 55 | ) 56 | if section_count <= 1: 57 | frac_stride = 1 58 | else: 59 | frac_stride = (size - 1) / (section_count - 1) 60 | cur_idx = 0.0 61 | taken_steps = [] 62 | for _ in range(section_count): 63 | taken_steps.append(start_idx + round(cur_idx)) 64 | cur_idx += frac_stride 65 | all_steps += taken_steps 66 | start_idx += size 67 | return set(all_steps) 68 | 69 | 70 | class SpacedDiffusion(GaussianDiffusion): 71 | """ 72 | A diffusion process which can skip steps in a base diffusion process. 73 | 74 | :param use_timesteps: a collection (sequence or set) of timesteps from the 75 | original diffusion process to retain. 76 | :param kwargs: the kwargs to create the base diffusion process.
77 | """ 78 | 79 | def __init__(self, use_timesteps, **kwargs): 80 | self.use_timesteps = set(use_timesteps) 81 | self.original_num_steps = len(kwargs["betas"]) 82 | 83 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 84 | last_alpha_cumprod = 1.0 85 | new_betas = [] 86 | timestep_map = [] 87 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 88 | if i in self.use_timesteps: 89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 90 | last_alpha_cumprod = alpha_cumprod 91 | timestep_map.append(i) 92 | kwargs["betas"] = th.tensor(new_betas).numpy() 93 | super().__init__(**kwargs) 94 | self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False) 95 | 96 | def p_mean_variance(self, model, *args, **kwargs): 97 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | def wrapped(x, ts, **kwargs): 107 | ts_cpu = ts.detach().to("cpu") 108 | return model( 109 | x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs 110 | ) 111 | 112 | return wrapped 113 | -------------------------------------------------------------------------------- /karlo/modules/nn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class 
GroupNorm32(nn.GroupNorm): 13 | def __init__(self, num_groups, num_channels, swish, eps=1e-5): 14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) 15 | self.swish = swish 16 | 17 | def forward(self, x): 18 | y = super().forward(x.float()).to(x.dtype) 19 | if self.swish == 1.0: 20 | y = F.silu(y) 21 | elif self.swish: 22 | y = y * F.sigmoid(y * float(self.swish)) 23 | return y 24 | 25 | 26 | def conv_nd(dims, *args, **kwargs): 27 | """ 28 | Create a 1D, 2D, or 3D convolution module. 29 | """ 30 | if dims == 1: 31 | return nn.Conv1d(*args, **kwargs) 32 | elif dims == 2: 33 | return nn.Conv2d(*args, **kwargs) 34 | elif dims == 3: 35 | return nn.Conv3d(*args, **kwargs) 36 | raise ValueError(f"unsupported dimensions: {dims}") 37 | 38 | 39 | def linear(*args, **kwargs): 40 | """ 41 | Create a linear module. 42 | """ 43 | return nn.Linear(*args, **kwargs) 44 | 45 | 46 | def avg_pool_nd(dims, *args, **kwargs): 47 | """ 48 | Create a 1D, 2D, or 3D average pooling module. 49 | """ 50 | if dims == 1: 51 | return nn.AvgPool1d(*args, **kwargs) 52 | elif dims == 2: 53 | return nn.AvgPool2d(*args, **kwargs) 54 | elif dims == 3: 55 | return nn.AvgPool3d(*args, **kwargs) 56 | raise ValueError(f"unsupported dimensions: {dims}") 57 | 58 | 59 | def zero_module(module): 60 | """ 61 | Zero out the parameters of a module and return it. 62 | """ 63 | for p in module.parameters(): 64 | p.detach().zero_() 65 | return module 66 | 67 | 68 | def scale_module(module, scale): 69 | """ 70 | Scale the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().mul_(scale) 74 | return module 75 | 76 | 77 | def normalization(channels, swish=0.0): 78 | """ 79 | Make a standard normalization layer, with an optional swish activation. 80 | 81 | :param channels: number of input channels. 82 | :return: an nn.Module for normalization. 
83 | """ 84 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) 85 | 86 | 87 | def timestep_embedding(timesteps, dim, max_period=10000): 88 | """ 89 | Create sinusoidal timestep embeddings. 90 | 91 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 92 | These may be fractional. 93 | :param dim: the dimension of the output. 94 | :param max_period: controls the minimum frequency of the embeddings. 95 | :return: an [N x dim] Tensor of positional embeddings. 96 | """ 97 | half = dim // 2 98 | freqs = th.exp( 99 | -math.log(max_period) 100 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) 101 | / half 102 | ) 103 | args = timesteps[:, None].float() * freqs[None] 104 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 105 | if dim % 2: 106 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 107 | return embedding 108 | 109 | 110 | def mean_flat(tensor): 111 | """ 112 | Take the mean over all non-batch dimensions. 113 | """ 114 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 115 | -------------------------------------------------------------------------------- /karlo/modules/resample.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | from abc import abstractmethod 6 | 7 | import torch as th 8 | 9 | 10 | def create_named_schedule_sampler(name, diffusion): 11 | """ 12 | Create a ScheduleSampler from a library of pre-defined samplers. 13 | 14 | :param name: the name of the sampler. 15 | :param diffusion: the diffusion object to sample for. 
16 | """ 17 | if name == "uniform": 18 | return UniformSampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(th.nn.Module): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / th.sum(w) 54 | indices = p.multinomial(batch_size, replacement=True) 55 | weights = 1 / (len(p) * p[indices]) 56 | return indices, weights 57 | 58 | 59 | class UniformSampler(ScheduleSampler): 60 | def __init__(self, diffusion): 61 | super(UniformSampler, self).__init__() 62 | self.diffusion = diffusion 63 | self.register_buffer( 64 | "_weights", th.ones([diffusion.num_timesteps]), persistent=False 65 | ) 66 | 67 | def weights(self): 68 | return self._weights 69 | -------------------------------------------------------------------------------- /karlo/modules/unet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) 3 | # ------------------------------------------------------------------------------------ 4 | 5 | import math 6 | from abc import abstractmethod 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .nn import ( 13 | avg_pool_nd, 14 | conv_nd, 15 | linear, 16 | normalization, 17 | timestep_embedding, 18 | zero_module, 19 | ) 20 | from .xf import LayerNorm 21 | 22 | 23 | class TimestepBlock(nn.Module): 24 | """ 25 | Any module where forward() takes timestep embeddings as a second argument. 26 | """ 27 | 28 | @abstractmethod 29 | def forward(self, x, emb): 30 | """ 31 | Apply the module to `x` given `emb` timestep embeddings. 32 | """ 33 | 34 | 35 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 36 | """ 37 | A sequential module that passes timestep embeddings to the children that 38 | support it as an extra input. 
39 | """ 40 | 41 | def forward(self, x, emb, encoder_out=None, mask=None): 42 | for layer in self: 43 | if isinstance(layer, TimestepBlock): 44 | x = layer(x, emb) 45 | elif isinstance(layer, AttentionBlock): 46 | x = layer(x, encoder_out, mask=mask) 47 | else: 48 | x = layer(x) 49 | return x 50 | 51 | 52 | class Upsample(nn.Module): 53 | """ 54 | An upsampling layer with an optional convolution. 55 | 56 | :param channels: channels in the inputs and outputs. 57 | :param use_conv: a bool determining if a convolution is applied. 58 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 59 | upsampling occurs in the inner-two dimensions. 60 | """ 61 | 62 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 63 | super().__init__() 64 | self.channels = channels 65 | self.out_channels = out_channels or channels 66 | self.use_conv = use_conv 67 | self.dims = dims 68 | if use_conv: 69 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 70 | 71 | def forward(self, x): 72 | assert x.shape[1] == self.channels 73 | if self.dims == 3: 74 | x = F.interpolate( 75 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 76 | ) 77 | else: 78 | x = F.interpolate(x, scale_factor=2, mode="nearest") 79 | if self.use_conv: 80 | x = self.conv(x) 81 | return x 82 | 83 | 84 | class Downsample(nn.Module): 85 | """ 86 | A downsampling layer with an optional convolution. 87 | 88 | :param channels: channels in the inputs and outputs. 89 | :param use_conv: a bool determining if a convolution is applied. 90 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 91 | downsampling occurs in the inner-two dimensions. 
92 | """ 93 | 94 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 95 | super().__init__() 96 | self.channels = channels 97 | self.out_channels = out_channels or channels 98 | self.use_conv = use_conv 99 | self.dims = dims 100 | stride = 2 if dims != 3 else (1, 2, 2) 101 | if use_conv: 102 | self.op = conv_nd( 103 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 104 | ) 105 | else: 106 | assert self.channels == self.out_channels 107 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 108 | 109 | def forward(self, x): 110 | assert x.shape[1] == self.channels 111 | return self.op(x) 112 | 113 | 114 | class ResBlock(TimestepBlock): 115 | """ 116 | A residual block that can optionally change the number of channels. 117 | 118 | :param channels: the number of input channels. 119 | :param emb_channels: the number of timestep embedding channels. 120 | :param dropout: the rate of dropout. 121 | :param out_channels: if specified, the number of out channels. 122 | :param use_conv: if True and out_channels is specified, use a spatial 123 | convolution instead of a smaller 1x1 convolution to change the 124 | channels in the skip connection. 125 | :param dims: determines if the signal is 1D, 2D, or 3D. 126 | :param use_checkpoint: if True, use gradient checkpointing on this module. 127 | :param up: if True, use this block for upsampling. 128 | :param down: if True, use this block for downsampling. 
129 | """ 130 | 131 | def __init__( 132 | self, 133 | channels, 134 | emb_channels, 135 | dropout, 136 | out_channels=None, 137 | use_conv=False, 138 | use_scale_shift_norm=False, 139 | dims=2, 140 | use_checkpoint=False, 141 | up=False, 142 | down=False, 143 | ): 144 | super().__init__() 145 | self.channels = channels 146 | self.emb_channels = emb_channels 147 | self.dropout = dropout 148 | self.out_channels = out_channels or channels 149 | self.use_conv = use_conv 150 | self.use_checkpoint = use_checkpoint 151 | self.use_scale_shift_norm = use_scale_shift_norm 152 | 153 | self.in_layers = nn.Sequential( 154 | normalization(channels, swish=1.0), 155 | nn.Identity(), 156 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 157 | ) 158 | 159 | self.updown = up or down 160 | 161 | if up: 162 | self.h_upd = Upsample(channels, False, dims) 163 | self.x_upd = Upsample(channels, False, dims) 164 | elif down: 165 | self.h_upd = Downsample(channels, False, dims) 166 | self.x_upd = Downsample(channels, False, dims) 167 | else: 168 | self.h_upd = self.x_upd = nn.Identity() 169 | 170 | self.emb_layers = nn.Sequential( 171 | nn.SiLU(), 172 | linear( 173 | emb_channels, 174 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 175 | ), 176 | ) 177 | self.out_layers = nn.Sequential( 178 | normalization( 179 | self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0 180 | ), 181 | nn.SiLU() if use_scale_shift_norm else nn.Identity(), 182 | nn.Dropout(p=dropout), 183 | zero_module( 184 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 185 | ), 186 | ) 187 | 188 | if self.out_channels == channels: 189 | self.skip_connection = nn.Identity() 190 | elif use_conv: 191 | self.skip_connection = conv_nd( 192 | dims, channels, self.out_channels, 3, padding=1 193 | ) 194 | else: 195 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 196 | 197 | def forward(self, x, emb): 198 | """ 199 | Apply the block to a Tensor, 
conditioned on a timestep embedding. 200 | 201 | :param x: an [N x C x ...] Tensor of features. 202 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 203 | :return: an [N x C x ...] Tensor of outputs. 204 | """ 205 | if self.updown: 206 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 207 | h = in_rest(x) 208 | h = self.h_upd(h) 209 | x = self.x_upd(x) 210 | h = in_conv(h) 211 | else: 212 | h = self.in_layers(x) 213 | emb_out = self.emb_layers(emb) 214 | while len(emb_out.shape) < len(h.shape): 215 | emb_out = emb_out[..., None] 216 | if self.use_scale_shift_norm: 217 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 218 | scale, shift = th.chunk(emb_out, 2, dim=1) 219 | h = out_norm(h) * (1 + scale) + shift 220 | h = out_rest(h) 221 | else: 222 | h = h + emb_out 223 | h = self.out_layers(h) 224 | return self.skip_connection(x) + h 225 | 226 | 227 | class ResBlockNoTimeEmbedding(nn.Module): 228 | """ 229 | A residual block without time embedding 230 | 231 | :param channels: the number of input channels. 232 | :param emb_channels: the number of timestep embedding channels. 233 | :param dropout: the rate of dropout. 234 | :param out_channels: if specified, the number of out channels. 235 | :param use_conv: if True and out_channels is specified, use a spatial 236 | convolution instead of a smaller 1x1 convolution to change the 237 | channels in the skip connection. 238 | :param dims: determines if the signal is 1D, 2D, or 3D. 239 | :param use_checkpoint: if True, use gradient checkpointing on this module. 240 | :param up: if True, use this block for upsampling. 241 | :param down: if True, use this block for downsampling. 
242 | """ 243 | 244 | def __init__( 245 | self, 246 | channels, 247 | emb_channels, 248 | dropout, 249 | out_channels=None, 250 | use_conv=False, 251 | dims=2, 252 | use_checkpoint=False, 253 | up=False, 254 | down=False, 255 | **kwargs, 256 | ): 257 | super().__init__() 258 | self.channels = channels 259 | self.emb_channels = emb_channels 260 | self.dropout = dropout 261 | self.out_channels = out_channels or channels 262 | self.use_conv = use_conv 263 | self.use_checkpoint = use_checkpoint 264 | 265 | self.in_layers = nn.Sequential( 266 | normalization(channels, swish=1.0), 267 | nn.Identity(), 268 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 269 | ) 270 | 271 | self.updown = up or down 272 | 273 | if up: 274 | self.h_upd = Upsample(channels, False, dims) 275 | self.x_upd = Upsample(channels, False, dims) 276 | elif down: 277 | self.h_upd = Downsample(channels, False, dims) 278 | self.x_upd = Downsample(channels, False, dims) 279 | else: 280 | self.h_upd = self.x_upd = nn.Identity() 281 | 282 | self.out_layers = nn.Sequential( 283 | normalization(self.out_channels, swish=1.0), 284 | nn.Dropout(p=dropout), 285 | zero_module( 286 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 287 | ), 288 | ) 289 | 290 | if self.out_channels == channels: 291 | self.skip_connection = nn.Identity() 292 | elif use_conv: 293 | self.skip_connection = conv_nd( 294 | dims, channels, self.out_channels, 3, padding=1 295 | ) 296 | else: 297 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 298 | 299 | def forward(self, x, emb=None): 300 | """ 301 | Apply the block to a Tensor, NOT conditioned on a timestep embedding. 302 | 303 | :param x: an [N x C x ...] Tensor of features. 304 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 305 | :return: an [N x C x ...] Tensor of outputs. 
306 | """ 307 | assert emb is None 308 | 309 | if self.updown: 310 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 311 | h = in_rest(x) 312 | h = self.h_upd(h) 313 | x = self.x_upd(x) 314 | h = in_conv(h) 315 | else: 316 | h = self.in_layers(x) 317 | h = self.out_layers(h) 318 | return self.skip_connection(x) + h 319 | 320 | 321 | class AttentionBlock(nn.Module): 322 | """ 323 | An attention block that allows spatial positions to attend to each other. 324 | 325 | Originally ported from here, but adapted to the N-d case. 326 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 327 | """ 328 | 329 | def __init__( 330 | self, 331 | channels, 332 | num_heads=1, 333 | num_head_channels=-1, 334 | use_checkpoint=False, 335 | encoder_channels=None, 336 | ): 337 | super().__init__() 338 | self.channels = channels 339 | if num_head_channels == -1: 340 | self.num_heads = num_heads 341 | else: 342 | assert ( 343 | channels % num_head_channels == 0 344 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 345 | self.num_heads = channels // num_head_channels 346 | self.use_checkpoint = use_checkpoint 347 | self.norm = normalization(channels, swish=0.0) 348 | self.qkv = conv_nd(1, channels, channels * 3, 1) 349 | self.attention = QKVAttention(self.num_heads) 350 | 351 | if encoder_channels is not None: 352 | self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) 353 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 354 | 355 | def forward(self, x, encoder_out=None, mask=None): 356 | b, c, *spatial = x.shape 357 | qkv = self.qkv(self.norm(x).view(b, c, -1)) 358 | if encoder_out is not None: 359 | encoder_out = self.encoder_kv(encoder_out) 360 | h = self.attention(qkv, encoder_out, mask=mask) 361 | else: 362 | h = self.attention(qkv) 363 | h = self.proj_out(h) 364 | return x + h.reshape(b, c, *spatial) 365 | 366 | 367 | class 
QKVAttention(nn.Module): 368 | """ 369 | A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping. 370 | """ 371 | 372 | def __init__(self, n_heads): 373 | super().__init__() 374 | self.n_heads = n_heads 375 | 376 | def forward(self, qkv, encoder_kv=None, mask=None): 377 | """ 378 | Apply QKV attention. 379 | 380 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 381 | :return: an [N x (H * C) x T] tensor after attention. 382 | """ 383 | bs, width, length = qkv.shape 384 | assert width % (3 * self.n_heads) == 0 385 | ch = width // (3 * self.n_heads) 386 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 387 | if encoder_kv is not None: 388 | assert encoder_kv.shape[1] == self.n_heads * ch * 2 389 | ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) 390 | k = th.cat([ek, k], dim=-1) 391 | v = th.cat([ev, v], dim=-1) 392 | scale = 1 / math.sqrt(math.sqrt(ch)) 393 | weight = th.einsum("bct,bcs->bts", q * scale, k * scale) 394 | if mask is not None: 395 | mask = F.pad(mask, (0, length), value=0.0) 396 | mask = ( 397 | mask.unsqueeze(1) 398 | .expand(-1, self.n_heads, -1) 399 | .reshape(bs * self.n_heads, 1, -1) 400 | ) 401 | weight = weight + mask 402 | weight = th.softmax(weight, dim=-1) 403 | a = th.einsum("bts,bcs->bct", weight, v) 404 | return a.reshape(bs, -1, length) 405 | 406 | 407 | class UNetModel(nn.Module): 408 | """ 409 | The full UNet model with attention and timestep embedding. 410 | 411 | :param in_channels: channels in the input Tensor. 412 | :param model_channels: base channel count for the model. 413 | :param out_channels: channels in the output Tensor. 414 | :param num_res_blocks: number of residual blocks per downsample. 415 | :param attention_resolutions: a collection of downsample rates at which 416 | attention will take place. May be a set, list, or tuple.
417 | For example, if this contains 4, then at 4x downsampling, attention 418 | will be used. 419 | :param dropout: the dropout probability. 420 | :param channel_mult: channel multiplier for each level of the UNet. 421 | :param conv_resample: if True, use learned convolutions for upsampling and 422 | downsampling. 423 | :param dims: determines if the signal is 1D, 2D, or 3D. 424 | :param clip_dim: dimension of the CLIP feature passed as `y`, or None to disable it. 425 | :param num_classes: if specified (as an int), then this model will be 426 | class-conditional with `num_classes` classes. 427 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 428 | :param num_heads: the number of attention heads in each attention layer. 429 | :param num_head_channels: if specified, ignore num_heads and instead use 430 | a fixed channel width per attention head. 431 | :param num_heads_upsample: works with num_heads to set a different number 432 | of heads for upsampling. Deprecated. 433 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 434 | :param resblock_updown: use residual blocks for up/downsampling. 435 | :param encoder_channels: channels of the encoder output, which AttentionBlock projects to match the query width. 436 | :param use_time_embedding: if True, condition the residual blocks on the timestep embedding.
437 | """ 438 | 439 | def __init__( 440 | self, 441 | in_channels, 442 | model_channels, 443 | out_channels, 444 | num_res_blocks, 445 | attention_resolutions, 446 | dropout=0, 447 | channel_mult=(1, 2, 4, 8), 448 | conv_resample=True, 449 | dims=2, 450 | clip_dim=None, 451 | use_checkpoint=False, 452 | num_heads=1, 453 | num_head_channels=-1, 454 | num_heads_upsample=-1, 455 | use_scale_shift_norm=False, 456 | use_middle_attention=True, 457 | resblock_updown=False, 458 | encoder_channels=None, 459 | use_time_embedding=True, 460 | ): 461 | super().__init__() 462 | 463 | if num_heads_upsample == -1: 464 | num_heads_upsample = num_heads 465 | 466 | self.in_channels = in_channels 467 | self.model_channels = model_channels 468 | self.out_channels = out_channels 469 | self.num_res_blocks = num_res_blocks 470 | self.attention_resolutions = attention_resolutions 471 | self.dropout = dropout 472 | self.channel_mult = channel_mult 473 | self.conv_resample = conv_resample 474 | self.clip_dim = clip_dim 475 | self.use_checkpoint = use_checkpoint 476 | self.num_heads = num_heads 477 | self.num_head_channels = num_head_channels 478 | self.num_heads_upsample = num_heads_upsample 479 | self.use_middle_attention = use_middle_attention 480 | self.use_time_embedding = use_time_embedding 481 | 482 | if self.use_time_embedding: 483 | time_embed_dim = model_channels * 4 484 | self.time_embed = nn.Sequential( 485 | linear(model_channels, time_embed_dim), 486 | nn.SiLU(), 487 | linear(time_embed_dim, time_embed_dim), 488 | ) 489 | 490 | if self.clip_dim is not None: 491 | self.clip_emb = nn.Linear(clip_dim, time_embed_dim) 492 | else: 493 | time_embed_dim = None 494 | 495 | CustomResidualBlock = ( 496 | ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding 497 | ) 498 | ch = input_ch = int(channel_mult[0] * model_channels) 499 | self.input_blocks = nn.ModuleList( 500 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 501 | ) 502 | self._feature_size = ch 
503 | input_block_chans = [ch] 504 | ds = 1 505 | for level, mult in enumerate(channel_mult): 506 | for _ in range(num_res_blocks): 507 | layers = [ 508 | CustomResidualBlock( 509 | ch, 510 | time_embed_dim, 511 | dropout, 512 | out_channels=int(mult * model_channels), 513 | dims=dims, 514 | use_checkpoint=use_checkpoint, 515 | use_scale_shift_norm=use_scale_shift_norm, 516 | ) 517 | ] 518 | ch = int(mult * model_channels) 519 | if ds in attention_resolutions: 520 | layers.append( 521 | AttentionBlock( 522 | ch, 523 | use_checkpoint=use_checkpoint, 524 | num_heads=num_heads, 525 | num_head_channels=num_head_channels, 526 | encoder_channels=encoder_channels, 527 | ) 528 | ) 529 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 530 | self._feature_size += ch 531 | input_block_chans.append(ch) 532 | if level != len(channel_mult) - 1: 533 | out_ch = ch 534 | self.input_blocks.append( 535 | TimestepEmbedSequential( 536 | CustomResidualBlock( 537 | ch, 538 | time_embed_dim, 539 | dropout, 540 | out_channels=out_ch, 541 | dims=dims, 542 | use_checkpoint=use_checkpoint, 543 | use_scale_shift_norm=use_scale_shift_norm, 544 | down=True, 545 | ) 546 | if resblock_updown 547 | else Downsample( 548 | ch, conv_resample, dims=dims, out_channels=out_ch 549 | ) 550 | ) 551 | ) 552 | ch = out_ch 553 | input_block_chans.append(ch) 554 | ds *= 2 555 | self._feature_size += ch 556 | 557 | self.middle_block = TimestepEmbedSequential( 558 | CustomResidualBlock( 559 | ch, 560 | time_embed_dim, 561 | dropout, 562 | dims=dims, 563 | use_checkpoint=use_checkpoint, 564 | use_scale_shift_norm=use_scale_shift_norm, 565 | ), 566 | *( 567 | AttentionBlock( 568 | ch, 569 | use_checkpoint=use_checkpoint, 570 | num_heads=num_heads, 571 | num_head_channels=num_head_channels, 572 | encoder_channels=encoder_channels, 573 | ), 574 | ) 575 | if self.use_middle_attention 576 | else tuple(), # add AttentionBlock or not 577 | CustomResidualBlock( 578 | ch, 579 | time_embed_dim, 580 | dropout, 581 
| dims=dims, 582 | use_checkpoint=use_checkpoint, 583 | use_scale_shift_norm=use_scale_shift_norm, 584 | ), 585 | ) 586 | self._feature_size += ch 587 | 588 | self.output_blocks = nn.ModuleList([]) 589 | for level, mult in list(enumerate(channel_mult))[::-1]: 590 | for i in range(num_res_blocks + 1): 591 | ich = input_block_chans.pop() 592 | layers = [ 593 | CustomResidualBlock( 594 | ch + ich, 595 | time_embed_dim, 596 | dropout, 597 | out_channels=int(model_channels * mult), 598 | dims=dims, 599 | use_checkpoint=use_checkpoint, 600 | use_scale_shift_norm=use_scale_shift_norm, 601 | ) 602 | ] 603 | ch = int(model_channels * mult) 604 | if ds in attention_resolutions: 605 | layers.append( 606 | AttentionBlock( 607 | ch, 608 | use_checkpoint=use_checkpoint, 609 | num_heads=num_heads_upsample, 610 | num_head_channels=num_head_channels, 611 | encoder_channels=encoder_channels, 612 | ) 613 | ) 614 | if level and i == num_res_blocks: 615 | out_ch = ch 616 | layers.append( 617 | CustomResidualBlock( 618 | ch, 619 | time_embed_dim, 620 | dropout, 621 | out_channels=out_ch, 622 | dims=dims, 623 | use_checkpoint=use_checkpoint, 624 | use_scale_shift_norm=use_scale_shift_norm, 625 | up=True, 626 | ) 627 | if resblock_updown 628 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 629 | ) 630 | ds //= 2 631 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 632 | self._feature_size += ch 633 | 634 | self.out = nn.Sequential( 635 | normalization(ch, swish=1.0), 636 | nn.Identity(), 637 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 638 | ) 639 | 640 | def forward(self, x, timesteps, y=None): 641 | """ 642 | Apply the model to an input batch. 643 | 644 | :param x: an [N x C x ...] Tensor of inputs. 645 | :param timesteps: a 1-D batch of timesteps. 646 | :param y: an [N] Tensor of labels, if class-conditional. 647 | :return: an [N x C x ...] Tensor of outputs. 
648 | """ 649 | assert (y is not None) == ( 650 | self.clip_dim is not None 651 | ), "must specify y if and only if the model is clip-rep-conditional" 652 | 653 | hs = [] 654 | if self.use_time_embedding: 655 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 656 | if self.clip_dim is not None: 657 | emb = emb + self.clip_emb(y) 658 | else: 659 | emb = None 660 | 661 | h = x 662 | for module in self.input_blocks: 663 | h = module(h, emb) 664 | hs.append(h) 665 | h = self.middle_block(h, emb) 666 | for module in self.output_blocks: 667 | h = th.cat([h, hs.pop()], dim=1) 668 | h = module(h, emb) 669 | 670 | return self.out(h) 671 | 672 | 673 | class SuperResUNetModel(UNetModel): 674 | """ 675 | A UNetModel that performs super-resolution. 676 | 677 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 678 | Assumes that the shape of low-resolution and the input should be the same. 679 | """ 680 | 681 | def __init__(self, *args, **kwargs): 682 | if "in_channels" in kwargs: 683 | kwargs = dict(kwargs) 684 | kwargs["in_channels"] = kwargs["in_channels"] * 2 685 | else: 686 | # Curse you, Python. Or really, just curse positional arguments :|. 687 | args = list(args) 688 | args[1] = args[1] * 2 689 | super().__init__(*args, **kwargs) 690 | 691 | def forward(self, x, timesteps, low_res=None, **kwargs): 692 | _, _, new_height, new_width = x.shape 693 | assert new_height == low_res.shape[2] and new_width == low_res.shape[3] 694 | 695 | x = th.cat([x, low_res], dim=1) 696 | return super().forward(x, timesteps, **kwargs) 697 | 698 | 699 | class PLMImUNet(UNetModel): 700 | """ 701 | A UNetModel that conditions on text with a pretrained text encoder in CLIP. 702 | 703 | :param text_ctx: number of text tokens to expect. 704 | :param xf_width: width of the transformer. 705 | :param clip_emb_mult: #extra tokens by projecting clip text feature. 706 | :param clip_emb_type: type of condition (here, we fix clip image feature). 
707 | :param clip_emb_drop: dropout ratio of clip image feature for cfg. 708 | """ 709 | 710 | def __init__( 711 | self, 712 | text_ctx, 713 | xf_width, 714 | *args, 715 | clip_emb_mult=None, 716 | clip_emb_type="image", 717 | clip_emb_drop=0.0, 718 | **kwargs, 719 | ): 720 | self.text_ctx = text_ctx 721 | self.xf_width = xf_width 722 | self.clip_emb_mult = clip_emb_mult 723 | self.clip_emb_type = clip_emb_type 724 | self.clip_emb_drop = clip_emb_drop 725 | 726 | if not xf_width: 727 | super().__init__(*args, **kwargs, encoder_channels=None) 728 | else: 729 | super().__init__(*args, **kwargs, encoder_channels=xf_width) 730 | 731 | # Project text encoded feat seq from pre-trained text encoder in CLIP 732 | self.text_seq_proj = nn.Sequential( 733 | nn.Linear(self.clip_dim, xf_width), 734 | LayerNorm(xf_width), 735 | ) 736 | # Project CLIP text feat 737 | self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4) 738 | 739 | assert clip_emb_mult is not None 740 | assert clip_emb_type == "image" 741 | assert self.clip_dim is not None, "CLIP representation dim should be specified" 742 | 743 | self.clip_tok_proj = nn.Linear( 744 | self.clip_dim, self.xf_width * self.clip_emb_mult 745 | ) 746 | if self.clip_emb_drop > 0: 747 | self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32)) 748 | 749 | def proc_clip_emb_drop(self, feat): 750 | if self.clip_emb_drop > 0: 751 | bsz, feat_dim = feat.shape 752 | assert ( 753 | feat_dim == self.clip_dim 754 | ), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}" 755 | drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop 756 | feat = th.where( 757 | drop_idx[..., None], self.cf_param[None].type_as(feat), feat 758 | ) 759 | return feat 760 | 761 | def forward( 762 | self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None 763 | ): 764 | bsz = x.shape[0] 765 | hs = [] 766 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 767 | emb = emb +
self.clip_emb(y) 768 | 769 | xf_out = self.text_seq_proj(txt_feat_seq) 770 | xf_out = xf_out.permute(0, 2, 1) 771 | emb = emb + self.text_feat_proj(txt_feat) 772 | xf_out = th.cat( 773 | [ 774 | self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult), 775 | xf_out, 776 | ], 777 | dim=2, 778 | ) 779 | mask = F.pad(mask, (self.clip_emb_mult, 0), value=True) 780 | mask = th.where(mask, 0.0, float("-inf")) 781 | 782 | h = x 783 | for module in self.input_blocks: 784 | h = module(h, emb, xf_out, mask=mask) 785 | hs.append(h) 786 | h = self.middle_block(h, emb, xf_out, mask=mask) 787 | for module in self.output_blocks: 788 | h = th.cat([h, hs.pop()], dim=1) 789 | h = module(h, emb, xf_out, mask=mask) 790 | h = self.out(h) 791 | 792 | return h 793 | -------------------------------------------------------------------------------- /karlo/modules/xf.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Adapted from the repos below: 3 | # (a) Guided-Diffusion (https://github.com/openai/guided-diffusion) 4 | # (b) CLIP ViT (https://github.com/openai/CLIP/) 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import math 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .nn import timestep_embedding 14 | 15 | 16 | def convert_module_to_f16(param): 17 | """ 18 | Convert primitive modules to float16. 19 | """ 20 | if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 21 | param.weight.data = param.weight.data.half() 22 | if param.bias is not None: 23 | param.bias.data = param.bias.data.half() 24 | 25 | 26 | class LayerNorm(nn.LayerNorm): 27 | """ 28 | Implementation that supports fp16 inputs but fp32 gains/biases. 
29 | """ 30 | 31 | def forward(self, x: th.Tensor): 32 | return super().forward(x.float()).to(x.dtype) 33 | 34 | 35 | class MultiheadAttention(nn.Module): 36 | def __init__(self, n_ctx, width, heads): 37 | super().__init__() 38 | self.n_ctx = n_ctx 39 | self.width = width 40 | self.heads = heads 41 | self.c_qkv = nn.Linear(width, width * 3) 42 | self.c_proj = nn.Linear(width, width) 43 | self.attention = QKVMultiheadAttention(heads, n_ctx) 44 | 45 | def forward(self, x, mask=None): 46 | x = self.c_qkv(x) 47 | x = self.attention(x, mask=mask) 48 | x = self.c_proj(x) 49 | return x 50 | 51 | 52 | class MLP(nn.Module): 53 | def __init__(self, width): 54 | super().__init__() 55 | self.width = width 56 | self.c_fc = nn.Linear(width, width * 4) 57 | self.c_proj = nn.Linear(width * 4, width) 58 | self.gelu = nn.GELU() 59 | 60 | def forward(self, x): 61 | return self.c_proj(self.gelu(self.c_fc(x))) 62 | 63 | 64 | class QKVMultiheadAttention(nn.Module): 65 | def __init__(self, n_heads: int, n_ctx: int): 66 | super().__init__() 67 | self.n_heads = n_heads 68 | self.n_ctx = n_ctx 69 | 70 | def forward(self, qkv, mask=None): 71 | bs, n_ctx, width = qkv.shape 72 | attn_ch = width // self.n_heads // 3 73 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 74 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 75 | q, k, v = th.split(qkv, attn_ch, dim=-1) 76 | weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale) 77 | wdtype = weight.dtype 78 | if mask is not None: 79 | weight = weight + mask[:, None, ...] 
80 | weight = th.softmax(weight, dim=-1).type(wdtype) 81 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 82 | 83 | 84 | class ResidualAttentionBlock(nn.Module): 85 | def __init__( 86 | self, 87 | n_ctx: int, 88 | width: int, 89 | heads: int, 90 | ): 91 | super().__init__() 92 | 93 | self.attn = MultiheadAttention( 94 | n_ctx, 95 | width, 96 | heads, 97 | ) 98 | self.ln_1 = LayerNorm(width) 99 | self.mlp = MLP(width) 100 | self.ln_2 = LayerNorm(width) 101 | 102 | def forward(self, x, mask=None): 103 | x = x + self.attn(self.ln_1(x), mask=mask) 104 | x = x + self.mlp(self.ln_2(x)) 105 | return x 106 | 107 | 108 | class Transformer(nn.Module): 109 | def __init__( 110 | self, 111 | n_ctx: int, 112 | width: int, 113 | layers: int, 114 | heads: int, 115 | ): 116 | super().__init__() 117 | self.n_ctx = n_ctx 118 | self.width = width 119 | self.layers = layers 120 | self.resblocks = nn.ModuleList( 121 | [ 122 | ResidualAttentionBlock( 123 | n_ctx, 124 | width, 125 | heads, 126 | ) 127 | for _ in range(layers) 128 | ] 129 | ) 130 | 131 | def forward(self, x, mask=None): 132 | for block in self.resblocks: 133 | x = block(x, mask=mask) 134 | return x 135 | 136 | 137 | class PriorTransformer(nn.Module): 138 | """ 139 | A Causal Transformer that conditions on CLIP text embedding, text. 140 | 141 | :param text_ctx: number of text tokens to expect. 142 | :param xf_width: width of the transformer. 143 | :param xf_layers: depth of the transformer. 144 | :param xf_heads: heads in the transformer. 145 | :param xf_final_ln: use a LayerNorm after the output layer. 146 | :param clip_dim: dimension of clip feature. 
147 | """ 148 | 149 | def __init__( 150 | self, 151 | text_ctx, 152 | xf_width, 153 | xf_layers, 154 | xf_heads, 155 | xf_final_ln, 156 | clip_dim, 157 | ): 158 | super().__init__() 159 | 160 | self.text_ctx = text_ctx 161 | self.xf_width = xf_width 162 | self.xf_layers = xf_layers 163 | self.xf_heads = xf_heads 164 | self.clip_dim = clip_dim 165 | self.ext_len = 4 166 | 167 | self.time_embed = nn.Sequential( 168 | nn.Linear(xf_width, xf_width), 169 | nn.SiLU(), 170 | nn.Linear(xf_width, xf_width), 171 | ) 172 | self.text_enc_proj = nn.Linear(clip_dim, xf_width) 173 | self.text_emb_proj = nn.Linear(clip_dim, xf_width) 174 | self.clip_img_proj = nn.Linear(clip_dim, xf_width) 175 | self.out_proj = nn.Linear(xf_width, clip_dim) 176 | self.transformer = Transformer( 177 | text_ctx + self.ext_len, 178 | xf_width, 179 | xf_layers, 180 | xf_heads, 181 | ) 182 | if xf_final_ln: 183 | self.final_ln = LayerNorm(xf_width) 184 | else: 185 | self.final_ln = None 186 | 187 | self.positional_embedding = nn.Parameter( 188 | th.empty(1, text_ctx + self.ext_len, xf_width) 189 | ) 190 | self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width))) 191 | 192 | nn.init.normal_(self.prd_emb, std=0.01) 193 | nn.init.normal_(self.positional_embedding, std=0.01) 194 | 195 | def forward( 196 | self, 197 | x, 198 | timesteps, 199 | text_emb=None, 200 | text_enc=None, 201 | mask=None, 202 | causal_mask=None, 203 | ): 204 | bsz = x.shape[0] 205 | mask = F.pad(mask, (0, self.ext_len), value=True) 206 | 207 | t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width)) 208 | text_enc = self.text_enc_proj(text_enc) 209 | text_emb = self.text_emb_proj(text_emb) 210 | x = self.clip_img_proj(x) 211 | 212 | input_seq = [ 213 | text_enc, 214 | text_emb[:, None, :], 215 | t_emb[:, None, :], 216 | x[:, None, :], 217 | self.prd_emb.to(x.dtype).expand(bsz, -1, -1), 218 | ] 219 | input = th.cat(input_seq, dim=1) 220 | input = input + self.positional_embedding.to(input.dtype) 221 | 222 | mask = 
th.where(mask, 0.0, float("-inf")) 223 | mask = (mask[:, None, :] + causal_mask).to(input.dtype) 224 | 225 | out = self.transformer(input, mask=mask) 226 | if self.final_ln is not None: 227 | out = self.final_ln(out) 228 | 229 | out = self.out_proj(out[:, -1]) 230 | 231 | return out 232 | -------------------------------------------------------------------------------- /karlo/sampler/i2i.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | from typing import Iterator 7 | 8 | import torch 9 | import torchvision.transforms.functional as TVF 10 | from torchvision.transforms import InterpolationMode 11 | 12 | from .template import BaseSampler, CKPT_PATH 13 | 14 | 15 | class I2ISampler(BaseSampler): 16 | """ 17 | A sampler for image variation. In the original unclip paper, image variation transforms the noise obtained by DDIM inversion into a sample in RGB space. 18 | Here, we simply transform the white noise to image, conditioned on the clip image feature. 19 | 20 | :param root_dir: directory for model checkpoints. 
21 | :param sampling_type: ["default", "fast"] 22 | """ 23 | 24 | def __init__( 25 | self, 26 | root_dir: str, 27 | sampling_type: str = "default", 28 | ): 29 | super().__init__(root_dir, sampling_type) 30 | 31 | @classmethod 32 | def from_pretrained( 33 | cls, 34 | root_dir: str, 35 | clip_model_path: str, 36 | clip_stat_path: str, 37 | sampling_type: str = "default", 38 | ): 39 | 40 | model = cls( 41 | root_dir=root_dir, 42 | sampling_type=sampling_type, 43 | ) 44 | model.load_clip(clip_model_path) 45 | model.load_decoder(f"{CKPT_PATH['decoder']}") 46 | model.load_sr_64_256(CKPT_PATH["sr_256"]) 47 | 48 | return model 49 | 50 | def preprocess( 51 | self, 52 | image, 53 | prompt: str, 54 | bsz: int, 55 | ): 56 | prompts_batch = [prompt for _ in range(bsz)] 57 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) 58 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") 59 | 60 | # preprocess input image 61 | image = TVF.normalize( 62 | TVF.to_tensor( 63 | TVF.resize( 64 | image, 65 | [224, 224], 66 | interpolation=InterpolationMode.BICUBIC, 67 | antialias=True, 68 | ) 69 | ), 70 | mean=[0.48145466, 0.4578275, 0.40821073], 71 | std=[0.26862954, 0.26130258, 0.27577711], 72 | ).unsqueeze(0) 73 | image_batch = image.repeat(bsz, 1, 1, 1).cuda() 74 | 75 | """ Get CLIP text and image features """ 76 | clip_model = self._clip 77 | tokenizer = self._tokenizer 78 | max_txt_length = 77 79 | 80 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) 81 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) 82 | if not (cf_token.shape == tok.shape): 83 | cf_token = cf_token.expand(tok.shape[0], -1) 84 | cf_mask = cf_mask.expand(tok.shape[0], -1) 85 | 86 | tok = torch.cat([tok, cf_token], dim=0) 87 | mask = torch.cat([mask, cf_mask], dim=0) 88 | 89 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda") 90 | txt_feat, txt_feat_seq = clip_model.encode_text(tok) 91 | img_feat = 
clip_model.encode_image(image_batch) 92 | 93 | return ( 94 | prompts_batch, 95 | decoder_cf_scales_batch, 96 | txt_feat, 97 | txt_feat_seq, 98 | tok, 99 | mask, 100 | img_feat, 101 | ) 102 | 103 | def __call__( 104 | self, 105 | image, 106 | bsz: int, 107 | progressive_mode=None, 108 | ) -> Iterator[torch.Tensor]: 109 | assert progressive_mode in ("loop", "stage", "final") 110 | with torch.no_grad(), torch.cuda.amp.autocast(): 111 | ( 112 | prompts_batch, 113 | decoder_cf_scales_batch, 114 | txt_feat, 115 | txt_feat_seq, 116 | tok, 117 | mask, 118 | img_feat, 119 | ) = self.preprocess( 120 | image=image, 121 | prompt="", 122 | bsz=bsz, 123 | ) 124 | 125 | """ Generate 64x64px images """ 126 | images_64_outputs = self._decoder( 127 | txt_feat, 128 | txt_feat_seq, 129 | tok, 130 | mask, 131 | img_feat, 132 | cf_guidance_scales=decoder_cf_scales_batch, 133 | timestep_respacing=self._decoder_sm, 134 | ) 135 | 136 | images_64 = None 137 | for k, out in enumerate(images_64_outputs): 138 | images_64 = out 139 | if progressive_mode == "loop": 140 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 141 | if progressive_mode == "stage": 142 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 143 | 144 | images_64 = torch.clamp(images_64, -1, 1) 145 | 146 | """ Upsample 64x64 to 256x256 """ 147 | images_256 = TVF.resize( 148 | images_64, 149 | [256, 256], 150 | interpolation=InterpolationMode.BICUBIC, 151 | antialias=True, 152 | ) 153 | images_256_outputs = self._sr_64_256( 154 | images_256, timestep_respacing=self._sr_sm 155 | ) 156 | 157 | for k, out in enumerate(images_256_outputs): 158 | images_256 = out 159 | if progressive_mode == "loop": 160 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 161 | if progressive_mode == "stage": 162 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 163 | 164 | yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0) 165 | -------------------------------------------------------------------------------- /karlo/sampler/t2i.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | from typing import Iterator 7 | 8 | import torch 9 | import torchvision.transforms.functional as TVF 10 | from torchvision.transforms import InterpolationMode 11 | 12 | from .template import BaseSampler, CKPT_PATH 13 | 14 | 15 | class T2ISampler(BaseSampler): 16 | """ 17 | A sampler for text-to-image generation. 18 | 19 | :param root_dir: directory for model checkpoints. 20 | :param sampling_type: ["default", "fast"] 21 | """ 22 | 23 | def __init__( 24 | self, 25 | root_dir: str, 26 | sampling_type: str = "default", 27 | ): 28 | super().__init__(root_dir, sampling_type) 29 | 30 | @classmethod 31 | def from_pretrained( 32 | cls, 33 | root_dir: str, 34 | clip_model_path: str, 35 | clip_stat_path: str, 36 | sampling_type: str = "default", 37 | ): 38 | 39 | model = cls( 40 | root_dir=root_dir, 41 | sampling_type=sampling_type, 42 | ) 43 | model.load_clip(clip_model_path) 44 | model.load_prior( 45 | f"{CKPT_PATH['prior']}", 46 | clip_stat_path=clip_stat_path, 47 | ) 48 | model.load_decoder(f"{CKPT_PATH['decoder']}") 49 | model.load_sr_64_256(CKPT_PATH["sr_256"]) 50 | 51 | return model 52 | 53 | def preprocess( 54 | self, 55 | prompt: str, 56 | bsz: int, 57 | ): 58 | """Setup prompts & cfg scales""" 59 | prompts_batch = [prompt for _ in range(bsz)] 60 | 61 | prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) 62 | prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") 63 | 64 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) 65 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") 66 | 67 | """ Get CLIP text feature """ 68 | clip_model = 
self._clip 69 | tokenizer = self._tokenizer 70 | max_txt_length = self._prior.model.text_ctx 71 | 72 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) 73 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) 74 | if not (cf_token.shape == tok.shape): 75 | cf_token = cf_token.expand(tok.shape[0], -1) 76 | cf_mask = cf_mask.expand(tok.shape[0], -1) 77 | 78 | tok = torch.cat([tok, cf_token], dim=0) 79 | mask = torch.cat([mask, cf_mask], dim=0) 80 | 81 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda") 82 | txt_feat, txt_feat_seq = clip_model.encode_text(tok) 83 | 84 | return ( 85 | prompts_batch, 86 | prior_cf_scales_batch, 87 | decoder_cf_scales_batch, 88 | txt_feat, 89 | txt_feat_seq, 90 | tok, 91 | mask, 92 | ) 93 | 94 | def __call__( 95 | self, 96 | prompt: str, 97 | bsz: int, 98 | progressive_mode=None, 99 | ) -> Iterator[torch.Tensor]: 100 | assert progressive_mode in ("loop", "stage", "final") 101 | with torch.no_grad(), torch.cuda.amp.autocast(): 102 | ( 103 | prompts_batch, 104 | prior_cf_scales_batch, 105 | decoder_cf_scales_batch, 106 | txt_feat, 107 | txt_feat_seq, 108 | tok, 109 | mask, 110 | ) = self.preprocess( 111 | prompt, 112 | bsz, 113 | ) 114 | 115 | """ Transform CLIP text feature into image feature """ 116 | img_feat = self._prior( 117 | txt_feat, 118 | txt_feat_seq, 119 | mask, 120 | prior_cf_scales_batch, 121 | timestep_respacing=self._prior_sm, 122 | ) 123 | 124 | """ Generate 64x64px images """ 125 | images_64_outputs = self._decoder( 126 | txt_feat, 127 | txt_feat_seq, 128 | tok, 129 | mask, 130 | img_feat, 131 | cf_guidance_scales=decoder_cf_scales_batch, 132 | timestep_respacing=self._decoder_sm, 133 | ) 134 | 135 | images_64 = None 136 | for k, out in enumerate(images_64_outputs): 137 | images_64 = out 138 | if progressive_mode == "loop": 139 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 140 | if progressive_mode == "stage": 141 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 142 
| 143 | images_64 = torch.clamp(images_64, -1, 1) 144 | 145 | """ Upsample 64x64 to 256x256 """ 146 | images_256 = TVF.resize( 147 | images_64, 148 | [256, 256], 149 | interpolation=InterpolationMode.BICUBIC, 150 | antialias=True, 151 | ) 152 | images_256_outputs = self._sr_64_256( 153 | images_256, timestep_respacing=self._sr_sm 154 | ) 155 | 156 | for k, out in enumerate(images_256_outputs): 157 | images_256 = out 158 | if progressive_mode == "loop": 159 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 160 | if progressive_mode == "stage": 161 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) 162 | 163 | yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0) 164 | -------------------------------------------------------------------------------- /karlo/sampler/template.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Karlo-v1.0.alpha 3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------------ 5 | 6 | import os 7 | import logging 8 | import torch 9 | 10 | from omegaconf import OmegaConf 11 | 12 | from ..models.clip import CustomizedCLIP, CustomizedTokenizer 13 | from ..models.prior_model import PriorDiffusionModel 14 | from ..models.decoder_model import Text2ImProgressiveModel 15 | from ..models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel 16 | 17 | 18 | SAMPLING_CONF = { 19 | "default": { 20 | "prior_sm": "25", 21 | "prior_n_samples": 1, 22 | "prior_cf_scale": 4.0, 23 | "decoder_sm": "50", 24 | "decoder_cf_scale": 8.0, 25 | "sr_sm": "7", 26 | }, 27 | "fast": { 28 | "prior_sm": "25", 29 | "prior_n_samples": 1, 30 | "prior_cf_scale": 4.0, 31 | "decoder_sm": "25", 32 | "decoder_cf_scale": 8.0, 33 | "sr_sm": "7", 34 | }, 35 | } 36 | 37 | CKPT_PATH = { 38 | "prior": "prior-ckpt-step=01000000-of-01000000.ckpt", 39 | "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt", 40 | "sr_256": "improved-sr-ckpt-step=1.2M.ckpt", 41 | } 42 | 43 | 44 | class BaseSampler: 45 | _PRIOR_CLASS = PriorDiffusionModel 46 | _DECODER_CLASS = Text2ImProgressiveModel 47 | _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel 48 | 49 | def __init__( 50 | self, 51 | root_dir: str, 52 | sampling_type: str = "fast", 53 | ): 54 | self._root_dir = root_dir 55 | 56 | sampling_type = SAMPLING_CONF[sampling_type] 57 | self._prior_sm = sampling_type["prior_sm"] 58 | self._prior_n_samples = sampling_type["prior_n_samples"] 59 | self._prior_cf_scale = sampling_type["prior_cf_scale"] 60 | 61 | assert self._prior_n_samples == 1 62 | 63 | self._decoder_sm = sampling_type["decoder_sm"] 64 | self._decoder_cf_scale = sampling_type["decoder_cf_scale"] 65 | 66 | self._sr_sm = sampling_type["sr_sm"] 67 | 68 | def __repr__(self): 69 | line = "" 70 | line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n" 71 | line += f"Decoder, sampling method: {self._decoder_sm}, 
cf_scale: {self._decoder_cf_scale}\n" 72 | line += f"SR(64->256), sampling method: {self._sr_sm}" 73 | 74 | return line 75 | 76 | def load_clip(self, clip_path: str): 77 | clip = CustomizedCLIP.load_from_checkpoint( 78 | os.path.join(self._root_dir, clip_path) 79 | ) 80 | clip = torch.jit.script(clip) 81 | clip.cuda() 82 | clip.eval() 83 | 84 | self._clip = clip 85 | self._tokenizer = CustomizedTokenizer() 86 | 87 | def load_prior( 88 | self, 89 | ckpt_path: str, 90 | clip_stat_path: str, 91 | ): 92 | logging.info(f"Loading prior: {ckpt_path}") 93 | 94 | config = OmegaConf.load("configs/prior_1B_vit_l.yaml") 95 | clip_mean, clip_std = torch.load( 96 | os.path.join(self._root_dir, clip_stat_path), map_location="cpu" 97 | ) 98 | 99 | prior = self._PRIOR_CLASS.load_from_checkpoint( 100 | config, 101 | self._tokenizer, 102 | clip_mean, 103 | clip_std, 104 | os.path.join(self._root_dir, ckpt_path), 105 | strict=True, 106 | ) 107 | prior.cuda() 108 | prior.eval() 109 | logging.info("done.") 110 | 111 | self._prior = prior 112 | 113 | def load_decoder(self, ckpt_path: str): 114 | logging.info(f"Loading decoder: {ckpt_path}") 115 | 116 | config = OmegaConf.load("configs/decoder_900M_vit_l.yaml") 117 | decoder = self._DECODER_CLASS.load_from_checkpoint( 118 | config, 119 | self._tokenizer, 120 | os.path.join(self._root_dir, ckpt_path), 121 | strict=True, 122 | ) 123 | decoder.cuda() 124 | decoder.eval() 125 | logging.info("done.") 126 | 127 | self._decoder = decoder 128 | 129 | def load_sr_64_256(self, ckpt_path: str): 130 | logging.info(f"Loading SR(64->256): {ckpt_path}") 131 | 132 | config = OmegaConf.load("configs/improved_sr_64_256_1.4B.yaml") 133 | sr = self._SR256_CLASS.load_from_checkpoint( 134 | config, os.path.join(self._root_dir, ckpt_path), strict=True 135 | ) 136 | sr.cuda() 137 | sr.eval() 138 | logging.info("done.") 139 | 140 | self._sr_64_256 = sr 141 | -------------------------------------------------------------------------------- /karlo/utils/util.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def set_seed(seed): 7 | random.seed(seed) 8 | np.random.seed(seed) 9 | torch.manual_seed(seed) 10 | torch.cuda.manual_seed_all(seed) 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.10 2 | torchvision>=0.8.2 3 | black 4 | einops 5 | omegaconf 6 | matplotlib 7 | gradio>=3.5.0 8 | git+https://github.com/openai/CLIP.git 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203, E226, E402, E731, W503, W504 4 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install -r requirements.txt 4 | 5 | export KARLO_ROOT_DIR=$HOME/.cache/karlo/v1.0.alpha/ 6 | 7 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt -P $KARLO_ROOT_DIR 8 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th -P $KARLO_ROOT_DIR 9 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR 10 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR 11 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt -P $KARLO_ROOT_DIR 12 | 
--------------------------------------------------------------------------------