├── .dockerignore
├── .github
│   └── FUNDING.yml
├── Dockerfile
├── LICENSE
├── README.md
├── Stable_Diffusion_v1_Model_Card.md
├── assets
│   ├── a-painting-of-a-fire.png
│   ├── a-photograph-of-a-fire.png
│   ├── a-shirt-with-a-fire-printed-on-it.png
│   ├── a-shirt-with-the-inscription-'fire'.png
│   ├── a-watercolor-painting-of-a-fire.png
│   ├── birdhouse.png
│   ├── fire.png
│   ├── inpainting.png
│   ├── modelfigure.png
│   ├── rdm-preview.jpg
│   ├── reconstruction1.png
│   ├── reconstruction2.png
│   ├── results.gif
│   ├── stable-samples
│   │   ├── img2img
│   │   │   ├── mountains-1.png
│   │   │   ├── mountains-2.png
│   │   │   ├── mountains-3.png
│   │   │   ├── sketch-mountains-input.jpg
│   │   │   ├── upscaling-in.png
│   │   │   └── upscaling-out.png
│   │   └── txt2img
│   │       ├── 000002025.png
│   │       ├── 000002035.png
│   │       ├── merged-0005.png
│   │       ├── merged-0006.png
│   │       └── merged-0007.png
│   ├── the-earth-is-on-fire,-oil-on-canvas.png
│   ├── txt2img-convsample.png
│   ├── txt2img-preview.png
│   └── v1-variants-scores.jpg
├── configs
│   ├── autoencoder
│   │   ├── autoencoder_kl_16x16x16.yaml
│   │   ├── autoencoder_kl_32x32x4.yaml
│   │   ├── autoencoder_kl_64x64x3.yaml
│   │   └── autoencoder_kl_8x8x64.yaml
│   ├── latent-diffusion
│   │   ├── celebahq-ldm-vq-4.yaml
│   │   ├── cin-ldm-vq-f8.yaml
│   │   ├── cin256-v2.yaml
│   │   ├── ffhq-ldm-vq-4.yaml
│   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   └── txt2img-1p4B-eval.yaml
│   ├── retrieval-augmented-diffusion
│   │   └── 768x768.yaml
│   └── stable-diffusion
│       └── v1-inference.yaml
├── data
│   ├── DejaVuSans.ttf
│   ├── example_conditioning
│   │   ├── superresolution
│   │   │   └── sample_0.jpg
│   │   └── text_conditional
│   │       └── sample_0.txt
│   ├── imagenet_clsidx_to_label.txt
│   ├── imagenet_train_hr_indices.p
│   ├── imagenet_val_hr_indices.p
│   ├── index_synset.yaml
│   └── inpainting_examples
│       ├── 6458524847_2f4c361183_k.png
│       ├── 6458524847_2f4c361183_k_mask.png
│       ├── 8399166846_f6fb4e4b8e_k.png
│       ├── 8399166846_f6fb4e4b8e_k_mask.png
│       ├── alex-iby-G_Pk4D9rMLs.png
│       ├── alex-iby-G_Pk4D9rMLs_mask.png
│       ├── bench2.png
│       ├── bench2_mask.png
│       ├── bertrand-gabioud-CpuFzIsHYJ0.png
│       ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png
│       ├── billow926-12-Wc-Zgx6Y.png
│       ├── billow926-12-Wc-Zgx6Y_mask.png
│       ├── overture-creations-5sI6fQgYIuo.png
│       ├── overture-creations-5sI6fQgYIuo_mask.png
│       ├── photo-1583445095369-9c651e7e5d34.png
│       └── photo-1583445095369-9c651e7e5d34_mask.png
├── docker-bootstrap.sh
├── docker-compose.yml
├── environment.yaml
├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── imagenet.py
│   │   └── lsun.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       └── plms.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   └── util.py
├── main.py
├── models
│   ├── first_stage_models
│   │   ├── kl-f16
│   │   │   └── config.yaml
│   │   ├── kl-f32
│   │   │   └── config.yaml
│   │   ├── kl-f4
│   │   │   └── config.yaml
│   │   ├── kl-f8
│   │   │   └── config.yaml
│   │   ├── vq-f16
│   │   │   └── config.yaml
│   │   ├── vq-f4-noattn
│   │   │   └── config.yaml
│   │   ├── vq-f4
│   │   │   └── config.yaml
│   │   ├── vq-f8-n256
│   │   │   └── config.yaml
│   │   └── vq-f8
│   │       └── config.yaml
│   └── ldm
│       ├── bsr_sr
│       │   └── config.yaml
│       ├── celeba256
│       │   └── config.yaml
│       ├── cin256
│       │   └── config.yaml
│       ├── ffhq256
│       │   └── config.yaml
│       ├── inpainting_big
│       │   └── config.yaml
│       ├── layout2img-openimages256
│       │   └── config.yaml
│       ├── lsun_beds256
│       │   └── config.yaml
│       ├── lsun_churches256
│       │   └── config.yaml
│       ├── semantic_synthesis256
│       │   └── config.yaml
│       ├── semantic_synthesis512
│       │   └── config.yaml
│       └── text2img256
│           └── config.yaml
├── notebook_helpers.py
├── optimizedSD
│   ├── LICENSE
│   ├── ddpm.py
│   ├── diffusers_txt2img.py
│   ├── img2img_gradio.py
│   ├── inpaint_gradio.py
│   ├── openaimodelSplit.py
│   ├── optimUtils.py
│   ├── optimized_img2img.py
│   ├── optimized_txt2img.py
│   ├── samplers.py
│   ├── splitAttention.py
│   ├── txt2img_gradio.py
│   └── v1-inference.yaml
├── scripts
│   ├── download_first_stages.sh
│   ├── download_models.sh
│   ├── img2img.py
│   ├── inpaint.py
│   ├── knn2img.py
│   ├── latent_imagenet_diffusion.ipynb
│   ├── sample_diffusion.py
│   ├── train_searcher.py
│   └── txt2img.py
└── setup.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | *Dockerfile*
2 | docker-compose.yml
3 | .git
4 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: basuj
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: ['https://paypal.me/basuj']
14 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM continuumio/miniconda3:4.12.0 AS build
2 |
3 | # Step to install git, which is needed for the pip dependencies pulled from Git repositories.
4 | RUN apt update \
5 | && apt install --no-install-recommends -y git \
6 | && apt-get clean
7 |
8 | COPY . /root/stable-diffusion/
9 |
10 | # Step to install dependencies with conda
11 | RUN eval "$(conda shell.bash hook)" \
12 | && conda install -c conda-forge conda-pack \
13 | && conda env create -f /root/stable-diffusion/environment.yaml \
14 | && conda activate ldm \
15 | && pip install gradio==3.1.7 \
16 | && conda activate base
17 |
18 | # Step to pack the conda environment and unpack it into the /venv folder
19 | RUN conda pack --ignore-missing-files --ignore-editable-packages -n ldm -o /tmp/env.tar \
20 | && mkdir /venv \
21 | && cd /venv \
22 | && tar xf /tmp/env.tar \
23 | && rm /tmp/env.tar
24 |
25 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as runtime
26 |
27 | ARG OPTIMIZED_FILE=txt2img_gradio.py
28 | WORKDIR /root/stable-diffusion
29 |
30 | COPY --from=build /venv /venv
31 | COPY --from=build /root/stable-diffusion /root/stable-diffusion
32 |
33 | RUN mkdir -p /output /root/stable-diffusion/outputs \
34 | && ln -s /data /root/stable-diffusion/models/ldm/stable-diffusion-v1 \
35 | && ln -s /output /root/stable-diffusion/outputs/txt2img-samples
36 |
37 | ENV PYTHONUNBUFFERED=1
38 | ENV GRADIO_SERVER_NAME=0.0.0.0
39 | ENV GRADIO_SERVER_PORT=7860
40 | ENV APP_MAIN_FILE=${OPTIMIZED_FILE}
41 | EXPOSE 7860
42 |
43 | VOLUME ["/root/.cache", "/data", "/output"]
44 |
45 | SHELL ["/bin/bash", "-c"]
46 | ENTRYPOINT ["/root/stable-diffusion/docker-bootstrap.sh"]
47 | CMD python optimizedSD/${APP_MAIN_FILE}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | All rights reserved by the authors.
2 | You must not distribute the weights provided to you directly or indirectly without explicit consent of the authors.
3 | You must not distribute harmful, offensive, dehumanizing content or otherwise harmful representations of people or their environments, cultures, religions, etc. produced with the model weights
4 | or other generated content described in the "Misuse and Malicious Use" section in the model card.
5 | The model weights are provided for research purposes only.
6 |
7 |
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
9 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
10 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
11 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
12 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
13 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
14 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Optimized Stable Diffusion
2 |
3 |
4 |
5 |
6 |
7 |
8 | This repo is a modified version of the Stable Diffusion repo, optimized to use less VRAM than the original by sacrificing inference speed.
9 |
10 | To reduce the VRAM usage, the following optimizations are used:
11 |
12 | - The Stable Diffusion model is split into four parts which are sent to the GPU only when needed. Once a part finishes its computation, it is moved back to the CPU.
13 | - The attention calculation is done in smaller parts. (A minimal sketch of both ideas follows below.)
14 |
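A minimal sketch of the two ideas, with hypothetical helper names — the actual implementation lives in `optimizedSD/ddpm.py` and `optimizedSD/splitAttention.py` and differs in detail:

```python
# Sketch only: illustrates on-demand GPU offloading and chunked attention.
import torch

def run_offloaded(module: torch.nn.Module, *inputs, device="cuda"):
    """Move a sub-network to the GPU only for the duration of its forward pass."""
    module.to(device)
    out = module(*(x.to(device) for x in inputs))
    module.to("cpu")            # free VRAM for the next sub-network
    torch.cuda.empty_cache()
    return out

def chunked_attention(q, k, v, chunks=4):
    """Attention computed over slices of the queries, so the full
    (tokens x tokens) attention matrix is never held in memory at once."""
    scale = q.shape[-1] ** -0.5
    out = []
    for q_part in q.chunk(chunks, dim=1):                  # split the token axis
        attn = torch.softmax(q_part @ k.transpose(-2, -1) * scale, dim=-1)
        out.append(attn @ v)
    return torch.cat(out, dim=1)
```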
15 | # Installation
16 |
17 | All the modified files are in the [optimizedSD](optimizedSD) folder, so if you have already cloned the original repository you can just download and copy this folder into the original instead of cloning the entire repo. You can also clone this repo and follow the same installation steps as the original (mainly creating the conda environment and placing the weights at the specified location).
18 |
19 | Alternatively, if you prefer to use Docker, you can do the following:
20 |
21 | 1. Install [Docker](https://docs.docker.com/engine/install/), [Docker Compose plugin](https://docs.docker.com/compose/install/), and [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
22 | 2. Clone this repo to, e.g., `~/stable-diffusion`
23 | 3. Put your downloaded `model.ckpt` file into `~/sd-data` (this path is set in `docker-compose.yml` and can be changed there)
24 | 4. `cd` into `~/stable-diffusion` and execute `docker compose up --build`
25 |
26 | This will launch the txt2img Gradio interface on port 7860. You can also use `docker compose run` to execute the other Python scripts.
27 |
28 | # Usage
29 |
30 | ## img2img
31 |
32 | - `img2img` can generate _512x512 images from a prior image and prompt using under 2.4GB VRAM in under 20 seconds per image_ on an RTX 2060.
33 |
34 | - The maximum size that fits on a 6GB GPU (RTX 2060) is around 1152x1088.
35 |
36 | - For example, the following command will generate 10 512x512 images:
37 |
38 | `python optimizedSD/optimized_img2img.py --prompt "Austrian alps" --init-img ~/sketch-mountains-input.jpg --strength 0.8 --n_iter 2 --n_samples 5 --H 512 --W 512`
39 |
40 | ## txt2img
41 |
42 | - `txt2img` can generate _512x512 images from a prompt using under 2.4GB GPU VRAM in under 24 seconds per image_ on an RTX 2060.
43 |
44 | - For example, the following command will generate 10 512x512 images:
45 |
46 | `python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --H 512 --W 512 --seed 27 --n_iter 2 --n_samples 5 --ddim_steps 50`
47 |
48 | ## inpainting
49 |
50 | - `inpaint_gradio.py` can fill masked parts of an image based on a given prompt. It can inpaint 512x512 images while using under 2.5GB of VRAM.
51 |
52 | - To launch the gradio interface for inpainting, run `python optimizedSD/inpaint_gradio.py`. The mask for the image can be drawn on the selected image using the brush tool.
53 |
54 | - The results are not yet perfect but can be improved by using a combination of prompt weighting, prompt engineering and testing out multiple values of the `--strength` argument.
55 |
56 | - _Suggestions to improve the inpainting algorithm are most welcome_.
57 |
58 | # Using the Gradio GUI
59 |
60 | - You can also use the built-in gradio interface for `img2img`, `txt2img` & `inpainting` instead of the command line interface. Activate the conda environment and install the latest version of gradio using `pip install gradio`.
61 |
62 | - Run img2img using `python optimizedSD/img2img_gradio.py`, txt2img using `python optimizedSD/txt2img_gradio.py` and inpainting using `python optimizedSD/inpaint_gradio.py`.
63 |
64 | - `img2img_gradio.py` has a feature to crop input images. Look for the pen symbol in the image box after selecting the image.
65 |
66 | # Arguments
67 |
68 | ## `--seed`
69 |
70 | **Seed for image generation**, can be used to reproduce previously generated images. Defaults to a random seed if unspecified.
71 |
72 | - The code prints the seed number along with each generated image. To generate the same image again, just specify the seed using the `--seed` argument. Images are saved with their seed number in the filename by default.
73 |
74 | - For example, if the seed number for an image is `1234` and it is the 55th image in the folder, it will be named `seed_1234_00055.png`.
75 |
76 | ## `--n_samples`
77 |
78 | **Batch size: the number of images to generate at once.**
79 |
80 | - To get the lowest inference time per image, use the maximum batch size `--n_samples` that fits on the GPU. The per-image inference time decreases as the batch size increases, but the required VRAM increases.
81 |
82 | - If you get a CUDA out-of-memory error, try reducing the batch size `--n_samples`. If that doesn't work, reduce the image width `--W`, the height `--H`, or both.
83 |
84 | ## `--n_iter`
85 |
86 | **Run the script _x_ number of times.**
87 |
88 | - Equivalent to running the script `n_iter` times, except that the model is loaded only once for all the iterations. Unlike `--n_samples`, changing it does not affect the required VRAM or the per-image inference time.
89 |
90 | ## `--H` & `--W`
91 |
92 | **Height & width of the generated image.**
93 |
94 | - Both height and width should be a multiple of 64.
95 |
96 | ## `--turbo`
97 |
98 | **Increases inference speed at the cost of extra VRAM usage.**
99 |
100 | - Using this argument increases the inference speed at the cost of around 700MB of extra GPU VRAM. It is especially effective when generating a small batch of images (~1 to 4). It takes under 20 seconds for txt2img and 15 seconds for img2img (on an RTX 2060, excluding the time to load the model). Use it on larger batch sizes if enough GPU VRAM is available.
101 |
102 | ## `--precision autocast` or `--precision full`
103 |
104 | **Whether to use `full` or `mixed` precision**
105 |
106 | - Mixed precision is enabled by default. If you don't have a GPU with tensor cores (e.g., any GTX 10-series card), you may not be able to use mixed precision. Use the `--precision full` argument to disable it.
107 |
108 | ## `--format png` or `--format jpg`
109 |
110 | **Output image format**
111 |
112 | - The default output format is `png`. While `png` is lossless, it takes up a lot of space (unless large portions of the image happen to be a single colour). Use lossy `jpg` to get smaller image file sizes.
113 |
114 | ## `--unet_bs`
115 |
116 | **Batch size for the unet model**
117 |
118 | - Takes up a lot of extra RAM for **very little improvement** in inference time. `unet_bs` > 1 is not recommended!
119 |
120 | - Should generally be a multiple of 2x(n_samples)
121 |
122 | # Weighted Prompts
123 |
124 | - Prompts can also be weighted to put relative emphasis on certain words,
125 |   e.g. `--prompt tabby cat:0.25 white duck:0.75 hybrid`.
126 |
127 | - The number after a colon is the weight given to the words before that colon. The weights can be fractions or integers. (A sketch of how such a prompt can be parsed follows below.)
128 |
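An illustrative parser for such weighted prompts — a sketch only; the parser actually shipped in `optimizedSD/optimUtils.py` may differ in detail:

```python
import re

def split_weighted_subprompts(prompt: str):
    """Return (subprompt, weight) pairs; text without an explicit ':<number>'
    weight gets a default weight of 1.0."""
    pairs = []
    remaining = prompt
    while remaining:
        match = re.search(r":\s*([-+]?\d*\.?\d+)", remaining)
        if match is None:                      # no more explicit weights
            if remaining.strip():
                pairs.append((remaining.strip(), 1.0))
            break
        text = remaining[:match.start()].strip()
        if text:
            pairs.append((text, float(match.group(1))))
        remaining = remaining[match.end():]
    return pairs

# split_weighted_subprompts("tabby cat:0.25 white duck:0.75 hybrid")
# -> [('tabby cat', 0.25), ('white duck', 0.75), ('hybrid', 1.0)]
```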
129 | ## Troubleshooting
130 |
131 | ### Green colored output images
132 |
133 | - If you have an NVIDIA GTX-series GPU, the output images may be entirely green. This is because GTX-series cards do not support half-precision calculation, which is the default mode of calculation in this repository. To overcome the issue, use the `--precision full` argument. The downside is that it will lead to higher GPU VRAM usage.
134 |
135 |
136 |
137 | ## Changelog
138 |
139 | - v1.0: Added support for multiple samplers for txt2img, based on [crowsonkb's k-diffusion](https://github.com/crowsonkb/k-diffusion).
140 | - v0.9: Added support for calculating attention in parts. (Thanks to @neonsecret, @Doggettx and @ryudrigo.)
141 | - v0.8: Added gradio interface for inpainting.
142 | - v0.7: Added support for logging, jpg file format
143 | - v0.6: Added support for using weighted prompts. (based on @lstein's [repo](https://github.com/lstein/stable-diffusion))
144 | - v0.5: Added support for using gradio interface.
145 | - v0.4: Added support for specifying image seed.
146 | - v0.3: Added support for using mixed precision.
147 | - v0.2: Added support for generating images in batches.
148 | - v0.1: Split the model into multiple parts to run it on lower VRAM.
149 |
--------------------------------------------------------------------------------
/Stable_Diffusion_v1_Model_Card.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion v1 Model Card
2 | This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
3 |
4 | ## Model Details
5 | - **Developed by:** Robin Rombach, Patrick Esser
6 | - **Model type:** Diffusion-based text-to-image generation model
7 | - **Language(s):** English
8 | - **License:** [Proprietary](LICENSE)
9 | - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
10 | - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
11 | - **Cite as:**
12 |
13 | @InProceedings{Rombach_2022_CVPR,
14 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
15 | title = {High-Resolution Image Synthesis With Latent Diffusion Models},
16 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
17 | month = {June},
18 | year = {2022},
19 | pages = {10684-10695}
20 | }
21 |
22 | # Uses
23 |
24 | ## Direct Use
25 | The model is intended for research purposes only. Possible research areas and
26 | tasks include
27 |
28 | - Safe deployment of models which have the potential to generate harmful content.
29 | - Probing and understanding the limitations and biases of generative models.
30 | - Generation of artworks and use in design and other artistic processes.
31 | - Applications in educational or creative tools.
32 | - Research on generative models.
33 |
34 | Excluded uses are described below.
35 |
36 | ### Misuse, Malicious Use, and Out-of-Scope Use
37 | _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
38 |
39 |
40 | The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
41 | #### Out-of-Scope Use
42 | The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
43 | #### Misuse and Malicious Use
44 | Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
45 |
46 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
47 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
48 | - Impersonating individuals without their consent.
49 | - Sexual content without consent of the people who might see it.
50 | - Mis- and disinformation
51 | - Representations of egregious violence and gore
52 | - Sharing of copyrighted or licensed material in violation of its terms of use.
53 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
54 |
55 | ## Limitations and Bias
56 |
57 | ### Limitations
58 |
59 | - The model does not achieve perfect photorealism
60 | - The model cannot render legible text
61 | - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
62 | - Faces and people in general may not be generated properly.
63 | - The model was trained mainly with English captions and will not work as well in other languages.
64 | - The autoencoding part of the model is lossy
65 | - The model was trained on a large-scale dataset
66 | [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
67 | and is not fit for product use without additional safety mechanisms and
68 | considerations.
69 |
70 | ### Bias
71 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
72 | Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
73 | which consists of images that are primarily limited to English descriptions.
74 | Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
75 | This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
76 | ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
77 |
78 |
79 | ## Training
80 |
81 | **Training Data**
82 | The model developers used the following dataset for training the model:
83 |
84 | - LAION-2B (en) and subsets thereof (see next section)
85 |
86 | **Training Procedure**
87 | Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
88 |
89 | - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
90 | - Text prompts are encoded through a ViT-L/14 text-encoder.
91 | - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
92 | - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
93 |
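Written out, this objective is the standard latent-diffusion noise-prediction loss (a restatement of the bullet above, not a verbatim quote; here \(z_t\) is the noised latent at timestep \(t\), \(c\) the text prompt, \(\tau_\theta\) the frozen CLIP text encoder and \(\epsilon_\theta\) the UNet):

$$
L_{LDM} = \mathbb{E}_{\mathcal{E}(x),\, c,\, \epsilon \sim \mathcal{N}(0,1),\, t}\Big[\, \lVert \epsilon - \epsilon_\theta(z_t, t, \tau_\theta(c)) \rVert_2^2 \,\Big]
$$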
94 | We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
95 | which were trained as follows,
96 |
97 | - `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
98 | 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
99 | - `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
100 | 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
101 | filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
102 | - `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
103 |
104 |
105 | - **Hardware:** 32 x 8 x A100 GPUs
106 | - **Optimizer:** AdamW
107 | - **Gradient Accumulations**: 2
108 | - **Batch:** 32 x 8 x 2 x 4 = 2048
109 | - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
110 |
111 | ## Evaluation Results
112 | Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
113 | 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
114 | steps show the relative improvements of the checkpoints:
115 |
116 | 
117 |
118 | Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
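
For reference, a classifier-free guidance scale \(s\) combines the conditional and unconditional noise predictions as in the linked paper (standard formulation, stated here for convenience):

$$
\tilde{\epsilon}_\theta(z_t, c) = \epsilon_\theta(z_t, \varnothing) + s \cdot \big(\epsilon_\theta(z_t, c) - \epsilon_\theta(z_t, \varnothing)\big)
$$

so \(s = 1\) corresponds to purely conditional sampling, and larger scales trade sample diversity for closer adherence to the prompt.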
119 | ## Environmental Impact
120 |
121 | **Stable Diffusion v1** **Estimated Emissions**
122 | Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
123 |
124 | - **Hardware Type:** A100 PCIe 40GB
125 | - **Hours used:** 150000
126 | - **Cloud Provider:** AWS
127 | - **Compute Region:** US-east
128 | - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
129 | ## Citation
130 | @InProceedings{Rombach_2022_CVPR,
131 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
132 | title = {High-Resolution Image Synthesis With Latent Diffusion Models},
133 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
134 | month = {June},
135 | year = {2022},
136 | pages = {10684-10695}
137 | }
138 |
139 | *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
140 |
141 |
--------------------------------------------------------------------------------
/assets/a-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-painting-of-a-fire.png
--------------------------------------------------------------------------------
/assets/a-photograph-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-photograph-of-a-fire.png
--------------------------------------------------------------------------------
/assets/a-shirt-with-a-fire-printed-on-it.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-shirt-with-a-fire-printed-on-it.png
--------------------------------------------------------------------------------
/assets/a-shirt-with-the-inscription-'fire'.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-shirt-with-the-inscription-'fire'.png
--------------------------------------------------------------------------------
/assets/a-watercolor-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-watercolor-painting-of-a-fire.png
--------------------------------------------------------------------------------
/assets/birdhouse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/birdhouse.png
--------------------------------------------------------------------------------
/assets/fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/fire.png
--------------------------------------------------------------------------------
/assets/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/inpainting.png
--------------------------------------------------------------------------------
/assets/modelfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/modelfigure.png
--------------------------------------------------------------------------------
/assets/rdm-preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/rdm-preview.jpg
--------------------------------------------------------------------------------
/assets/reconstruction1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/reconstruction1.png
--------------------------------------------------------------------------------
/assets/reconstruction2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/reconstruction2.png
--------------------------------------------------------------------------------
/assets/results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/results.gif
--------------------------------------------------------------------------------
/assets/stable-samples/img2img/mountains-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/mountains-1.png
--------------------------------------------------------------------------------
/assets/stable-samples/img2img/mountains-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/mountains-2.png
--------------------------------------------------------------------------------
/assets/stable-samples/img2img/mountains-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/mountains-3.png
--------------------------------------------------------------------------------
/assets/stable-samples/img2img/sketch-mountains-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/sketch-mountains-input.jpg
--------------------------------------------------------------------------------
/assets/stable-samples/img2img/upscaling-in.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/upscaling-in.png
--------------------------------------------------------------------------------
/assets/stable-samples/img2img/upscaling-out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/upscaling-out.png
--------------------------------------------------------------------------------
/assets/stable-samples/txt2img/000002025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/000002025.png
--------------------------------------------------------------------------------
/assets/stable-samples/txt2img/000002035.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/000002035.png
--------------------------------------------------------------------------------
/assets/stable-samples/txt2img/merged-0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/merged-0005.png
--------------------------------------------------------------------------------
/assets/stable-samples/txt2img/merged-0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/merged-0006.png
--------------------------------------------------------------------------------
/assets/stable-samples/txt2img/merged-0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/merged-0007.png
--------------------------------------------------------------------------------
/assets/the-earth-is-on-fire,-oil-on-canvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/the-earth-is-on-fire,-oil-on-canvas.png
--------------------------------------------------------------------------------
/assets/txt2img-convsample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/txt2img-convsample.png
--------------------------------------------------------------------------------
/assets/txt2img-preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/txt2img-preview.png
--------------------------------------------------------------------------------
/assets/v1-variants-scores.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/v1-variants-scores.jpg
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
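Note on the configs in this folder: every `target:`/`params:` block follows the same convention and is turned into a Python object by a small helper. The sketch below paraphrases the `instantiate_from_config` helper in `ldm/util.py`; treat the exact code as an assumption rather than a verbatim copy.

```python
import importlib
from omegaconf import OmegaConf

def get_obj_from_str(name: str):
    # "ldm.models.autoencoder.AutoencoderKL" -> the AutoencoderKL class
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # Build the object named by `target`, passing `params` as keyword arguments.
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# config = OmegaConf.load("configs/autoencoder/autoencoder_kl_16x16x16.yaml")
# autoencoder = instantiate_from_config(config.model)
```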
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 | # note: this isn't actually the resolution but
24 | # the downsampling factor, i.e. this corresponds to
25 | # attention on spatial resolution 8,16,32, as the
26 | # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | # note: this isn't actually the resolution but
26 | # the downsampling factor, i.e. this corresponds to
27 | # attention on spatial resolution 8,16,32, as the
28 | # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
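Here cond_stage_config is torch.nn.Identity and cond_stage_trainable is false, so the retrieval-augmented model expects its conditioning (cond_stage_key: nix) to arrive as pre-computed embeddings whose last dimension already matches context_dim: 768; CLIP ViT-L/14 features of retrieved neighbours would fit that shape, but the choice of embedding source is an assumption, not something this file states. A minimal sketch with random stand-in embeddings:

    import torch

    cond_stage = torch.nn.Identity()   # as configured above

    # Hypothetical pre-computed neighbour embeddings: batch of 2, 4 neighbours, 768-dim.
    nix = torch.randn(2, 4, 768)
    context = cond_stage(nix)          # passed through unchanged
    print(context.shape)               # (2, 4, 768)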
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
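A minimal sketch of loading this config and instantiating the model with the helpers shipped in this repository (OmegaConf is pinned in environment.yaml; instantiate_from_config lives in ldm/util.py). Instantiation downloads the frozen CLIP text encoder from the Hugging Face hub, and the checkpoint path is a placeholder; the v1 weights are distributed separately:

    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
    model = instantiate_from_config(config.model)

    # Hypothetical checkpoint location; point this at the downloaded weights.
    ckpt = torch.load("models/ldm/stable-diffusion-v1/model.ckpt", map_location="cpu")
    missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
    model.eval()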
--------------------------------------------------------------------------------
/data/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/DejaVuSans.ttf
--------------------------------------------------------------------------------
/data/example_conditioning/superresolution/sample_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/example_conditioning/superresolution/sample_0.jpg
--------------------------------------------------------------------------------
/data/example_conditioning/text_conditional/sample_0.txt:
--------------------------------------------------------------------------------
 1 | A basket of cherries
2 |
--------------------------------------------------------------------------------
/data/imagenet_train_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/imagenet_train_hr_indices.p
--------------------------------------------------------------------------------
/data/imagenet_val_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/imagenet_val_hr_indices.p
--------------------------------------------------------------------------------
/data/inpainting_examples/6458524847_2f4c361183_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/6458524847_2f4c361183_k.png
--------------------------------------------------------------------------------
/data/inpainting_examples/6458524847_2f4c361183_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/6458524847_2f4c361183_k_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png
--------------------------------------------------------------------------------
/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png
--------------------------------------------------------------------------------
/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bench2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bench2.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bench2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bench2_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png
--------------------------------------------------------------------------------
/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png
--------------------------------------------------------------------------------
/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png
--------------------------------------------------------------------------------
/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png
--------------------------------------------------------------------------------
/docker-bootstrap.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | source /venv/bin/activate
4 | update-ca-certificates --fresh
5 | export SSL_CERT_DIR=/etc/ssl/certs
6 | exec "$@"
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.9"
2 | services:
3 | sd:
4 | build: .
5 | ports:
6 | - "7860:7860"
7 | volumes:
8 | - ../sd-data:/data
9 | - ../sd-output:/output
10 | - sd-cache:/root/.cache
11 | environment:
12 | - APP_MAIN_FILE=txt2img_gradio.py
13 | deploy:
14 | resources:
15 | reservations:
16 | devices:
17 | - driver: nvidia
18 | count: 1
19 | capabilities: [gpu]
20 | volumes:
21 | sd-cache:
22 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.20.3
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.4.2
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - torch-fidelity==0.3.0
24 | - transformers==4.19.2
25 | - torchmetrics==0.6.0
26 | - kornia==0.6
27 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
28 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
29 | - -e .
30 |
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
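Txt2ImgIterableBaseDataset only fixes the record-count bookkeeping and leaves __iter__ abstract. A minimal sketch of a concrete subclass, with a hypothetical in-memory list of (image, caption) pairs standing in for a real shard reader; the dict keys mirror the first_stage_key/cond_stage_key used by the txt2img config:

    import numpy as np
    from ldm.data.base import Txt2ImgIterableBaseDataset

    class ToyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
        """Yields dicts with an image array and its caption."""

        def __init__(self, records, size=256):
            super().__init__(num_records=len(records),
                             valid_ids=list(range(len(records))), size=size)
            self.records = records

        def __iter__(self):
            for image, caption in self.records:
                yield {"image": image, "caption": caption}

    records = [(np.zeros((256, 256, 3), dtype=np.float32), "a photograph of a fire")]
    ds = ToyTxt2ImgDataset(records)
    print(len(ds), next(iter(ds))["caption"])  # 1 a photograph of a fire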
--------------------------------------------------------------------------------
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 |         h, w = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
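A minimal usage sketch, assuming the LSUN images and the data/lsun/*.txt file lists exist at the hard-coded paths above:

    from ldm.data.lsun import LSUNChurchesValidation

    ds = LSUNChurchesValidation(size=256)
    example = ds[0]
    # (256, 256, 3) float32, center-cropped, resized, scaled to roughly [-1, 1]
    print(example["image"].shape, example["image"].min(), example["image"].max())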
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
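These schedulers return a multiplier on the base learning rate (hence the "use with a base_lr of 1.0" notes) and are meant to be wrapped in torch's LambdaLR. A minimal sketch using the same parameters as the scheduler_config in configs/stable-diffusion/v1-inference.yaml; the linear layer and optimizer are placeholders:

    import torch
    from ldm.lr_scheduler import LambdaLinearScheduler

    model = torch.nn.Linear(4, 4)                      # placeholder module
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

    schedule = LambdaLinearScheduler(warm_up_steps=[10000], f_min=[1.0], f_max=[1.0],
                                     f_start=[1.0e-6], cycle_lengths=[10000000000000])
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

    for step in range(3):
        optimizer.step()                               # no-op here: no gradients yet
        scheduler.step()
        print(step, optimizer.param_groups[0]["lr"])   # ramps linearly from ~1e-10 toward 1e-4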
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
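A minimal shape check for SpatialTransformer, mirroring how the v1 UNet uses it (320 channels, 8 heads of dim 40, context_dim 768); random tensors stand in for real latent feature maps and CLIP text embeddings:

    import torch
    from ldm.modules.attention import SpatialTransformer

    block = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
                               depth=1, context_dim=768)

    x = torch.randn(1, 320, 64, 64)    # latent feature map
    context = torch.randn(1, 77, 768)  # e.g. a FrozenCLIPEmbedder output
    with torch.no_grad():
        out = block(x, context=context)
    print(out.shape)                   # (1, 320, 64, 64): spatial shape is preserved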
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
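DiagonalGaussianDistribution expects the mean and log-variance stacked along the channel axis, which is why the KL first-stage configs set double_z: true (the encoder emits 2 * z_channels channels). A minimal sketch with random moments for the z_channels=4 case:

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    moments = torch.randn(2, 8, 32, 32)   # 2 * z_channels = 8
    posterior = DiagonalGaussianDistribution(moments)

    z = posterior.sample()                # (2, 4, 32, 32), reparameterised sample
    kl = posterior.kl()                   # (2,), KL against a standard normal
    print(z.shape, kl.shape)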
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
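A minimal sketch of the store / copy_to / restore pattern the docstrings describe, i.e. evaluating with EMA weights and then putting the live weights back; the toy module and the in-place update standing in for an optimizer step are placeholders:

    import torch
    from ldm.modules.ema import LitEma

    model = torch.nn.Linear(4, 4)
    ema = LitEma(model, decay=0.9999)

    for _ in range(3):
        with torch.no_grad():
            model.weight += 0.01     # stand-in for an optimizer update
        ema(model)                   # fold live weights into the shadow copy

    ema.store(model.parameters())    # keep the live weights
    ema.copy_to(model)               # evaluate with EMA weights here ...
    ema.restore(model.parameters())  # ... then restore the live weights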
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 |     """Uses the BERT tokenizer and adds some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 | class FrozenCLIPEmbedder(AbstractEncoder):
138 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
140 | super().__init__()
141 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
142 | self.transformer = CLIPTextModel.from_pretrained(version)
143 | self.device = device
144 | self.max_length = max_length
145 | self.freeze()
146 |
147 | def freeze(self):
148 | self.transformer = self.transformer.eval()
149 | for param in self.parameters():
150 | param.requires_grad = False
151 |
152 | def forward(self, text):
153 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155 | tokens = batch_encoding["input_ids"].to(self.device)
156 | outputs = self.transformer(input_ids=tokens)
157 |
158 | z = outputs.last_hidden_state
159 | return z
160 |
161 | def encode(self, text):
162 | return self(text)
163 |
164 |
165 | class FrozenCLIPTextEmbedder(nn.Module):
166 | """
167 | Uses the CLIP transformer encoder for text.
168 | """
169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
170 | super().__init__()
171 | self.model, _ = clip.load(version, jit=False, device="cpu")
172 | self.device = device
173 | self.max_length = max_length
174 | self.n_repeat = n_repeat
175 | self.normalize = normalize
176 |
177 | def freeze(self):
178 | self.model = self.model.eval()
179 | for param in self.parameters():
180 | param.requires_grad = False
181 |
182 | def forward(self, text):
183 | tokens = clip.tokenize(text).to(self.device)
184 | z = self.model.encode_text(tokens)
185 | if self.normalize:
186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187 | return z
188 |
189 | def encode(self, text):
190 | z = self(text)
191 | if z.ndim==2:
192 | z = z[:, None, :]
193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194 | return z
195 |
196 |
197 | class FrozenClipImageEmbedder(nn.Module):
198 | """
199 | Uses the CLIP image encoder.
200 | """
201 | def __init__(
202 | self,
203 | model,
204 | jit=False,
205 | device='cuda' if torch.cuda.is_available() else 'cpu',
206 | antialias=False,
207 | ):
208 | super().__init__()
209 | self.model, _ = clip.load(name=model, device=device, jit=jit)
210 |
211 | self.antialias = antialias
212 |
213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
215 |
216 | def preprocess(self, x):
217 | # normalize to [0,1]
218 | x = kornia.geometry.resize(x, (224, 224),
219 | interpolation='bicubic',align_corners=True,
220 | antialias=self.antialias)
221 | x = (x + 1.) / 2.
222 | # renormalize according to clip
223 | x = kornia.enhance.normalize(x, self.mean, self.std)
224 | return x
225 |
226 | def forward(self, x):
227 | # x is assumed to be in range [-1,1]
228 | return self.model.encode_image(self.preprocess(x))
229 |
230 |
231 | if __name__ == "__main__":
232 | from ldm.util import count_params
233 | model = FrozenCLIPEmbedder()
234 | count_params(model, verbose=True)
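A minimal sketch of the encoder referenced by configs/stable-diffusion/v1-inference.yaml; it downloads openai/clip-vit-large-patch14 from the Hugging Face hub on first use, and a CUDA device is assumed since that is the default device argument:

    import torch
    from ldm.modules.encoders.modules import FrozenCLIPEmbedder

    encoder = FrozenCLIPEmbedder(device="cuda").to("cuda")
    with torch.no_grad():
        z = encoder.encode(["a watercolor painting of a fire"])
    print(z.shape)  # (1, 77, 768): matches the v1 UNet's context_dim of 768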
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:  # no exists() helper is imported in this module
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
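
The adaptive weight used above (calculate_adaptive_weight) balances the reconstruction and adversarial terms by the ratio of their gradient norms at the decoder's final layer. A minimal sketch of the same rule on toy tensors (the shapes and losses here are made up; only the formula mirrors the method above):

import torch

# hypothetical stand-in for the decoder's last-layer weight
last_layer = torch.randn(8, 4, requires_grad=True)

# hypothetical scalar losses that both depend on last_layer
nll_loss = (last_layer ** 2).mean()
g_loss = last_layer.sin().mean()

# d_weight = ||grad(nll)|| / (||grad(g)|| + 1e-4), clamped and detached,
# then scaled by discriminator_weight in the class above
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = (torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)).clamp(0.0, 1e4).detach()
print(d_weight)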
--------------------------------------------------------------------------------
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 |     # worker body: apply func to this chunk and put the result on the shared queue
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
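
instantiate_from_config is the glue used by every config below: it imports the dotted `target` path and calls it with `params`. A minimal usage sketch, with a plain dict standing in for the OmegaConf node the repo normally passes and torch.nn.Linear chosen only because it is importable:

import torch
from ldm.util import instantiate_from_config, get_obj_from_str

config = {
    "target": "torch.nn.Linear",                      # any importable dotted path
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)               # same as torch.nn.Linear(8, 4)
assert isinstance(layer, torch.nn.Linear)

# get_obj_from_str resolves the class without instantiating it
assert get_obj_from_str("torch.nn.Linear") is torch.nn.Linear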
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
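
The downsampling factor in these first-stage names (f4/f8/f16/f32) follows from ch_mult: each entry after the first adds one 2x downsampling stage, so the spatial factor is 2**(len(ch_mult) - 1). A quick check against this kl-f16 config, assuming that convention (it is consistent with the other first-stage configs below):

ch_mult = [1, 1, 2, 2, 4]          # from the config above
z_channels = 16
resolution = 256

f = 2 ** (len(ch_mult) - 1)        # -> 16, hence "kl-f16"
latent_hw = resolution // f        # -> 16
print(f, (z_channels, latent_hw, latent_hw))   # 16 (16, 16, 16)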
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
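
For the VQ variants, n_embed is the codebook size and embed_dim the dimensionality of each code, so at f16 a 256x256 image becomes a 16x16 grid of indices into a 16384-entry codebook. A back-of-the-envelope sanity check (not something computed anywhere in the repo):

import math

n_embed, f, resolution = 16384, 16, 256     # from the vq-f16 config above
tokens = (resolution // f) ** 2             # 256 discrete codes per image
bits = tokens * math.log2(n_embed)          # 256 * 14 bits
print(tokens, bits / 8, "bytes")            # 256 448.0 bytes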
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
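
With concat_mode: true, the super-resolution UNet's in_channels: 6 is the 3-channel image latent (z_channels: 3 from the f4 first stage) concatenated with a 3-channel representation of the low-resolution conditioning image. The actual wiring lives in ldm.models.diffusion.ddpm; the following is only a shape-level illustration of that concatenation:

import torch

z_t  = torch.randn(1, 3, 64, 64)    # noisy image latent (z_channels = 3)
c_lr = torch.randn(1, 3, 64, 64)    # conditioning derived from the LR image
unet_in = torch.cat([z_t, c_lr], dim=1)
print(unet_in.shape)                # torch.Size([1, 6, 64, 64]) -> in_channels: 6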
--------------------------------------------------------------------------------
/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
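
Here the class label is injected via cross-attention: the cond stage maps each label to a 512-dim context vector that the UNet's spatial transformer attends to (context_dim: 512). A toy stand-in with the same input/output shapes, hedged because the real ClassEmbedder lives in ldm.modules.encoders.modules and may differ in detail:

import torch
import torch.nn as nn

class ToyClassEmbedder(nn.Module):
    # hypothetical stand-in: class id -> (B, 1, embed_dim) context tokens
    def __init__(self, embed_dim=512, n_classes=1000):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, class_label):
        return self.embedding(class_label)[:, None, :]

ctx = ToyClassEmbedder()(torch.tensor([207, 360]))
print(ctx.shape)    # torch.Size([2, 1, 512]) -> matches context_dim: 512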
--------------------------------------------------------------------------------
/models/ldm/ffhq256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 42
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
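
The inpainting UNet takes in_channels: 7. With concat_mode: true and cond_stage_config: __is_first_stage__, a plausible reading (an assumption, not stated in the config) is 3 channels of noisy latent plus 3 channels of the encoded masked image plus 1 channel of mask resized to latent resolution:

import torch

z_t           = torch.randn(1, 3, 64, 64)   # noisy image latent
masked_latent = torch.randn(1, 3, 64, 64)   # first-stage encoding of the masked image
mask          = torch.randn(1, 1, 64, 64)   # mask downsampled to latent resolution
unet_in = torch.cat([z_t, masked_latent, mask], dim=1)
print(unet_in.shape)                        # torch.Size([1, 7, 64, 64]) -> in_channels: 7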
--------------------------------------------------------------------------------
/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/models/ldm/lsun_beds256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.lsun.LSUNBedroomsTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.lsun.LSUNBedroomsValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
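
Semantic synthesis conditions on a 182-channel (one-hot) segmentation map. SpatialRescaler with n_stages: 2 halves the spatial size twice (256 -> 64, matching the latent resolution) and projects 182 channels down to 3 so they can be concatenated with the 3-channel latent (in_channels: 6 in the UNet). A rough functional equivalent of that rescaling step, assuming bilinear downsampling:

import torch
import torch.nn.functional as F

seg = torch.randn(1, 182, 256, 256)            # segmentation input
for _ in range(2):                             # n_stages: 2 -> overall factor 4
    seg = F.interpolate(seg, scale_factor=0.5, mode="bilinear", align_corners=False)
seg = torch.nn.Conv2d(182, 3, kernel_size=1)(seg)   # channel projection 182 -> 3
print(seg.shape)                               # torch.Size([1, 3, 64, 64])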
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
--------------------------------------------------------------------------------
/optimizedSD/diffusers_txt2img.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import LDMTextToImagePipeline
3 |
4 | pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)
5 |
6 | prompt = "19th Century wooden engraving of Elon musk"
7 |
8 | seed = torch.manual_seed(1024)
9 | images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed, torch_device="cpu")["sample"]
10 |
11 | # save images
12 | for idx, image in enumerate(images):
13 | image.save(f"image-{idx}.png")
14 |
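
The script above targets an early diffusers API (the torch_device keyword and the ["sample"] output key). With more recent diffusers releases the usual pattern is to move the pipeline to a device and pass a device-bound generator; a hedged sketch, using the public CompVis/ldm-text2im-large-256 checkpoint as a stand-in for the gated Stable Diffusion weights, with argument names that may need adjusting to the installed version:

import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

generator = torch.Generator(device=pipe.device).manual_seed(1024)
result = pipe("19th Century wooden engraving of Elon Musk",
              num_inference_steps=50, guidance_scale=7.0, generator=generator)
result.images[0].save("image-0.png")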
--------------------------------------------------------------------------------
/optimizedSD/img2img_gradio.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import numpy as np
3 | import torch
4 | from torchvision.utils import make_grid
5 | import os, re
6 | from PIL import Image
7 | import torch
8 | import numpy as np
9 | from random import randint
10 | from omegaconf import OmegaConf
11 | from PIL import Image
12 | from tqdm import tqdm, trange
13 | from itertools import islice
14 | from einops import rearrange
15 | from torchvision.utils import make_grid
16 | import time
17 | from pytorch_lightning import seed_everything
18 | from torch import autocast
19 | from einops import rearrange, repeat
20 | from contextlib import nullcontext
21 | from ldm.util import instantiate_from_config
22 | from transformers import logging
23 | import pandas as pd
24 | from optimUtils import split_weighted_subprompts, logger
25 | logging.set_verbosity_error()
26 | import mimetypes
27 | mimetypes.init()
28 | mimetypes.add_type("application/javascript", ".js")
29 |
30 |
31 | def chunk(it, size):
32 | it = iter(it)
33 | return iter(lambda: tuple(islice(it, size)), ())
34 |
35 |
36 | def load_model_from_config(ckpt, verbose=False):
37 | print(f"Loading model from {ckpt}")
38 | pl_sd = torch.load(ckpt, map_location="cpu")
39 | if "global_step" in pl_sd:
40 | print(f"Global Step: {pl_sd['global_step']}")
41 | sd = pl_sd["state_dict"]
42 | return sd
43 |
44 |
45 | def load_img(image, h0, w0):
46 |
47 | image = image.convert("RGB")
48 | w, h = image.size
49 | print(f"loaded input image of size ({w}, {h})")
50 | if h0 is not None and w0 is not None:
51 | h, w = h0, w0
52 |
53 |     w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
54 |
55 | print(f"New image size ({w}, {h})")
56 | image = image.resize((w, h), resample=Image.LANCZOS)
57 | image = np.array(image).astype(np.float32) / 255.0
58 | image = image[None].transpose(0, 3, 1, 2)
59 | image = torch.from_numpy(image)
60 | return 2.0 * image - 1.0
61 |
62 | config = "optimizedSD/v1-inference.yaml"
63 | ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
64 | sd = load_model_from_config(f"{ckpt}")
65 | li, lo = [], []
66 | for key, v_ in sd.items():
67 | sp = key.split(".")
68 | if (sp[0]) == "model":
69 | if "input_blocks" in sp:
70 | li.append(key)
71 | elif "middle_block" in sp:
72 | li.append(key)
73 | elif "time_embed" in sp:
74 | li.append(key)
75 | else:
76 | lo.append(key)
77 | for key in li:
78 | sd["model1." + key[6:]] = sd.pop(key)
79 | for key in lo:
80 | sd["model2." + key[6:]] = sd.pop(key)
81 |
82 | config = OmegaConf.load(f"{config}")
83 |
84 | model = instantiate_from_config(config.modelUNet)
85 | _, _ = model.load_state_dict(sd, strict=False)
86 | model.eval()
87 |
88 | modelCS = instantiate_from_config(config.modelCondStage)
89 | _, _ = modelCS.load_state_dict(sd, strict=False)
90 | modelCS.eval()
91 |
92 | modelFS = instantiate_from_config(config.modelFirstStage)
93 | _, _ = modelFS.load_state_dict(sd, strict=False)
94 | modelFS.eval()
95 | del sd
96 |
97 | def generate(
98 | image,
99 | prompt,
100 | strength,
101 | ddim_steps,
102 | n_iter,
103 | batch_size,
104 | Height,
105 | Width,
106 | scale,
107 | ddim_eta,
108 | unet_bs,
109 | device,
110 | seed,
111 | outdir,
112 | img_format,
113 | turbo,
114 | full_precision,
115 | ):
116 |
117 | if seed == "":
118 | seed = randint(0, 1000000)
119 | seed = int(seed)
120 | seed_everything(seed)
121 |
122 | # Logging
123 | sampler = "ddim"
124 | logger(locals(), log_csv = "logs/img2img_gradio_logs.csv")
125 |
126 | init_image = load_img(image, Height, Width).to(device)
127 | model.unet_bs = unet_bs
128 | model.turbo = turbo
129 | model.cdevice = device
130 | modelCS.cond_stage_model.device = device
131 |
132 | if device != "cpu" and full_precision == False:
133 | model.half()
134 | modelCS.half()
135 | modelFS.half()
136 | init_image = init_image.half()
137 |
138 | tic = time.time()
139 | os.makedirs(outdir, exist_ok=True)
140 | outpath = outdir
141 | sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150]
142 | os.makedirs(sample_path, exist_ok=True)
143 | base_count = len(os.listdir(sample_path))
144 |
145 | # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
146 | assert prompt is not None
147 | data = [batch_size * [prompt]]
148 |
149 | modelFS.to(device)
150 |
151 | init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
152 | init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
153 |
154 | if device != "cpu":
155 | mem = torch.cuda.memory_allocated() / 1e6
156 | modelFS.to("cpu")
157 | while torch.cuda.memory_allocated() / 1e6 >= mem:
158 | time.sleep(1)
159 |
160 | assert 0.0 <= strength <= 1.0, "can only work with strength in [0.0, 1.0]"
161 | t_enc = int(strength * ddim_steps)
162 | print(f"target t_enc is {t_enc} steps")
163 |
164 | if full_precision == False and device != "cpu":
165 | precision_scope = autocast
166 | else:
167 | precision_scope = nullcontext
168 |
169 | all_samples = []
170 | seeds = ""
171 | with torch.no_grad():
172 | all_samples = list()
173 | for _ in trange(n_iter, desc="Sampling"):
174 | for prompts in tqdm(data, desc="data"):
175 | with precision_scope("cuda"):
176 | modelCS.to(device)
177 | uc = None
178 | if scale != 1.0:
179 | uc = modelCS.get_learned_conditioning(batch_size * [""])
180 | if isinstance(prompts, tuple):
181 | prompts = list(prompts)
182 |
183 | subprompts, weights = split_weighted_subprompts(prompts[0])
184 | if len(subprompts) > 1:
185 | c = torch.zeros_like(uc)
186 | totalWeight = sum(weights)
187 | # normalize each "sub prompt" and add it
188 | for i in range(len(subprompts)):
189 | weight = weights[i]
190 | # if not skip_normalize:
191 | weight = weight / totalWeight
192 | c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
193 | else:
194 | c = modelCS.get_learned_conditioning(prompts)
195 |
196 | if device != "cpu":
197 | mem = torch.cuda.memory_allocated() / 1e6
198 | modelCS.to("cpu")
199 | while torch.cuda.memory_allocated() / 1e6 >= mem:
200 | time.sleep(1)
201 |
202 | # encode (scaled latent)
203 | z_enc = model.stochastic_encode(
204 | init_latent, torch.tensor([t_enc] * batch_size).to(device), seed, ddim_eta, ddim_steps
205 | )
206 | # decode it
207 | samples_ddim = model.sample(
208 | t_enc,
209 | c,
210 | z_enc,
211 | unconditional_guidance_scale=scale,
212 | unconditional_conditioning=uc,
213 | sampler = sampler
214 | )
215 |
216 | modelFS.to(device)
217 | print("saving images")
218 | for i in range(batch_size):
219 |
220 | x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
221 | x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
222 | all_samples.append(x_sample.to("cpu"))
223 | x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
224 | Image.fromarray(x_sample.astype(np.uint8)).save(
225 | os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
226 | )
227 | seeds += str(seed) + ","
228 | seed += 1
229 | base_count += 1
230 |
231 | if device != "cpu":
232 | mem = torch.cuda.memory_allocated() / 1e6
233 | modelFS.to("cpu")
234 | while torch.cuda.memory_allocated() / 1e6 >= mem:
235 | time.sleep(1)
236 |
237 | del samples_ddim
238 | del x_sample
239 | del x_samples_ddim
240 | print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
241 |
242 | toc = time.time()
243 |
244 | time_taken = (toc - tic) / 60.0
245 | grid = torch.cat(all_samples, 0)
246 | grid = make_grid(grid, nrow=n_iter)
247 | grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
248 |
249 | txt = (
250 | "Samples finished in "
251 | + str(round(time_taken, 3))
252 | + " minutes and exported to \n"
253 | + sample_path
254 | + "\nSeeds used = "
255 | + seeds[:-1]
256 | )
257 | return Image.fromarray(grid.astype(np.uint8)), txt
258 |
259 |
260 | demo = gr.Interface(
261 | fn=generate,
262 | inputs=[
263 | gr.Image(tool="editor", type="pil"),
264 | "text",
265 | gr.Slider(0, 1, value=0.75),
266 | gr.Slider(1, 1000, value=50),
267 | gr.Slider(1, 100, step=1),
268 | gr.Slider(1, 100, step=1),
269 | gr.Slider(64, 4096, value=512, step=64),
270 | gr.Slider(64, 4096, value=512, step=64),
271 | gr.Slider(0, 50, value=7.5, step=0.1),
272 | gr.Slider(0, 1, step=0.01),
273 | gr.Slider(1, 2, value=1, step=1),
274 | gr.Text(value="cuda"),
275 | "text",
276 | gr.Text(value="outputs/img2img-samples"),
277 | gr.Radio(["png", "jpg"], value='png'),
278 | "checkbox",
279 | "checkbox",
280 | ],
281 | outputs=["image", "text"],
282 | )
283 | demo.launch()
284 |
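
In generate above, the img2img strength controls how many of the ddim_steps are actually run: the init latent is noised to step t_enc = int(strength * ddim_steps) via stochastic_encode and then denoised from there, so strength near 0 stays close to the input image while strength 1.0 effectively discards it. The arithmetic with the interface defaults:

strength, ddim_steps = 0.75, 50      # slider defaults in the Interface above
t_enc = int(strength * ddim_steps)
print(t_enc)                         # 37 -> "target t_enc is 37 steps" in the log output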
--------------------------------------------------------------------------------
/optimizedSD/optimUtils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 |
4 |
5 | def split_weighted_subprompts(text):
6 | """
7 |     Grabs all text up to the first occurrence of ':',
8 |     uses the grabbed text as a sub-prompt, and takes the value following ':' as its weight.
9 |     If ':' has no value defined, the weight defaults to 1.0.
10 |     Repeats until no text remains.
11 | """
12 | remaining = len(text)
13 | prompts = []
14 | weights = []
15 | while remaining > 0:
16 | if ":" in text:
17 | idx = text.index(":") # first occurrence from start
18 | # grab up to index as sub-prompt
19 | prompt = text[:idx]
20 | remaining -= idx
21 | # remove from main text
22 | text = text[idx+1:]
23 | # find value for weight
24 | if " " in text:
25 |                 idx = text.index(" ")  # first occurrence
26 | else: # no space, read to end
27 | idx = len(text)
28 | if idx != 0:
29 | try:
30 | weight = float(text[:idx])
31 |                 except ValueError:  # couldn't treat as float
32 | print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
33 | weight = 1.0
34 | else: # no value found
35 | weight = 1.0
36 | # remove from main text
37 | remaining -= idx
38 | text = text[idx+1:]
39 | # append the sub-prompt and its weight
40 | prompts.append(prompt)
41 | weights.append(weight)
42 | else: # no : found
43 | if len(text) > 0: # there is still text though
44 | # take remainder as weight 1
45 | prompts.append(text)
46 | weights.append(1.0)
47 | remaining = 0
48 | return prompts, weights
49 |
50 | def logger(params, log_csv):
51 | os.makedirs('logs', exist_ok=True)
52 | cols = [arg for arg, _ in params.items()]
53 | if not os.path.exists(log_csv):
54 | df = pd.DataFrame(columns=cols)
55 | df.to_csv(log_csv, index=False)
56 |
57 | df = pd.read_csv(log_csv)
58 | for arg in cols:
59 | if arg not in df.columns:
60 | df[arg] = ""
61 |             df.to_csv(log_csv, index=False)
62 |
63 | li = {}
64 | cols = [col for col in df.columns]
65 | data = {arg:value for arg, value in params.items()}
66 | for col in cols:
67 | if col in data:
68 | li[col] = data[col]
69 | else:
70 | li[col] = ''
71 |
72 |     df = pd.DataFrame(li, index=[0])
73 |     df.to_csv(log_csv, index=False, mode='a', header=False)
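
A quick trace of split_weighted_subprompts shows the expected prompt syntax (subprompt:weight pairs, with a missing weight defaulting to 1.0); the import assumes the optimizedSD directory is on the path, as in the scripts above:

from optimUtils import split_weighted_subprompts

prompts, weights = split_weighted_subprompts("sunset:2 beach")
print(prompts, weights)   # ['sunset', 'beach'] [2.0, 1.0]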
--------------------------------------------------------------------------------
/optimizedSD/splitAttention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 | self.att_step = att_step
161 |
162 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
163 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
164 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
165 |
166 | self.to_out = nn.Sequential(
167 | nn.Linear(inner_dim, query_dim),
168 | nn.Dropout(dropout)
169 | )
170 |
171 | def forward(self, x, context=None, mask=None):
172 | h = self.heads
173 |
174 | q = self.to_q(x)
175 | context = default(context, x)
176 | k = self.to_k(context)
177 | v = self.to_v(context)
178 | del context, x
179 |
180 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
181 |
182 |
183 | limit = k.shape[0]
184 | att_step = self.att_step
185 | q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
186 | k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
187 | v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
188 |
189 | q_chunks.reverse()
190 | k_chunks.reverse()
191 | v_chunks.reverse()
192 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
193 | del k, q, v
194 |         for i in range(0, limit, att_step):
195 |
196 | q_buffer = q_chunks.pop()
197 | k_buffer = k_chunks.pop()
198 | v_buffer = v_chunks.pop()
199 | sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
200 |
201 | del k_buffer, q_buffer
202 | # attention, what we cannot get enough of, by chunks
203 |
204 | sim_buffer = sim_buffer.softmax(dim=-1)
205 |
206 | sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
207 | del v_buffer
208 | sim[i:i+att_step,:,:] = sim_buffer
209 |
210 | del sim_buffer
211 | sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
212 | return self.to_out(sim)
213 |
214 |
215 | class BasicTransformerBlock(nn.Module):
216 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
217 | super().__init__()
218 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
219 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
220 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
221 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
222 | self.norm1 = nn.LayerNorm(dim)
223 | self.norm2 = nn.LayerNorm(dim)
224 | self.norm3 = nn.LayerNorm(dim)
225 | self.checkpoint = checkpoint
226 |
227 | def forward(self, x, context=None):
228 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
229 |
230 | def _forward(self, x, context=None):
231 | x = self.attn1(self.norm1(x)) + x
232 | x = self.attn2(self.norm2(x), context=context) + x
233 | x = self.ff(self.norm3(x)) + x
234 | return x
235 |
236 |
237 | class SpatialTransformer(nn.Module):
238 | """
239 | Transformer block for image-like data.
240 | First, project the input (aka embedding)
241 | and reshape to b, t, d.
242 | Then apply standard transformer action.
243 |     Finally, reshape back to an image.
244 | """
245 | def __init__(self, in_channels, n_heads, d_head,
246 | depth=1, dropout=0., context_dim=None):
247 | super().__init__()
248 | self.in_channels = in_channels
249 | inner_dim = n_heads * d_head
250 | self.norm = Normalize(in_channels)
251 |
252 | self.proj_in = nn.Conv2d(in_channels,
253 | inner_dim,
254 | kernel_size=1,
255 | stride=1,
256 | padding=0)
257 |
258 | self.transformer_blocks = nn.ModuleList(
259 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
260 | for d in range(depth)]
261 | )
262 |
263 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
264 | in_channels,
265 | kernel_size=1,
266 | stride=1,
267 | padding=0))
268 |
269 | def forward(self, x, context=None):
270 | # note: if no context is given, cross-attention defaults to self-attention
271 | b, c, h, w = x.shape
272 | x_in = x
273 | x = self.norm(x)
274 | x = self.proj_in(x)
275 | x = rearrange(x, 'b c h w -> b (h w) c')
276 | for block in self.transformer_blocks:
277 | x = block(x, context=context)
278 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
279 | x = self.proj_out(x)
280 | return x + x_in
281 |
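The chunked forward above trades peak memory for a Python loop over the flattened (batch * heads) dimension: each slice computes its own softmax(Q K^T * scale) @ V and writes the result into a pre-allocated buffer, so the full attention matrix never has to exist at once. Below is a small, self-contained sketch of the same idea in plain PyTorch (hypothetical shapes, direct slicing instead of tensor_split) that checks the chunked result against ordinary attention:

import torch

def full_attention(q, k, v, scale):
    # q, k, v: (B, N, D), where B already folds batch and heads together
    sim = torch.einsum('bid,bjd->bij', q, k) * scale
    return torch.einsum('bij,bjd->bid', sim.softmax(dim=-1), v)

def chunked_attention(q, k, v, scale, att_step=2):
    # process att_step rows of the folded batch dimension at a time,
    # mirroring the loop in CrossAttention.forward above
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], att_step):
        qb, kb, vb = q[i:i + att_step], k[i:i + att_step], v[i:i + att_step]
        sim = torch.einsum('bid,bjd->bij', qb, kb) * scale
        out[i:i + att_step] = torch.einsum('bij,bjd->bid', sim.softmax(dim=-1), vb)
    return out

if __name__ == "__main__":
    q, k, v = (torch.randn(8, 16, 40) for _ in range(3))
    scale = 40 ** -0.5
    assert torch.allclose(full_attention(q, k, v, scale),
                          chunked_attention(q, k, v, scale), atol=1e-5)

Because softmax is applied row-wise and the batch entries are independent, splitting along the batch dimension changes memory use but not the result.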
--------------------------------------------------------------------------------
/optimizedSD/txt2img_gradio.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import numpy as np
3 | import torch
4 | from torchvision.utils import make_grid
5 | from einops import rearrange
6 | import os, re
7 | from PIL import Image
9 | import pandas as pd
11 | from random import randint
12 | from omegaconf import OmegaConf
14 | from tqdm import tqdm, trange
15 | from itertools import islice
18 | import time
19 | from pytorch_lightning import seed_everything
20 | from torch import autocast
21 | from contextlib import nullcontext
22 | from ldm.util import instantiate_from_config
23 | from optimUtils import split_weighted_subprompts, logger
24 | from transformers import logging
25 | logging.set_verbosity_error()
26 | import mimetypes
27 | mimetypes.init()
28 | mimetypes.add_type("application/javascript", ".js")
29 |
30 |
31 | def chunk(it, size):
32 | it = iter(it)
33 | return iter(lambda: tuple(islice(it, size)), ())
34 |
35 |
36 | def load_model_from_config(ckpt, verbose=False):
37 | print(f"Loading model from {ckpt}")
38 | pl_sd = torch.load(ckpt, map_location="cpu")
39 | if "global_step" in pl_sd:
40 | print(f"Global Step: {pl_sd['global_step']}")
41 | sd = pl_sd["state_dict"]
42 | return sd
43 |
44 | config = "optimizedSD/v1-inference.yaml"
45 | ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
46 | sd = load_model_from_config(f"{ckpt}")
47 | li, lo = [], []
48 | for key, v_ in sd.items():
49 | sp = key.split(".")
50 | if (sp[0]) == "model":
51 | if "input_blocks" in sp:
52 | li.append(key)
53 | elif "middle_block" in sp:
54 | li.append(key)
55 | elif "time_embed" in sp:
56 | li.append(key)
57 | else:
58 | lo.append(key)
59 | for key in li:
60 | sd["model1." + key[6:]] = sd.pop(key)
61 | for key in lo:
62 | sd["model2." + key[6:]] = sd.pop(key)
63 |
64 | config = OmegaConf.load(f"{config}")
65 |
66 | model = instantiate_from_config(config.modelUNet)
67 | _, _ = model.load_state_dict(sd, strict=False)
68 | model.eval()
69 |
70 | modelCS = instantiate_from_config(config.modelCondStage)
71 | _, _ = modelCS.load_state_dict(sd, strict=False)
72 | modelCS.eval()
73 |
74 | modelFS = instantiate_from_config(config.modelFirstStage)
75 | _, _ = modelFS.load_state_dict(sd, strict=False)
76 | modelFS.eval()
77 | del sd
78 |
79 |
80 | def generate(
81 | prompt,
82 | ddim_steps,
83 | n_iter,
84 | batch_size,
85 | Height,
86 | Width,
87 | scale,
88 | ddim_eta,
89 | unet_bs,
90 | device,
91 | seed,
92 | outdir,
93 | img_format,
94 | turbo,
95 | full_precision,
96 | sampler,
97 | ):
98 |
99 | C = 4
100 | f = 8
101 | start_code = None
102 | model.unet_bs = unet_bs
103 | model.turbo = turbo
104 | model.cdevice = device
105 | modelCS.cond_stage_model.device = device
106 |
107 | if seed == "":
108 | seed = randint(0, 1000000)
109 | seed = int(seed)
110 | seed_everything(seed)
111 | # Logging
112 | logger(locals(), "logs/txt2img_gradio_logs.csv")
113 |
114 | if device != "cpu" and full_precision == False:
115 | model.half()
116 | modelFS.half()
117 | modelCS.half()
118 |
119 | tic = time.time()
120 | os.makedirs(outdir, exist_ok=True)
121 | outpath = outdir
122 | sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150]
123 | os.makedirs(sample_path, exist_ok=True)
124 | base_count = len(os.listdir(sample_path))
125 |
126 | # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
127 | assert prompt is not None
128 | data = [batch_size * [prompt]]
129 |
130 | if full_precision == False and device != "cpu":
131 | precision_scope = autocast
132 | else:
133 | precision_scope = nullcontext
134 |
135 | all_samples = []
136 | seeds = ""
137 | with torch.no_grad():
138 |
139 | all_samples = list()
140 | for _ in trange(n_iter, desc="Sampling"):
141 | for prompts in tqdm(data, desc="data"):
142 | with precision_scope("cuda"):
143 | modelCS.to(device)
144 | uc = None
145 | if scale != 1.0:
146 | uc = modelCS.get_learned_conditioning(batch_size * [""])
147 | if isinstance(prompts, tuple):
148 | prompts = list(prompts)
149 |
150 | subprompts, weights = split_weighted_subprompts(prompts[0])
151 | if len(subprompts) > 1:
152 | c = torch.zeros_like(uc)
153 | totalWeight = sum(weights)
154 | # normalize each "sub prompt" and add it
155 | for i in range(len(subprompts)):
156 | weight = weights[i]
157 | # if not skip_normalize:
158 | weight = weight / totalWeight
159 | c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
160 | else:
161 | c = modelCS.get_learned_conditioning(prompts)
162 |
163 | shape = [batch_size, C, Height // f, Width // f]
164 |
165 | if device != "cpu":
166 | mem = torch.cuda.memory_allocated() / 1e6
167 | modelCS.to("cpu")
168 | while torch.cuda.memory_allocated() / 1e6 >= mem:
169 | time.sleep(1)
170 |
171 | samples_ddim = model.sample(
172 | S=ddim_steps,
173 | conditioning=c,
174 | seed=seed,
175 | shape=shape,
176 | verbose=False,
177 | unconditional_guidance_scale=scale,
178 | unconditional_conditioning=uc,
179 | eta=ddim_eta,
180 | x_T=start_code,
181 | sampler = sampler,
182 | )
183 |
184 | modelFS.to(device)
185 | print("saving images")
186 | for i in range(batch_size):
187 |
188 | x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
189 | x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
190 | all_samples.append(x_sample.to("cpu"))
191 | x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
192 | Image.fromarray(x_sample.astype(np.uint8)).save(
193 | os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
194 | )
195 | seeds += str(seed) + ","
196 | seed += 1
197 | base_count += 1
198 |
199 | if device != "cpu":
200 | mem = torch.cuda.memory_allocated() / 1e6
201 | modelFS.to("cpu")
202 | while torch.cuda.memory_allocated() / 1e6 >= mem:
203 | time.sleep(1)
204 |
205 | del samples_ddim
206 | del x_sample
207 | del x_samples_ddim
208 | print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
209 |
210 | toc = time.time()
211 |
212 | time_taken = (toc - tic) / 60.0
213 | grid = torch.cat(all_samples, 0)
214 | grid = make_grid(grid, nrow=n_iter)
215 | grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
216 |
217 | txt = (
218 | "Samples finished in "
219 | + str(round(time_taken, 3))
220 | + " minutes and exported to "
221 | + sample_path
222 | + "\nSeeds used = "
223 | + seeds[:-1]
224 | )
225 | return Image.fromarray(grid.astype(np.uint8)), txt
226 |
227 |
228 | demo = gr.Interface(
229 | fn=generate,
230 | inputs=[
231 | "text",
232 | gr.Slider(1, 1000, value=50),
233 | gr.Slider(1, 100, step=1),
234 | gr.Slider(1, 100, step=1),
235 | gr.Slider(64, 4096, value=512, step=64),
236 | gr.Slider(64, 4096, value=512, step=64),
237 | gr.Slider(0, 50, value=7.5, step=0.1),
238 | gr.Slider(0, 1, step=0.01),
239 | gr.Slider(1, 2, value=1, step=1),
240 | gr.Text(value="cuda"),
241 | "text",
242 | gr.Text(value="outputs/txt2img-samples"),
243 | gr.Radio(["png", "jpg"], value='png'),
244 | "checkbox",
245 | "checkbox",
246 | gr.Radio(["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"], value="plms"),
247 | ],
248 | outputs=["image", "text"],
249 | )
250 | demo.launch()
251 |
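One detail worth calling out in generate() above: when a prompt contains weighted sub-prompts, the conditioning is built as a weight-normalised sum of the CLIP embeddings of each sub-prompt. A toy sketch of that blending step, with random tensors standing in for modelCS.get_learned_conditioning (the prompt/weight parsing done by split_weighted_subprompts is assumed, not reproduced, and blend_conditionings is a hypothetical helper):

import torch

def blend_conditionings(embeddings, weights):
    # embeddings: list of (1, 77, 768) tensors, one per sub-prompt
    # weights are normalised so the blended conditioning stays on the same scale
    total = sum(weights)
    c = torch.zeros_like(embeddings[0])
    for emb, w in zip(embeddings, weights):
        c = torch.add(c, emb, alpha=w / total)  # same pattern as the loop above
    return c

embs = [torch.randn(1, 77, 768) for _ in range(2)]  # stand-ins for CLIP outputs
c = blend_conditionings(embs, weights=[1.0, 0.5])
print(c.shape)  # torch.Size([1, 77, 768])

Normalising by the total weight keeps the blended conditioning on the same scale as a single-prompt embedding.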
--------------------------------------------------------------------------------
/optimizedSD/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | modelUNet:
2 | base_learning_rate: 1.0e-04
3 | target: optimizedSD.ddpm.UNet
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unetConfigEncode:
21 | target: optimizedSD.openaimodelSplit.UNetModelEncode
22 | params:
23 | image_size: 32 # unused
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions: [4, 2, 1]
28 | num_res_blocks: 2
29 | channel_mult: [1, 2, 4, 4]
30 | num_heads: 8
31 | use_spatial_transformer: True
32 | transformer_depth: 1
33 | context_dim: 768
34 | use_checkpoint: True
35 | legacy: False
36 |
37 | unetConfigDecode:
38 | target: optimizedSD.openaimodelSplit.UNetModelDecode
39 | params:
40 | image_size: 32 # unused
41 | in_channels: 4
42 | out_channels: 4
43 | model_channels: 320
44 | attention_resolutions: [4, 2, 1]
45 | num_res_blocks: 2
46 | channel_mult: [1, 2, 4, 4]
47 | num_heads: 8
48 | use_spatial_transformer: True
49 | transformer_depth: 1
50 | context_dim: 768
51 | use_checkpoint: True
52 | legacy: False
53 |
54 | modelFirstStage:
55 | target: optimizedSD.ddpm.FirstStage
56 | params:
57 | linear_start: 0.00085
58 | linear_end: 0.0120
59 | num_timesteps_cond: 1
60 | log_every_t: 200
61 | timesteps: 1000
62 | first_stage_key: "jpg"
63 | cond_stage_key: "txt"
64 | image_size: 64
65 | channels: 4
66 | cond_stage_trainable: false # Note: different from the one we trained before
67 | conditioning_key: crossattn
68 | monitor: val/loss_simple_ema
69 | scale_factor: 0.18215
70 | use_ema: False
71 | first_stage_config:
72 | target: ldm.models.autoencoder.AutoencoderKL
73 | params:
74 | embed_dim: 4
75 | monitor: val/rec_loss
76 | ddconfig:
77 | double_z: true
78 | z_channels: 4
79 | resolution: 256
80 | in_channels: 3
81 | out_ch: 3
82 | ch: 128
83 | ch_mult:
84 | - 1
85 | - 2
86 | - 4
87 | - 4
88 | num_res_blocks: 2
89 | attn_resolutions: []
90 | dropout: 0.0
91 | lossconfig:
92 | target: torch.nn.Identity
93 |
94 | modelCondStage:
95 | target: optimizedSD.ddpm.CondStage
96 | params:
97 | linear_start: 0.00085
98 | linear_end: 0.0120
99 | num_timesteps_cond: 1
100 | log_every_t: 200
101 | timesteps: 1000
102 | first_stage_key: "jpg"
103 | cond_stage_key: "txt"
104 | image_size: 64
105 | channels: 4
106 | cond_stage_trainable: false # Note: different from the one we trained before
107 | conditioning_key: crossattn
108 | monitor: val/loss_simple_ema
109 | scale_factor: 0.18215
110 | use_ema: False
111 | cond_stage_config:
112 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
113 | params:
114 | device: cpu
115 |
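Every top-level block in this config (modelUNet, modelFirstStage, modelCondStage and the nested first_stage_config / cond_stage_config) follows the same target + params convention consumed by ldm.util.instantiate_from_config, as used in txt2img_gradio.py above: target is a dotted import path and params are passed as keyword arguments. Roughly, the helper behaves like this minimal sketch (an illustration, not the actual implementation):

import importlib

from omegaconf import OmegaConf


def instantiate_from_config_sketch(config):
    # config carries a "target" dotted path and an optional "params" mapping
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))


config = OmegaConf.load("optimizedSD/v1-inference.yaml")  # path as used in the scripts above
# e.g. model_fs = instantiate_from_config_sketch(config.modelFirstStage)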
--------------------------------------------------------------------------------
/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11 |
12 |
13 |
14 | cd models/first_stage_models/kl-f4
15 | unzip -o model.zip
16 |
17 | cd ../kl-f8
18 | unzip -o model.zip
19 |
20 | cd ../kl-f16
21 | unzip -o model.zip
22 |
23 | cd ../kl-f32
24 | unzip -o model.zip
25 |
26 | cd ../vq-f4
27 | unzip -o model.zip
28 |
29 | cd ../vq-f4-noattn
30 | unzip -o model.zip
31 |
32 | cd ../vq-f8
33 | unzip -o model.zip
34 |
35 | cd ../vq-f8-n256
36 | unzip -o model.zip
37 |
38 | cd ../vq-f16
39 | unzip -o model.zip
40 |
41 | cd ../..
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13 |
14 |
15 |
16 | cd models/ldm/celeba256
17 | unzip -o celeba-256.zip
18 |
19 | cd ../ffhq256
20 | unzip -o ffhq-256.zip
21 |
22 | cd ../lsun_churches256
23 | unzip -o lsun_churches-256.zip
24 |
25 | cd ../lsun_beds256
26 | unzip -o lsun_beds-256.zip
27 |
28 | cd ../text2img256
29 | unzip -o model.zip
30 |
31 | cd ../cin256
32 | unzip -o model.zip
33 |
34 | cd ../semantic_synthesis512
35 | unzip -o model.zip
36 |
37 | cd ../semantic_synthesis256
38 | unzip -o model.zip
39 |
40 | cd ../bsr_sr
41 | unzip -o model.zip
42 |
43 | cd ../layout2img-openimages256
44 | unzip -o model.zip
45 |
46 | cd ../inpainting_big
47 | unzip -o model.zip
48 |
49 | cd ../..
50 |
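Both download scripts are plain wget + unzip loops. If wget or unzip is unavailable, the same fetch-and-extract step can be reproduced from Python; a sketch for a single checkpoint, with the URL and target directory copied from the script above (network access assumed, no error handling):

import os
import urllib.request
import zipfile

url = "https://ommer-lab.com/files/latent-diffusion/celeba.zip"
dest_dir = "models/ldm/celeba256"
zip_path = os.path.join(dest_dir, "celeba-256.zip")

os.makedirs(dest_dir, exist_ok=True)
urllib.request.urlretrieve(url, zip_path)   # download the archive
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(dest_dir)                 # same effect as `unzip -o` in that folder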
--------------------------------------------------------------------------------
/scripts/img2img.py:
--------------------------------------------------------------------------------
1 | """make variations of input image"""
2 |
3 | import argparse, os, sys, glob
4 | import PIL
5 | import torch
6 | import numpy as np
7 | from omegaconf import OmegaConf
8 | from PIL import Image
9 | from tqdm import tqdm, trange
10 | from itertools import islice
11 | from einops import rearrange, repeat
12 | from torchvision.utils import make_grid
13 | from torch import autocast
14 | from contextlib import nullcontext
15 | import time
16 | from pytorch_lightning import seed_everything
17 |
18 | from ldm.util import instantiate_from_config
19 | from ldm.models.diffusion.ddim import DDIMSampler
20 | from ldm.models.diffusion.plms import PLMSSampler
21 |
22 |
23 | def chunk(it, size):
24 | it = iter(it)
25 | return iter(lambda: tuple(islice(it, size)), ())
26 |
27 |
28 | def load_model_from_config(config, ckpt, verbose=False):
29 | print(f"Loading model from {ckpt}")
30 | pl_sd = torch.load(ckpt, map_location="cpu")
31 | if "global_step" in pl_sd:
32 | print(f"Global Step: {pl_sd['global_step']}")
33 | sd = pl_sd["state_dict"]
34 | model = instantiate_from_config(config.model)
35 | m, u = model.load_state_dict(sd, strict=False)
36 | if len(m) > 0 and verbose:
37 | print("missing keys:")
38 | print(m)
39 | if len(u) > 0 and verbose:
40 | print("unexpected keys:")
41 | print(u)
42 |
43 | model.cuda()
44 | model.eval()
45 | return model
46 |
47 |
48 | def load_img(path):
49 | image = Image.open(path).convert("RGB")
50 | w, h = image.size
51 | print(f"loaded input image of size ({w}, {h}) from {path}")
52 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
53 | image = image.resize((w, h), resample=PIL.Image.LANCZOS)
54 | image = np.array(image).astype(np.float32) / 255.0
55 | image = image[None].transpose(0, 3, 1, 2)
56 | image = torch.from_numpy(image)
57 | return 2.*image - 1.
58 |
59 |
60 | def main():
61 | parser = argparse.ArgumentParser()
62 |
63 | parser.add_argument(
64 | "--prompt",
65 | type=str,
66 | nargs="?",
67 | default="a painting of a virus monster playing guitar",
68 | help="the prompt to render"
69 | )
70 |
71 | parser.add_argument(
72 | "--init-img",
73 | type=str,
74 | nargs="?",
75 | help="path to the input image"
76 | )
77 |
78 | parser.add_argument(
79 | "--outdir",
80 | type=str,
81 | nargs="?",
82 | help="dir to write results to",
83 | default="outputs/img2img-samples"
84 | )
85 |
86 | parser.add_argument(
87 | "--skip_grid",
88 | action='store_true',
89 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
90 | )
91 |
92 | parser.add_argument(
93 | "--skip_save",
94 | action='store_true',
95 | help="do not save indiviual samples. For speed measurements.",
96 | )
97 |
98 | parser.add_argument(
99 | "--ddim_steps",
100 | type=int,
101 | default=50,
102 | help="number of ddim sampling steps",
103 | )
104 |
105 | parser.add_argument(
106 | "--plms",
107 | action='store_true',
108 | help="use plms sampling",
109 | )
110 | parser.add_argument(
111 | "--fixed_code",
112 | action='store_true',
113 | help="if enabled, uses the same starting code across all samples ",
114 | )
115 |
116 | parser.add_argument(
117 | "--ddim_eta",
118 | type=float,
119 | default=0.0,
120 | help="ddim eta (eta=0.0 corresponds to deterministic sampling",
121 | )
122 | parser.add_argument(
123 | "--n_iter",
124 | type=int,
125 | default=1,
126 | help="sample this often",
127 | )
128 | parser.add_argument(
129 | "--C",
130 | type=int,
131 | default=4,
132 | help="latent channels",
133 | )
134 | parser.add_argument(
135 | "--f",
136 | type=int,
137 | default=8,
138 | help="downsampling factor, most often 8 or 16",
139 | )
140 | parser.add_argument(
141 | "--n_samples",
142 | type=int,
143 | default=2,
144 | help="how many samples to produce for each given prompt. A.k.a batch size",
145 | )
146 | parser.add_argument(
147 | "--n_rows",
148 | type=int,
149 | default=0,
150 | help="rows in the grid (default: n_samples)",
151 | )
152 | parser.add_argument(
153 | "--scale",
154 | type=float,
155 | default=5.0,
156 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
157 | )
158 |
159 | parser.add_argument(
160 | "--strength",
161 | type=float,
162 | default=0.75,
163 | help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
164 | )
165 | parser.add_argument(
166 | "--from-file",
167 | type=str,
168 | help="if specified, load prompts from this file",
169 | )
170 | parser.add_argument(
171 | "--config",
172 | type=str,
173 | default="configs/stable-diffusion/v1-inference.yaml",
174 | help="path to config which constructs model",
175 | )
176 | parser.add_argument(
177 | "--ckpt",
178 | type=str,
179 | default="models/ldm/stable-diffusion-v1/model.ckpt",
180 | help="path to checkpoint of model",
181 | )
182 | parser.add_argument(
183 | "--seed",
184 | type=int,
185 | default=42,
186 | help="the seed (for reproducible sampling)",
187 | )
188 | parser.add_argument(
189 | "--precision",
190 | type=str,
191 | help="evaluate at this precision",
192 | choices=["full", "autocast"],
193 | default="autocast"
194 | )
195 |
196 | opt = parser.parse_args()
197 | seed_everything(opt.seed)
198 |
199 | config = OmegaConf.load(f"{opt.config}")
200 | model = load_model_from_config(config, f"{opt.ckpt}")
201 |
202 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
203 | model = model.to(device)
204 |
205 | if opt.plms:
206 | raise NotImplementedError("PLMS sampler not (yet) supported")
207 | sampler = PLMSSampler(model)
208 | else:
209 | sampler = DDIMSampler(model)
210 |
211 | os.makedirs(opt.outdir, exist_ok=True)
212 | outpath = opt.outdir
213 |
214 | batch_size = opt.n_samples
215 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
216 | if not opt.from_file:
217 | prompt = opt.prompt
218 | assert prompt is not None
219 | data = [batch_size * [prompt]]
220 |
221 | else:
222 | print(f"reading prompts from {opt.from_file}")
223 | with open(opt.from_file, "r") as f:
224 | data = f.read().splitlines()
225 | data = list(chunk(data, batch_size))
226 |
227 | sample_path = os.path.join(outpath, "samples")
228 | os.makedirs(sample_path, exist_ok=True)
229 | base_count = len(os.listdir(sample_path))
230 | grid_count = len(os.listdir(outpath)) - 1
231 |
232 | assert os.path.isfile(opt.init_img)
233 | init_image = load_img(opt.init_img).to(device)
234 | init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
235 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
236 |
237 | sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
238 |
239 | assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
240 | t_enc = int(opt.strength * opt.ddim_steps)
241 | print(f"target t_enc is {t_enc} steps")
242 |
243 | precision_scope = autocast if opt.precision == "autocast" else nullcontext
244 | with torch.no_grad():
245 | with precision_scope("cuda"):
246 | with model.ema_scope():
247 | tic = time.time()
248 | all_samples = list()
249 | for n in trange(opt.n_iter, desc="Sampling"):
250 | for prompts in tqdm(data, desc="data"):
251 | uc = None
252 | if opt.scale != 1.0:
253 | uc = model.get_learned_conditioning(batch_size * [""])
254 | if isinstance(prompts, tuple):
255 | prompts = list(prompts)
256 | c = model.get_learned_conditioning(prompts)
257 |
258 | # encode (scaled latent)
259 | z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
260 | # decode it
261 | samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
262 | unconditional_conditioning=uc,)
263 |
264 | x_samples = model.decode_first_stage(samples)
265 | x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
266 |
267 | if not opt.skip_save:
268 | for x_sample in x_samples:
269 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
270 | Image.fromarray(x_sample.astype(np.uint8)).save(
271 | os.path.join(sample_path, f"{base_count:05}.png"))
272 | base_count += 1
273 | all_samples.append(x_samples)
274 |
275 | if not opt.skip_grid:
276 | # additionally, save as grid
277 | grid = torch.stack(all_samples, 0)
278 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
279 | grid = make_grid(grid, nrow=n_rows)
280 |
281 | # to image
282 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
283 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
284 | grid_count += 1
285 |
286 | toc = time.time()
287 |
288 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
289 | f" \nEnjoy.")
290 |
291 |
292 | if __name__ == "__main__":
293 | main()
294 |
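The key knob in this script is --strength: t_enc = int(strength * ddim_steps) noise steps are applied to the init latent by sampler.stochastic_encode, and sampler.decode then denoises for those same t_enc steps under the prompt, so lower strength preserves more of the input image. A tiny worked example of that arithmetic (img2img_steps is a hypothetical helper for illustration):

def img2img_steps(strength, ddim_steps):
    # mirrors: t_enc = int(opt.strength * opt.ddim_steps) above
    assert 0.0 <= strength <= 1.0, "can only work with strength in [0.0, 1.0]"
    return int(strength * ddim_steps)


# with the defaults above (strength=0.75, ddim_steps=50) the sampler noises the
# init latent for 37 steps and then re-runs those 37 steps of denoising,
# while strength=1.0 would discard the init image entirely (all 50 steps).
print(img2img_steps(0.75, 50))  # 37
print(img2img_steps(1.0, 50))   # 50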
--------------------------------------------------------------------------------
/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | opt = parser.parse_args()
54 |
55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
56 | images = [x.replace("_mask.png", ".png") for x in masks]
57 | print(f"Found {len(masks)} inputs.")
58 |
59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
60 | model = instantiate_from_config(config.model)
61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
62 | strict=False)
63 |
64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
65 | model = model.to(device)
66 | sampler = DDIMSampler(model)
67 |
68 | os.makedirs(opt.outdir, exist_ok=True)
69 | with torch.no_grad():
70 | with model.ema_scope():
71 | for image, mask in tqdm(zip(images, masks)):
72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
73 | batch = make_batch(image, mask, device=device)
74 |
75 | # encode masked image and concat downsampled mask
76 | c = model.cond_stage_model.encode(batch["masked_image"])
77 | cc = torch.nn.functional.interpolate(batch["mask"],
78 | size=c.shape[-2:])
79 | c = torch.cat((c, cc), dim=1)
80 |
81 | shape = (c.shape[1]-1,)+c.shape[2:]
82 | samples_ddim, _ = sampler.sample(S=opt.steps,
83 | conditioning=c,
84 | batch_size=c.shape[0],
85 | shape=shape,
86 | verbose=False)
87 | x_samples_ddim = model.decode_first_stage(samples_ddim)
88 |
89 | image = torch.clamp((batch["image"]+1.0)/2.0,
90 | min=0.0, max=1.0)
91 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
92 | min=0.0, max=1.0)
93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
94 | min=0.0, max=1.0)
95 |
96 | inpainted = (1-mask)*image+mask*predicted_image
97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
99 |
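The final composite keeps the original pixels wherever the mask is 0 and takes the model prediction wherever the mask is 1: inpainted = (1-mask)*image + mask*predicted_image. A tiny numpy illustration of that blend with 2x2 toy "images" in place of real tensors:

import numpy as np

image     = np.array([[1.0, 1.0], [1.0, 1.0]])   # original content
predicted = np.array([[0.0, 0.0], [0.0, 0.0]])   # model output
mask      = np.array([[0.0, 1.0], [0.0, 1.0]])   # 1 = region to inpaint

inpainted = (1 - mask) * image + mask * predicted
print(inpainted)
# [[1. 0.]
#  [1. 0.]]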
--------------------------------------------------------------------------------
/scripts/train_searcher.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import scann
4 | import argparse
5 | import glob
6 | from multiprocessing import cpu_count
7 | from tqdm import tqdm
8 |
9 | from ldm.util import parallel_data_prefetch
10 |
11 |
12 | def search_bruteforce(searcher):
13 | return searcher.score_brute_force().build()
14 |
15 |
16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
17 | partioning_trainsize, num_leaves, num_leaves_to_search):
18 | return searcher.tree(num_leaves=num_leaves,
19 | num_leaves_to_search=num_leaves_to_search,
20 | training_sample_size=partioning_trainsize). \
21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
22 |
23 |
24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
26 | reorder_k).build()
27 |
28 | def load_datapool(dpath):
29 |
30 |
31 | def load_single_file(saved_embeddings):
32 | compressed = np.load(saved_embeddings)
33 | database = {key: compressed[key] for key in compressed.files}
34 | return database
35 |
36 | def load_multi_files(data_archive):
37 | database = {key: [] for key in data_archive[0].files}
38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
39 | for key in d.files:
40 | database[key].append(d[key])
41 |
42 | return database
43 |
44 |     print(f'Loading saved patch embeddings from "{dpath}"')
45 | file_content = glob.glob(os.path.join(dpath, '*.npz'))
46 |
47 | if len(file_content) == 1:
48 | data_pool = load_single_file(file_content[0])
49 | elif len(file_content) > 1:
50 | data = [np.load(f) for f in file_content]
51 | prefetched_data = parallel_data_prefetch(load_multi_files, data,
52 | n_proc=min(len(data), cpu_count()), target_data_type='dict')
53 |
54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
55 | else:
56 |         raise ValueError(f'No npz files found in specified path "{dpath}"; does this directory exist?')
57 |
58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
59 | return data_pool
60 |
61 |
62 | def train_searcher(opt,
63 | metric='dot_product',
64 | partioning_trainsize=None,
65 | reorder_k=None,
66 | # todo tune
67 | aiq_thld=0.2,
68 | dims_per_block=2,
69 | num_leaves=None,
70 | num_leaves_to_search=None,):
71 |
72 | data_pool = load_datapool(opt.database)
73 | k = opt.knn
74 |
75 | if not reorder_k:
76 | reorder_k = 2 * k
77 |
78 | # normalize
79 | # embeddings =
80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
81 | pool_size = data_pool['embedding'].shape[0]
82 |
83 | print(*(['#'] * 100))
84 | print('Initializing scaNN searcher with the following values:')
85 | print(f'k: {k}')
86 | print(f'metric: {metric}')
87 | print(f'reorder_k: {reorder_k}')
88 | print(f'anisotropic_quantization_threshold: {aiq_thld}')
89 | print(f'dims_per_block: {dims_per_block}')
90 | print(*(['#'] * 100))
91 | print('Start training searcher....')
92 | print(f'N samples in pool is {pool_size}')
93 |
94 | # this reflects the recommended design choices proposed at
95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
96 | if pool_size < 2e4:
97 | print('Using brute force search.')
98 | searcher = search_bruteforce(searcher)
99 | elif 2e4 <= pool_size and pool_size < 1e5:
100 | print('Using asymmetric hashing search and reordering.')
101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
102 | else:
103 |         print('Using partitioning, asymmetric hashing search and reordering.')
104 |
105 | if not partioning_trainsize:
106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10
107 | if not num_leaves:
108 | num_leaves = int(np.sqrt(pool_size))
109 |
110 | if not num_leaves_to_search:
111 | num_leaves_to_search = max(num_leaves // 20, 1)
112 |
113 | print('Partitioning params:')
114 | print(f'num_leaves: {num_leaves}')
115 | print(f'num_leaves_to_search: {num_leaves_to_search}')
116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
118 | partioning_trainsize, num_leaves, num_leaves_to_search)
119 |
120 |     print('Finished training searcher.')
121 | searcher_savedir = opt.target_path
122 | os.makedirs(searcher_savedir, exist_ok=True)
123 | searcher.serialize(searcher_savedir)
124 | print(f'Saved trained searcher under "{searcher_savedir}"')
125 |
126 | if __name__ == '__main__':
127 | sys.path.append(os.getcwd())
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--database',
130 | '-d',
131 | default='data/rdm/retrieval_databases/openimages',
132 | type=str,
133 | help='path to folder containing the clip feature of the database')
134 | parser.add_argument('--target_path',
135 | '-t',
136 | default='data/rdm/searchers/openimages',
137 | type=str,
138 | help='path to the target folder where the searcher shall be stored.')
139 | parser.add_argument('--knn',
140 | '-k',
141 | default=20,
142 | type=int,
143 | help='number of nearest neighbors, for which the searcher shall be optimized')
144 |
145 | opt, _ = parser.parse_known_args()
146 |
147 | train_searcher(opt,)
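Because the embeddings are L2-normalised before the searcher is built and the metric is dot_product, retrieval is effectively cosine similarity. The sketch below shows the brute-force lookup that the trained ScaNN index approximates (shapes are illustrative, knn_dot_product is a hypothetical helper, and the real query path goes through the serialized searcher rather than raw numpy):

import numpy as np

def knn_dot_product(query, database, k=20):
    # query: (d,), database: (n, d); both assumed L2-normalised as above
    scores = database @ query           # dot product == cosine similarity here
    idx = np.argsort(-scores)[:k]       # indices of the k best matches
    return idx, scores[idx]

db = np.random.randn(1000, 512).astype(np.float32)
db /= np.linalg.norm(db, axis=1, keepdims=True)   # same normalisation as in train_searcher
q = db[0]                                         # query with a database entry
idx, scores = knn_dot_product(q, db, k=5)
print(idx[0], round(float(scores[0]), 3))         # 0 1.0 -- the entry matches itself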
--------------------------------------------------------------------------------
/scripts/txt2img.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | import torch
3 | import numpy as np
4 | from omegaconf import OmegaConf
5 | from PIL import Image
6 | from tqdm import tqdm, trange
7 | from itertools import islice
8 | from einops import rearrange
9 | from torchvision.utils import make_grid
10 | import time
11 | from pytorch_lightning import seed_everything
12 | from torch import autocast
13 | from contextlib import contextmanager, nullcontext
14 |
15 | from ldm.util import instantiate_from_config
16 | from ldm.models.diffusion.ddim import DDIMSampler
17 | from ldm.models.diffusion.plms import PLMSSampler
18 |
19 |
20 | def chunk(it, size):
21 | it = iter(it)
22 | return iter(lambda: tuple(islice(it, size)), ())
23 |
24 |
25 | def load_model_from_config(config, ckpt, verbose=False):
26 | print(f"Loading model from {ckpt}")
27 | pl_sd = torch.load(ckpt, map_location="cpu")
28 | if "global_step" in pl_sd:
29 | print(f"Global Step: {pl_sd['global_step']}")
30 | sd = pl_sd["state_dict"]
31 | model = instantiate_from_config(config.model)
32 | m, u = model.load_state_dict(sd, strict=False)
33 | if len(m) > 0 and verbose:
34 | print("missing keys:")
35 | print(m)
36 | if len(u) > 0 and verbose:
37 | print("unexpected keys:")
38 | print(u)
39 |
40 | model.cuda()
41 | model.eval()
42 | return model
43 |
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 |
48 | parser.add_argument(
49 | "--prompt",
50 | type=str,
51 | nargs="?",
52 | default="a painting of a virus monster playing guitar",
53 | help="the prompt to render"
54 | )
55 | parser.add_argument(
56 | "--outdir",
57 | type=str,
58 | nargs="?",
59 | help="dir to write results to",
60 | default="outputs/txt2img-samples"
61 | )
62 | parser.add_argument(
63 | "--skip_grid",
64 | action='store_true',
65 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
66 | )
67 | parser.add_argument(
68 | "--skip_save",
69 | action='store_true',
70 | help="do not save individual samples. For speed measurements.",
71 | )
72 | parser.add_argument(
73 | "--ddim_steps",
74 | type=int,
75 | default=50,
76 | help="number of ddim sampling steps",
77 | )
78 | parser.add_argument(
79 | "--plms",
80 | action='store_true',
81 | help="use plms sampling",
82 | )
83 | parser.add_argument(
84 | "--laion400m",
85 | action='store_true',
86 | help="uses the LAION400M model",
87 | )
88 | parser.add_argument(
89 | "--fixed_code",
90 | action='store_true',
91 | help="if enabled, uses the same starting code across samples ",
92 | )
93 | parser.add_argument(
94 | "--ddim_eta",
95 | type=float,
96 | default=0.0,
97 | help="ddim eta (eta=0.0 corresponds to deterministic sampling",
98 | )
99 | parser.add_argument(
100 | "--n_iter",
101 | type=int,
102 | default=2,
103 | help="sample this often",
104 | )
105 | parser.add_argument(
106 | "--H",
107 | type=int,
108 | default=512,
109 | help="image height, in pixel space",
110 | )
111 | parser.add_argument(
112 | "--W",
113 | type=int,
114 | default=512,
115 | help="image width, in pixel space",
116 | )
117 | parser.add_argument(
118 | "--C",
119 | type=int,
120 | default=4,
121 | help="latent channels",
122 | )
123 | parser.add_argument(
124 | "--f",
125 | type=int,
126 | default=8,
127 | help="downsampling factor",
128 | )
129 | parser.add_argument(
130 | "--n_samples",
131 | type=int,
132 | default=3,
133 | help="how many samples to produce for each given prompt. A.k.a. batch size",
134 | )
135 | parser.add_argument(
136 | "--n_rows",
137 | type=int,
138 | default=0,
139 | help="rows in the grid (default: n_samples)",
140 | )
141 | parser.add_argument(
142 | "--scale",
143 | type=float,
144 | default=7.5,
145 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
146 | )
147 | parser.add_argument(
148 | "--from-file",
149 | type=str,
150 | help="if specified, load prompts from this file",
151 | )
152 | parser.add_argument(
153 | "--config",
154 | type=str,
155 | default="configs/stable-diffusion/v1-inference.yaml",
156 | help="path to config which constructs model",
157 | )
158 | parser.add_argument(
159 | "--ckpt",
160 | type=str,
161 | default="models/ldm/stable-diffusion-v1/model.ckpt",
162 | help="path to checkpoint of model",
163 | )
164 | parser.add_argument(
165 | "--seed",
166 | type=int,
167 | default=42,
168 | help="the seed (for reproducible sampling)",
169 | )
170 | parser.add_argument(
171 | "--precision",
172 | type=str,
173 | help="evaluate at this precision",
174 | choices=["full", "autocast"],
175 | default="autocast"
176 | )
177 | opt = parser.parse_args()
178 |
179 | if opt.laion400m:
180 | print("Falling back to LAION 400M model...")
181 | opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
182 | opt.ckpt = "models/ldm/text2img-large/model.ckpt"
183 | opt.outdir = "outputs/txt2img-samples-laion400m"
184 |
185 | seed_everything(opt.seed)
186 |
187 | config = OmegaConf.load(f"{opt.config}")
188 | model = load_model_from_config(config, f"{opt.ckpt}")
189 |
190 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
191 | model = model.to(device)
192 |
193 | if opt.plms:
194 | sampler = PLMSSampler(model)
195 | else:
196 | sampler = DDIMSampler(model)
197 |
198 | os.makedirs(opt.outdir, exist_ok=True)
199 | outpath = opt.outdir
200 |
201 | batch_size = opt.n_samples
202 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
203 | if not opt.from_file:
204 | prompt = opt.prompt
205 | assert prompt is not None
206 | data = [batch_size * [prompt]]
207 |
208 | else:
209 | print(f"reading prompts from {opt.from_file}")
210 | with open(opt.from_file, "r") as f:
211 | data = f.read().splitlines()
212 | data = list(chunk(data, batch_size))
213 |
214 | sample_path = os.path.join(outpath, "samples")
215 | os.makedirs(sample_path, exist_ok=True)
216 | base_count = len(os.listdir(sample_path))
217 | grid_count = len(os.listdir(outpath)) - 1
218 |
219 | start_code = None
220 | if opt.fixed_code:
221 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
222 |
223 | precision_scope = autocast if opt.precision=="autocast" else nullcontext
224 | with torch.no_grad():
225 | with precision_scope("cuda"):
226 | with model.ema_scope():
227 | tic = time.time()
228 | all_samples = list()
229 | for n in trange(opt.n_iter, desc="Sampling"):
230 | for prompts in tqdm(data, desc="data"):
231 | uc = None
232 | if opt.scale != 1.0:
233 | uc = model.get_learned_conditioning(batch_size * [""])
234 | if isinstance(prompts, tuple):
235 | prompts = list(prompts)
236 | c = model.get_learned_conditioning(prompts)
237 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
238 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
239 | conditioning=c,
240 | batch_size=opt.n_samples,
241 | shape=shape,
242 | verbose=False,
243 | unconditional_guidance_scale=opt.scale,
244 | unconditional_conditioning=uc,
245 | eta=opt.ddim_eta,
246 | x_T=start_code)
247 |
248 | x_samples_ddim = model.decode_first_stage(samples_ddim)
249 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
250 |
251 | if not opt.skip_save:
252 | for x_sample in x_samples_ddim:
253 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
254 | Image.fromarray(x_sample.astype(np.uint8)).save(
255 | os.path.join(sample_path, f"{base_count:05}.png"))
256 | base_count += 1
257 |
258 | if not opt.skip_grid:
259 | all_samples.append(x_samples_ddim)
260 |
261 | if not opt.skip_grid:
262 | # additionally, save as grid
263 | grid = torch.stack(all_samples, 0)
264 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
265 | grid = make_grid(grid, nrow=n_rows)
266 |
267 | # to image
268 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
269 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
270 | grid_count += 1
271 |
272 | toc = time.time()
273 |
274 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
275 | f" \nEnjoy.")
276 |
277 |
278 | if __name__ == "__main__":
279 | main()
280 |
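The --scale help text spells out classifier-free guidance: inside the sampler the unconditional and conditional noise predictions are combined as eps = eps_uncond + scale * (eps_cond - eps_uncond), which is what unconditional_guidance_scale and unconditional_conditioning control. A minimal sketch of that combination with dummy tensors (cfg_combine is a hypothetical helper for illustration):

import torch

def cfg_combine(eps_uncond, eps_cond, scale):
    # eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
    return eps_uncond + scale * (eps_cond - eps_uncond)

e_u, e_c = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64)
guided = cfg_combine(e_u, e_c, scale=7.5)
assert torch.allclose(cfg_combine(e_u, e_c, 1.0), e_c)  # scale=1.0 reduces to the conditional prediction

With scale=1.0 the guidance term vanishes, which is why the scripts skip computing uc in that case; larger values push samples harder toward the prompt, typically at the cost of diversity.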
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------