├── .gitignore
├── LICENSE
├── README.md
├── assets
├── A man with a face of avocado, in the drawing style of Rene Magritte..png
├── A teddy bear on a skateboard, children drawing style..png
├── Goryeo celadon in the shape of bird.png
├── Photo of a business woman, silver hair.png
├── a black porcelain in the shape of pikachu.png
├── a portrait of an old monk, highly detailed.png
├── example.gif
├── improved_sr_arch.jpg
├── variation_A man with a face of avocado, in the drawing style of Rene Magritte..png
└── variation_a black porcelain in the shape of pikachu.png
├── configs
├── decoder_900M_vit_l.yaml
├── improved_sr_64_256_1.4B.yaml
└── prior_1B_vit_l.yaml
├── demo
├── components.py
└── product_demo.py
├── example.py
├── karlo
├── __init__.py
├── models
│ ├── __init__.py
│ ├── clip.py
│ ├── decoder_model.py
│ ├── prior_model.py
│ ├── sr_256_1k.py
│ └── sr_64_256.py
├── modules
│ ├── __init__.py
│ ├── diffusion
│ │ ├── gaussian_diffusion.py
│ │ └── respace.py
│ ├── nn.py
│ ├── resample.py
│ ├── unet.py
│ └── xf.py
├── sampler
│ ├── i2i.py
│ ├── t2i.py
│ └── template.py
└── utils
│ └── util.py
├── requirements.txt
├── setup.cfg
└── setup.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | .idea/
3 | __pycache__/
4 |
5 | results
6 | results.*
7 | notebooks/
8 | outputs
9 | outputs/
10 | configs/temp/
11 | pytests/
12 | _cache/
13 | scripts/
14 |
15 | *.ckpt
16 | core.*
17 |
18 | __pycache__/
19 | **/__pycache__/
20 | *.py[cod]
21 | **/*.py[cod]
22 | **/*.pyc
23 | result*/
24 | results*/
25 | backup*/
26 | test.*/
27 | .nfs*
28 |
29 | .ipynb_*/
30 |
31 | # MacOSX
32 | **/*.DS_Store
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022 Kakao Brain
2 |
3 | CreativeML Open RAIL-M
4 | dated August 22, 2022
5 |
6 | Section I: PREAMBLE
7 |
8 | Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
9 |
10 | Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
11 |
12 | In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
13 |
14 | Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
15 |
16 | This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
17 |
18 | NOW THEREFORE, You and Licensor agree as follows:
19 |
20 | 1. Definitions
21 |
22 | - "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
23 | - "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
24 | - "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
25 | - "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
26 | - "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
27 | - "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
28 | - "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
29 | - "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
30 | - "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
31 | - "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
32 | - "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
33 | - "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
34 |
35 | Section II: INTELLECTUAL PROPERTY RIGHTS
36 |
37 | Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
38 |
39 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
40 | 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
41 |
42 | Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
43 |
44 | 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
45 | Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
46 | You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
47 | You must cause any modified files to carry prominent notices stating that You changed the files;
48 | You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
49 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
50 | 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
51 | 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
52 |
53 | Section IV: OTHER PROVISIONS
54 |
55 | 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
56 | 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
57 | 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
58 | 10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
59 | 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
60 | 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
61 |
62 | END OF TERMS AND CONDITIONS
63 |
64 |
65 |
66 |
67 | Attachment A
68 |
69 | Use Restrictions
70 |
71 | You agree not to use the Model or Derivatives of the Model:
72 | - In any way that violates any applicable national, federal, state, local or international law or regulation;
73 | - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
74 | - To generate or disseminate verifiably false information and/or content with the purpose of harming others;
75 | - To generate or disseminate personal identifiable information that can be used to harm an individual;
76 | - To defame, disparage or otherwise harass others;
77 | - For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
78 | - For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
79 | - To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
80 | - For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
81 | - To provide medical advice and medical results interpretation;
82 | - To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Karlo-v1.0.alpha on COYO-100M and CC15M
2 |
3 | Karlo is a text-conditional image generation model based on OpenAI's unCLIP architecture, with an improved super-resolution module that upscales images from 64px to 256px while recovering high-frequency details in only a small number of denoising steps.
4 |
5 |
6 |
7 |
8 |
9 |
10 | "a portrait of an old monk, highly detailed."
11 |
12 |
13 |
14 |
15 |
16 | "Photo of a business woman, silver hair"
17 |
18 |
19 |
20 |
21 |
22 | "A teddy bear on a skateboard, children drawing style."
23 |
24 |
25 |
26 |
27 |
28 | "Goryeo celadon in the shape of bird"
29 |
30 |
31 |
32 |
33 |
34 | This alpha version of Karlo is trained on 115M image-text pairs, including a high-quality subset of [COYO](https://github.com/kakaobrain/coyo-dataset)-100M, CC3M, and CC12M. For those interested in a better version of Karlo trained on larger-scale, high-quality datasets, please visit the landing page of our application [B^DISCOVER](https://bdiscover.kakaobrain.com/).
35 |
36 | ### Updates
37 | * [2022-12-01] Karlo-v1.0.alpha is released!
38 | * [2022-12-19] Karlo-v1.0.alpha was [integrated into the 🧨 diffusers library](#-diffusers-integration)
39 | * [2022-12-20] Karlo-v1.0.alpha was integrated into Hugging Face Spaces 🤗 using [Gradio](https://github.com/gradio-app/gradio). Try out the [Web Demo](https://huggingface.co/spaces/kakaobrain/karlo)
40 |
41 | ## Model Architecture
42 |
43 | ### Overview
44 | Karlo is a text-conditional diffusion model based on unCLIP, composed of prior, decoder, and super-resolution modules. In this repository, we include an improved version of the standard super-resolution module that upscales 64px to 256px in only 7 reverse steps, as illustrated in the figure below:
45 |
46 |
47 |
48 |
49 |
50 | Specifically, the standard SR module, trained with the DDPM objective, upscales 64px to 256px in the first 6 denoising steps based on the respacing technique. Then, an additional SR module fine-tuned with a [VQ-GAN](https://compvis.github.io/taming-transformers/)-style loss performs the final reverse step to recover high-frequency details. We observe that this approach is very effective for upscaling low-resolution images in a small number of reverse steps.
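The snippet below is only a schematic sketch of how the 7 reverse steps are split between the two modules. All function names are hypothetical placeholders and conditioning/noise handling is omitted; the actual implementation lives in [karlo/models/sr_64_256.py](karlo/models/sr_64_256.py) and is configured by [configs/improved_sr_64_256_1.4B.yaml](configs/improved_sr_64_256_1.4B.yaml), which respaces the 1000-step diffusion to 7 steps.

```python
import torch
import torch.nn.functional as F

# Hypothetical placeholders for the two SR components described above:
# a DDPM-trained SR model and the VQ-GAN-style fine-tuned refinement module.
def ddpm_sr_reverse_step(x: torch.Tensor, t: int) -> torch.Tensor:
    return x  # one respaced DDPM reverse step (placeholder)

def vqgan_finetuned_step(x: torch.Tensor) -> torch.Tensor:
    return x  # final step recovering high-frequency details (placeholder)

def improved_sr_64_256(x_64: torch.Tensor) -> torch.Tensor:
    # 64px -> 256px; the 7 respaced reverse steps are split as 6 + 1.
    x = F.interpolate(x_64, scale_factor=4, mode="nearest")
    for t in reversed(range(1, 7)):    # steps 6..1: standard DDPM SR module
        x = ddpm_sr_reverse_step(x, t)
    return vqgan_finetuned_step(x)     # step 0: fine-tuned SR module

print(improved_sr_64_256(torch.rand(1, 3, 64, 64)).shape)  # (1, 3, 256, 256)
```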
51 |
52 | ### Details
53 | We train all components from scratch on 115M image-text pairs, including COYO-100M, CC3M, and CC12M. For the prior and decoder, we use ViT-L/14 provided by OpenAI's [CLIP repository](https://github.com/openai/CLIP). Unlike the original implementation of unCLIP, we replace the trainable transformer in the decoder with the text encoder of ViT-L/14 for efficiency. For the SR module, we first train the model with the DDPM objective for 1M steps, followed by an additional 234K steps to fine-tune the additional component. The table below summarizes the important statistics of our components:
54 |
55 | | | Prior | Decoder | SR |
56 | |:------|----:|----:|----:|
57 | | CLIP | ViT-L/14 | ViT-L/14 | - |
58 | | #param | 1B | 900M | 700M + 700M |
59 | | #optimization steps | 1M | 1M | 1M + 0.2M |
60 | | #sampling steps | 25 | 50 (default), 25 (fast) | 7 |
61 | |Checkpoint links| [ViT-L-14](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt), [ViT-L-14 stats](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th), [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt) | [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt)| [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt) |
62 |
63 | In the checkpoint links, ViT-L-14 is equivalent to the original version, but we include it for convenience. We also remark that ViT-L-14-stats is required to normalize the outputs of the prior module.
64 |
65 | ### Evaluation
66 | We quantitatively measure the performance of Karlo-v1.0.alpha on the validation splits of CC3M and MS-COCO. The tables below present CLIP-score and FID. To measure FID, we resize the shorter side of each image to 256px and then crop it at the center. We set the classifier-free guidance scales for the prior and decoder to 4 and 8 in all cases. We observe that our model achieves reasonable performance even with 25 decoder sampling steps.
67 |
68 | CC3M
69 | | Sampling step | CLIP-s (ViT-B/16) | FID (13k from val)|
70 | |:------|----:|----:|
71 | | Prior (25) + Decoder (25) + SR (7) | 0.3081 | 14.37 |
72 | | Prior (25) + Decoder (50) + SR (7) | 0.3086 | 13.95 |
73 |
74 | MS-COCO
75 | | Sampling step | CLIP-s (ViT-B/16) | FID (30k from val)|
76 | |:------|----:|----:|
77 | | Prior (25) + Decoder (25) + SR (7) | 0.3192 | 15.24 |
78 | | Prior (25) + Decoder (50) + SR (7) | 0.3192 | 14.43 |
79 |
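As a reference for the evaluation preprocessing described above, the resize-then-center-crop can be reproduced with standard torchvision transforms. This is only a sketch: the interpolation mode and the file path below are assumptions, not something specified in this README.

```python
from PIL import Image
from torchvision import transforms

# Resize the shorter side to 256px (keeping aspect ratio), then center-crop
# to 256x256, following the FID evaluation protocol described above.
fid_preprocess = transforms.Compose([
    transforms.Resize(256),      # shorter side -> 256px; default bilinear (assumed)
    transforms.CenterCrop(256),  # crop at the center
])

image = Image.open("generated_sample.png").convert("RGB")  # placeholder path
image_256 = fid_preprocess(image)
```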
80 |
81 | For more information, please refer to the upcoming technical report.
82 |
83 | ## 🧨 Diffusers integration
84 | Our unCLIP implementation is officially integrated into the [🧨 diffusers library](https://huggingface.co/docs/diffusers/api/pipelines/unclip).
85 | ```
86 | # Requirements to run Karlo unCLIP on diffusers
87 | pip install diffusers transformers accelerate safetensors
88 | ```
89 |
90 | ```py
91 | from diffusers import UnCLIPPipeline
92 | import torch
93 |
94 | pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
95 | pipe = pipe.to('cuda')
96 |
97 | prompt = "a high-resolution photograph of a big red frog on a green leaf."
98 | image = pipe(prompt).images[0]
99 | image.save("./frog.png")
100 | ```
101 | Check out the [diffusers docs](https://huggingface.co/docs/diffusers/api/pipelines/unclip) for the full usage of `UnCLIPPipeline`.
102 |
103 | ## Environment Setup
104 | We use a single V100 with 32GB VRAM for sampling under PyTorch >= 1.10 and CUDA >= 11. The following commands install the additional Python packages and download the pretrained model checkpoints. Alternatively, you can simply install the packages and download the weights via [setup.sh](setup.sh).
105 | - Additional python packages
106 | ```
107 | pip install -r requirements.txt
108 | ```
109 | - Model checkpoints
110 | ```
111 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt -P $KARLO_ROOT_DIR # same with the official ViT-L/14 from OpenAI
112 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th -P $KARLO_ROOT_DIR
113 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
114 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
115 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt -P $KARLO_ROOT_DIR
116 | ```
117 |
118 | ## Sampling
119 |
120 | ### Gradio demo (T2I and Image variation)
121 | The following command launches the Gradio demo for text-to-image generation and image variation. We notice that the second run in Gradio is unexpectedly slower than usual with PyTorch >= 1.12. We suspect this happens because launching the CUDA kernels takes some time, usually up to 2 minutes.
122 | ```
123 | python demo/product_demo.py --host 0.0.0.0 --port $PORT --root-dir $KARLO_ROOT_DIR
124 | ```
125 |
126 | The samples below are non-cherry-picked T2I and image variation examples with random seed 0.
127 | In each case, the first row shows T2I samples and the second row shows image variation samples of the leftmost image in the first row.
128 |
129 |
130 | [T2I + Image variation] "A man with a face of avocado, in the drawing style of Rene Magritte."
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 | [T2I + Image variation] "a black porcelain in the shape of pikachu"
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 | ### T2I command line example
147 | Here, we include a command-line example of T2I. For image variation, you can refer to [karlo/sampler/i2i.py](karlo/sampler/i2i.py) to see how the prior output is replaced with the CLIP image feature; a usage sketch is also included after the command below.
148 | ```
149 | python example.py --root-dir=$KARLO_ROOT_DIR \
150 | --prompt="A man with a face of avocado, in the drawing style of Rene Magritte" \
151 | --output-dir=$OUTPUT_DIR \
152 | --max-bsz=2 \
153 | --sampling-type=fast
154 | ```
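For image variation in Python, a minimal sketch is shown below. It assumes `I2ISampler` exposes the same `from_pretrained` signature as `T2ISampler` in [example.py](example.py) and accepts the `image`, `bsz`, and `progressive_mode` arguments used in [demo/components.py](demo/components.py); please check [karlo/sampler/i2i.py](karlo/sampler/i2i.py) for the actual interface.

```python
import torch
from PIL import Image

from karlo.sampler.i2i import I2ISampler
from karlo.utils.util import set_seed

set_seed(0)

# Assumption: I2ISampler mirrors T2ISampler.from_pretrained (see example.py).
model = I2ISampler.from_pretrained(
    root_dir="/path/to/KARLO_ROOT_DIR",  # placeholder for $KARLO_ROOT_DIR
    clip_model_path="ViT-L-14.pt",
    clip_stat_path="ViT-L-14_stats.th",
    sampling_type="fast",
)

source = Image.open("source.png").convert("RGB")  # placeholder input image
images = next(iter(model(image=source, bsz=2, progressive_mode="final")))

# NCHW, [0, 1], float32 -> NHWC, [0, 255], uint8 (same conversion as example.py)
images = torch.permute(images * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()
Image.fromarray(images[0]).save("variation_00.png")
```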
155 |
156 | ## License and Disclaimer
157 | This project, including the weights, is distributed under the [CreativeML Open RAIL-M license](LICENSE), the same license as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion/blob/main/LICENSE). You may use this model in commercial applications, but we highly recommend adopting a powerful safety checker as a post-processing step. We also remark that we are not responsible for any kind of use of the generated images.
158 |
159 | ## BibTeX
160 | If you find this repository useful in your research, please cite:
161 | ```
162 | @misc{kakaobrain2022karlo-v1-alpha,
163 | title = {Karlo-v1.0.alpha on COYO-100M and CC15M},
164 | author = {Donghoon Lee and Jiseob Kim and Jisu Choi and Jongmin Kim and Minwoo Byeon and Woonhyuk Baek and Saehoon Kim},
165 | year = {2022},
166 | howpublished = {\url{https://github.com/kakaobrain/karlo}},
167 | }
168 | ```
169 |
170 | ## Acknowledgement
171 | * We deeply appreciate all the contributors to OpenAI’s [Guided-Diffusion](https://github.com/openai/guided-diffusion) project.
172 | * We also greatly appreciate [Apolinário Passos](https://github.com/apolinario) and [Will Berman](https://github.com/williamberman) from Hugging Face for integrating this model into [diffusers](https://github.com/huggingface/diffusers).
173 |
174 | ## Contact
175 | If you would like to collaborate with us or share feedback, please e-mail us at contact@kakaobrain.com.
176 |
--------------------------------------------------------------------------------
/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png
--------------------------------------------------------------------------------
/assets/A teddy bear on a skateboard, children drawing style..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/A teddy bear on a skateboard, children drawing style..png
--------------------------------------------------------------------------------
/assets/Goryeo celadon in the shape of bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/Goryeo celadon in the shape of bird.png
--------------------------------------------------------------------------------
/assets/Photo of a business woman, silver hair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/Photo of a business woman, silver hair.png
--------------------------------------------------------------------------------
/assets/a black porcelain in the shape of pikachu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/a black porcelain in the shape of pikachu.png
--------------------------------------------------------------------------------
/assets/a portrait of an old monk, highly detailed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/a portrait of an old monk, highly detailed.png
--------------------------------------------------------------------------------
/assets/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/example.gif
--------------------------------------------------------------------------------
/assets/improved_sr_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/improved_sr_arch.jpg
--------------------------------------------------------------------------------
/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png
--------------------------------------------------------------------------------
/assets/variation_a black porcelain in the shape of pikachu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/assets/variation_a black porcelain in the shape of pikachu.png
--------------------------------------------------------------------------------
/configs/decoder_900M_vit_l.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: t2i-decoder
3 | diffusion_sampler: uniform
4 | hparams:
5 | image_size: 64
6 | num_channels: 320
7 | num_res_blocks: 3
8 | channel_mult: ''
9 | attention_resolutions: 32,16,8
10 | num_heads: -1
11 | num_head_channels: 64
12 | num_heads_upsample: -1
13 | use_scale_shift_norm: true
14 | dropout: 0.1
15 | clip_dim: 768
16 | clip_emb_mult: 4
17 | text_ctx: 77
18 | xf_width: 1536
19 | xf_layers: 0
20 | xf_heads: 0
21 | xf_final_ln: false
22 | resblock_updown: true
23 | learn_sigma: true
24 | text_drop: 0.3
25 | clip_emb_type: image
26 | clip_emb_drop: 0.1
27 | use_plm: true
28 |
29 | diffusion:
30 | steps: 1000
31 | learn_sigma: true
32 | sigma_small: false
33 | noise_schedule: squaredcos_cap_v2
34 | use_kl: false
35 | predict_xstart: false
36 | rescale_learned_sigmas: true
37 | timestep_respacing: ''
38 |
--------------------------------------------------------------------------------
/configs/improved_sr_64_256_1.4B.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: improved_sr_64_256
3 | diffusion_sampler: uniform
4 | hparams:
5 | channels: 320
6 | depth: 3
7 | channels_multiple:
8 | - 1
9 | - 2
10 | - 3
11 | - 4
12 | dropout: 0.0
13 |
14 | diffusion:
15 | steps: 1000
16 | learn_sigma: false
17 | sigma_small: true
18 | noise_schedule: squaredcos_cap_v2
19 | use_kl: false
20 | predict_xstart: false
21 | rescale_learned_sigmas: true
22 | timestep_respacing: '7'
23 |
24 |
25 | sampling:
26 | timestep_respacing: '7' # fix
27 | clip_denoise: true
28 |
--------------------------------------------------------------------------------
/configs/prior_1B_vit_l.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: prior
3 | diffusion_sampler: uniform
4 | hparams:
5 | text_ctx: 77
6 | xf_width: 2048
7 | xf_layers: 20
8 | xf_heads: 32
9 | xf_final_ln: true
10 | text_drop: 0.2
11 | clip_dim: 768
12 |
13 | diffusion:
14 | steps: 1000
15 | learn_sigma: false
16 | sigma_small: true
17 | noise_schedule: squaredcos_cap_v2
18 | use_kl: false
19 | predict_xstart: true
20 | rescale_learned_sigmas: false
21 | timestep_respacing: ''
22 |
--------------------------------------------------------------------------------
/demo/components.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import time
7 | import sys
8 | import os
9 | import threading
10 | import logging
11 | from queue import Queue
12 | from PIL import Image
13 |
14 | import gradio as gr
15 | import numpy as np
16 | import torch
17 |
18 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
19 |
20 | from karlo.sampler.template import CKPT_PATH, BaseSampler
21 | from karlo.sampler.t2i import T2ISampler
22 | from karlo.sampler.i2i import I2ISampler
23 | from karlo.utils.util import set_seed
24 |
25 |
26 | def tensor_to_images(tensor: torch.Tensor, output_res=(1024, 1024)):
27 | assert tensor.ndim == 4
28 | tensor = torch.clone(tensor)
29 | # NCHW -> NHWC
30 | images = torch.permute(tensor * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()
31 | concat_image = np.concatenate(images, axis=1)
32 | target_size = (output_res[1] * tensor.shape[0], output_res[0])
33 | concat_image = Image.fromarray(concat_image).resize(
34 | target_size, resample=Image.NEAREST
35 | )
36 | return images, concat_image
37 |
38 |
39 | class GradioSampler:
40 | def __init__(
41 | self,
42 | root_dir,
43 | max_bsz,
44 | progressive,
45 | sampling_type: str,
46 | ):
47 | self._root_dir = root_dir
48 | self._max_bsz = max_bsz
49 | self._progressive = progressive
50 | self._sampling_type = sampling_type
51 |
52 | self.load_ckpt()
53 | self.set_options_from_sampler()
54 |
55 | self.result_queue = Queue()
56 |
57 | def load_ckpt(self):
58 | base_sampler = BaseSampler(root_dir=self._root_dir)
59 | base_sampler.load_clip(clip_path="ViT-L-14.pt")
60 | base_sampler.load_prior(
61 | f"{CKPT_PATH['prior']}",
62 | clip_stat_path="ViT-L-14_stats.th",
63 | )
64 | base_sampler.load_decoder(f"{CKPT_PATH['decoder']}")
65 | base_sampler.load_sr_64_256(f"{CKPT_PATH['sr_256']}")
66 |
67 | self.t2i_sampler = T2ISampler(
68 | root_dir=self._root_dir, sampling_type=self._sampling_type
69 | )
70 | self.i2i_sampler = I2ISampler(
71 | root_dir=self._root_dir, sampling_type=self._sampling_type
72 | )
73 |
74 | self.t2i_sampler._clip = base_sampler._clip
75 | self.t2i_sampler._tokenizer = base_sampler._tokenizer
76 | self.t2i_sampler._prior = base_sampler._prior
77 | self.t2i_sampler._decoder = base_sampler._decoder
78 | self.t2i_sampler._sr_64_256 = base_sampler._sr_64_256
79 |
80 | self.i2i_sampler._clip = base_sampler._clip
81 | self.i2i_sampler._tokenizer = base_sampler._tokenizer
82 | self.i2i_sampler._prior = base_sampler._prior
83 | self.i2i_sampler._decoder = base_sampler._decoder
84 | self.i2i_sampler._sr_64_256 = base_sampler._sr_64_256
85 |
86 | self.ckpt_info = f"""
87 | * **prior**: `{self._root_dir}/{CKPT_PATH['prior']}`
88 | * **decoder**: `{self._root_dir}/{CKPT_PATH['decoder']}`
89 | * **sr_64_256**: `{self._root_dir}/{CKPT_PATH['sr_256']}`
90 | """
91 |
92 | def set_options_from_sampler(self):
93 | self.global_options = {"seed": 0, "max_bsz": self._max_bsz}
94 |
95 | self.prior_options = {
96 | "sm": self.t2i_sampler._prior_sm,
97 | "cf_scale": self.t2i_sampler._prior_cf_scale,
98 | }
99 | self.decoder_options = {
100 | "sm": self.t2i_sampler._decoder_sm,
101 | "cf_scale": self.t2i_sampler._decoder_cf_scale,
102 | }
103 | self.sr_64_256_options = {
104 | "sm": self.t2i_sampler._sr_sm,
105 | }
106 |
107 | def make_global_options(self):
108 | gr.Markdown("Global Options")
109 | with gr.Row():
110 | return [
111 | gr.Slider(
112 | label="seed",
113 | value=self.global_options["seed"],
114 | minimum=np.iinfo(np.uint32).min,
115 | maximum=np.iinfo(np.uint32).max,
116 | step=1,
117 | ),
118 | gr.Slider(
119 | label="maximum batch size",
120 | value=self.global_options["max_bsz"],
121 | minimum=1,
122 | maximum=4,
123 | step=1,
124 | ),
125 | ]
126 |
127 | def make_prior_options(self):
128 | gr.Markdown("Prior Options")
129 | return [
130 | gr.Textbox(
131 | label="sampling method",
132 | value=self.prior_options["sm"],
133 | ),
134 | gr.Slider(
135 | label="Classifier-free guidance scales",
136 | value=self.prior_options["cf_scale"],
137 | minimum=0.1,
138 | maximum=24,
139 | ),
140 | ]
141 |
142 | def make_decoder_options(self):
143 | gr.Markdown("Decoder Options")
144 | with gr.Row():
145 | return [
146 | gr.Textbox(
147 | label="sampling method",
148 | value=self.decoder_options["sm"],
149 | ),
150 | gr.Slider(
151 | label="Classifier-free guidance scales",
152 | value=self.decoder_options["cf_scale"],
153 | minimum=0.1,
154 | maximum=24,
155 | ),
156 | ]
157 |
158 | def make_sr_64_256_options(self):
159 | return [gr.Variable(self.sr_64_256_options["sm"])]
160 |
161 | def make_basic_options(self):
162 | self.global_options_gr = self.make_global_options()
163 | self.prior_optios_gr = self.make_prior_options()
164 | self.decoder_options_gr = self.make_decoder_options()
165 | self.sr_64_256_options_gr = self.make_sr_64_256_options()
166 |
167 | def seed(self, seed):
168 | set_seed(seed)
169 |
170 | def _sample(self, output_generator):
171 | for k, out in enumerate(output_generator):
172 | self.result_queue.put((out, False))
173 | self.result_queue.put((None, True))
174 |
175 | def t2i_sample(
176 | self,
177 | text_input,
178 | prior_sm,
179 | prior_cf_scale,
180 | decoder_sm,
181 | decoder_cf_scale,
182 | sr_sm,
183 | seed,
184 | max_bsz,
185 | ):
186 | t0 = time.time()
187 | assert hasattr(self.t2i_sampler, "_prior_sm")
188 | assert hasattr(self.t2i_sampler, "_prior_cf_scale")
189 | assert hasattr(self.t2i_sampler, "_decoder_sm")
190 | assert hasattr(self.t2i_sampler, "_decoder_cf_scale")
191 | assert hasattr(self.t2i_sampler, "_sr_sm")
192 |
193 | print("-" * 100)
194 | print(f"text_input: {text_input}")
195 | print(f"prior_sm: {prior_sm}")
196 | print(f"prior_cf_scale: {prior_cf_scale}")
197 | print(f"decoder_sm: {decoder_sm}")
198 | print(f"decoder_cf_scale: {decoder_cf_scale}")
199 | print(f"sr_sm: {sr_sm}")
200 | print(f"seed: {seed}")
201 | print(f"max_bsz: {max_bsz}")
202 |
203 | self.t2i_sampler._prior_sm = prior_sm
204 | self.t2i_sampler._prior_cf_scale = prior_cf_scale
205 |
206 | self.t2i_sampler._decoder_sm = decoder_sm
207 | self.t2i_sampler._decoder_cf_scale = decoder_cf_scale
208 |
209 | self.t2i_sampler._sr_sm = sr_sm
210 |
211 | self.seed(seed)
212 |
213 | output_generator = self.t2i_sampler(
214 | prompt=text_input,
215 | bsz=max_bsz,
216 | progressive_mode=self._progressive,
217 | )
218 |
219 | thread = threading.Thread(target=self._sample, args=(output_generator,))
220 | thread.start()
221 | done = False
222 |
223 | while not done:
224 | if self.result_queue.empty():
225 | time.sleep(0.1)
226 | else:
227 | while not self.result_queue.empty():
228 | _out, done = self.result_queue.get(0) # get last item to display
229 | if not done:
230 | out = _out
231 | images, concat_image = tensor_to_images(out, (256, 256))
232 | yield (text_input, images), concat_image
233 |
234 | thread.join()
235 | yield (text_input, images), concat_image
236 |
237 | t1 = time.time()
238 | execution_time = t1 - t0
239 | logging.info(f"Generation done. {text_input} -- {execution_time:.6f}secs")
240 | print("-" * 100)
241 |
242 | def i2i_sample(
243 | self,
244 | image_input,
245 | decoder_sm,
246 | decoder_cf_scale,
247 | sr_sm,
248 | seed,
249 | max_bsz,
250 | ):
251 | t0 = time.time()
252 | assert hasattr(self.i2i_sampler, "_decoder_sm")
253 | assert hasattr(self.i2i_sampler, "_decoder_cf_scale")
254 | assert hasattr(self.i2i_sampler, "_sr_sm")
255 |
256 | print("-" * 100)
257 | print(f"decoder_sm: {decoder_sm}")
258 | print(f"decoder_cf_scale: {decoder_cf_scale}")
259 | print(f"sr_sm: {sr_sm}")
260 | print(f"seed: {seed}")
261 | print(f"max_bsz: {max_bsz}")
262 |
263 | self.i2i_sampler._decoder_sm = decoder_sm
264 | self.i2i_sampler._decoder_cf_scale = decoder_cf_scale
265 |
266 | self.i2i_sampler._sr_sm = sr_sm
267 |
268 | self.seed(seed)
269 |
270 | output_generator = self.i2i_sampler(
271 | image=image_input,
272 | bsz=max_bsz,
273 | progressive_mode=self._progressive,
274 | )
275 |
276 | thread = threading.Thread(target=self._sample, args=(output_generator,))
277 | thread.start()
278 | done = False
279 |
280 | while not done:
281 | if self.result_queue.empty():
282 | time.sleep(0.1)
283 | else:
284 | while not self.result_queue.empty():
285 | _out, done = self.result_queue.get(0) # get last item to display
286 | if not done:
287 | out = _out
288 | images, concat_image = tensor_to_images(out, (256, 256))
289 | yield ("", images), concat_image
290 |
291 | thread.join()
292 | yield ("", images), concat_image
293 |
294 | t1 = time.time()
295 | execution_time = t1 - t0
296 | logging.info(f"Variation done. {execution_time:.6f}secs")
297 | print("-" * 100)
298 |
299 |
300 | class ImageSelecter:
301 | @classmethod
302 | def make_basic_ui(cls, max_bsz):
303 | with gr.Box():
304 | i2i_select_idx = gr.Radio(
305 | choices=[str(i) for i in range(0, max_bsz)],
306 | value="0",
307 | label="Image index",
308 | )
309 | i2i_select_button = gr.Button(
310 | "Select for Image Variation", variant="primary"
311 | )
312 | return {
313 | "i2i_select_idx": i2i_select_idx,
314 | "i2i_select_button": i2i_select_button,
315 | }
316 |
317 | @classmethod
318 | def select_fn(cls, stash, idx):
319 | if stash is not None:
320 | return Image.fromarray(stash[1][int(idx)].copy())
321 |
322 | @classmethod
323 | def setup_button_click(
324 | cls,
325 | selector_ui,
326 | stash,
327 | i2i_input_images,
328 | ):
329 | selector_ui["i2i_select_button"].click(
330 | fn=cls.select_fn,
331 | inputs=[stash, selector_ui["i2i_select_idx"]],
332 | outputs=[i2i_input_images],
333 | )
334 |
--------------------------------------------------------------------------------
/demo/product_demo.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import argparse
7 | import logging
8 | import gradio as gr
9 | import os
10 | import sys
11 |
12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13 |
14 | from karlo import __version__ as karlo_ver
15 | from demo.components import GradioSampler, ImageSelecter
16 |
17 |
18 | class GradioDemo:
19 | def __init__(
20 | self,
21 | root_dir: str,
22 | max_bsz: int,
23 | progressive: str,
24 | sampling_type: str,
25 | ):
26 | sampler = GradioSampler(
27 | root_dir=root_dir,
28 | max_bsz=max_bsz,
29 | progressive=progressive,
30 | sampling_type=sampling_type,
31 | )
32 |
33 | demo = gr.Blocks()
34 | with demo:
35 | gr.Markdown(f"# Karlo Demo {karlo_ver}")
36 | with gr.Box():
37 | gr.Markdown("## Generate 64px images + Upscaling to 256px")
38 |
39 | with gr.Tabs():
40 | with gr.TabItem("Image Generation"):
41 | t2i_text_input = gr.Textbox(
42 | lines=1,
43 | placeholder="Type text prompt...",
44 | label="Text prompts",
45 | )
46 | t2i_button = gr.Button("Generate", variant="primary")
47 | with gr.TabItem("Image Variation"):
48 | i2i_img_input = gr.Image(label="Image input", type="pil")
49 | i2i_button = gr.Button("Generate", variant="primary")
50 |
51 | with gr.Box():
52 | outputs = gr.Image(label="Generated", type="pil")
53 | stash = gr.Variable()
54 | with gr.Row():
55 | selector_ui = ImageSelecter.make_basic_ui(max_bsz=max_bsz)
56 |
57 | with gr.Box():
58 | with gr.Accordion(label="Advanced Options", open=False):
59 | sampler.make_basic_options()
60 |
61 | with gr.Box():
62 | with gr.Accordion(label="Checkpoint Information", open=False):
63 | gr.Markdown(sampler.ckpt_info)
64 |
65 | t2i_button.click(
66 | fn=sampler.t2i_sample,
67 | inputs=[t2i_text_input]
68 | + sampler.prior_optios_gr
69 | + sampler.decoder_options_gr
70 | + sampler.sr_64_256_options_gr
71 | + sampler.global_options_gr,
72 | outputs=[stash, outputs],
73 | )
74 | i2i_button.click(
75 | fn=sampler.i2i_sample,
76 | inputs=[i2i_img_input]
77 | + sampler.decoder_options_gr
78 | + sampler.sr_64_256_options_gr
79 | + sampler.global_options_gr,
80 | outputs=[stash, outputs],
81 | )
82 |
83 | ImageSelecter.setup_button_click(selector_ui, stash, i2i_img_input)
84 |
85 | demo.queue()
86 | self.demo = demo
87 |
88 |
89 | def default_parser():
90 | parser = argparse.ArgumentParser()
91 | parser.add_argument("--root-dir", type=str, default=None)
92 | parser.add_argument("--max_bsz", type=int, default=1)
93 | parser.add_argument(
94 | "--progressive", type=str, default="loop", choices=("loop", "stage", "final")
95 | )
96 | parser.add_argument("--host", type=str, default="localhost")
97 | parser.add_argument("--port", type=int, default=6006)
98 |
99 | parser.add_argument(
100 | "--sampling-type",
101 | type=str,
102 | default="fast",
103 | choices=("fast", "default"),
104 | )
105 |
106 | return parser
107 |
108 |
109 | if __name__ == "__main__":
110 | parser = default_parser()
111 | args = parser.parse_args()
112 | logging.getLogger().setLevel(logging.INFO)
113 |
114 | assert (
115 | args.root_dir is not None
116 | ), "--root-dir argument should be specified to load the pretrained ckpt"
117 |
118 | """Making Gradio"""
119 | gradio_demo = GradioDemo(
120 | root_dir=args.root_dir,
121 | max_bsz=args.max_bsz,
122 | progressive=args.progressive,
123 | sampling_type=args.sampling_type,
124 | )
125 | gradio_demo.demo.launch(server_name=args.host, server_port=args.port)
126 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import os
7 | import argparse
8 | import logging
9 | import time
10 | from datetime import datetime
11 |
12 | import torch
13 | from PIL import Image
14 |
15 | from karlo.sampler.t2i import T2ISampler
16 | from karlo.utils.util import set_seed
17 |
18 |
19 | def default_parser():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | "--root-dir", type=str, required=True, help="path for model checkpoints"
23 | )
24 | parser.add_argument("--max-bsz", type=int, default=1, help="#images to generate")
25 | parser.add_argument(
26 | "--output-dir",
27 | type=str,
28 | default="outputs",
29 | help="output path for generated images",
30 | )
31 | parser.add_argument(
32 | "--sampling-type",
33 | type=str,
34 | default="fast",
35 | choices=("fast", "default"),
36 | )
37 | parser.add_argument(
38 | "--prompt", type=str, default="A photo of a baby puppy waiting for her mom."
39 | )
40 | parser.add_argument("--seed", type=int, default=0)
41 |
42 | return parser
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = default_parser()
47 | args = parser.parse_args()
48 |
49 | set_seed(args.seed)
50 | logging.getLogger().setLevel(logging.INFO)
51 |
52 | save_dir = os.path.join(args.output_dir, datetime.now().strftime("%d%m%Y_%H%M%S"))
53 | if not os.path.exists(save_dir):
54 | os.makedirs(save_dir)
55 |
56 | model = T2ISampler.from_pretrained(
57 | root_dir=args.root_dir,
58 | clip_model_path="ViT-L-14.pt",
59 | clip_stat_path="ViT-L-14_stats.th",
60 | sampling_type=args.sampling_type,
61 | )
62 |
63 | for i in range(5):
64 | t1 = time.time()
65 |
66 | images = iter(
67 | model(
68 | prompt=args.prompt,
69 | bsz=args.max_bsz,
70 | progressive_mode="final",
71 | )
72 | ).__next__()
73 |
74 | # NCHW, [0, 1], float32 -> NHWC, [0, 255], uint8
75 | images = (
76 | torch.permute(images * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()
77 | )
78 |
79 | t2 = time.time()
80 | execution_time = t2 - t1
81 | logging.info(f"Iteration {i} -- {execution_time:.6f}secs")
82 |
83 | # Select the first one
84 | image = Image.fromarray(images[0])
85 | image_name = "_".join(args.prompt.split(" "))
86 | image.save(f"{save_dir}/{image_name}_{i:02d}.jpg")
87 |
--------------------------------------------------------------------------------
/karlo/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.alpha"
2 |
--------------------------------------------------------------------------------
/karlo/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/karlo/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/models/__init__.py
--------------------------------------------------------------------------------
/karlo/models/clip.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 | # ------------------------------------------------------------------------------------
6 | # Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
7 | # ------------------------------------------------------------------------------------
8 |
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import clip
14 |
15 | from clip.model import CLIP, convert_weights
16 | from clip.simple_tokenizer import SimpleTokenizer, default_bpe
17 |
18 |
19 | """===== Monkey-Patching original CLIP for JIT compile ====="""
20 |
21 |
22 | class LayerNorm(nn.LayerNorm):
23 | """Subclass torch's LayerNorm to handle fp16."""
24 |
25 | def forward(self, x: torch.Tensor):
26 | orig_type = x.dtype
27 | ret = F.layer_norm(
28 | x.type(torch.float32),
29 | self.normalized_shape,
30 | self.weight,
31 | self.bias,
32 | self.eps,
33 | )
34 | return ret.type(orig_type)
35 |
36 |
37 | clip.model.LayerNorm = LayerNorm
38 | delattr(clip.model.CLIP, "forward")
39 |
40 | """===== End of Monkey-Patching ====="""
41 |
42 |
43 | class CustomizedCLIP(CLIP):
44 | def __init__(self, *args, **kwargs):
45 | super().__init__(*args, **kwargs)
46 |
47 | @torch.jit.export
48 | def encode_image(self, image):
49 | return self.visual(image)
50 |
51 | @torch.jit.export
52 | def encode_text(self, text):
53 | # re-define this function to return unpooled text features
54 |
55 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
56 |
57 | x = x + self.positional_embedding.type(self.dtype)
58 | x = x.permute(1, 0, 2) # NLD -> LND
59 | x = self.transformer(x)
60 | x = x.permute(1, 0, 2) # LND -> NLD
61 | x = self.ln_final(x).type(self.dtype)
62 |
63 | x_seq = x
64 | # x.shape = [batch_size, n_ctx, transformer.width]
65 | # take features from the eot embedding (eot_token is the highest number in each sequence)
66 | x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
67 |
68 | return x_out, x_seq
69 |
70 | @torch.jit.ignore
71 | def forward(self, image, text):
72 | super().forward(image, text)
73 |
74 | @classmethod
75 | def load_from_checkpoint(cls, ckpt_path: str):
76 | state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()
77 |
78 | vit = "visual.proj" in state_dict
79 | if vit:
80 | vision_width = state_dict["visual.conv1.weight"].shape[0]
81 | vision_layers = len(
82 | [
83 | k
84 | for k in state_dict.keys()
85 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
86 | ]
87 | )
88 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
89 | grid_size = round(
90 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
91 | )
92 | image_resolution = vision_patch_size * grid_size
93 | else:
94 | counts: list = [
95 | len(
96 | set(
97 | k.split(".")[2]
98 | for k in state_dict
99 | if k.startswith(f"visual.layer{b}")
100 | )
101 | )
102 | for b in [1, 2, 3, 4]
103 | ]
104 | vision_layers = tuple(counts)
105 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
106 | output_width = round(
107 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
108 | )
109 | vision_patch_size = None
110 | assert (
111 | output_width**2 + 1
112 | == state_dict["visual.attnpool.positional_embedding"].shape[0]
113 | )
114 | image_resolution = output_width * 32
115 |
116 | embed_dim = state_dict["text_projection"].shape[1]
117 | context_length = state_dict["positional_embedding"].shape[0]
118 | vocab_size = state_dict["token_embedding.weight"].shape[0]
119 | transformer_width = state_dict["ln_final.weight"].shape[0]
120 | transformer_heads = transformer_width // 64
121 | transformer_layers = len(
122 | set(
123 | k.split(".")[2]
124 | for k in state_dict
125 | if k.startswith("transformer.resblocks")
126 | )
127 | )
128 |
129 | model = cls(
130 | embed_dim,
131 | image_resolution,
132 | vision_layers,
133 | vision_width,
134 | vision_patch_size,
135 | context_length,
136 | vocab_size,
137 | transformer_width,
138 | transformer_heads,
139 | transformer_layers,
140 | )
141 |
142 | for key in ["input_resolution", "context_length", "vocab_size"]:
143 | if key in state_dict:
144 | del state_dict[key]
145 |
146 | convert_weights(model)
147 | model.load_state_dict(state_dict)
148 | model.eval()
149 | model.float()
150 | return model
151 |
152 |
153 | class CustomizedTokenizer(SimpleTokenizer):
154 | def __init__(self):
155 | super().__init__(bpe_path=default_bpe())
156 |
157 | self.sot_token = self.encoder["<|startoftext|>"]
158 | self.eot_token = self.encoder["<|endoftext|>"]
159 |
160 | def padded_tokens_and_mask(self, texts, text_ctx):
161 | assert isinstance(texts, list) and all(
162 | isinstance(elem, str) for elem in texts
163 | ), "texts should be a list of strings"
164 |
165 | all_tokens = [
166 | [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
167 | ]
168 |
169 | mask = [
170 | [True] * min(text_ctx, len(tokens))
171 | + [False] * max(text_ctx - len(tokens), 0)
172 | for tokens in all_tokens
173 | ]
174 | mask = torch.tensor(mask, dtype=torch.bool)
175 | result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
176 | for i, tokens in enumerate(all_tokens):
177 | if len(tokens) > text_ctx:
178 | tokens = tokens[:text_ctx]
179 | tokens[-1] = self.eot_token
180 | result[i, : len(tokens)] = torch.tensor(tokens)
181 |
182 | return result, mask
183 |
--------------------------------------------------------------------------------
/karlo/models/decoder_model.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import copy
7 | import torch
8 |
9 | from ..modules import create_gaussian_diffusion
10 | from ..modules.unet import PLMImUNet
11 |
12 |
13 | class Text2ImProgressiveModel(torch.nn.Module):
14 | """
15 | A decoder that generates 64x64px images based on the text prompt.
16 |
17 | :param config: yaml config to define the decoder.
18 | :param tokenizer: tokenizer used in clip.
19 | """
20 |
21 | def __init__(
22 | self,
23 | config,
24 | tokenizer,
25 | ):
26 | super().__init__()
27 |
28 | self._conf = config
29 | self._model_conf = config.model.hparams
30 | self._diffusion_kwargs = dict(
31 | steps=config.diffusion.steps,
32 | learn_sigma=config.diffusion.learn_sigma,
33 | sigma_small=config.diffusion.sigma_small,
34 | noise_schedule=config.diffusion.noise_schedule,
35 | use_kl=config.diffusion.use_kl,
36 | predict_xstart=config.diffusion.predict_xstart,
37 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
38 | timestep_respacing=config.diffusion.timestep_respacing,
39 | )
40 | self._tokenizer = tokenizer
41 |
42 | self.model = self.create_plm_dec_model()
43 |
44 | cf_token, cf_mask = self.set_cf_text_tensor()
45 | self.register_buffer("cf_token", cf_token, persistent=False)
46 | self.register_buffer("cf_mask", cf_mask, persistent=False)
47 |
48 | @classmethod
49 | def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
50 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
51 |
52 | model = cls(config, tokenizer)
53 | model.load_state_dict(ckpt, strict=strict)
54 | return model
55 |
56 | def create_plm_dec_model(self):
57 | image_size = self._model_conf.image_size
58 | if self._model_conf.channel_mult == "":
59 | if image_size == 256:
60 | channel_mult = (1, 1, 2, 2, 4, 4)
61 | elif image_size == 128:
62 | channel_mult = (1, 1, 2, 3, 4)
63 | elif image_size == 64:
64 | channel_mult = (1, 2, 3, 4)
65 | else:
66 | raise ValueError(f"unsupported image size: {image_size}")
67 | else:
68 | channel_mult = tuple(
69 | int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
70 | )
71 | assert 2 ** (len(channel_mult) + 2) == image_size
72 |
73 | attention_ds = []
74 | for res in self._model_conf.attention_resolutions.split(","):
75 | attention_ds.append(image_size // int(res))
76 |
77 | return PLMImUNet(
78 | text_ctx=self._model_conf.text_ctx,
79 | xf_width=self._model_conf.xf_width,
80 | in_channels=3,
81 | model_channels=self._model_conf.num_channels,
82 | out_channels=6 if self._model_conf.learn_sigma else 3,
83 | num_res_blocks=self._model_conf.num_res_blocks,
84 | attention_resolutions=tuple(attention_ds),
85 | dropout=self._model_conf.dropout,
86 | channel_mult=channel_mult,
87 | num_heads=self._model_conf.num_heads,
88 | num_head_channels=self._model_conf.num_head_channels,
89 | num_heads_upsample=self._model_conf.num_heads_upsample,
90 | use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
91 | resblock_updown=self._model_conf.resblock_updown,
92 | clip_dim=self._model_conf.clip_dim,
93 | clip_emb_mult=self._model_conf.clip_emb_mult,
94 | clip_emb_type=self._model_conf.clip_emb_type,
95 | clip_emb_drop=self._model_conf.clip_emb_drop,
96 | )
97 |
98 | def set_cf_text_tensor(self):
99 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
100 |
101 | def get_sample_fn(self, timestep_respacing):
102 | use_ddim = timestep_respacing.startswith(("ddim", "fast"))
103 |
104 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
105 | diffusion_kwargs.update(timestep_respacing=timestep_respacing)
106 | diffusion = create_gaussian_diffusion(**diffusion_kwargs)
107 | sample_fn = (
108 | diffusion.ddim_sample_loop_progressive
109 | if use_ddim
110 | else diffusion.p_sample_loop_progressive
111 | )
112 |
113 | return sample_fn
114 |
115 | def forward(
116 | self,
117 | txt_feat,
118 | txt_feat_seq,
119 | tok,
120 | mask,
121 | img_feat=None,
122 | cf_guidance_scales=None,
123 | timestep_respacing=None,
124 | ):
125 | # classifier-free guidance (CFG) must be enabled at inference
126 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
127 | assert img_feat is not None
128 |
129 | bsz = txt_feat.shape[0]
130 | img_sz = self._model_conf.image_size
131 |
132 | def guided_model_fn(x_t, ts, **kwargs):
133 | half = x_t[: len(x_t) // 2]
134 | combined = torch.cat([half, half], dim=0)
135 | model_out = self.model(combined, ts, **kwargs)
136 | eps, rest = model_out[:, :3], model_out[:, 3:]
137 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
138 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
139 | cond_eps - uncond_eps
140 | )
141 | eps = torch.cat([half_eps, half_eps], dim=0)
142 | return torch.cat([eps, rest], dim=1)
143 |
144 | cf_feat = self.model.cf_param.unsqueeze(0)
145 | cf_feat = cf_feat.expand(bsz // 2, -1)
146 | feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)
147 |
148 | cond = {
149 | "y": feat,
150 | "txt_feat": txt_feat,
151 | "txt_feat_seq": txt_feat_seq,
152 | "mask": mask,
153 | }
154 | sample_fn = self.get_sample_fn(timestep_respacing)
155 | sample_outputs = sample_fn(
156 | guided_model_fn,
157 | (bsz, 3, img_sz, img_sz),
158 | noise=None,
159 | device=txt_feat.device,
160 | clip_denoised=True,
161 | model_kwargs=cond,
162 | )
163 |
164 | for out in sample_outputs:
165 | sample = out["sample"]
166 | yield sample if cf_guidance_scales is None else sample[
167 | : sample.shape[0] // 2
168 | ]
169 |
170 |
171 | class Text2ImModel(Text2ImProgressiveModel):
172 | def forward(
173 | self,
174 | txt_feat,
175 | txt_feat_seq,
176 | tok,
177 | mask,
178 | img_feat=None,
179 | cf_guidance_scales=None,
180 | timestep_respacing=None,
181 | ):
182 | last_out = None
183 | for out in super().forward(
184 | txt_feat,
185 | txt_feat_seq,
186 | tok,
187 | mask,
188 | img_feat,
189 | cf_guidance_scales,
190 | timestep_respacing,
191 | ):
192 | last_out = out
193 | return last_out
194 |
--------------------------------------------------------------------------------
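
The guided_model_fn closure above implements classifier-free guidance: the batch stacks the conditional half on top of the unconditional half, the UNet runs once on the duplicated latent, and only the epsilon channels are re-mixed as eps_uncond + s * (eps_cond - eps_uncond) while the learned-sigma channels pass through unchanged. A small numeric sketch of just that mixing step, with a random tensor standing in for the UNet output (shapes are assumptions):

import torch

def cfg_mix(model_out, guidance_scale):
    # model_out: [2B, 6, H, W] = conditional half stacked on unconditional half
    eps, rest = model_out[:, :3], model_out[:, 3:]  # epsilon vs. learned-sigma channels
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale.view(-1, 1, 1, 1) * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)  # duplicate so the shape matches x_t
    return torch.cat([eps, rest], dim=1)

model_out = torch.randn(4, 6, 64, 64)     # 2 conditional + 2 unconditional samples
scales = torch.full((2,), 4.0)            # per-sample guidance scale
print(cfg_mix(model_out, scales).shape)   # torch.Size([4, 6, 64, 64])
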
/karlo/models/prior_model.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import copy
7 | import torch
8 |
9 | from ..modules import create_gaussian_diffusion
10 | from ..modules.xf import PriorTransformer
11 |
12 |
13 | class PriorDiffusionModel(torch.nn.Module):
14 | """
15 | A prior that generates a CLIP image feature conditioned on the text prompt.
16 |
17 | :param config: yaml config to define the prior.
18 | :param tokenizer: tokenizer used in CLIP.
19 | :param clip_mean: mean to normalize the CLIP image feature (zero-mean, unit variance).
20 | :param clip_std: std to normalize the CLIP image feature (zero-mean, unit variance).
21 | """
22 |
23 | def __init__(self, config, tokenizer, clip_mean, clip_std):
24 | super().__init__()
25 |
26 | self._conf = config
27 | self._model_conf = config.model.hparams
28 | self._diffusion_kwargs = dict(
29 | steps=config.diffusion.steps,
30 | learn_sigma=config.diffusion.learn_sigma,
31 | sigma_small=config.diffusion.sigma_small,
32 | noise_schedule=config.diffusion.noise_schedule,
33 | use_kl=config.diffusion.use_kl,
34 | predict_xstart=config.diffusion.predict_xstart,
35 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
36 | timestep_respacing=config.diffusion.timestep_respacing,
37 | )
38 | self._tokenizer = tokenizer
39 |
40 | self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
41 | self.register_buffer("clip_std", clip_std[None, :], persistent=False)
42 |
43 | causal_mask = self.get_causal_mask()
44 | self.register_buffer("causal_mask", causal_mask, persistent=False)
45 |
46 | self.model = PriorTransformer(
47 | text_ctx=self._model_conf.text_ctx,
48 | xf_width=self._model_conf.xf_width,
49 | xf_layers=self._model_conf.xf_layers,
50 | xf_heads=self._model_conf.xf_heads,
51 | xf_final_ln=self._model_conf.xf_final_ln,
52 | clip_dim=self._model_conf.clip_dim,
53 | )
54 |
55 | cf_token, cf_mask = self.set_cf_text_tensor()
56 | self.register_buffer("cf_token", cf_token, persistent=False)
57 | self.register_buffer("cf_mask", cf_mask, persistent=False)
58 |
59 | @classmethod
60 | def load_from_checkpoint(
61 | cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
62 | ):
63 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
64 |
65 | model = cls(config, tokenizer, clip_mean, clip_std)
66 | model.load_state_dict(ckpt, strict=strict)
67 | return model
68 |
69 | def set_cf_text_tensor(self):
70 | return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
71 |
72 | def get_sample_fn(self, timestep_respacing):
73 | use_ddim = timestep_respacing.startswith(("ddim", "fast"))
74 |
75 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
76 | diffusion_kwargs.update(timestep_respacing=timestep_respacing)
77 | diffusion = create_gaussian_diffusion(**diffusion_kwargs)
78 | sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
79 |
80 | return sample_fn
81 |
82 | def get_causal_mask(self):
83 | seq_len = self._model_conf.text_ctx + 4
84 | mask = torch.empty(seq_len, seq_len)
85 | mask.fill_(float("-inf"))
86 | mask.triu_(1)
87 | mask = mask[None, ...]
88 | return mask
89 |
90 | def forward(
91 | self,
92 | txt_feat,
93 | txt_feat_seq,
94 | mask,
95 | cf_guidance_scales=None,
96 | timestep_respacing=None,
97 | denoised_fn=True,
98 | ):
99 | # classifier-free guidance (CFG) must be enabled at inference
100 | assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
101 |
102 | bsz_ = txt_feat.shape[0]
103 | bsz = bsz_ // 2
104 |
105 | def guided_model_fn(x_t, ts, **kwargs):
106 | half = x_t[: len(x_t) // 2]
107 | combined = torch.cat([half, half], dim=0)
108 | model_out = self.model(combined, ts, **kwargs)
109 | eps, rest = (
110 | model_out[:, : int(x_t.shape[1])],
111 | model_out[:, int(x_t.shape[1]) :],
112 | )
113 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
114 | half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
115 | cond_eps - uncond_eps
116 | )
117 | eps = torch.cat([half_eps, half_eps], dim=0)
118 | return torch.cat([eps, rest], dim=1)
119 |
120 | cond = {
121 | "text_emb": txt_feat,
122 | "text_enc": txt_feat_seq,
123 | "mask": mask,
124 | "causal_mask": self.causal_mask,
125 | }
126 | sample_fn = self.get_sample_fn(timestep_respacing)
127 | sample = sample_fn(
128 | guided_model_fn,
129 | (bsz_, self.model.clip_dim),
130 | noise=None,
131 | device=txt_feat.device,
132 | clip_denoised=False,
133 | denoised_fn=lambda x: torch.clamp(x, -10, 10),
134 | model_kwargs=cond,
135 | )
136 | sample = (sample * self.clip_std) + self.clip_mean
137 |
138 | return sample[:bsz]
139 |
--------------------------------------------------------------------------------
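
get_causal_mask builds an additive attention mask of shape [1, text_ctx + 4, text_ctx + 4]: zeros on and below the diagonal, -inf strictly above, so every position in the prior transformer attends only to earlier positions. A tiny reproduction for an assumed text_ctx = 3:

import torch

def causal_mask(text_ctx):
    seq_len = text_ctx + 4      # text tokens plus the prior's 4 extra tokens
    mask = torch.empty(seq_len, seq_len)
    mask.fill_(float("-inf"))
    mask.triu_(1)               # -inf strictly above the diagonal, zero elsewhere
    return mask[None, ...]      # broadcastable over the batch dimension

m = causal_mask(3)
print(m.shape)        # torch.Size([1, 7, 7])
print(m[0, :3, :3])   # [[0, -inf, -inf], [0, 0, -inf], [0, 0, 0]]
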
/karlo/models/sr_256_1k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | from .sr_64_256 import SupRes64to256Progressive
7 |
8 |
9 | class SupRes256to1kProgressive(SupRes64to256Progressive):
10 | pass # no difference currently
11 |
--------------------------------------------------------------------------------
/karlo/models/sr_64_256.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import copy
7 | import torch
8 |
9 | from ..modules.unet import SuperResUNetModel
10 | from ..modules import create_gaussian_diffusion
11 |
12 |
13 | class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
14 | """
15 | The ImprovedSR model fine-tunes the pretrained DDPM-based SR model using adversarial and perceptual losses.
16 | Specifically, the low-resolution sample is iteratively refined over 6 steps with the frozen pretrained SR model.
17 | In one additional final step, a separate fine-tuned model recovers the high-frequency details.
18 | This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps.
19 | """
20 |
21 | def __init__(self, config):
22 | super().__init__()
23 |
24 | self._config = config
25 | self._diffusion_kwargs = dict(
26 | steps=config.diffusion.steps,
27 | learn_sigma=config.diffusion.learn_sigma,
28 | sigma_small=config.diffusion.sigma_small,
29 | noise_schedule=config.diffusion.noise_schedule,
30 | use_kl=config.diffusion.use_kl,
31 | predict_xstart=config.diffusion.predict_xstart,
32 | rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
33 | )
34 |
35 | self.model_first_steps = SuperResUNetModel(
36 | in_channels=3, # auto-changed to 6 inside the model
37 | model_channels=config.model.hparams.channels,
38 | out_channels=3,
39 | num_res_blocks=config.model.hparams.depth,
40 | attention_resolutions=(), # no attention
41 | dropout=config.model.hparams.dropout,
42 | channel_mult=config.model.hparams.channels_multiple,
43 | resblock_updown=True,
44 | use_middle_attention=False,
45 | )
46 | self.model_last_step = SuperResUNetModel(
47 | in_channels=3, # auto-changed to 6 inside the model
48 | model_channels=config.model.hparams.channels,
49 | out_channels=3,
50 | num_res_blocks=config.model.hparams.depth,
51 | attention_resolutions=(), # no attention
52 | dropout=config.model.hparams.dropout,
53 | channel_mult=config.model.hparams.channels_multiple,
54 | resblock_updown=True,
55 | use_middle_attention=False,
56 | )
57 |
58 | @classmethod
59 | def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
60 | ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
61 |
62 | model = cls(config)
63 | model.load_state_dict(ckpt, strict=strict)
64 | return model
65 |
66 | def get_sample_fn(self, timestep_respacing):
67 | diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
68 | diffusion_kwargs.update(timestep_respacing=timestep_respacing)
69 | diffusion = create_gaussian_diffusion(**diffusion_kwargs)
70 | return diffusion.p_sample_loop_progressive_for_improved_sr
71 |
72 | def forward(self, low_res, timestep_respacing="7", **kwargs):
73 | assert (
74 | timestep_respacing == "7"
75 | ), "different respacing method may work, but no guaranteed"
76 |
77 | sample_fn = self.get_sample_fn(timestep_respacing)
78 | sample_outputs = sample_fn(
79 | self.model_first_steps,
80 | self.model_last_step,
81 | shape=low_res.shape,
82 | clip_denoised=True,
83 | model_kwargs=dict(low_res=low_res),
84 | **kwargs,
85 | )
86 | for x in sample_outputs:
87 | sample = x["sample"]
88 | yield sample
89 |
--------------------------------------------------------------------------------
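
The timestep_respacing == "7" assertion ties this model to p_sample_loop_progressive_for_improved_sr: six of the seven respaced steps run through the frozen model_first_steps, and only the final step is handed to the separately fine-tuned model_last_step. A toy loop mirroring just that model-selection logic (the callables below are stand-ins, not the real UNets):

indices = list(range(7))[::-1]  # 7 respaced timesteps, noisiest first

def first_steps_model(t):       # stand-in for self.model_first_steps
    return f"frozen SR model      @ t={t}"

def last_step_model(t):         # stand-in for self.model_last_step
    return f"fine-tuned last step @ t={t}"

for idx, t in enumerate(indices):
    model = last_step_model if idx == len(indices) - 1 else first_steps_model
    print(model(t))
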
/karlo/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 |
6 | from .diffusion import gaussian_diffusion as gd
7 | from .diffusion.respace import (
8 | SpacedDiffusion,
9 | space_timesteps,
10 | )
11 |
12 |
13 | def create_gaussian_diffusion(
14 | steps,
15 | learn_sigma,
16 | sigma_small,
17 | noise_schedule,
18 | use_kl,
19 | predict_xstart,
20 | rescale_learned_sigmas,
21 | timestep_respacing,
22 | ):
23 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
24 | if use_kl:
25 | loss_type = gd.LossType.RESCALED_KL
26 | elif rescale_learned_sigmas:
27 | loss_type = gd.LossType.RESCALED_MSE
28 | else:
29 | loss_type = gd.LossType.MSE
30 | if not timestep_respacing:
31 | timestep_respacing = [steps]
32 |
33 | return SpacedDiffusion(
34 | use_timesteps=space_timesteps(steps, timestep_respacing),
35 | betas=betas,
36 | model_mean_type=(
37 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
38 | ),
39 | model_var_type=(
40 | (
41 | gd.ModelVarType.FIXED_LARGE
42 | if not sigma_small
43 | else gd.ModelVarType.FIXED_SMALL
44 | )
45 | if not learn_sigma
46 | else gd.ModelVarType.LEARNED_RANGE
47 | ),
48 | loss_type=loss_type,
49 | )
50 |
--------------------------------------------------------------------------------
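
create_gaussian_diffusion is the single factory that turns a config's diffusion block into a SpacedDiffusion. A hedged usage sketch; the hyperparameter values below are illustrative assumptions, not necessarily those in configs/*.yaml:

from karlo.modules import create_gaussian_diffusion

diffusion = create_gaussian_diffusion(
    steps=1000,
    learn_sigma=True,                    # -> ModelVarType.LEARNED_RANGE
    sigma_small=False,
    noise_schedule="squaredcos_cap_v2",
    use_kl=False,
    predict_xstart=False,                # -> ModelMeanType.EPSILON
    rescale_learned_sigmas=True,         # -> LossType.RESCALED_MSE
    timestep_respacing="25",             # keep 25 evenly spaced sampling steps
)
print(diffusion.num_timesteps)           # 25
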
/karlo/modules/diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 | import enum
6 | import math
7 |
8 | import numpy as np
9 | import torch as th
10 |
11 |
12 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
13 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
14 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
15 | betas[:warmup_time] = np.linspace(
16 | beta_start, beta_end, warmup_time, dtype=np.float64
17 | )
18 | return betas
19 |
20 |
21 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
22 | """
23 | This is the deprecated API for creating beta schedules.
24 | See get_named_beta_schedule() for the new library of schedules.
25 | """
26 | if beta_schedule == "quad":
27 | betas = (
28 | np.linspace(
29 | beta_start**0.5,
30 | beta_end**0.5,
31 | num_diffusion_timesteps,
32 | dtype=np.float64,
33 | )
34 | ** 2
35 | )
36 | elif beta_schedule == "linear":
37 | betas = np.linspace(
38 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
39 | )
40 | elif beta_schedule == "warmup10":
41 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
42 | elif beta_schedule == "warmup50":
43 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
44 | elif beta_schedule == "const":
45 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
46 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
47 | betas = 1.0 / np.linspace(
48 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
49 | )
50 | else:
51 | raise NotImplementedError(beta_schedule)
52 | assert betas.shape == (num_diffusion_timesteps,)
53 | return betas
54 |
55 |
56 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
57 | """
58 | Get a pre-defined beta schedule for the given name.
59 | The beta schedule library consists of beta schedules which remain similar
60 | in the limit of num_diffusion_timesteps.
61 | Beta schedules may be added, but should not be removed or changed once
62 | they are committed to maintain backwards compatibility.
63 | """
64 | if schedule_name == "linear":
65 | # Linear schedule from Ho et al, extended to work for any number of
66 | # diffusion steps.
67 | scale = 1000 / num_diffusion_timesteps
68 | return get_beta_schedule(
69 | "linear",
70 | beta_start=scale * 0.0001,
71 | beta_end=scale * 0.02,
72 | num_diffusion_timesteps=num_diffusion_timesteps,
73 | )
74 | elif schedule_name == "squaredcos_cap_v2":
75 | return betas_for_alpha_bar(
76 | num_diffusion_timesteps,
77 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
78 | )
79 | else:
80 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
81 |
82 |
83 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
84 | """
85 | Create a beta schedule that discretizes the given alpha_t_bar function,
86 | which defines the cumulative product of (1-beta) over time from t = [0,1].
87 | :param num_diffusion_timesteps: the number of betas to produce.
88 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
89 | produces the cumulative product of (1-beta) up to that
90 | part of the diffusion process.
91 | :param max_beta: the maximum beta to use; use values lower than 1 to
92 | prevent singularities.
93 | """
94 | betas = []
95 | for i in range(num_diffusion_timesteps):
96 | t1 = i / num_diffusion_timesteps
97 | t2 = (i + 1) / num_diffusion_timesteps
98 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
99 | return np.array(betas)
100 |
101 |
102 | class ModelMeanType(enum.Enum):
103 | """
104 | Which type of output the model predicts.
105 | """
106 |
107 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
108 | START_X = enum.auto() # the model predicts x_0
109 | EPSILON = enum.auto() # the model predicts epsilon
110 |
111 |
112 | class ModelVarType(enum.Enum):
113 | """
114 | What is used as the model's output variance.
115 | The LEARNED_RANGE option has been added to allow the model to predict
116 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
117 | """
118 |
119 | LEARNED = enum.auto()
120 | FIXED_SMALL = enum.auto()
121 | FIXED_LARGE = enum.auto()
122 | LEARNED_RANGE = enum.auto()
123 |
124 |
125 | class LossType(enum.Enum):
126 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
127 | RESCALED_MSE = (
128 | enum.auto()
129 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
130 | KL = enum.auto() # use the variational lower-bound
131 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
132 |
133 | def is_vb(self):
134 | return self == LossType.KL or self == LossType.RESCALED_KL
135 |
136 |
137 | class GaussianDiffusion(th.nn.Module):
138 | """
139 | Utilities for training and sampling diffusion models.
140 | Original ported from this codebase:
141 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
142 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
143 | starting at T and going to 1.
144 | """
145 |
146 | def __init__(
147 | self,
148 | *,
149 | betas,
150 | model_mean_type,
151 | model_var_type,
152 | loss_type,
153 | ):
154 | super(GaussianDiffusion, self).__init__()
155 | self.model_mean_type = model_mean_type
156 | self.model_var_type = model_var_type
157 | self.loss_type = loss_type
158 |
159 | # Use float64 for accuracy.
160 | betas = np.array(betas, dtype=np.float64)
161 | assert len(betas.shape) == 1, "betas must be 1-D"
162 | assert (betas > 0).all() and (betas <= 1).all()
163 |
164 | self.num_timesteps = int(betas.shape[0])
165 |
166 | alphas = 1.0 - betas
167 | alphas_cumprod = np.cumprod(alphas, axis=0)
168 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
169 | alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
170 | assert alphas_cumprod_prev.shape == (self.num_timesteps,)
171 |
172 | # calculations for diffusion q(x_t | x_{t-1}) and others
173 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
174 | sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
175 | log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
176 | sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
177 | sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
178 |
179 | # calculations for posterior q(x_{t-1} | x_t, x_0)
180 | posterior_variance = (
181 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
182 | )
183 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
184 | posterior_log_variance_clipped = np.log(
185 | np.append(posterior_variance[1], posterior_variance[1:])
186 | )
187 | posterior_mean_coef1 = (
188 | betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
189 | )
190 | posterior_mean_coef2 = (
191 | (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
192 | )
193 |
194 | self.register_buffer("betas", th.from_numpy(betas), persistent=False)
195 | self.register_buffer(
196 | "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
197 | )
198 | self.register_buffer(
199 | "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
200 | )
201 | self.register_buffer(
202 | "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
203 | )
204 |
205 | self.register_buffer(
206 | "sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
207 | )
208 | self.register_buffer(
209 | "sqrt_one_minus_alphas_cumprod",
210 | th.from_numpy(sqrt_one_minus_alphas_cumprod),
211 | persistent=False,
212 | )
213 | self.register_buffer(
214 | "log_one_minus_alphas_cumprod",
215 | th.from_numpy(log_one_minus_alphas_cumprod),
216 | persistent=False,
217 | )
218 | self.register_buffer(
219 | "sqrt_recip_alphas_cumprod",
220 | th.from_numpy(sqrt_recip_alphas_cumprod),
221 | persistent=False,
222 | )
223 | self.register_buffer(
224 | "sqrt_recipm1_alphas_cumprod",
225 | th.from_numpy(sqrt_recipm1_alphas_cumprod),
226 | persistent=False,
227 | )
228 |
229 | self.register_buffer(
230 | "posterior_variance", th.from_numpy(posterior_variance), persistent=False
231 | )
232 | self.register_buffer(
233 | "posterior_log_variance_clipped",
234 | th.from_numpy(posterior_log_variance_clipped),
235 | persistent=False,
236 | )
237 | self.register_buffer(
238 | "posterior_mean_coef1",
239 | th.from_numpy(posterior_mean_coef1),
240 | persistent=False,
241 | )
242 | self.register_buffer(
243 | "posterior_mean_coef2",
244 | th.from_numpy(posterior_mean_coef2),
245 | persistent=False,
246 | )
247 |
248 | def q_mean_variance(self, x_start, t):
249 | """
250 | Get the distribution q(x_t | x_0).
251 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
252 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
253 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
254 | """
255 | mean = (
256 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
257 | )
258 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
259 | log_variance = _extract_into_tensor(
260 | self.log_one_minus_alphas_cumprod, t, x_start.shape
261 | )
262 | return mean, variance, log_variance
263 |
264 | def q_sample(self, x_start, t, noise=None):
265 | """
266 | Diffuse the data for a given number of diffusion steps.
267 | In other words, sample from q(x_t | x_0).
268 | :param x_start: the initial data batch.
269 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
270 | :param noise: if specified, the split-out normal noise.
271 | :return: A noisy version of x_start.
272 | """
273 | if noise is None:
274 | noise = th.randn_like(x_start)
275 | assert noise.shape == x_start.shape
276 | return (
277 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
278 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
279 | * noise
280 | )
281 |
282 | def q_posterior_mean_variance(self, x_start, x_t, t):
283 | """
284 | Compute the mean and variance of the diffusion posterior:
285 | q(x_{t-1} | x_t, x_0)
286 | """
287 | assert x_start.shape == x_t.shape
288 | posterior_mean = (
289 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
290 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
291 | )
292 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
293 | posterior_log_variance_clipped = _extract_into_tensor(
294 | self.posterior_log_variance_clipped, t, x_t.shape
295 | )
296 | assert (
297 | posterior_mean.shape[0]
298 | == posterior_variance.shape[0]
299 | == posterior_log_variance_clipped.shape[0]
300 | == x_start.shape[0]
301 | )
302 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
303 |
304 | def p_mean_variance(
305 | self,
306 | model,
307 | x,
308 | t,
309 | clip_denoised=True,
310 | denoised_fn=None,
311 | model_kwargs=None,
312 | **ignore_kwargs,
313 | ):
314 | """
315 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
316 | the initial x, x_0.
317 | :param model: the model, which takes a signal and a batch of timesteps
318 | as input.
319 | :param x: the [N x C x ...] tensor at time t.
320 | :param t: a 1-D Tensor of timesteps.
321 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
322 | :param denoised_fn: if not None, a function which applies to the
323 | x_start prediction before it is used to sample. Applies before
324 | clip_denoised.
325 | :param model_kwargs: if not None, a dict of extra keyword arguments to
326 | pass to the model. This can be used for conditioning.
327 | :return: a dict with the following keys:
328 | - 'mean': the model mean output.
329 | - 'variance': the model variance output.
330 | - 'log_variance': the log of 'variance'.
331 | - 'pred_xstart': the prediction for x_0.
332 | """
333 | if model_kwargs is None:
334 | model_kwargs = {}
335 |
336 | B, C = x.shape[:2]
337 | assert t.shape == (B,)
338 | model_output = model(x, t, **model_kwargs)
339 | if isinstance(model_output, tuple):
340 | model_output, extra = model_output
341 | else:
342 | extra = None
343 |
344 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
345 | assert model_output.shape == (B, C * 2, *x.shape[2:])
346 | model_output, model_var_values = th.split(model_output, C, dim=1)
347 | if self.model_var_type == ModelVarType.LEARNED:
348 | model_log_variance = model_var_values
349 | model_variance = th.exp(model_log_variance)
350 | else:
351 | min_log = _extract_into_tensor(
352 | self.posterior_log_variance_clipped, t, x.shape
353 | )
354 | max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
355 | # The model_var_values is [-1, 1] for [min_var, max_var].
356 | frac = (model_var_values + 1) / 2
357 | model_log_variance = frac * max_log + (1 - frac) * min_log
358 | model_variance = th.exp(model_log_variance)
359 | else:
360 | model_variance, model_log_variance = {
361 | # for fixedlarge, we set the initial (log-)variance like so
362 | # to get a better decoder log likelihood.
363 | ModelVarType.FIXED_LARGE: (
364 | th.cat([self.posterior_variance[1][None], self.betas[1:]]),
365 | th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
366 | ),
367 | ModelVarType.FIXED_SMALL: (
368 | self.posterior_variance,
369 | self.posterior_log_variance_clipped,
370 | ),
371 | }[self.model_var_type]
372 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
373 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
374 |
375 | def process_xstart(x):
376 | if denoised_fn is not None:
377 | x = denoised_fn(x)
378 | if clip_denoised:
379 | return x.clamp(-1, 1)
380 | return x
381 |
382 | if self.model_mean_type == ModelMeanType.PREVIOUS_X:
383 | pred_xstart = process_xstart(
384 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
385 | )
386 | model_mean = model_output
387 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
388 | if self.model_mean_type == ModelMeanType.START_X:
389 | pred_xstart = process_xstart(model_output)
390 | else:
391 | pred_xstart = process_xstart(
392 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
393 | )
394 | model_mean, _, _ = self.q_posterior_mean_variance(
395 | x_start=pred_xstart, x_t=x, t=t
396 | )
397 | else:
398 | raise NotImplementedError(self.model_mean_type)
399 |
400 | assert (
401 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
402 | )
403 | return {
404 | "mean": model_mean,
405 | "variance": model_variance,
406 | "log_variance": model_log_variance,
407 | "pred_xstart": pred_xstart,
408 | }
409 |
410 | def _predict_xstart_from_eps(self, x_t, t, eps):
411 | assert x_t.shape == eps.shape
412 | return (
413 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
414 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
415 | )
416 |
417 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
418 | return (
419 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
420 | - pred_xstart
421 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
422 |
423 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
424 | """
425 | Compute the mean for the previous step, given a function cond_fn that
426 | computes the gradient of a conditional log probability with respect to
427 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
428 | condition on y.
429 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
430 | """
431 | gradient = cond_fn(x, t, **model_kwargs)
432 | new_mean = (
433 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
434 | )
435 | return new_mean
436 |
437 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
438 | """
439 | Compute what the p_mean_variance output would have been, should the
440 | model's score function be conditioned by cond_fn.
441 | See condition_mean() for details on cond_fn.
442 | Unlike condition_mean(), this instead uses the conditioning strategy
443 | from Song et al (2020).
444 | """
445 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
446 |
447 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
448 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
449 |
450 | out = p_mean_var.copy()
451 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
452 | out["mean"], _, _ = self.q_posterior_mean_variance(
453 | x_start=out["pred_xstart"], x_t=x, t=t
454 | )
455 | return out
456 |
457 | def p_sample(
458 | self,
459 | model,
460 | x,
461 | t,
462 | clip_denoised=True,
463 | denoised_fn=None,
464 | cond_fn=None,
465 | model_kwargs=None,
466 | ):
467 | """
468 | Sample x_{t-1} from the model at the given timestep.
469 | :param model: the model to sample from.
470 | :param x: the current tensor at x_{t-1}.
471 | :param t: the value of t, starting at 0 for the first diffusion step.
472 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
473 | :param denoised_fn: if not None, a function which applies to the
474 | x_start prediction before it is used to sample.
475 | :param cond_fn: if not None, this is a gradient function that acts
476 | similarly to the model.
477 | :param model_kwargs: if not None, a dict of extra keyword arguments to
478 | pass to the model. This can be used for conditioning.
479 | :return: a dict containing the following keys:
480 | - 'sample': a random sample from the model.
481 | - 'pred_xstart': a prediction of x_0.
482 | """
483 | out = self.p_mean_variance(
484 | model,
485 | x,
486 | t,
487 | clip_denoised=clip_denoised,
488 | denoised_fn=denoised_fn,
489 | model_kwargs=model_kwargs,
490 | )
491 | noise = th.randn_like(x)
492 | nonzero_mask = (
493 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
494 | ) # no noise when t == 0
495 | if cond_fn is not None:
496 | out["mean"] = self.condition_mean(
497 | cond_fn, out, x, t, model_kwargs=model_kwargs
498 | )
499 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
500 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
501 |
502 | def p_sample_loop(
503 | self,
504 | model,
505 | shape,
506 | noise=None,
507 | clip_denoised=True,
508 | denoised_fn=None,
509 | cond_fn=None,
510 | model_kwargs=None,
511 | device=None,
512 | progress=False,
513 | ):
514 | """
515 | Generate samples from the model.
516 | :param model: the model module.
517 | :param shape: the shape of the samples, (N, C, H, W).
518 | :param noise: if specified, the noise from the encoder to sample.
519 | Should be of the same shape as `shape`.
520 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
521 | :param denoised_fn: if not None, a function which applies to the
522 | x_start prediction before it is used to sample.
523 | :param cond_fn: if not None, this is a gradient function that acts
524 | similarly to the model.
525 | :param model_kwargs: if not None, a dict of extra keyword arguments to
526 | pass to the model. This can be used for conditioning.
527 | :param device: if specified, the device to create the samples on.
528 | If not specified, use a model parameter's device.
529 | :param progress: if True, show a tqdm progress bar.
530 | :return: a non-differentiable batch of samples.
531 | """
532 | final = None
533 | for sample in self.p_sample_loop_progressive(
534 | model,
535 | shape,
536 | noise=noise,
537 | clip_denoised=clip_denoised,
538 | denoised_fn=denoised_fn,
539 | cond_fn=cond_fn,
540 | model_kwargs=model_kwargs,
541 | device=device,
542 | progress=progress,
543 | ):
544 | final = sample
545 | return final["sample"]
546 |
547 | def p_sample_loop_progressive(
548 | self,
549 | model,
550 | shape,
551 | noise=None,
552 | clip_denoised=True,
553 | denoised_fn=None,
554 | cond_fn=None,
555 | model_kwargs=None,
556 | device=None,
557 | progress=False,
558 | ):
559 | """
560 | Generate samples from the model and yield intermediate samples from
561 | each timestep of diffusion.
562 | Arguments are the same as p_sample_loop().
563 | Returns a generator over dicts, where each dict is the return value of
564 | p_sample().
565 | """
566 | if device is None:
567 | device = next(model.parameters()).device
568 | assert isinstance(shape, (tuple, list))
569 | if noise is not None:
570 | img = noise
571 | else:
572 | img = th.randn(*shape, device=device)
573 | indices = list(range(self.num_timesteps))[::-1]
574 |
575 | if progress:
576 | # Lazy import so that we don't depend on tqdm.
577 | from tqdm.auto import tqdm
578 |
579 | indices = tqdm(indices)
580 |
581 | for idx, i in enumerate(indices):
582 | t = th.tensor([i] * shape[0], device=device)
583 | with th.no_grad():
584 | out = self.p_sample(
585 | model,
586 | img,
587 | t,
588 | clip_denoised=clip_denoised,
589 | denoised_fn=denoised_fn,
590 | cond_fn=cond_fn,
591 | model_kwargs=model_kwargs,
592 | )
593 | yield out
594 | img = out["sample"]
595 |
596 | def p_sample_loop_progressive_for_improved_sr(
597 | self,
598 | model,
599 | model_aux,
600 | shape,
601 | noise=None,
602 | clip_denoised=True,
603 | denoised_fn=None,
604 | cond_fn=None,
605 | model_kwargs=None,
606 | device=None,
607 | progress=False,
608 | ):
609 | """
610 | Modified version of p_sample_loop_progressive for sampling from the improved sr model
611 | """
612 |
613 | if device is None:
614 | device = next(model.parameters()).device
615 | assert isinstance(shape, (tuple, list))
616 | if noise is not None:
617 | img = noise
618 | else:
619 | img = th.randn(*shape, device=device)
620 | indices = list(range(self.num_timesteps))[::-1]
621 |
622 | if progress:
623 | # Lazy import so that we don't depend on tqdm.
624 | from tqdm.auto import tqdm
625 |
626 | indices = tqdm(indices)
627 |
628 | for idx, i in enumerate(indices):
629 | t = th.tensor([i] * shape[0], device=device)
630 | with th.no_grad():
631 | out = self.p_sample(
632 | model_aux if len(indices) - 1 == idx else model,
633 | img,
634 | t,
635 | clip_denoised=clip_denoised,
636 | denoised_fn=denoised_fn,
637 | cond_fn=cond_fn,
638 | model_kwargs=model_kwargs,
639 | )
640 | yield out
641 | img = out["sample"]
642 |
643 | def ddim_sample(
644 | self,
645 | model,
646 | x,
647 | t,
648 | clip_denoised=True,
649 | denoised_fn=None,
650 | cond_fn=None,
651 | model_kwargs=None,
652 | eta=0.0,
653 | ):
654 | """
655 | Sample x_{t-1} from the model using DDIM.
656 | Same usage as p_sample().
657 | """
658 | out = self.p_mean_variance(
659 | model,
660 | x,
661 | t,
662 | clip_denoised=clip_denoised,
663 | denoised_fn=denoised_fn,
664 | model_kwargs=model_kwargs,
665 | )
666 | if cond_fn is not None:
667 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
668 |
669 | # Usually our model outputs epsilon, but we re-derive it
670 | # in case we used x_start or x_prev prediction.
671 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
672 |
673 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
674 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
675 | sigma = (
676 | eta
677 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
678 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
679 | )
680 | # Equation 12.
681 | noise = th.randn_like(x)
682 | mean_pred = (
683 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
684 | + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
685 | )
686 | nonzero_mask = (
687 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
688 | ) # no noise when t == 0
689 | sample = mean_pred + nonzero_mask * sigma * noise
690 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
691 |
692 | def ddim_reverse_sample(
693 | self,
694 | model,
695 | x,
696 | t,
697 | clip_denoised=True,
698 | denoised_fn=None,
699 | cond_fn=None,
700 | model_kwargs=None,
701 | eta=0.0,
702 | ):
703 | """
704 | Sample x_{t+1} from the model using DDIM reverse ODE.
705 | """
706 | assert eta == 0.0, "Reverse ODE only for deterministic path"
707 | out = self.p_mean_variance(
708 | model,
709 | x,
710 | t,
711 | clip_denoised=clip_denoised,
712 | denoised_fn=denoised_fn,
713 | model_kwargs=model_kwargs,
714 | )
715 | if cond_fn is not None:
716 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
717 | # Usually our model outputs epsilon, but we re-derive it
718 | # in case we used x_start or x_prev prediction.
719 | eps = (
720 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
721 | - out["pred_xstart"]
722 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
723 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
724 |
725 | # Equation 12. reversed
726 | mean_pred = (
727 | out["pred_xstart"] * th.sqrt(alpha_bar_next)
728 | + th.sqrt(1 - alpha_bar_next) * eps
729 | )
730 |
731 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
732 |
733 | def ddim_sample_loop(
734 | self,
735 | model,
736 | shape,
737 | noise=None,
738 | clip_denoised=True,
739 | denoised_fn=None,
740 | cond_fn=None,
741 | model_kwargs=None,
742 | device=None,
743 | progress=False,
744 | eta=0.0,
745 | ):
746 | """
747 | Generate samples from the model using DDIM.
748 | Same usage as p_sample_loop().
749 | """
750 | final = None
751 | for sample in self.ddim_sample_loop_progressive(
752 | model,
753 | shape,
754 | noise=noise,
755 | clip_denoised=clip_denoised,
756 | denoised_fn=denoised_fn,
757 | cond_fn=cond_fn,
758 | model_kwargs=model_kwargs,
759 | device=device,
760 | progress=progress,
761 | eta=eta,
762 | ):
763 | final = sample
764 | return final["sample"]
765 |
766 | def ddim_sample_loop_progressive(
767 | self,
768 | model,
769 | shape,
770 | noise=None,
771 | clip_denoised=True,
772 | denoised_fn=None,
773 | cond_fn=None,
774 | model_kwargs=None,
775 | device=None,
776 | progress=False,
777 | eta=0.0,
778 | ):
779 | """
780 | Use DDIM to sample from the model and yield intermediate samples from
781 | each timestep of DDIM.
782 | Same usage as p_sample_loop_progressive().
783 | """
784 | if device is None:
785 | device = next(model.parameters()).device
786 | assert isinstance(shape, (tuple, list))
787 | if noise is not None:
788 | img = noise
789 | else:
790 | img = th.randn(*shape, device=device)
791 | indices = list(range(self.num_timesteps))[::-1]
792 |
793 | if progress:
794 | # Lazy import so that we don't depend on tqdm.
795 | from tqdm.auto import tqdm
796 |
797 | indices = tqdm(indices)
798 |
799 | for i in indices:
800 | t = th.tensor([i] * shape[0], device=device)
801 | with th.no_grad():
802 | out = self.ddim_sample(
803 | model,
804 | img,
805 | t,
806 | clip_denoised=clip_denoised,
807 | denoised_fn=denoised_fn,
808 | cond_fn=cond_fn,
809 | model_kwargs=model_kwargs,
810 | eta=eta,
811 | )
812 | yield out
813 | img = out["sample"]
814 |
815 |
816 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
817 | """
818 | Extract values from a 1-D tensor for a batch of indices.
819 | :param arr: the 1-D tensor of per-timestep values (here, a registered buffer).
820 | :param timesteps: a tensor of indices into the array to extract.
821 | :param broadcast_shape: a larger shape of K dimensions with the batch
822 | dimension equal to the length of timesteps.
823 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
824 | """
825 | res = arr.to(device=timesteps.device)[timesteps].float()
826 | while len(res.shape) < len(broadcast_shape):
827 | res = res[..., None]
828 | return res + th.zeros(broadcast_shape, device=timesteps.device)
829 |
--------------------------------------------------------------------------------
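
All the schedule constants registered in GaussianDiffusion.__init__ derive from alphas_cumprod, and q_sample is the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A standalone numpy sketch that mirrors (rather than imports) the cosine schedule and that forward step, as a quick sanity check of the math:

import math
import numpy as np

T = 1000
alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
betas = np.array(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)
alphas_cumprod = np.cumprod(1.0 - betas)

x0 = np.random.randn(3, 64, 64)           # toy "image"
eps = np.random.randn(*x0.shape)
t = 500
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * eps
print(betas[0], alphas_cumprod[-1])       # tiny first beta; alpha_bar at t=T is close to 0
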
/karlo/modules/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 |
6 | import torch as th
7 |
8 | from .gaussian_diffusion import GaussianDiffusion
9 |
10 |
11 | def space_timesteps(num_timesteps, section_counts):
12 | """
13 | Create a list of timesteps to use from an original diffusion process,
14 | given the number of timesteps we want to take from equally-sized portions
15 | of the original process.
16 |
17 | For example, if there are 300 timesteps and the section counts are [10,15,20],
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 |
21 | :param num_timesteps: the number of diffusion steps in the original
22 | process to divide up.
23 | :param section_counts: either a list of numbers, or a string containing
24 | comma-separated numbers, indicating the step count
25 | per section. As a special case, use "ddimN" where N
26 | is a number of steps to use the striding from the
27 | DDIM paper.
28 | :return: a set of diffusion steps from the original process to use.
29 | """
30 | if isinstance(section_counts, str):
31 | if section_counts.startswith("ddim"):
32 | desired_count = int(section_counts[len("ddim") :])
33 | for i in range(1, num_timesteps):
34 | if len(range(0, num_timesteps, i)) == desired_count:
35 | return set(range(0, num_timesteps, i))
36 | raise ValueError(
37 | f"cannot create exactly {num_timesteps} steps with an integer stride"
38 | )
39 | elif section_counts == "fast27":
40 | steps = space_timesteps(num_timesteps, "10,10,3,2,2")
41 | # Help reduce DDIM artifacts from noisiest timesteps.
42 | steps.remove(num_timesteps - 1)
43 | steps.add(num_timesteps - 3)
44 | return steps
45 | section_counts = [int(x) for x in section_counts.split(",")]
46 | size_per = num_timesteps // len(section_counts)
47 | extra = num_timesteps % len(section_counts)
48 | start_idx = 0
49 | all_steps = []
50 | for i, section_count in enumerate(section_counts):
51 | size = size_per + (1 if i < extra else 0)
52 | if size < section_count:
53 | raise ValueError(
54 | f"cannot divide section of {size} steps into {section_count}"
55 | )
56 | if section_count <= 1:
57 | frac_stride = 1
58 | else:
59 | frac_stride = (size - 1) / (section_count - 1)
60 | cur_idx = 0.0
61 | taken_steps = []
62 | for _ in range(section_count):
63 | taken_steps.append(start_idx + round(cur_idx))
64 | cur_idx += frac_stride
65 | all_steps += taken_steps
66 | start_idx += size
67 | return set(all_steps)
68 |
69 |
70 | class SpacedDiffusion(GaussianDiffusion):
71 | """
72 | A diffusion process which can skip steps in a base diffusion process.
73 |
74 | :param use_timesteps: a collection (sequence or set) of timesteps from the
75 | original diffusion process to retain.
76 | :param kwargs: the kwargs to create the base diffusion process.
77 | """
78 |
79 | def __init__(self, use_timesteps, **kwargs):
80 | self.use_timesteps = set(use_timesteps)
81 | self.original_num_steps = len(kwargs["betas"])
82 |
83 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
84 | last_alpha_cumprod = 1.0
85 | new_betas = []
86 | timestep_map = []
87 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
88 | if i in self.use_timesteps:
89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90 | last_alpha_cumprod = alpha_cumprod
91 | timestep_map.append(i)
92 | kwargs["betas"] = th.tensor(new_betas).numpy()
93 | super().__init__(**kwargs)
94 | self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
95 |
96 | def p_mean_variance(self, model, *args, **kwargs):
97 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | def wrapped(x, ts, **kwargs):
107 | ts_cpu = ts.detach().to("cpu")
108 | return model(
109 | x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
110 | )
111 |
112 | return wrapped
113 |
--------------------------------------------------------------------------------
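
space_timesteps is what "ddimN", "fast27", and plain comma-separated counts resolve to: the set of original timesteps to keep. A quick sketch of its output for a few representative respacing strings (the 1000-step base count is an assumption matching typical configs):

from karlo.modules.diffusion.respace import space_timesteps

print(sorted(space_timesteps(1000, "ddim25"))[:5])  # stride 40: [0, 40, 80, 120, 160]
print(len(space_timesteps(1000, "25")))             # 25 evenly spread steps
print(len(space_timesteps(1000, "fast27")))         # 27 steps, denser near t = 0
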
/karlo/modules/nn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class GroupNorm32(nn.GroupNorm):
13 | def __init__(self, num_groups, num_channels, swish, eps=1e-5):
14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
15 | self.swish = swish
16 |
17 | def forward(self, x):
18 | y = super().forward(x.float()).to(x.dtype)
19 | if self.swish == 1.0:
20 | y = F.silu(y)
21 | elif self.swish:
22 | y = y * F.sigmoid(y * float(self.swish))
23 | return y
24 |
25 |
26 | def conv_nd(dims, *args, **kwargs):
27 | """
28 | Create a 1D, 2D, or 3D convolution module.
29 | """
30 | if dims == 1:
31 | return nn.Conv1d(*args, **kwargs)
32 | elif dims == 2:
33 | return nn.Conv2d(*args, **kwargs)
34 | elif dims == 3:
35 | return nn.Conv3d(*args, **kwargs)
36 | raise ValueError(f"unsupported dimensions: {dims}")
37 |
38 |
39 | def linear(*args, **kwargs):
40 | """
41 | Create a linear module.
42 | """
43 | return nn.Linear(*args, **kwargs)
44 |
45 |
46 | def avg_pool_nd(dims, *args, **kwargs):
47 | """
48 | Create a 1D, 2D, or 3D average pooling module.
49 | """
50 | if dims == 1:
51 | return nn.AvgPool1d(*args, **kwargs)
52 | elif dims == 2:
53 | return nn.AvgPool2d(*args, **kwargs)
54 | elif dims == 3:
55 | return nn.AvgPool3d(*args, **kwargs)
56 | raise ValueError(f"unsupported dimensions: {dims}")
57 |
58 |
59 | def zero_module(module):
60 | """
61 | Zero out the parameters of a module and return it.
62 | """
63 | for p in module.parameters():
64 | p.detach().zero_()
65 | return module
66 |
67 |
68 | def scale_module(module, scale):
69 | """
70 | Scale the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().mul_(scale)
74 | return module
75 |
76 |
77 | def normalization(channels, swish=0.0):
78 | """
79 | Make a standard normalization layer, with an optional swish activation.
80 |
81 | :param channels: number of input channels.
82 | :return: an nn.Module for normalization.
83 | """
84 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
85 |
86 |
87 | def timestep_embedding(timesteps, dim, max_period=10000):
88 | """
89 | Create sinusoidal timestep embeddings.
90 |
91 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
92 | These may be fractional.
93 | :param dim: the dimension of the output.
94 | :param max_period: controls the minimum frequency of the embeddings.
95 | :return: an [N x dim] Tensor of positional embeddings.
96 | """
97 | half = dim // 2
98 | freqs = th.exp(
99 | -math.log(max_period)
100 | * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
101 | / half
102 | )
103 | args = timesteps[:, None].float() * freqs[None]
104 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
105 | if dim % 2:
106 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
107 | return embedding
108 |
109 |
110 | def mean_flat(tensor):
111 | """
112 | Take the mean over all non-batch dimensions.
113 | """
114 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
115 |
--------------------------------------------------------------------------------
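
timestep_embedding maps N (possibly fractional) timesteps to an [N, dim] sinusoidal embedding [cos(t * f_k), sin(t * f_k)], with frequencies spaced geometrically from 1 down to roughly 1/max_period and a zero-padded last column when dim is odd. A small usage check (the timestep values and dim below are arbitrary):

import torch
from karlo.modules.nn import timestep_embedding

t = torch.tensor([0.0, 10.0, 999.5])
emb = timestep_embedding(t, dim=321)
print(emb.shape)    # torch.Size([3, 321]): 160 cos + 160 sin + 1 zero pad
print(emb[0, :3])   # cos(0 * f_k) = 1.0 in the cosine half
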
/karlo/modules/resample.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 | from abc import abstractmethod
6 |
7 | import torch as th
8 |
9 |
10 | def create_named_schedule_sampler(name, diffusion):
11 | """
12 | Create a ScheduleSampler from a library of pre-defined samplers.
13 |
14 | :param name: the name of the sampler.
15 | :param diffusion: the diffusion object to sample for.
16 | """
17 | if name == "uniform":
18 | return UniformSampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(th.nn.Module):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a tensor of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / th.sum(w)
54 | indices = p.multinomial(batch_size, replacement=True)
55 | weights = 1 / (len(p) * p[indices])
56 | return indices, weights
57 |
58 |
59 | class UniformSampler(ScheduleSampler):
60 | def __init__(self, diffusion):
61 | super(UniformSampler, self).__init__()
62 | self.diffusion = diffusion
63 | self.register_buffer(
64 | "_weights", th.ones([diffusion.num_timesteps]), persistent=False
65 | )
66 |
67 | def weights(self):
68 | return self._weights
69 |
--------------------------------------------------------------------------------
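
UniformSampler draws training timesteps uniformly, so the returned importance weights are all exactly 1 (1 / (len(p) * p[i]) with uniform p). A usage sketch; the stand-in "diffusion" object only needs the num_timesteps attribute that the sampler actually reads:

import types
from karlo.modules.resample import UniformSampler

fake_diffusion = types.SimpleNamespace(num_timesteps=1000)  # stand-in, not a real diffusion
sampler = UniformSampler(fake_diffusion)
timesteps, weights = sampler.sample(batch_size=4, device="cpu")
print(timesteps.shape, weights)  # 4 uniform timestep indices; weights are all 1.0
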
/karlo/modules/unet.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
3 | # ------------------------------------------------------------------------------------
4 |
5 | import math
6 | from abc import abstractmethod
7 |
8 | import torch as th
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from .nn import (
13 | avg_pool_nd,
14 | conv_nd,
15 | linear,
16 | normalization,
17 | timestep_embedding,
18 | zero_module,
19 | )
20 | from .xf import LayerNorm
21 |
22 |
23 | class TimestepBlock(nn.Module):
24 | """
25 | Any module where forward() takes timestep embeddings as a second argument.
26 | """
27 |
28 | @abstractmethod
29 | def forward(self, x, emb):
30 | """
31 | Apply the module to `x` given `emb` timestep embeddings.
32 | """
33 |
34 |
35 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
36 | """
37 | A sequential module that passes timestep embeddings to the children that
38 | support it as an extra input.
39 | """
40 |
41 | def forward(self, x, emb, encoder_out=None, mask=None):
42 | for layer in self:
43 | if isinstance(layer, TimestepBlock):
44 | x = layer(x, emb)
45 | elif isinstance(layer, AttentionBlock):
46 | x = layer(x, encoder_out, mask=mask)
47 | else:
48 | x = layer(x)
49 | return x
50 |
51 |
52 | class Upsample(nn.Module):
53 | """
54 | An upsampling layer with an optional convolution.
55 |
56 | :param channels: channels in the inputs and outputs.
57 | :param use_conv: a bool determining if a convolution is applied.
58 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
59 | upsampling occurs in the inner-two dimensions.
60 | """
61 |
62 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
63 | super().__init__()
64 | self.channels = channels
65 | self.out_channels = out_channels or channels
66 | self.use_conv = use_conv
67 | self.dims = dims
68 | if use_conv:
69 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
70 |
71 | def forward(self, x):
72 | assert x.shape[1] == self.channels
73 | if self.dims == 3:
74 | x = F.interpolate(
75 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
76 | )
77 | else:
78 | x = F.interpolate(x, scale_factor=2, mode="nearest")
79 | if self.use_conv:
80 | x = self.conv(x)
81 | return x
82 |
83 |
84 | class Downsample(nn.Module):
85 | """
86 | A downsampling layer with an optional convolution.
87 |
88 | :param channels: channels in the inputs and outputs.
89 | :param use_conv: a bool determining if a convolution is applied.
90 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
91 | downsampling occurs in the inner-two dimensions.
92 | """
93 |
94 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
95 | super().__init__()
96 | self.channels = channels
97 | self.out_channels = out_channels or channels
98 | self.use_conv = use_conv
99 | self.dims = dims
100 | stride = 2 if dims != 3 else (1, 2, 2)
101 | if use_conv:
102 | self.op = conv_nd(
103 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1
104 | )
105 | else:
106 | assert self.channels == self.out_channels
107 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
108 |
109 | def forward(self, x):
110 | assert x.shape[1] == self.channels
111 | return self.op(x)
112 |
113 |
114 | class ResBlock(TimestepBlock):
115 | """
116 | A residual block that can optionally change the number of channels.
117 |
118 | :param channels: the number of input channels.
119 | :param emb_channels: the number of timestep embedding channels.
120 | :param dropout: the rate of dropout.
121 | :param out_channels: if specified, the number of out channels.
122 | :param use_conv: if True and out_channels is specified, use a spatial
123 | convolution instead of a smaller 1x1 convolution to change the
124 | channels in the skip connection.
125 | :param dims: determines if the signal is 1D, 2D, or 3D.
126 | :param use_checkpoint: if True, use gradient checkpointing on this module.
127 | :param up: if True, use this block for upsampling.
128 | :param down: if True, use this block for downsampling.
129 | """
130 |
131 | def __init__(
132 | self,
133 | channels,
134 | emb_channels,
135 | dropout,
136 | out_channels=None,
137 | use_conv=False,
138 | use_scale_shift_norm=False,
139 | dims=2,
140 | use_checkpoint=False,
141 | up=False,
142 | down=False,
143 | ):
144 | super().__init__()
145 | self.channels = channels
146 | self.emb_channels = emb_channels
147 | self.dropout = dropout
148 | self.out_channels = out_channels or channels
149 | self.use_conv = use_conv
150 | self.use_checkpoint = use_checkpoint
151 | self.use_scale_shift_norm = use_scale_shift_norm
152 |
153 | self.in_layers = nn.Sequential(
154 | normalization(channels, swish=1.0),
155 | nn.Identity(),
156 | conv_nd(dims, channels, self.out_channels, 3, padding=1),
157 | )
158 |
159 | self.updown = up or down
160 |
161 | if up:
162 | self.h_upd = Upsample(channels, False, dims)
163 | self.x_upd = Upsample(channels, False, dims)
164 | elif down:
165 | self.h_upd = Downsample(channels, False, dims)
166 | self.x_upd = Downsample(channels, False, dims)
167 | else:
168 | self.h_upd = self.x_upd = nn.Identity()
169 |
170 | self.emb_layers = nn.Sequential(
171 | nn.SiLU(),
172 | linear(
173 | emb_channels,
174 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
175 | ),
176 | )
177 | self.out_layers = nn.Sequential(
178 | normalization(
179 | self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
180 | ),
181 | nn.SiLU() if use_scale_shift_norm else nn.Identity(),
182 | nn.Dropout(p=dropout),
183 | zero_module(
184 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
185 | ),
186 | )
187 |
188 | if self.out_channels == channels:
189 | self.skip_connection = nn.Identity()
190 | elif use_conv:
191 | self.skip_connection = conv_nd(
192 | dims, channels, self.out_channels, 3, padding=1
193 | )
194 | else:
195 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
196 |
197 | def forward(self, x, emb):
198 | """
199 | Apply the block to a Tensor, conditioned on a timestep embedding.
200 |
201 | :param x: an [N x C x ...] Tensor of features.
202 | :param emb: an [N x emb_channels] Tensor of timestep embeddings.
203 | :return: an [N x C x ...] Tensor of outputs.
204 | """
205 | if self.updown:
206 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
207 | h = in_rest(x)
208 | h = self.h_upd(h)
209 | x = self.x_upd(x)
210 | h = in_conv(h)
211 | else:
212 | h = self.in_layers(x)
213 | emb_out = self.emb_layers(emb)
214 | while len(emb_out.shape) < len(h.shape):
215 | emb_out = emb_out[..., None]
216 | if self.use_scale_shift_norm:
217 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
218 | scale, shift = th.chunk(emb_out, 2, dim=1)
219 | h = out_norm(h) * (1 + scale) + shift
220 | h = out_rest(h)
221 | else:
222 | h = h + emb_out
223 | h = self.out_layers(h)
224 | return self.skip_connection(x) + h
225 |
226 |
227 | class ResBlockNoTimeEmbedding(nn.Module):
228 | """
229 | A residual block without time embedding.
230 |
231 | :param channels: the number of input channels.
232 | :param emb_channels: the number of timestep embedding channels (unused here; kept for interface compatibility).
233 | :param dropout: the rate of dropout.
234 | :param out_channels: if specified, the number of out channels.
235 | :param use_conv: if True and out_channels is specified, use a spatial
236 | convolution instead of a smaller 1x1 convolution to change the
237 | channels in the skip connection.
238 | :param dims: determines if the signal is 1D, 2D, or 3D.
239 | :param use_checkpoint: if True, use gradient checkpointing on this module.
240 | :param up: if True, use this block for upsampling.
241 | :param down: if True, use this block for downsampling.
242 | """
243 |
244 | def __init__(
245 | self,
246 | channels,
247 | emb_channels,
248 | dropout,
249 | out_channels=None,
250 | use_conv=False,
251 | dims=2,
252 | use_checkpoint=False,
253 | up=False,
254 | down=False,
255 | **kwargs,
256 | ):
257 | super().__init__()
258 | self.channels = channels
259 | self.emb_channels = emb_channels
260 | self.dropout = dropout
261 | self.out_channels = out_channels or channels
262 | self.use_conv = use_conv
263 | self.use_checkpoint = use_checkpoint
264 |
265 | self.in_layers = nn.Sequential(
266 | normalization(channels, swish=1.0),
267 | nn.Identity(),
268 | conv_nd(dims, channels, self.out_channels, 3, padding=1),
269 | )
270 |
271 | self.updown = up or down
272 |
273 | if up:
274 | self.h_upd = Upsample(channels, False, dims)
275 | self.x_upd = Upsample(channels, False, dims)
276 | elif down:
277 | self.h_upd = Downsample(channels, False, dims)
278 | self.x_upd = Downsample(channels, False, dims)
279 | else:
280 | self.h_upd = self.x_upd = nn.Identity()
281 |
282 | self.out_layers = nn.Sequential(
283 | normalization(self.out_channels, swish=1.0),
284 | nn.Dropout(p=dropout),
285 | zero_module(
286 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
287 | ),
288 | )
289 |
290 | if self.out_channels == channels:
291 | self.skip_connection = nn.Identity()
292 | elif use_conv:
293 | self.skip_connection = conv_nd(
294 | dims, channels, self.out_channels, 3, padding=1
295 | )
296 | else:
297 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
298 |
299 | def forward(self, x, emb=None):
300 | """
301 | Apply the block to a Tensor, NOT conditioned on a timestep embedding.
302 |
303 | :param x: an [N x C x ...] Tensor of features.
304 | :param emb: an [N x emb_channels] Tensor of timestep embeddings.
305 | :return: an [N x C x ...] Tensor of outputs.
306 | """
307 | assert emb is None
308 |
309 | if self.updown:
310 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
311 | h = in_rest(x)
312 | h = self.h_upd(h)
313 | x = self.x_upd(x)
314 | h = in_conv(h)
315 | else:
316 | h = self.in_layers(x)
317 | h = self.out_layers(h)
318 | return self.skip_connection(x) + h
319 |
320 |
321 | class AttentionBlock(nn.Module):
322 | """
323 | An attention block that allows spatial positions to attend to each other.
324 |
325 | Originally ported from here, but adapted to the N-d case.
326 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
327 | """
328 |
329 | def __init__(
330 | self,
331 | channels,
332 | num_heads=1,
333 | num_head_channels=-1,
334 | use_checkpoint=False,
335 | encoder_channels=None,
336 | ):
337 | super().__init__()
338 | self.channels = channels
339 | if num_head_channels == -1:
340 | self.num_heads = num_heads
341 | else:
342 | assert (
343 | channels % num_head_channels == 0
344 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
345 | self.num_heads = channels // num_head_channels
346 | self.use_checkpoint = use_checkpoint
347 | self.norm = normalization(channels, swish=0.0)
348 | self.qkv = conv_nd(1, channels, channels * 3, 1)
349 | self.attention = QKVAttention(self.num_heads)
350 |
351 | if encoder_channels is not None:
352 | self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
353 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
354 |
355 | def forward(self, x, encoder_out=None, mask=None):
356 | b, c, *spatial = x.shape
357 | qkv = self.qkv(self.norm(x).view(b, c, -1))
358 | if encoder_out is not None:
359 | encoder_out = self.encoder_kv(encoder_out)
360 | h = self.attention(qkv, encoder_out, mask=mask)
361 | else:
362 | h = self.attention(qkv)
363 | h = self.proj_out(h)
364 | return x + h.reshape(b, c, *spatial)
365 |
366 |
367 | class QKVAttention(nn.Module):
368 | """
369 | A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
370 | """
371 |
372 | def __init__(self, n_heads):
373 | super().__init__()
374 | self.n_heads = n_heads
375 |
376 | def forward(self, qkv, encoder_kv=None, mask=None):
377 | """
378 | Apply QKV attention.
379 |
380 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
381 | :return: an [N x (H * C) x T] tensor after attention.
382 | """
383 | bs, width, length = qkv.shape
384 | assert width % (3 * self.n_heads) == 0
385 | ch = width // (3 * self.n_heads)
386 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
387 | if encoder_kv is not None:
388 | assert encoder_kv.shape[1] == self.n_heads * ch * 2
389 | ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
390 | k = th.cat([ek, k], dim=-1)
391 | v = th.cat([ev, v], dim=-1)
392 | scale = 1 / math.sqrt(math.sqrt(ch))
393 | weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
394 | if mask is not None:
395 | mask = F.pad(mask, (0, length), value=0.0)
396 | mask = (
397 | mask.unsqueeze(1)
398 | .expand(-1, self.n_heads, -1)
399 | .reshape(bs * self.n_heads, 1, -1)
400 | )
401 | weight = weight + mask
402 | weight = th.softmax(weight, dim=-1)
403 | a = th.einsum("bts,bcs->bct", weight, v)
404 | return a.reshape(bs, -1, length)
405 |
406 |
407 | class UNetModel(nn.Module):
408 | """
409 | The full UNet model with attention and timestep embedding.
410 |
411 | :param in_channels: channels in the input Tensor.
412 | :param model_channels: base channel count for the model.
413 | :param out_channels: channels in the output Tensor.
414 | :param num_res_blocks: number of residual blocks per downsample.
415 | :param attention_resolutions: a collection of downsample rates at which
416 | attention will take place. May be a set, list, or tuple.
417 | For example, if this contains 4, then at 4x downsampling, attention
418 | will be used.
419 | :param dropout: the dropout probability.
420 | :param channel_mult: channel multiplier for each level of the UNet.
421 | :param conv_resample: if True, use learned convolutions for upsampling and
422 | downsampling.
423 | :param dims: determines if the signal is 1D, 2D, or 3D.
424 | :param clip_dim: dimension of clip feature.
425 | :param use_middle_attention: if True (default), apply an attention block in
426 | the middle block of the UNet.
427 | :param use_checkpoint: use gradient checkpointing to reduce memory usage.
428 | :param num_heads: the number of attention heads in each attention layer.
429 | :param num_head_channels: if specified, ignore num_heads and instead use
430 | a fixed channel width per attention head.
431 | :param num_heads_upsample: works with num_heads to set a different number
432 | of heads for upsampling. Deprecated.
433 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
434 | :param resblock_updown: use residual blocks for up/downsampling.
435 | :param encoder_channels: use to make the dimension of query and kv same in AttentionBlock.
436 | :param use_time_embedding: use time embedding for condition.
437 | """
438 |
439 | def __init__(
440 | self,
441 | in_channels,
442 | model_channels,
443 | out_channels,
444 | num_res_blocks,
445 | attention_resolutions,
446 | dropout=0,
447 | channel_mult=(1, 2, 4, 8),
448 | conv_resample=True,
449 | dims=2,
450 | clip_dim=None,
451 | use_checkpoint=False,
452 | num_heads=1,
453 | num_head_channels=-1,
454 | num_heads_upsample=-1,
455 | use_scale_shift_norm=False,
456 | use_middle_attention=True,
457 | resblock_updown=False,
458 | encoder_channels=None,
459 | use_time_embedding=True,
460 | ):
461 | super().__init__()
462 |
463 | if num_heads_upsample == -1:
464 | num_heads_upsample = num_heads
465 |
466 | self.in_channels = in_channels
467 | self.model_channels = model_channels
468 | self.out_channels = out_channels
469 | self.num_res_blocks = num_res_blocks
470 | self.attention_resolutions = attention_resolutions
471 | self.dropout = dropout
472 | self.channel_mult = channel_mult
473 | self.conv_resample = conv_resample
474 | self.clip_dim = clip_dim
475 | self.use_checkpoint = use_checkpoint
476 | self.num_heads = num_heads
477 | self.num_head_channels = num_head_channels
478 | self.num_heads_upsample = num_heads_upsample
479 | self.use_middle_attention = use_middle_attention
480 | self.use_time_embedding = use_time_embedding
481 |
482 | if self.use_time_embedding:
483 | time_embed_dim = model_channels * 4
484 | self.time_embed = nn.Sequential(
485 | linear(model_channels, time_embed_dim),
486 | nn.SiLU(),
487 | linear(time_embed_dim, time_embed_dim),
488 | )
489 |
490 | if self.clip_dim is not None:
491 | self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
492 | else:
493 | time_embed_dim = None
494 |
495 | CustomResidualBlock = (
496 | ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
497 | )
498 | ch = input_ch = int(channel_mult[0] * model_channels)
499 | self.input_blocks = nn.ModuleList(
500 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
501 | )
502 | self._feature_size = ch
503 | input_block_chans = [ch]
504 | ds = 1
505 | for level, mult in enumerate(channel_mult):
506 | for _ in range(num_res_blocks):
507 | layers = [
508 | CustomResidualBlock(
509 | ch,
510 | time_embed_dim,
511 | dropout,
512 | out_channels=int(mult * model_channels),
513 | dims=dims,
514 | use_checkpoint=use_checkpoint,
515 | use_scale_shift_norm=use_scale_shift_norm,
516 | )
517 | ]
518 | ch = int(mult * model_channels)
519 | if ds in attention_resolutions:
520 | layers.append(
521 | AttentionBlock(
522 | ch,
523 | use_checkpoint=use_checkpoint,
524 | num_heads=num_heads,
525 | num_head_channels=num_head_channels,
526 | encoder_channels=encoder_channels,
527 | )
528 | )
529 | self.input_blocks.append(TimestepEmbedSequential(*layers))
530 | self._feature_size += ch
531 | input_block_chans.append(ch)
532 | if level != len(channel_mult) - 1:
533 | out_ch = ch
534 | self.input_blocks.append(
535 | TimestepEmbedSequential(
536 | CustomResidualBlock(
537 | ch,
538 | time_embed_dim,
539 | dropout,
540 | out_channels=out_ch,
541 | dims=dims,
542 | use_checkpoint=use_checkpoint,
543 | use_scale_shift_norm=use_scale_shift_norm,
544 | down=True,
545 | )
546 | if resblock_updown
547 | else Downsample(
548 | ch, conv_resample, dims=dims, out_channels=out_ch
549 | )
550 | )
551 | )
552 | ch = out_ch
553 | input_block_chans.append(ch)
554 | ds *= 2
555 | self._feature_size += ch
556 |
557 | self.middle_block = TimestepEmbedSequential(
558 | CustomResidualBlock(
559 | ch,
560 | time_embed_dim,
561 | dropout,
562 | dims=dims,
563 | use_checkpoint=use_checkpoint,
564 | use_scale_shift_norm=use_scale_shift_norm,
565 | ),
566 | *(
567 | AttentionBlock(
568 | ch,
569 | use_checkpoint=use_checkpoint,
570 | num_heads=num_heads,
571 | num_head_channels=num_head_channels,
572 | encoder_channels=encoder_channels,
573 | ),
574 | )
575 | if self.use_middle_attention
576 | else tuple(), # add AttentionBlock or not
577 | CustomResidualBlock(
578 | ch,
579 | time_embed_dim,
580 | dropout,
581 | dims=dims,
582 | use_checkpoint=use_checkpoint,
583 | use_scale_shift_norm=use_scale_shift_norm,
584 | ),
585 | )
586 | self._feature_size += ch
587 |
588 | self.output_blocks = nn.ModuleList([])
589 | for level, mult in list(enumerate(channel_mult))[::-1]:
590 | for i in range(num_res_blocks + 1):
591 | ich = input_block_chans.pop()
592 | layers = [
593 | CustomResidualBlock(
594 | ch + ich,
595 | time_embed_dim,
596 | dropout,
597 | out_channels=int(model_channels * mult),
598 | dims=dims,
599 | use_checkpoint=use_checkpoint,
600 | use_scale_shift_norm=use_scale_shift_norm,
601 | )
602 | ]
603 | ch = int(model_channels * mult)
604 | if ds in attention_resolutions:
605 | layers.append(
606 | AttentionBlock(
607 | ch,
608 | use_checkpoint=use_checkpoint,
609 | num_heads=num_heads_upsample,
610 | num_head_channels=num_head_channels,
611 | encoder_channels=encoder_channels,
612 | )
613 | )
614 | if level and i == num_res_blocks:
615 | out_ch = ch
616 | layers.append(
617 | CustomResidualBlock(
618 | ch,
619 | time_embed_dim,
620 | dropout,
621 | out_channels=out_ch,
622 | dims=dims,
623 | use_checkpoint=use_checkpoint,
624 | use_scale_shift_norm=use_scale_shift_norm,
625 | up=True,
626 | )
627 | if resblock_updown
628 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
629 | )
630 | ds //= 2
631 | self.output_blocks.append(TimestepEmbedSequential(*layers))
632 | self._feature_size += ch
633 |
634 | self.out = nn.Sequential(
635 | normalization(ch, swish=1.0),
636 | nn.Identity(),
637 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
638 | )
639 |
640 | def forward(self, x, timesteps, y=None):
641 | """
642 | Apply the model to an input batch.
643 |
644 | :param x: an [N x C x ...] Tensor of inputs.
645 | :param timesteps: a 1-D batch of timesteps.
646 | :param y: an [N] Tensor of labels, if class-conditional.
647 | :return: an [N x C x ...] Tensor of outputs.
648 | """
649 | assert (y is not None) == (
650 | self.clip_dim is not None
651 | ), "must specify y if and only if the model is clip-rep-conditional"
652 |
653 | hs = []
654 | if self.use_time_embedding:
655 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
656 | if self.clip_dim is not None:
657 | emb = emb + self.clip_emb(y)
658 | else:
659 | emb = None
660 |
661 | h = x
662 | for module in self.input_blocks:
663 | h = module(h, emb)
664 | hs.append(h)
665 | h = self.middle_block(h, emb)
666 | for module in self.output_blocks:
667 | h = th.cat([h, hs.pop()], dim=1)
668 | h = module(h, emb)
669 |
670 | return self.out(h)
671 |
672 |
673 | class SuperResUNetModel(UNetModel):
674 | """
675 | A UNetModel that performs super-resolution.
676 |
677 | Expects an extra kwarg `low_res` to condition on a low-resolution image.
678 | Assumes the low-resolution image has the same spatial shape as the input.
679 | """
680 |
681 | def __init__(self, *args, **kwargs):
682 | if "in_channels" in kwargs:
683 | kwargs = dict(kwargs)
684 | kwargs["in_channels"] = kwargs["in_channels"] * 2
685 | else:
686 | # Curse you, Python. Or really, just curse positional arguments :|.
687 | args = list(args)
688 | args[1] = args[1] * 2
689 | super().__init__(*args, **kwargs)
690 |
691 | def forward(self, x, timesteps, low_res=None, **kwargs):
692 | _, _, new_height, new_width = x.shape
693 | assert new_height == low_res.shape[2] and new_width == low_res.shape[3]
694 |
695 | x = th.cat([x, low_res], dim=1)
696 | return super().forward(x, timesteps, **kwargs)
697 |
698 |
699 | class PLMImUNet(UNetModel):
700 | """
701 | A UNetModel that conditions on text with a pretrained text encoder in CLIP.
702 |
703 | :param text_ctx: number of text tokens to expect.
704 | :param xf_width: width of the transformer.
705 | :param clip_emb_mult: number of extra tokens obtained by projecting the CLIP image feature.
706 | :param clip_emb_type: type of conditioning feature (fixed to the CLIP image feature here).
707 | :param clip_emb_drop: dropout ratio of the CLIP image feature for classifier-free guidance.
708 | """
709 |
710 | def __init__(
711 | self,
712 | text_ctx,
713 | xf_width,
714 | *args,
715 | clip_emb_mult=None,
716 | clip_emb_type="image",
717 | clip_emb_drop=0.0,
718 | **kwargs,
719 | ):
720 | self.text_ctx = text_ctx
721 | self.xf_width = xf_width
722 | self.clip_emb_mult = clip_emb_mult
723 | self.clip_emb_type = clip_emb_type
724 | self.clip_emb_drop = clip_emb_drop
725 |
726 | if not xf_width:
727 | super().__init__(*args, **kwargs, encoder_channels=None)
728 | else:
729 | super().__init__(*args, **kwargs, encoder_channels=xf_width)
730 |
731 | # Project text encoded feat seq from pre-trained text encoder in CLIP
732 | self.text_seq_proj = nn.Sequential(
733 | nn.Linear(self.clip_dim, xf_width),
734 | LayerNorm(xf_width),
735 | )
736 | # Project CLIP text feat
737 | self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)
738 |
739 | assert clip_emb_mult is not None
740 | assert clip_emb_type == "image"
741 | assert self.clip_dim is not None, "CLIP representation dim should be specified"
742 |
743 | self.clip_tok_proj = nn.Linear(
744 | self.clip_dim, self.xf_width * self.clip_emb_mult
745 | )
746 | if self.clip_emb_drop > 0:
747 | self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))
748 |
749 | def proc_clip_emb_drop(self, feat):
750 | if self.clip_emb_drop > 0:
751 | bsz, feat_dim = feat.shape
752 | assert (
753 | feat_dim == self.clip_dim
754 | ), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}"
755 | drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop
756 | feat = th.where(
757 | drop_idx[..., None], self.cf_param[None].type_as(feat), feat
758 | )
759 | return feat
760 |
761 | def forward(
762 | self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None
763 | ):
764 | bsz = x.shape[0]
765 | hs = []
766 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
767 | emb = emb + self.clip_emb(y)
768 |
769 | xf_out = self.text_seq_proj(txt_feat_seq)
770 | xf_out = xf_out.permute(0, 2, 1)
771 | emb = emb + self.text_feat_proj(txt_feat)
772 | xf_out = th.cat(
773 | [
774 | self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
775 | xf_out,
776 | ],
777 | dim=2,
778 | )
779 | mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
780 | mask = th.where(mask, 0.0, float("-inf"))
781 |
782 | h = x
783 | for module in self.input_blocks:
784 | h = module(h, emb, xf_out, mask=mask)
785 | hs.append(h)
786 | h = self.middle_block(h, emb, xf_out, mask=mask)
787 | for module in self.output_blocks:
788 | h = th.cat([h, hs.pop()], dim=1)
789 | h = module(h, emb, xf_out, mask=mask)
790 | h = self.out(h)
791 |
792 | return h
793 |
--------------------------------------------------------------------------------
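A toy forward pass through UNetModel, just to make the expected tensor shapes concrete. The tiny hyper-parameters below are illustrative assumptions (the shipped settings live in configs/*.yaml), and model_channels is kept at 32 on the assumption that `normalization` uses 32-group GroupNorm as in Guided-Diffusion:

import torch as th
from karlo.modules.unet import UNetModel

unet = UNetModel(
    in_channels=3,
    model_channels=32,
    out_channels=3,
    num_res_blocks=1,
    attention_resolutions=(2,),  # attend once the feature map has been downsampled 2x
    channel_mult=(1, 2),
    clip_dim=None,               # no CLIP conditioning in this toy run
)
x = th.randn(2, 3, 16, 16)       # [N x C x H x W] input batch
t = th.randint(0, 1000, (2,))    # one timestep per sample
out = unet(x, t)                 # -> [2, 3, 16, 16], same spatial size as the input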
/karlo/modules/xf.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Adapted from the repos below:
3 | # (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
4 | # (b) CLIP ViT (https://github.com/openai/CLIP/)
5 | # ------------------------------------------------------------------------------------
6 |
7 | import math
8 |
9 | import torch as th
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .nn import timestep_embedding
14 |
15 |
16 | def convert_module_to_f16(param):
17 | """
18 | Convert primitive modules to float16.
19 | """
20 | if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
21 | param.weight.data = param.weight.data.half()
22 | if param.bias is not None:
23 | param.bias.data = param.bias.data.half()
24 |
25 |
26 | class LayerNorm(nn.LayerNorm):
27 | """
28 | Implementation that supports fp16 inputs but fp32 gains/biases.
29 | """
30 |
31 | def forward(self, x: th.Tensor):
32 | return super().forward(x.float()).to(x.dtype)
33 |
34 |
35 | class MultiheadAttention(nn.Module):
36 | def __init__(self, n_ctx, width, heads):
37 | super().__init__()
38 | self.n_ctx = n_ctx
39 | self.width = width
40 | self.heads = heads
41 | self.c_qkv = nn.Linear(width, width * 3)
42 | self.c_proj = nn.Linear(width, width)
43 | self.attention = QKVMultiheadAttention(heads, n_ctx)
44 |
45 | def forward(self, x, mask=None):
46 | x = self.c_qkv(x)
47 | x = self.attention(x, mask=mask)
48 | x = self.c_proj(x)
49 | return x
50 |
51 |
52 | class MLP(nn.Module):
53 | def __init__(self, width):
54 | super().__init__()
55 | self.width = width
56 | self.c_fc = nn.Linear(width, width * 4)
57 | self.c_proj = nn.Linear(width * 4, width)
58 | self.gelu = nn.GELU()
59 |
60 | def forward(self, x):
61 | return self.c_proj(self.gelu(self.c_fc(x)))
62 |
63 |
64 | class QKVMultiheadAttention(nn.Module):
65 | def __init__(self, n_heads: int, n_ctx: int):
66 | super().__init__()
67 | self.n_heads = n_heads
68 | self.n_ctx = n_ctx
69 |
70 | def forward(self, qkv, mask=None):
71 | bs, n_ctx, width = qkv.shape
72 | attn_ch = width // self.n_heads // 3
73 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
74 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
75 | q, k, v = th.split(qkv, attn_ch, dim=-1)
76 | weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
77 | wdtype = weight.dtype
78 | if mask is not None:
79 | weight = weight + mask[:, None, ...]
80 | weight = th.softmax(weight, dim=-1).type(wdtype)
81 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
82 |
83 |
84 | class ResidualAttentionBlock(nn.Module):
85 | def __init__(
86 | self,
87 | n_ctx: int,
88 | width: int,
89 | heads: int,
90 | ):
91 | super().__init__()
92 |
93 | self.attn = MultiheadAttention(
94 | n_ctx,
95 | width,
96 | heads,
97 | )
98 | self.ln_1 = LayerNorm(width)
99 | self.mlp = MLP(width)
100 | self.ln_2 = LayerNorm(width)
101 |
102 | def forward(self, x, mask=None):
103 | x = x + self.attn(self.ln_1(x), mask=mask)
104 | x = x + self.mlp(self.ln_2(x))
105 | return x
106 |
107 |
108 | class Transformer(nn.Module):
109 | def __init__(
110 | self,
111 | n_ctx: int,
112 | width: int,
113 | layers: int,
114 | heads: int,
115 | ):
116 | super().__init__()
117 | self.n_ctx = n_ctx
118 | self.width = width
119 | self.layers = layers
120 | self.resblocks = nn.ModuleList(
121 | [
122 | ResidualAttentionBlock(
123 | n_ctx,
124 | width,
125 | heads,
126 | )
127 | for _ in range(layers)
128 | ]
129 | )
130 |
131 | def forward(self, x, mask=None):
132 | for block in self.resblocks:
133 | x = block(x, mask=mask)
134 | return x
135 |
136 |
137 | class PriorTransformer(nn.Module):
138 | """
139 | A causal Transformer that conditions on the CLIP text embedding and the encoded text.
140 |
141 | :param text_ctx: number of text tokens to expect.
142 | :param xf_width: width of the transformer.
143 | :param xf_layers: depth of the transformer.
144 | :param xf_heads: heads in the transformer.
145 | :param xf_final_ln: use a LayerNorm after the output layer.
146 | :param clip_dim: dimension of clip feature.
147 | """
148 |
149 | def __init__(
150 | self,
151 | text_ctx,
152 | xf_width,
153 | xf_layers,
154 | xf_heads,
155 | xf_final_ln,
156 | clip_dim,
157 | ):
158 | super().__init__()
159 |
160 | self.text_ctx = text_ctx
161 | self.xf_width = xf_width
162 | self.xf_layers = xf_layers
163 | self.xf_heads = xf_heads
164 | self.clip_dim = clip_dim
165 | self.ext_len = 4
166 |
167 | self.time_embed = nn.Sequential(
168 | nn.Linear(xf_width, xf_width),
169 | nn.SiLU(),
170 | nn.Linear(xf_width, xf_width),
171 | )
172 | self.text_enc_proj = nn.Linear(clip_dim, xf_width)
173 | self.text_emb_proj = nn.Linear(clip_dim, xf_width)
174 | self.clip_img_proj = nn.Linear(clip_dim, xf_width)
175 | self.out_proj = nn.Linear(xf_width, clip_dim)
176 | self.transformer = Transformer(
177 | text_ctx + self.ext_len,
178 | xf_width,
179 | xf_layers,
180 | xf_heads,
181 | )
182 | if xf_final_ln:
183 | self.final_ln = LayerNorm(xf_width)
184 | else:
185 | self.final_ln = None
186 |
187 | self.positional_embedding = nn.Parameter(
188 | th.empty(1, text_ctx + self.ext_len, xf_width)
189 | )
190 | self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
191 |
192 | nn.init.normal_(self.prd_emb, std=0.01)
193 | nn.init.normal_(self.positional_embedding, std=0.01)
194 |
195 | def forward(
196 | self,
197 | x,
198 | timesteps,
199 | text_emb=None,
200 | text_enc=None,
201 | mask=None,
202 | causal_mask=None,
203 | ):
204 | bsz = x.shape[0]
205 | mask = F.pad(mask, (0, self.ext_len), value=True)
206 |
207 | t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
208 | text_enc = self.text_enc_proj(text_enc)
209 | text_emb = self.text_emb_proj(text_emb)
210 | x = self.clip_img_proj(x)
211 |
212 | input_seq = [
213 | text_enc,
214 | text_emb[:, None, :],
215 | t_emb[:, None, :],
216 | x[:, None, :],
217 | self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
218 | ]
219 | input = th.cat(input_seq, dim=1)
220 | input = input + self.positional_embedding.to(input.dtype)
221 |
222 | mask = th.where(mask, 0.0, float("-inf"))
223 | mask = (mask[:, None, :] + causal_mask).to(input.dtype)
224 |
225 | out = self.transformer(input, mask=mask)
226 | if self.final_ln is not None:
227 | out = self.final_ln(out)
228 |
229 | out = self.out_proj(out[:, -1])
230 |
231 | return out
232 |
--------------------------------------------------------------------------------
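A shape sketch for PriorTransformer.forward. The sizes below are illustrative assumptions (the real values presumably come from configs/prior_1B_vit_l.yaml), and the all-zero causal_mask simply disables causal masking for this toy run; a real run would pass an upper-triangular -inf mask:

import torch as th
from karlo.modules.xf import PriorTransformer

text_ctx, clip_dim, xf_width = 77, 768, 512
prior = PriorTransformer(text_ctx, xf_width, xf_layers=2, xf_heads=8,
                         xf_final_ln=True, clip_dim=clip_dim)

bsz = 2
x = th.randn(bsz, clip_dim)                     # noisy CLIP image embedding being denoised
t = th.randint(0, 1000, (bsz,))
text_emb = th.randn(bsz, clip_dim)              # pooled CLIP text feature
text_enc = th.randn(bsz, text_ctx, clip_dim)    # per-token CLIP text features
mask = th.ones(bsz, text_ctx, dtype=th.bool)    # all text tokens visible
causal_mask = th.zeros(text_ctx + 4, text_ctx + 4)

out = prior(x, t, text_emb=text_emb, text_enc=text_enc, mask=mask, causal_mask=causal_mask)
# The transformer sees text_ctx encoded-text tokens plus 4 extra tokens
# (text_emb, t_emb, x, prd), matching ext_len = 4; the prediction is read off
# the final (prd) position, so out has shape [bsz, clip_dim].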
/karlo/sampler/i2i.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | from typing import Iterator
7 |
8 | import torch
9 | import torchvision.transforms.functional as TVF
10 | from torchvision.transforms import InterpolationMode
11 |
12 | from .template import BaseSampler, CKPT_PATH
13 |
14 |
15 | class I2ISampler(BaseSampler):
16 | """
17 | A sampler for image variation. In the original unCLIP paper, image variation transforms the noise obtained by DDIM inversion back into a sample in RGB space.
18 | Here, we simply transform white noise into an image, conditioned on the CLIP image feature of the input.
19 |
20 | :param root_dir: directory for model checkpoints.
21 | :param sampling_type: ["default", "fast"]
22 | """
23 |
24 | def __init__(
25 | self,
26 | root_dir: str,
27 | sampling_type: str = "default",
28 | ):
29 | super().__init__(root_dir, sampling_type)
30 |
31 | @classmethod
32 | def from_pretrained(
33 | cls,
34 | root_dir: str,
35 | clip_model_path: str,
36 | clip_stat_path: str,
37 | sampling_type: str = "default",
38 | ):
39 |
40 | model = cls(
41 | root_dir=root_dir,
42 | sampling_type=sampling_type,
43 | )
44 | model.load_clip(clip_model_path)
45 | model.load_decoder(f"{CKPT_PATH['decoder']}")
46 | model.load_sr_64_256(CKPT_PATH["sr_256"])
47 |
48 | return model
49 |
50 | def preprocess(
51 | self,
52 | image,
53 | prompt: str,
54 | bsz: int,
55 | ):
56 | prompts_batch = [prompt for _ in range(bsz)]
57 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
58 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
59 |
60 | # preprocess input image
61 | image = TVF.normalize(
62 | TVF.to_tensor(
63 | TVF.resize(
64 | image,
65 | [224, 224],
66 | interpolation=InterpolationMode.BICUBIC,
67 | antialias=True,
68 | )
69 | ),
70 | mean=[0.48145466, 0.4578275, 0.40821073],
71 | std=[0.26862954, 0.26130258, 0.27577711],
72 | ).unsqueeze(0)
73 | image_batch = image.repeat(bsz, 1, 1, 1).cuda()
74 |
75 | """ Get CLIP text and image features """
76 | clip_model = self._clip
77 | tokenizer = self._tokenizer
78 | max_txt_length = 77
79 |
80 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
81 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
82 | if not (cf_token.shape == tok.shape):
83 | cf_token = cf_token.expand(tok.shape[0], -1)
84 | cf_mask = cf_mask.expand(tok.shape[0], -1)
85 |
86 | tok = torch.cat([tok, cf_token], dim=0)
87 | mask = torch.cat([mask, cf_mask], dim=0)
88 |
89 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
90 | txt_feat, txt_feat_seq = clip_model.encode_text(tok)
91 | img_feat = clip_model.encode_image(image_batch)
92 |
93 | return (
94 | prompts_batch,
95 | decoder_cf_scales_batch,
96 | txt_feat,
97 | txt_feat_seq,
98 | tok,
99 | mask,
100 | img_feat,
101 | )
102 |
103 | def __call__(
104 | self,
105 | image,
106 | bsz: int,
107 | progressive_mode=None,
108 | ) -> Iterator[torch.Tensor]:
109 | assert progressive_mode in ("loop", "stage", "final")
110 | with torch.no_grad(), torch.cuda.amp.autocast():
111 | (
112 | prompts_batch,
113 | decoder_cf_scales_batch,
114 | txt_feat,
115 | txt_feat_seq,
116 | tok,
117 | mask,
118 | img_feat,
119 | ) = self.preprocess(
120 | image=image,
121 | prompt="",
122 | bsz=bsz,
123 | )
124 |
125 | """ Generate 64x64px images """
126 | images_64_outputs = self._decoder(
127 | txt_feat,
128 | txt_feat_seq,
129 | tok,
130 | mask,
131 | img_feat,
132 | cf_guidance_scales=decoder_cf_scales_batch,
133 | timestep_respacing=self._decoder_sm,
134 | )
135 |
136 | images_64 = None
137 | for k, out in enumerate(images_64_outputs):
138 | images_64 = out
139 | if progressive_mode == "loop":
140 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
141 | if progressive_mode == "stage":
142 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
143 |
144 | images_64 = torch.clamp(images_64, -1, 1)
145 |
146 | """ Upsample 64x64 to 256x256 """
147 | images_256 = TVF.resize(
148 | images_64,
149 | [256, 256],
150 | interpolation=InterpolationMode.BICUBIC,
151 | antialias=True,
152 | )
153 | images_256_outputs = self._sr_64_256(
154 | images_256, timestep_respacing=self._sr_sm
155 | )
156 |
157 | for k, out in enumerate(images_256_outputs):
158 | images_256 = out
159 | if progressive_mode == "loop":
160 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
161 | if progressive_mode == "stage":
162 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
163 |
164 | yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
165 |
--------------------------------------------------------------------------------
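A hedged usage sketch for I2ISampler. It assumes a CUDA GPU, checkpoints downloaded by setup.sh into ~/.cache/karlo/v1.0.alpha/, and a placeholder input path; note also that from_pretrained accepts clip_stat_path but the variation pipeline does not use it:

import os
from PIL import Image
from karlo.sampler.i2i import I2ISampler

root_dir = os.path.expanduser("~/.cache/karlo/v1.0.alpha/")  # same directory as setup.sh
sampler = I2ISampler.from_pretrained(
    root_dir=root_dir,
    clip_model_path="ViT-L-14.pt",
    clip_stat_path="ViT-L-14_stats.th",
    sampling_type="fast",
)
source = Image.open("my_image.png").convert("RGB")  # placeholder input path
# progressive_mode="final" yields a single [bsz, 3, 256, 256] tensor scaled to [0, 1].
for images in sampler(image=source, bsz=2, progressive_mode="final"):
    pass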
/karlo/sampler/t2i.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | from typing import Iterator
7 |
8 | import torch
9 | import torchvision.transforms.functional as TVF
10 | from torchvision.transforms import InterpolationMode
11 |
12 | from .template import BaseSampler, CKPT_PATH
13 |
14 |
15 | class T2ISampler(BaseSampler):
16 | """
17 | A sampler for text-to-image generation.
18 |
19 | :param root_dir: directory for model checkpoints.
20 | :param sampling_type: ["default", "fast"]
21 | """
22 |
23 | def __init__(
24 | self,
25 | root_dir: str,
26 | sampling_type: str = "default",
27 | ):
28 | super().__init__(root_dir, sampling_type)
29 |
30 | @classmethod
31 | def from_pretrained(
32 | cls,
33 | root_dir: str,
34 | clip_model_path: str,
35 | clip_stat_path: str,
36 | sampling_type: str = "default",
37 | ):
38 |
39 | model = cls(
40 | root_dir=root_dir,
41 | sampling_type=sampling_type,
42 | )
43 | model.load_clip(clip_model_path)
44 | model.load_prior(
45 | f"{CKPT_PATH['prior']}",
46 | clip_stat_path=clip_stat_path,
47 | )
48 | model.load_decoder(f"{CKPT_PATH['decoder']}")
49 | model.load_sr_64_256(CKPT_PATH["sr_256"])
50 |
51 | return model
52 |
53 | def preprocess(
54 | self,
55 | prompt: str,
56 | bsz: int,
57 | ):
58 | """Setup prompts & cfg scales"""
59 | prompts_batch = [prompt for _ in range(bsz)]
60 |
61 | prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
62 | prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
63 |
64 | decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
65 | decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
66 |
67 | """ Get CLIP text feature """
68 | clip_model = self._clip
69 | tokenizer = self._tokenizer
70 | max_txt_length = self._prior.model.text_ctx
71 |
72 | tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
73 | cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
74 | if not (cf_token.shape == tok.shape):
75 | cf_token = cf_token.expand(tok.shape[0], -1)
76 | cf_mask = cf_mask.expand(tok.shape[0], -1)
77 |
78 | tok = torch.cat([tok, cf_token], dim=0)
79 | mask = torch.cat([mask, cf_mask], dim=0)
80 |
81 | tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
82 | txt_feat, txt_feat_seq = clip_model.encode_text(tok)
83 |
84 | return (
85 | prompts_batch,
86 | prior_cf_scales_batch,
87 | decoder_cf_scales_batch,
88 | txt_feat,
89 | txt_feat_seq,
90 | tok,
91 | mask,
92 | )
93 |
94 | def __call__(
95 | self,
96 | prompt: str,
97 | bsz: int,
98 | progressive_mode=None,
99 | ) -> Iterator[torch.Tensor]:
100 | assert progressive_mode in ("loop", "stage", "final")
101 | with torch.no_grad(), torch.cuda.amp.autocast():
102 | (
103 | prompts_batch,
104 | prior_cf_scales_batch,
105 | decoder_cf_scales_batch,
106 | txt_feat,
107 | txt_feat_seq,
108 | tok,
109 | mask,
110 | ) = self.preprocess(
111 | prompt,
112 | bsz,
113 | )
114 |
115 | """ Transform CLIP text feature into image feature """
116 | img_feat = self._prior(
117 | txt_feat,
118 | txt_feat_seq,
119 | mask,
120 | prior_cf_scales_batch,
121 | timestep_respacing=self._prior_sm,
122 | )
123 |
124 | """ Generate 64x64px images """
125 | images_64_outputs = self._decoder(
126 | txt_feat,
127 | txt_feat_seq,
128 | tok,
129 | mask,
130 | img_feat,
131 | cf_guidance_scales=decoder_cf_scales_batch,
132 | timestep_respacing=self._decoder_sm,
133 | )
134 |
135 | images_64 = None
136 | for k, out in enumerate(images_64_outputs):
137 | images_64 = out
138 | if progressive_mode == "loop":
139 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
140 | if progressive_mode == "stage":
141 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
142 |
143 | images_64 = torch.clamp(images_64, -1, 1)
144 |
145 | """ Upsample 64x64 to 256x256 """
146 | images_256 = TVF.resize(
147 | images_64,
148 | [256, 256],
149 | interpolation=InterpolationMode.BICUBIC,
150 | antialias=True,
151 | )
152 | images_256_outputs = self._sr_64_256(
153 | images_256, timestep_respacing=self._sr_sm
154 | )
155 |
156 | for k, out in enumerate(images_256_outputs):
157 | images_256 = out
158 | if progressive_mode == "loop":
159 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
160 | if progressive_mode == "stage":
161 | yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
162 |
163 | yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
164 |
--------------------------------------------------------------------------------
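An analogous sketch for T2ISampler; the prompt is arbitrary, and the same checkpoint layout produced by setup.sh is assumed:

import os
from karlo.sampler.t2i import T2ISampler

root_dir = os.path.expanduser("~/.cache/karlo/v1.0.alpha/")
sampler = T2ISampler.from_pretrained(
    root_dir=root_dir,
    clip_model_path="ViT-L-14.pt",
    clip_stat_path="ViT-L-14_stats.th",  # CLIP feature statistics consumed by the prior
    sampling_type="fast",
)
# "final" mode yields one [bsz, 3, 256, 256] tensor in [0, 1] once the prior,
# decoder, and 64->256 super-resolution stages have all finished.
for images in sampler("a watercolor painting of a lighthouse", bsz=1, progressive_mode="final"):
    pass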
/karlo/sampler/template.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Karlo-v1.0.alpha
3 | # Copyright (c) 2022 KakaoBrain. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import os
7 | import logging
8 | import torch
9 |
10 | from omegaconf import OmegaConf
11 |
12 | from ..models.clip import CustomizedCLIP, CustomizedTokenizer
13 | from ..models.prior_model import PriorDiffusionModel
14 | from ..models.decoder_model import Text2ImProgressiveModel
15 | from ..models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
16 |
17 |
18 | SAMPLING_CONF = {
19 | "default": {
20 | "prior_sm": "25",
21 | "prior_n_samples": 1,
22 | "prior_cf_scale": 4.0,
23 | "decoder_sm": "50",
24 | "decoder_cf_scale": 8.0,
25 | "sr_sm": "7",
26 | },
27 | "fast": {
28 | "prior_sm": "25",
29 | "prior_n_samples": 1,
30 | "prior_cf_scale": 4.0,
31 | "decoder_sm": "25",
32 | "decoder_cf_scale": 8.0,
33 | "sr_sm": "7",
34 | },
35 | }
36 |
37 | CKPT_PATH = {
38 | "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
39 | "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
40 | "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
41 | }
42 |
43 |
44 | class BaseSampler:
45 | _PRIOR_CLASS = PriorDiffusionModel
46 | _DECODER_CLASS = Text2ImProgressiveModel
47 | _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
48 |
49 | def __init__(
50 | self,
51 | root_dir: str,
52 | sampling_type: str = "fast",
53 | ):
54 | self._root_dir = root_dir
55 |
56 | sampling_type = SAMPLING_CONF[sampling_type]
57 | self._prior_sm = sampling_type["prior_sm"]
58 | self._prior_n_samples = sampling_type["prior_n_samples"]
59 | self._prior_cf_scale = sampling_type["prior_cf_scale"]
60 |
61 | assert self._prior_n_samples == 1
62 |
63 | self._decoder_sm = sampling_type["decoder_sm"]
64 | self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
65 |
66 | self._sr_sm = sampling_type["sr_sm"]
67 |
68 | def __repr__(self):
69 | line = ""
70 | line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
71 | line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
72 | line += f"SR(64->256), sampling method: {self._sr_sm}"
73 |
74 | return line
75 |
76 | def load_clip(self, clip_path: str):
77 | clip = CustomizedCLIP.load_from_checkpoint(
78 | os.path.join(self._root_dir, clip_path)
79 | )
80 | clip = torch.jit.script(clip)
81 | clip.cuda()
82 | clip.eval()
83 |
84 | self._clip = clip
85 | self._tokenizer = CustomizedTokenizer()
86 |
87 | def load_prior(
88 | self,
89 | ckpt_path: str,
90 | clip_stat_path: str,
91 | ):
92 | logging.info(f"Loading prior: {ckpt_path}")
93 |
94 | config = OmegaConf.load("configs/prior_1B_vit_l.yaml")
95 | clip_mean, clip_std = torch.load(
96 | os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
97 | )
98 |
99 | prior = self._PRIOR_CLASS.load_from_checkpoint(
100 | config,
101 | self._tokenizer,
102 | clip_mean,
103 | clip_std,
104 | os.path.join(self._root_dir, ckpt_path),
105 | strict=True,
106 | )
107 | prior.cuda()
108 | prior.eval()
109 | logging.info("done.")
110 |
111 | self._prior = prior
112 |
113 | def load_decoder(self, ckpt_path: str):
114 | logging.info(f"Loading decoder: {ckpt_path}")
115 |
116 | config = OmegaConf.load("configs/decoder_900M_vit_l.yaml")
117 | decoder = self._DECODER_CLASS.load_from_checkpoint(
118 | config,
119 | self._tokenizer,
120 | os.path.join(self._root_dir, ckpt_path),
121 | strict=True,
122 | )
123 | decoder.cuda()
124 | decoder.eval()
125 | logging.info("done.")
126 |
127 | self._decoder = decoder
128 |
129 | def load_sr_64_256(self, ckpt_path: str):
130 | logging.info(f"Loading SR(64->256): {ckpt_path}")
131 |
132 | config = OmegaConf.load("configs/improved_sr_64_256_1.4B.yaml")
133 | sr = self._SR256_CLASS.load_from_checkpoint(
134 | config, os.path.join(self._root_dir, ckpt_path), strict=True
135 | )
136 | sr.cuda()
137 | sr.eval()
138 | logging.info("done.")
139 |
140 | self._sr_64_256 = sr
141 |
--------------------------------------------------------------------------------
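One practical detail about BaseSampler: the load_* helpers resolve their YAML configs with paths relative to the working directory ("configs/*.yaml"), so scripts are expected to run from the repository root. A minimal guard, illustrative only and not part of the repository:

import os

assert os.path.isfile("configs/prior_1B_vit_l.yaml"), (
    "Run from the repository root so BaseSampler can find its YAML configs."
)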
/karlo/utils/util.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def set_seed(seed):
7 | random.seed(seed)
8 | np.random.seed(seed)
9 | torch.manual_seed(seed)
10 | torch.cuda.manual_seed_all(seed)
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.10
2 | torchvision>=0.8.2
3 | black
4 | einops
5 | omegaconf
6 | matplotlib
7 | gradio>=3.5.0
8 | git+https://github.com/openai/CLIP.git
9 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E203, E226, E402, E731, W503, W504
4 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip install -r requirements.txt
4 |
5 | export KARLO_ROOT_DIR=$HOME/.cache/karlo/v1.0.alpha/
6 |
7 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt -P $KARLO_ROOT_DIR
8 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th -P $KARLO_ROOT_DIR
9 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
10 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
11 | wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt -P $KARLO_ROOT_DIR
12 |
--------------------------------------------------------------------------------