├── .dockerignore ├── .github └── FUNDING.yml ├── Dockerfile ├── LICENSE ├── README.md ├── Stable_Diffusion_v1_Model_Card.md ├── assets ├── a-painting-of-a-fire.png ├── a-photograph-of-a-fire.png ├── a-shirt-with-a-fire-printed-on-it.png ├── a-shirt-with-the-inscription-'fire'.png ├── a-watercolor-painting-of-a-fire.png ├── birdhouse.png ├── fire.png ├── inpainting.png ├── modelfigure.png ├── rdm-preview.jpg ├── reconstruction1.png ├── reconstruction2.png ├── results.gif ├── stable-samples │ ├── img2img │ │ ├── mountains-1.png │ │ ├── mountains-2.png │ │ ├── mountains-3.png │ │ ├── sketch-mountains-input.jpg │ │ ├── upscaling-in.png │ │ └── upscaling-out.png │ └── txt2img │ │ ├── 000002025.png │ │ ├── 000002035.png │ │ ├── merged-0005.png │ │ ├── merged-0006.png │ │ └── merged-0007.png ├── the-earth-is-on-fire,-oil-on-canvas.png ├── txt2img-convsample.png ├── txt2img-preview.png └── v1-variants-scores.jpg ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml └── stable-diffusion │ └── v1-inference.yaml ├── data ├── DejaVuSans.ttf ├── example_conditioning │ ├── superresolution │ │ └── sample_0.jpg │ └── text_conditional │ │ └── sample_0.txt ├── imagenet_clsidx_to_label.txt ├── imagenet_train_hr_indices.p ├── imagenet_val_hr_indices.p ├── index_synset.yaml └── inpainting_examples │ ├── 6458524847_2f4c361183_k.png │ ├── 6458524847_2f4c361183_k_mask.png │ ├── 8399166846_f6fb4e4b8e_k.png │ ├── 8399166846_f6fb4e4b8e_k_mask.png │ ├── alex-iby-G_Pk4D9rMLs.png │ ├── alex-iby-G_Pk4D9rMLs_mask.png │ ├── bench2.png │ ├── bench2_mask.png │ ├── bertrand-gabioud-CpuFzIsHYJ0.png │ ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png │ ├── billow926-12-Wc-Zgx6Y.png │ ├── billow926-12-Wc-Zgx6Y_mask.png │ ├── overture-creations-5sI6fQgYIuo.png │ ├── overture-creations-5sI6fQgYIuo_mask.png │ ├── photo-1583445095369-9c651e7e5d34.png │ └── photo-1583445095369-9c651e7e5d34_mask.png ├── docker-bootstrap.sh ├── docker-compose.yml ├── environment.yaml ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── models ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml └── ldm │ ├── bsr_sr │ └── config.yaml │ ├── celeba256 │ └── config.yaml │ ├── cin256 │ └── config.yaml │ ├── 
ffhq256 │ └── config.yaml │ ├── inpainting_big │ └── config.yaml │ ├── layout2img-openimages256 │ └── config.yaml │ ├── lsun_beds256 │ └── config.yaml │ ├── lsun_churches256 │ └── config.yaml │ ├── semantic_synthesis256 │ └── config.yaml │ ├── semantic_synthesis512 │ └── config.yaml │ └── text2img256 │ └── config.yaml ├── notebook_helpers.py ├── optimizedSD ├── LICENSE ├── ddpm.py ├── diffusers_txt2img.py ├── img2img_gradio.py ├── inpaint_gradio.py ├── openaimodelSplit.py ├── optimUtils.py ├── optimized_img2img.py ├── optimized_txt2img.py ├── samplers.py ├── splitAttention.py ├── txt2img_gradio.py └── v1-inference.yaml ├── scripts ├── download_first_stages.sh ├── download_models.sh ├── img2img.py ├── inpaint.py ├── knn2img.py ├── latent_imagenet_diffusion.ipynb ├── sample_diffusion.py ├── train_searcher.py └── txt2img.py └── setup.py /.dockerignore: -------------------------------------------------------------------------------- 1 | *Dockerfile* 2 | docker-compose.yml 3 | .git 4 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: basuj 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: ['https://paypal.me/basuj'] 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3:4.12.0 AS build 2 | 3 | # Step for image utility dependencies. 4 | RUN apt update \ 5 | && apt install --no-install-recommends -y git \ 6 | && apt-get clean 7 | 8 | COPY . 
/root/stable-diffusion/ 9 | 10 | # Step to install dependencies with conda 11 | RUN eval "$(conda shell.bash hook)" \ 12 | && conda install -c conda-forge conda-pack \ 13 | && conda env create -f /root/stable-diffusion/environment.yaml \ 14 | && conda activate ldm \ 15 | && pip install gradio==3.1.7 \ 16 | && conda activate base 17 | 18 | # Step to zip and conda environment to "venv" folder 19 | RUN conda pack --ignore-missing-files --ignore-editable-packages -n ldm -o /tmp/env.tar \ 20 | && mkdir /venv \ 21 | && cd /venv \ 22 | && tar xf /tmp/env.tar \ 23 | && rm /tmp/env.tar 24 | 25 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as runtime 26 | 27 | ARG OPTIMIZED_FILE=txt2img_gradio.py 28 | WORKDIR /root/stable-diffusion 29 | 30 | COPY --from=build /venv /venv 31 | COPY --from=build /root/stable-diffusion /root/stable-diffusion 32 | 33 | RUN mkdir -p /output /root/stable-diffusion/outputs \ 34 | && ln -s /data /root/stable-diffusion/models/ldm/stable-diffusion-v1 \ 35 | && ln -s /output /root/stable-diffusion/outputs/txt2img-samples 36 | 37 | ENV PYTHONUNBUFFERED=1 38 | ENV GRADIO_SERVER_NAME=0.0.0.0 39 | ENV GRADIO_SERVER_PORT=7860 40 | ENV APP_MAIN_FILE=${OPTIMIZED_FILE} 41 | EXPOSE 7860 42 | 43 | VOLUME ["/root/.cache", "/data", "/output"] 44 | 45 | SHELL ["/bin/bash", "-c"] 46 | ENTRYPOINT ["/root/stable-diffusion/docker-bootstrap.sh"] 47 | CMD python optimizedSD/${APP_MAIN_FILE} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | All rights reserved by the authors. 2 | You must not distribute the weights provided to you directly or indirectly without explicit consent of the authors. 3 | You must not distribute harmful, offensive, dehumanizing content or otherwise harmful representations of people or their environments, cultures, religions, etc. produced with the model weights 4 | or other generated content described in the "Misuse and Malicious Use" section in the model card. 5 | The model weights are provided for research purposes only. 6 | 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 9 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 10 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 11 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 12 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 13 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 14 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Optimized Stable Diffusion

2 |

3 | 4 | 5 | 6 |

7 | 8 | This repo is a modified version of the Stable Diffusion repo, optimized to use less VRAM than the original by sacrificing inference speed. 9 | 10 | To reduce the VRAM usage, the following optimizations are used (a conceptual sketch follows this list): 11 | 12 | - The Stable Diffusion model is fragmented into four parts which are sent to the GPU only when needed. After the calculation is done, they are moved back to the CPU. 13 | - The attention calculation is done in parts. 14 | 15 | 
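The following is a minimal, illustrative sketch of the first optimization, not the repo's actual code (see the files in the [optimizedSD](optimizedSD) folder for the real implementation): each part stays on the CPU and borrows the GPU only for its own forward pass. The modules used as "parts" here are hypothetical placeholders.

```python
# Conceptual sketch only (assumes PyTorch and a CUDA device); the actual model
# splitting in this repo is more involved than this.
import torch
import torch.nn as nn

def run_part_on_gpu(part: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Move one model part to the GPU, run it, then move it back to the CPU."""
    part.to("cuda")
    with torch.no_grad():
        y = part(x.to("cuda"))
    part.to("cpu")              # frees VRAM for the next part
    torch.cuda.empty_cache()
    return y.to("cpu")

# Hypothetical stand-ins for the four fragments (e.g. text encoder, UNet halves, decoder).
parts = [nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)]
x = torch.randn(1, 64)
for part in parts:
    x = run_part_on_gpu(part, x)
```

Shuttling weights between CPU and GPU this way trades transfer time for peak VRAM usage, which is why these optimized scripts are slower than the originals.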

# Installation

16 | 17 | All the modified files are in the [optimizedSD](optimizedSD) folder, so if you have already cloned the original repository you can just download and copy this folder into the original instead of cloning the entire repo. You can also clone this repo and follow the same installation steps as the original (mainly creating the conda environment and placing the weights at the specified location). 18 | 19 | Alternatively, if you prefer to use Docker, you can do the following: 20 | 21 | 1. Install [Docker](https://docs.docker.com/engine/install/), [Docker Compose plugin](https://docs.docker.com/compose/install/), and [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) 22 | 2. Clone this repo to, e.g., `~/stable-diffusion` 23 | 3. Put your downloaded `model.ckpt` file into `~/sd-data` (it's a relative path, you can change it in `docker-compose.yml`) 24 | 4. `cd` into `~/stable-diffusion` and execute `docker compose up --build` 25 | 26 | This will launch gradio on port 7860 with txt2img. You can also use `docker compose run` to execute other Python scripts. 27 | 28 |

# Usage

29 | 30 | ## img2img 31 | 32 | - `img2img` can generate _512x512 images from an initial image and a prompt using under 2.4GB VRAM in under 20 seconds per image_ on an RTX 2060. 33 | 34 | - The maximum size that can fit on a 6GB GPU (RTX 2060) is around 1152x1088. 35 | 36 | - For example, the following command will generate 10 512x512 images: 37 | 38 | `python optimizedSD/optimized_img2img.py --prompt "Austrian alps" --init-img ~/sketch-mountains-input.jpg --strength 0.8 --n_iter 2 --n_samples 5 --H 512 --W 512` 39 | 40 | ## txt2img 41 | 42 | - `txt2img` can generate _512x512 images from a prompt using under 2.4GB GPU VRAM in under 24 seconds per image_ on an RTX 2060. 43 | 44 | - For example, the following command will generate 10 512x512 images: 45 | 46 | `python optimizedSD/optimized_txt2img.py --prompt "Cyberpunk style image of a Tesla car reflection in rain" --H 512 --W 512 --seed 27 --n_iter 2 --n_samples 5 --ddim_steps 50` 47 | 48 | ## inpainting 49 | 50 | - `inpaint_gradio.py` can fill masked parts of an image based on a given prompt. It can inpaint 512x512 images while using under 2.5GB of VRAM. 51 | 52 | - To launch the Gradio interface for inpainting, run `python optimizedSD/inpaint_gradio.py`. The mask for the image can be drawn on the selected image using the brush tool. 53 | 54 | - The results are not yet perfect but can be improved by using a combination of prompt weighting, prompt engineering, and testing out multiple values of the `--strength` argument. 55 | 56 | - _Suggestions to improve the inpainting algorithm are most welcome_. 57 | 58 | 
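If you want to queue up several prompts, one convenient option is to drive the commands above from a short Python script. The snippet below is only a hedged convenience example built from the documented arguments; it assumes it is run from the repo root with the `ldm` conda environment active.

```python
# Batch several prompts by calling the documented CLI from Python.
import subprocess

prompts = [
    "Cyberpunk style image of a Tesla car reflection in rain",
    "Austrian alps",
]

for seed, prompt in enumerate(prompts, start=27):
    subprocess.run(
        [
            "python", "optimizedSD/optimized_txt2img.py",
            "--prompt", prompt,
            "--H", "512", "--W", "512",
            "--seed", str(seed),
            "--n_iter", "1", "--n_samples", "2",
            "--ddim_steps", "50",
        ],
        check=True,  # stop the loop if a run fails
    )
```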

# Using the Gradio GUI

59 | 60 | - You can also use the built-in Gradio interfaces for `img2img`, `txt2img` & `inpainting` instead of the command line interface. Activate the conda environment and install the latest version of gradio using `pip install gradio`. 61 | 62 | - Run img2img using `python optimizedSD/img2img_gradio.py`, txt2img using `python optimizedSD/txt2img_gradio.py`, and inpainting using `python optimizedSD/inpaint_gradio.py`. 63 | 64 | - `img2img_gradio.py` has a feature to crop input images. Look for the pen symbol in the image box after selecting the image. 65 | 66 | 
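For orientation, a script of this kind essentially wraps a generation function in `gr.Interface`. The sketch below is not the repo's code: `generate` is a placeholder that returns a blank image instead of running the sampler, so the snippet runs on its own with the gradio 3.x release pinned in the Dockerfile.

```python
# Minimal, illustrative Gradio wiring; the real *_gradio.py scripts load the
# Stable Diffusion model and run the sampler inside the callback.
import gradio as gr
import numpy as np

def generate(prompt: str, ddim_steps: float):
    # Placeholder: a real implementation would sample an image conditioned on
    # `prompt` using `ddim_steps` denoising steps and return the decoded result.
    return np.zeros((512, 512, 3), dtype=np.uint8)

demo = gr.Interface(
    fn=generate,
    inputs=[gr.Textbox(label="prompt"), gr.Slider(1, 150, value=50, label="ddim_steps")],
    outputs=gr.Image(label="sample"),
)
demo.launch(server_name="0.0.0.0", server_port=7860)  # same host/port the Dockerfile exposes
```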

# Arguments

67 | 68 | ## `--seed` 69 | 70 | **Seed for image generation**; can be used to reproduce previously generated images. Defaults to a random seed if unspecified. 71 | 72 | - The code will give the seed number along with each generated image. To generate the same image again, just specify the seed using the `--seed` argument. Images are saved with their seed number as their name by default. 73 | 74 | - For example, if the seed number for an image is `1234` and it's the 55th image in the folder, it will be named `seed_1234_00055.png`. 75 | 76 | ## `--n_samples` 77 | 78 | **Batch size/number of images to generate at once.** 79 | 80 | - To get the lowest inference time per image, use the maximum batch size `--n_samples` that can fit on the GPU. Inference time per image decreases as the batch size increases, but the required VRAM increases as well. 81 | 82 | - If you get a CUDA out of memory error, try reducing the batch size `--n_samples`. If that doesn't work, the other option is to reduce the image width `--W` or height `--H` or both. 83 | 84 | ## `--n_iter` 85 | 86 | **Run the script _x_ times** 87 | 88 | - Equivalent to running the script `n_iter` times, except that the model is loaded only once for all the iterations. Unlike `n_samples`, changing it doesn't affect the required VRAM or the inference time per image. 89 | 90 | ## `--H` & `--W` 91 | 92 | **Height & width of the generated image.** 93 | 94 | - Both height and width should be a multiple of 64. 95 | 96 | ## `--turbo` 97 | 98 | **Increases inference speed at the cost of extra VRAM usage.** 99 | 100 | - Using this argument increases the inference speed by using around 700MB of extra GPU VRAM. It is especially effective when generating a small batch of images (~1 to 4 images). It takes under 20 seconds for txt2img and 15 seconds for img2img (on an RTX 2060, excluding the time to load the model). Use it with larger batch sizes if enough GPU VRAM is available. 101 | 102 | ## `--precision autocast` or `--precision full` 103 | 104 | **Whether to use `full` or `mixed` precision** 105 | 106 | - Mixed precision is enabled by default. If you don't have a GPU with tensor cores (e.g., any GTX 10 series card), you may not be able to use mixed precision. Use the `--precision full` argument to disable it. 107 | 108 | ## `--format png` or `--format jpg` 109 | 110 | **Output image format** 111 | 112 | - The default output format is `png`. While `png` is lossless, it takes up a lot of space (unless large portions of the image happen to be a single colour). Use lossy `jpg` to get smaller image file sizes. 113 | 114 | ## `--unet_bs` 115 | 116 | **Batch size for the UNet model** 117 | 118 | - Takes up a lot of extra RAM for **very little improvement** in inference time. `unet_bs` > 1 is not recommended! 119 | 120 | - Should generally be a multiple of 2x(`n_samples`). 121 | 122 | 
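To make the relationship between `--seed`, `--n_iter`, and `--n_samples` concrete, here is a small sketch of the assumed behaviour (not the actual sampler code): the seed fixes the starting noise, each of the `n_iter` iterations draws a batch of `n_samples` latents, and the total image count is `n_iter × n_samples`.

```python
# Illustrative sketch with PyTorch: the same seed gives identical starting noise,
# and each iteration produces one batch of latents. The latent shape (4 channels
# at 1/8 of the image resolution) follows the v1-inference.yaml config.
import torch

def sample_latents(seed: int, n_iter: int, n_samples: int, H: int = 512, W: int = 512):
    torch.manual_seed(seed)               # same seed -> identical starting noise
    for _ in range(n_iter):
        yield torch.randn(n_samples, 4, H // 8, W // 8)  # one batch per iteration

batches = list(sample_latents(seed=27, n_iter=2, n_samples=5))
print(sum(b.shape[0] for b in batches))   # 10 images, as in the README examples
```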

# Weighted Prompts

123 | 124 | - Prompts can also be weighted to put relative emphasis on certain words. 125 | e.g. `--prompt tabby cat:0.25 white duck:0.75 hybrid`. 126 | 127 | - The number after the colon represents the weight given to the words before the colon. The weights can be either fractions or integers. 128 | 129 | ## Troubleshooting 130 | 131 | ### Green colored output images 132 | 133 | - If you have an Nvidia GTX series GPU, the output images may be entirely green. This is because GTX series GPUs do not support half-precision calculation, which is the default mode of calculation in this repository. To overcome the issue, use the `--precision full` argument. The downside is that it will lead to higher GPU VRAM usage. 134 | 135 | 136 | 137 | ## Changelog 138 | 139 | - v1.0: Added support for multiple samplers for txt2img. Based on [crowsonkb's k-diffusion](https://github.com/crowsonkb/k-diffusion). 140 | - v0.9: Added support for calculating attention in parts. (Thanks to @neonsecret, @Doggettx and @ryudrigo) 141 | - v0.8: Added a Gradio interface for inpainting. 142 | - v0.7: Added support for logging and the jpg file format. 143 | - v0.6: Added support for using weighted prompts. (based on @lstein's [repo](https://github.com/lstein/stable-diffusion)) 144 | - v0.5: Added support for using a Gradio interface. 145 | - v0.4: Added support for specifying the image seed. 146 | - v0.3: Added support for using mixed precision. 147 | - v0.2: Added support for generating images in batches. 148 | - v0.1: Split the model into multiple parts to run it on lower VRAM. 149 | -------------------------------------------------------------------------------- /Stable_Diffusion_v1_Model_Card.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion v1 Model Card 2 | This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion). 3 | 4 | ## Model Details 5 | - **Developed by:** Robin Rombach, Patrick Esser 6 | - **Model type:** Diffusion-based text-to-image generation model 7 | - **Language(s):** English 8 | - **License:** [Proprietary](LICENSE) 9 | - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487). 10 | - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752). 11 | - **Cite as:** 12 | 13 | @InProceedings{Rombach_2022_CVPR, 14 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, 15 | title = {High-Resolution Image Synthesis With Latent Diffusion Models}, 16 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 17 | month = {June}, 18 | year = {2022}, 19 | pages = {10684-10695} 20 | } 21 | 22 | # Uses 23 | 24 | ## Direct Use 25 | The model is intended for research purposes only. Possible research areas and 26 | tasks include 27 | 28 | - Safe deployment of models which have the potential to generate harmful content. 29 | - Probing and understanding the limitations and biases of generative models. 30 | - Generation of artworks and use in design and other artistic processes. 31 | - Applications in educational or creative tools. 
32 | - Research on generative models. 33 | 34 | Excluded uses are described below. 35 | 36 | ### Misuse, Malicious Use, and Out-of-Scope Use 37 | _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_. 38 | 39 | 40 | The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 41 | #### Out-of-Scope Use 42 | The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. 43 | #### Misuse and Malicious Use 44 | Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to: 45 | 46 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. 47 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes. 48 | - Impersonating individuals without their consent. 49 | - Sexual content without consent of the people who might see it. 50 | - Mis- and disinformation 51 | - Representations of egregious violence and gore 52 | - Sharing of copyrighted or licensed material in violation of its terms of use. 53 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. 54 | 55 | ## Limitations and Bias 56 | 57 | ### Limitations 58 | 59 | - The model does not achieve perfect photorealism 60 | - The model cannot render legible text 61 | - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere” 62 | - Faces and people in general may not be generated properly. 63 | - The model was trained mainly with English captions and will not work as well in other languages. 64 | - The autoencoding part of the model is lossy 65 | - The model was trained on a large-scale dataset 66 | [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material 67 | and is not fit for product use without additional safety mechanisms and 68 | considerations. 69 | 70 | ### Bias 71 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. 72 | Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), 73 | which consists of images that are primarily limited to English descriptions. 74 | Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. 75 | This affects the overall output of the model, as white and western cultures are often set as the default. Further, the 76 | ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts. 77 | 78 | 79 | ## Training 80 | 81 | **Training Data** 82 | The model developers used the following dataset for training the model: 83 | 84 | - LAION-2B (en) and subsets thereof (see next section) 85 | 86 | **Training Procedure** 87 | Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. 
During training, 88 | 89 | - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4 90 | - Text prompts are encoded through a ViT-L/14 text-encoder. 91 | - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention. 92 | - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. 93 | 94 | We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`, 95 | which were trained as follows, 96 | 97 | - `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en). 98 | 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`). 99 | - `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`. 100 | 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en, 101 | filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)). 102 | - `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 103 | 104 | 105 | - **Hardware:** 32 x 8 x A100 GPUs 106 | - **Optimizer:** AdamW 107 | - **Gradient Accumulations**: 2 108 | - **Batch:** 32 x 8 x 2 x 4 = 2048 109 | - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant 110 | 111 | ## Evaluation Results 112 | Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, 113 | 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling 114 | steps show the relative improvements of the checkpoints: 115 | 116 | ![pareto](assets/v1-variants-scores.jpg) 117 | 118 | Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores. 119 | ## Environmental Impact 120 | 121 | **Stable Diffusion v1** **Estimated Emissions** 122 | Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact. 123 | 124 | - **Hardware Type:** A100 PCIe 40GB 125 | - **Hours used:** 150000 126 | - **Cloud Provider:** AWS 127 | - **Compute Region:** US-east 128 | - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq. 
129 | ## Citation 130 | @InProceedings{Rombach_2022_CVPR, 131 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, 132 | title = {High-Resolution Image Synthesis With Latent Diffusion Models}, 133 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 134 | month = {June}, 135 | year = {2022}, 136 | pages = {10684-10695} 137 | } 138 | 139 | *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).* 140 | 141 | -------------------------------------------------------------------------------- /assets/a-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-painting-of-a-fire.png -------------------------------------------------------------------------------- /assets/a-photograph-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-photograph-of-a-fire.png -------------------------------------------------------------------------------- /assets/a-shirt-with-a-fire-printed-on-it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-shirt-with-a-fire-printed-on-it.png -------------------------------------------------------------------------------- /assets/a-shirt-with-the-inscription-'fire'.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-shirt-with-the-inscription-'fire'.png -------------------------------------------------------------------------------- /assets/a-watercolor-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/a-watercolor-painting-of-a-fire.png -------------------------------------------------------------------------------- /assets/birdhouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/birdhouse.png -------------------------------------------------------------------------------- /assets/fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/fire.png -------------------------------------------------------------------------------- /assets/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/inpainting.png -------------------------------------------------------------------------------- /assets/modelfigure.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/modelfigure.png -------------------------------------------------------------------------------- /assets/rdm-preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/rdm-preview.jpg -------------------------------------------------------------------------------- /assets/reconstruction1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/reconstruction1.png -------------------------------------------------------------------------------- /assets/reconstruction2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/reconstruction2.png -------------------------------------------------------------------------------- /assets/results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/results.gif -------------------------------------------------------------------------------- /assets/stable-samples/img2img/mountains-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/mountains-1.png -------------------------------------------------------------------------------- /assets/stable-samples/img2img/mountains-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/mountains-2.png -------------------------------------------------------------------------------- /assets/stable-samples/img2img/mountains-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/mountains-3.png -------------------------------------------------------------------------------- /assets/stable-samples/img2img/sketch-mountains-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/sketch-mountains-input.jpg -------------------------------------------------------------------------------- /assets/stable-samples/img2img/upscaling-in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/upscaling-in.png -------------------------------------------------------------------------------- /assets/stable-samples/img2img/upscaling-out.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/img2img/upscaling-out.png -------------------------------------------------------------------------------- /assets/stable-samples/txt2img/000002025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/000002025.png -------------------------------------------------------------------------------- /assets/stable-samples/txt2img/000002035.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/000002035.png -------------------------------------------------------------------------------- /assets/stable-samples/txt2img/merged-0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/merged-0005.png -------------------------------------------------------------------------------- /assets/stable-samples/txt2img/merged-0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/merged-0006.png -------------------------------------------------------------------------------- /assets/stable-samples/txt2img/merged-0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/stable-samples/txt2img/merged-0007.png -------------------------------------------------------------------------------- /assets/the-earth-is-on-fire,-oil-on-canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/the-earth-is-on-fire,-oil-on-canvas.png -------------------------------------------------------------------------------- /assets/txt2img-convsample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/txt2img-convsample.png -------------------------------------------------------------------------------- /assets/txt2img-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/txt2img-preview.png -------------------------------------------------------------------------------- /assets/v1-variants-scores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/assets/v1-variants-scores.jpg -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: 
ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | 
degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | 
timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /data/example_conditioning/superresolution/sample_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/example_conditioning/superresolution/sample_0.jpg -------------------------------------------------------------------------------- /data/example_conditioning/text_conditional/sample_0.txt: -------------------------------------------------------------------------------- 1 | A basket of cerries 2 | -------------------------------------------------------------------------------- /data/imagenet_train_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/imagenet_train_hr_indices.p -------------------------------------------------------------------------------- /data/imagenet_val_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/imagenet_val_hr_indices.p -------------------------------------------------------------------------------- /data/inpainting_examples/6458524847_2f4c361183_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/6458524847_2f4c361183_k.png -------------------------------------------------------------------------------- /data/inpainting_examples/6458524847_2f4c361183_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/6458524847_2f4c361183_k_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/8399166846_f6fb4e4b8e_k.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png -------------------------------------------------------------------------------- /data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png -------------------------------------------------------------------------------- /data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/bench2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bench2.png -------------------------------------------------------------------------------- /data/inpainting_examples/bench2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bench2_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png -------------------------------------------------------------------------------- /data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/billow926-12-Wc-Zgx6Y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png -------------------------------------------------------------------------------- /data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png 
-------------------------------------------------------------------------------- /data/inpainting_examples/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png -------------------------------------------------------------------------------- /data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png -------------------------------------------------------------------------------- /docker-bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | source /venv/bin/activate 4 | update-ca-certificates --fresh 5 | export SSL_CERT_DIR=/etc/ssl/certs 6 | exec "$@" -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | services: 3 | sd: 4 | build: . 5 | ports: 6 | - "7860:7860" 7 | volumes: 8 | - ../sd-data:/data 9 | - ../sd-output:/output 10 | - sd-cache:/root/.cache 11 | environment: 12 | - APP_MAIN_FILE=txt2img_gradio.py 13 | deploy: 14 | resources: 15 | reservations: 16 | devices: 17 | - driver: nvidia 18 | count: 1 19 | capabilities: [gpu] 20 | volumes: 21 | sd-cache: 22 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.20.3 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - torch-fidelity==0.3.0 24 | - transformers==4.19.2 25 | - torchmetrics==0.6.0 26 | - kornia==0.6 27 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 28 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 29 | - -e . 
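#   A typical setup with this file (standard conda workflow):
#     conda env create -f environment.yaml
#     conda activate ldm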
30 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", 
data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
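# The schedulers in this module return a multiplier for a base LR of 1.0
# (typically wrapped in torch.optim.lr_scheduler.LambdaLR). Within each cycle
# the multiplier ramps linearly from f_start to f_max over warm_up_steps, then
# decays from f_max towards f_min -- cosine here, linear in
# LambdaLinearScheduler below. With f_min == f_max == 1.0 (as in the
# stable-diffusion v1 inference config) it simply stays at 1.0 after warm-up.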
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | 
nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | 
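# Multi-head attention via batch folding: (B, N, heads*dim_head) tensors are
# reshaped to (B*heads, N, dim_head), the similarity Q @ K^T is scaled by
# dim_head**-0.5, softmaxed over the key axis, applied to V, and the heads are
# folded back into the channel dimension before the output projection.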
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
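Concretely: the input is group-normalized, projected to inner_dim = n_heads * d_head
with a 1x1 convolution, flattened from (b, c, h, w) to (b, h*w, inner_dim) tokens,
and passed through `depth` BasicTransformerBlocks that cross-attend to `context`
(falling back to self-attention when no context is given).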
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 
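# Against the standard normal prior, the closed form above reduces to
#   KL(N(mu, diag(var)) || N(0, I)) = 0.5 * sum(mu^2 + var - 1 - log var)
# summed over the non-batch dimensions; this is the term weighted by kl_weight
# in LPIPSWithDiscriminator when training the KL autoencoders.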
52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 
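A typical evaluation round is therefore: store(model.parameters()) to stash
the raw weights, copy_to(model) to load the EMA weights, run validation,
then restore(model.parameters()) to put the raw weights back.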
61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import kornia 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | class FrozenCLIPEmbedder(AbstractEncoder): 138 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 140 | super().__init__() 141 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 142 | self.transformer = CLIPTextModel.from_pretrained(version) 143 | self.device = device 144 | self.max_length = max_length 145 | self.freeze() 146 | 147 | def freeze(self): 148 | self.transformer = self.transformer.eval() 149 | for param in self.parameters(): 150 | param.requires_grad = False 151 | 152 | def forward(self, text): 153 | batch_encoding = 
self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 155 | tokens = batch_encoding["input_ids"].to(self.device) 156 | outputs = self.transformer(input_ids=tokens) 157 | 158 | z = outputs.last_hidden_state 159 | return z 160 | 161 | def encode(self, text): 162 | return self(text) 163 | 164 | 165 | class FrozenCLIPTextEmbedder(nn.Module): 166 | """ 167 | Uses the CLIP transformer encoder for text. 168 | """ 169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 170 | super().__init__() 171 | self.model, _ = clip.load(version, jit=False, device="cpu") 172 | self.device = device 173 | self.max_length = max_length 174 | self.n_repeat = n_repeat 175 | self.normalize = normalize 176 | 177 | def freeze(self): 178 | self.model = self.model.eval() 179 | for param in self.parameters(): 180 | param.requires_grad = False 181 | 182 | def forward(self, text): 183 | tokens = clip.tokenize(text).to(self.device) 184 | z = self.model.encode_text(tokens) 185 | if self.normalize: 186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 187 | return z 188 | 189 | def encode(self, text): 190 | z = self(text) 191 | if z.ndim==2: 192 | z = z[:, None, :] 193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 194 | return z 195 | 196 | 197 | class FrozenClipImageEmbedder(nn.Module): 198 | """ 199 | Uses the CLIP image encoder. 200 | """ 201 | def __init__( 202 | self, 203 | model, 204 | jit=False, 205 | device='cuda' if torch.cuda.is_available() else 'cpu', 206 | antialias=False, 207 | ): 208 | super().__init__() 209 | self.model, _ = clip.load(name=model, device=device, jit=jit) 210 | 211 | self.antialias = antialias 212 | 213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 215 | 216 | def preprocess(self, x): 217 | # normalize to [0,1] 218 | x = kornia.geometry.resize(x, (224, 224), 219 | interpolation='bicubic',align_corners=True, 220 | antialias=self.antialias) 221 | x = (x + 1.) / 2. 
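# forward() assumes inputs in [-1, 1]; the line above shifts them to [0, 1]
# so CLIP's dataset mean/std normalization can be applied next.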
222 | # renormalize according to clip 223 | x = kornia.enhance.normalize(x, self.mean, self.std) 224 | return x 225 | 226 | def forward(self, x): 227 | # x is assumed to be in range [-1,1] 228 | return self.model.encode_image(self.preprocess(x)) 229 | 230 | 231 | if __name__ == "__main__": 232 | from ldm.util import count_params 233 | model = FrozenCLIPEmbedder() 234 | count_params(model, verbose=True) -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basujindal/stable-diffusion/54b4a91633e38f6060d29c4e0194efe8f66eeedc/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
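# The star import above provides the helpers used below -- LPIPS,
# NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss and
# adopt_weight -- from the taming-transformers dependency.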
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
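# Taking the mean over all elements (rather than the commented-out per-sample
# sum) rescales nll_loss by the number of elements per sample, which also
# changes its weight relative to the GAN and codebook terms below.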
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 
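# wrap the caption every nc characters (nc scales with the target width) so
# the rendered text fits inside the wh-sized canvas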
28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | 
ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: 
main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 
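  # main.DataModuleFromConfig builds the train/validation dataloaders from the
  # dataset configs below (OpenImages resized to 384 px and cropped to 256 px).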
34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | 
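Note on how these model configs are used: a config like the one above is consumed through ldm.util.instantiate_from_config (see ldm/util.py earlier in this dump), where "target" names the class to build and "params" are passed as constructor keyword arguments. A minimal sketch of loading such a model follows; the checkpoint filename is a placeholder and is not defined by the config itself.

from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config

config = OmegaConf.load("models/ldm/bsr_sr/config.yaml")
model = instantiate_from_config(config.model)              # resolves `target`, passes `params`
ckpt = torch.load("models/ldm/bsr_sr/model.ckpt", map_location="cpu")  # placeholder path
model.load_state_dict(ckpt["state_dict"], strict=False)    # Lightning-style checkpoint, as in the repo scripts
model.eval()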
-------------------------------------------------------------------------------- /models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: 
ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- 
/models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | 
target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: 
ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 
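      # n_embed below (640) matches the UNet's context_dim above, so the BERT text
      # embeddings can be consumed directly by the cross-attention layers.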
61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /optimizedSD/diffusers_txt2img.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import LDMTextToImagePipeline 3 | 4 | pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True) 5 | 6 | prompt = "19th Century wooden engraving of Elon musk" 7 | 8 | seed = torch.manual_seed(1024) 9 | images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed,torch_device="cpu" )["sample"] 10 | 11 | # save images 12 | for idx, image in enumerate(images): 13 | image.save(f"image-{idx}.png") 14 | -------------------------------------------------------------------------------- /optimizedSD/img2img_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | import torch 4 | from torchvision.utils import make_grid 5 | import os, re 6 | from PIL import Image 7 | import torch 8 | import numpy as np 9 | from random import randint 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | from tqdm import tqdm, trange 13 | from itertools import islice 14 | from einops import rearrange 15 | from torchvision.utils import make_grid 16 | import time 17 | from pytorch_lightning import seed_everything 18 | from torch import autocast 19 | from einops import rearrange, repeat 20 | from contextlib import nullcontext 21 | from ldm.util import instantiate_from_config 22 | from transformers import logging 23 | import pandas as pd 24 | from optimUtils import split_weighted_subprompts, logger 25 | logging.set_verbosity_error() 26 | import mimetypes 27 | mimetypes.init() 28 | mimetypes.add_type("application/javascript", ".js") 29 | 30 | 31 | def chunk(it, size): 32 | it = iter(it) 33 | return iter(lambda: tuple(islice(it, size)), ()) 34 | 35 | 36 | def load_model_from_config(ckpt, verbose=False): 37 | print(f"Loading model from {ckpt}") 38 | pl_sd = torch.load(ckpt, map_location="cpu") 39 | if "global_step" in pl_sd: 40 | print(f"Global Step: {pl_sd['global_step']}") 41 | sd = pl_sd["state_dict"] 42 | return sd 43 | 44 | 45 | def load_img(image, h0, w0): 46 | 47 | image = image.convert("RGB") 48 | w, h = image.size 49 | print(f"loaded input image of size ({w}, {h})") 50 | if h0 is not None and w0 is not None: 51 | h, w = h0, w0 52 | 53 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 54 | 55 | print(f"New image size ({w}, {h})") 56 | image = image.resize((w, h), resample=Image.LANCZOS) 57 | image = np.array(image).astype(np.float32) / 255.0 58 | image = image[None].transpose(0, 3, 1, 2) 59 | image = torch.from_numpy(image) 60 | return 2.0 * image - 1.0 61 | 62 | config = "optimizedSD/v1-inference.yaml" 63 | ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" 64 | sd = load_model_from_config(f"{ckpt}") 65 | li, lo = [], [] 66 | for key, v_ in sd.items(): 67 | sp = key.split(".") 68 | if (sp[0]) == "model": 69 | if "input_blocks" in sp: 70 | li.append(key) 71 | elif "middle_block" in sp: 72 | 
li.append(key) 73 | elif "time_embed" in sp: 74 | li.append(key) 75 | else: 76 | lo.append(key) 77 | for key in li: 78 | sd["model1." + key[6:]] = sd.pop(key) 79 | for key in lo: 80 | sd["model2." + key[6:]] = sd.pop(key) 81 | 82 | config = OmegaConf.load(f"{config}") 83 | 84 | model = instantiate_from_config(config.modelUNet) 85 | _, _ = model.load_state_dict(sd, strict=False) 86 | model.eval() 87 | 88 | modelCS = instantiate_from_config(config.modelCondStage) 89 | _, _ = modelCS.load_state_dict(sd, strict=False) 90 | modelCS.eval() 91 | 92 | modelFS = instantiate_from_config(config.modelFirstStage) 93 | _, _ = modelFS.load_state_dict(sd, strict=False) 94 | modelFS.eval() 95 | del sd 96 | 97 | def generate( 98 | image, 99 | prompt, 100 | strength, 101 | ddim_steps, 102 | n_iter, 103 | batch_size, 104 | Height, 105 | Width, 106 | scale, 107 | ddim_eta, 108 | unet_bs, 109 | device, 110 | seed, 111 | outdir, 112 | img_format, 113 | turbo, 114 | full_precision, 115 | ): 116 | 117 | if seed == "": 118 | seed = randint(0, 1000000) 119 | seed = int(seed) 120 | seed_everything(seed) 121 | 122 | # Logging 123 | sampler = "ddim" 124 | logger(locals(), log_csv = "logs/img2img_gradio_logs.csv") 125 | 126 | init_image = load_img(image, Height, Width).to(device) 127 | model.unet_bs = unet_bs 128 | model.turbo = turbo 129 | model.cdevice = device 130 | modelCS.cond_stage_model.device = device 131 | 132 | if device != "cpu" and full_precision == False: 133 | model.half() 134 | modelCS.half() 135 | modelFS.half() 136 | init_image = init_image.half() 137 | 138 | tic = time.time() 139 | os.makedirs(outdir, exist_ok=True) 140 | outpath = outdir 141 | sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150] 142 | os.makedirs(sample_path, exist_ok=True) 143 | base_count = len(os.listdir(sample_path)) 144 | 145 | # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 146 | assert prompt is not None 147 | data = [batch_size * [prompt]] 148 | 149 | modelFS.to(device) 150 | 151 | init_image = repeat(init_image, "1 ... 
-> b ...", b=batch_size) 152 | init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space 153 | 154 | if device != "cpu": 155 | mem = torch.cuda.memory_allocated() / 1e6 156 | modelFS.to("cpu") 157 | while torch.cuda.memory_allocated() / 1e6 >= mem: 158 | time.sleep(1) 159 | 160 | assert 0.0 <= strength <= 1.0, "can only work with strength in [0.0, 1.0]" 161 | t_enc = int(strength * ddim_steps) 162 | print(f"target t_enc is {t_enc} steps") 163 | 164 | if full_precision == False and device != "cpu": 165 | precision_scope = autocast 166 | else: 167 | precision_scope = nullcontext 168 | 169 | all_samples = [] 170 | seeds = "" 171 | with torch.no_grad(): 172 | all_samples = list() 173 | for _ in trange(n_iter, desc="Sampling"): 174 | for prompts in tqdm(data, desc="data"): 175 | with precision_scope("cuda"): 176 | modelCS.to(device) 177 | uc = None 178 | if scale != 1.0: 179 | uc = modelCS.get_learned_conditioning(batch_size * [""]) 180 | if isinstance(prompts, tuple): 181 | prompts = list(prompts) 182 | 183 | subprompts, weights = split_weighted_subprompts(prompts[0]) 184 | if len(subprompts) > 1: 185 | c = torch.zeros_like(uc) 186 | totalWeight = sum(weights) 187 | # normalize each "sub prompt" and add it 188 | for i in range(len(subprompts)): 189 | weight = weights[i] 190 | # if not skip_normalize: 191 | weight = weight / totalWeight 192 | c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) 193 | else: 194 | c = modelCS.get_learned_conditioning(prompts) 195 | 196 | if device != "cpu": 197 | mem = torch.cuda.memory_allocated() / 1e6 198 | modelCS.to("cpu") 199 | while torch.cuda.memory_allocated() / 1e6 >= mem: 200 | time.sleep(1) 201 | 202 | # encode (scaled latent) 203 | z_enc = model.stochastic_encode( 204 | init_latent, torch.tensor([t_enc] * batch_size).to(device), seed, ddim_eta, ddim_steps 205 | ) 206 | # decode it 207 | samples_ddim = model.sample( 208 | t_enc, 209 | c, 210 | z_enc, 211 | unconditional_guidance_scale=scale, 212 | unconditional_conditioning=uc, 213 | sampler = sampler 214 | ) 215 | 216 | modelFS.to(device) 217 | print("saving images") 218 | for i in range(batch_size): 219 | 220 | x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) 221 | x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 222 | all_samples.append(x_sample.to("cpu")) 223 | x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") 224 | Image.fromarray(x_sample.astype(np.uint8)).save( 225 | os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}") 226 | ) 227 | seeds += str(seed) + "," 228 | seed += 1 229 | base_count += 1 230 | 231 | if device != "cpu": 232 | mem = torch.cuda.memory_allocated() / 1e6 233 | modelFS.to("cpu") 234 | while torch.cuda.memory_allocated() / 1e6 >= mem: 235 | time.sleep(1) 236 | 237 | del samples_ddim 238 | del x_sample 239 | del x_samples_ddim 240 | print("memory_final = ", torch.cuda.memory_allocated() / 1e6) 241 | 242 | toc = time.time() 243 | 244 | time_taken = (toc - tic) / 60.0 245 | grid = torch.cat(all_samples, 0) 246 | grid = make_grid(grid, nrow=n_iter) 247 | grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() 248 | 249 | txt = ( 250 | "Samples finished in " 251 | + str(round(time_taken, 3)) 252 | + " minutes and exported to \n" 253 | + sample_path 254 | + "\nSeeds used = " 255 | + seeds[:-1] 256 | ) 257 | return Image.fromarray(grid.astype(np.uint8)), txt 258 | 259 | 260 | demo = gr.Interface( 
261 | fn=generate, 262 | inputs=[ 263 | gr.Image(tool="editor", type="pil"), 264 | "text", 265 | gr.Slider(0, 1, value=0.75), 266 | gr.Slider(1, 1000, value=50), 267 | gr.Slider(1, 100, step=1), 268 | gr.Slider(1, 100, step=1), 269 | gr.Slider(64, 4096, value=512, step=64), 270 | gr.Slider(64, 4096, value=512, step=64), 271 | gr.Slider(0, 50, value=7.5, step=0.1), 272 | gr.Slider(0, 1, step=0.01), 273 | gr.Slider(1, 2, value=1, step=1), 274 | gr.Text(value="cuda"), 275 | "text", 276 | gr.Text(value="outputs/img2img-samples"), 277 | gr.Radio(["png", "jpg"], value='png'), 278 | "checkbox", 279 | "checkbox", 280 | ], 281 | outputs=["image", "text"], 282 | ) 283 | demo.launch() 284 | -------------------------------------------------------------------------------- /optimizedSD/optimUtils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | 5 | def split_weighted_subprompts(text): 6 | """ 7 | grabs all text up to the first occurrence of ':' 8 | uses the grabbed text as a sub-prompt, and takes the value following ':' as weight 9 | if ':' has no value defined, defaults to 1.0 10 | repeats until no text remaining 11 | """ 12 | remaining = len(text) 13 | prompts = [] 14 | weights = [] 15 | while remaining > 0: 16 | if ":" in text: 17 | idx = text.index(":") # first occurrence from start 18 | # grab up to index as sub-prompt 19 | prompt = text[:idx] 20 | remaining -= idx 21 | # remove from main text 22 | text = text[idx+1:] 23 | # find value for weight 24 | if " " in text: 25 | idx = text.index(" ") # first occurence 26 | else: # no space, read to end 27 | idx = len(text) 28 | if idx != 0: 29 | try: 30 | weight = float(text[:idx]) 31 | except: # couldn't treat as float 32 | print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") 33 | weight = 1.0 34 | else: # no value found 35 | weight = 1.0 36 | # remove from main text 37 | remaining -= idx 38 | text = text[idx+1:] 39 | # append the sub-prompt and its weight 40 | prompts.append(prompt) 41 | weights.append(weight) 42 | else: # no : found 43 | if len(text) > 0: # there is still text though 44 | # take remainder as weight 1 45 | prompts.append(text) 46 | weights.append(1.0) 47 | remaining = 0 48 | return prompts, weights 49 | 50 | def logger(params, log_csv): 51 | os.makedirs('logs', exist_ok=True) 52 | cols = [arg for arg, _ in params.items()] 53 | if not os.path.exists(log_csv): 54 | df = pd.DataFrame(columns=cols) 55 | df.to_csv(log_csv, index=False) 56 | 57 | df = pd.read_csv(log_csv) 58 | for arg in cols: 59 | if arg not in df.columns: 60 | df[arg] = "" 61 | df.to_csv(log_csv, index = False) 62 | 63 | li = {} 64 | cols = [col for col in df.columns] 65 | data = {arg:value for arg, value in params.items()} 66 | for col in cols: 67 | if col in data: 68 | li[col] = data[col] 69 | else: 70 | li[col] = '' 71 | 72 | df = pd.DataFrame(li,index = [0]) 73 | df.to_csv(log_csv,index=False, mode='a', header=False) -------------------------------------------------------------------------------- /optimizedSD/splitAttention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in 
arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = 
self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | self.att_step = att_step 161 | 162 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 163 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 164 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 165 | 166 | self.to_out = nn.Sequential( 167 | nn.Linear(inner_dim, query_dim), 168 | nn.Dropout(dropout) 169 | ) 170 | 171 | def forward(self, x, context=None, mask=None): 172 | h = self.heads 173 | 174 | q = self.to_q(x) 175 | context = default(context, x) 176 | k = self.to_k(context) 177 | v = self.to_v(context) 178 | del context, x 179 | 180 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 181 | 182 | 183 | limit = k.shape[0] 184 | att_step = self.att_step 185 | q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0)) 186 | k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0)) 187 | v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0)) 188 | 189 | q_chunks.reverse() 190 | k_chunks.reverse() 191 | v_chunks.reverse() 192 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 193 | del k, q, v 194 | for i in range (0, limit, att_step): 195 | 196 | q_buffer = q_chunks.pop() 197 | k_buffer = k_chunks.pop() 198 | v_buffer = v_chunks.pop() 199 | sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 200 | 201 | del k_buffer, q_buffer 202 | # attention, what we cannot get enough of, by chunks 203 | 204 | sim_buffer = sim_buffer.softmax(dim=-1) 205 | 206 | sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 207 | del v_buffer 208 | sim[i:i+att_step,:,:] = sim_buffer 209 | 210 | del sim_buffer 211 | sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) 212 | return self.to_out(sim) 213 | 214 | 215 | class BasicTransformerBlock(nn.Module): 216 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 217 | super().__init__() 218 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 219 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 220 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 221 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 222 | self.norm1 = nn.LayerNorm(dim) 223 | self.norm2 = nn.LayerNorm(dim) 224 | self.norm3 = nn.LayerNorm(dim) 225 | self.checkpoint = checkpoint 226 | 227 | def forward(self, x, context=None): 228 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 229 | 230 | def _forward(self, x, context=None): 231 | x = self.attn1(self.norm1(x)) + x 232 | x = self.attn2(self.norm2(x), context=context) + x 233 | x = self.ff(self.norm3(x)) + x 234 | return x 235 | 236 | 237 | class SpatialTransformer(nn.Module): 238 | """ 239 | Transformer block for image-like data. 240 | First, project the input (aka embedding) 241 | and reshape to b, t, d. 242 | Then apply standard transformer action. 
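    (Cross-attention is computed against `context` when one is passed; with no
    context the blocks reduce to plain self-attention, as noted in forward() below.)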
243 | Finally, reshape to image 244 | """ 245 | def __init__(self, in_channels, n_heads, d_head, 246 | depth=1, dropout=0., context_dim=None): 247 | super().__init__() 248 | self.in_channels = in_channels 249 | inner_dim = n_heads * d_head 250 | self.norm = Normalize(in_channels) 251 | 252 | self.proj_in = nn.Conv2d(in_channels, 253 | inner_dim, 254 | kernel_size=1, 255 | stride=1, 256 | padding=0) 257 | 258 | self.transformer_blocks = nn.ModuleList( 259 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 260 | for d in range(depth)] 261 | ) 262 | 263 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 264 | in_channels, 265 | kernel_size=1, 266 | stride=1, 267 | padding=0)) 268 | 269 | def forward(self, x, context=None): 270 | # note: if no context is given, cross-attention defaults to self-attention 271 | b, c, h, w = x.shape 272 | x_in = x 273 | x = self.norm(x) 274 | x = self.proj_in(x) 275 | x = rearrange(x, 'b c h w -> b (h w) c') 276 | for block in self.transformer_blocks: 277 | x = block(x, context=context) 278 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 279 | x = self.proj_out(x) 280 | return x + x_in 281 | -------------------------------------------------------------------------------- /optimizedSD/txt2img_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | import torch 4 | from torchvision.utils import make_grid 5 | from einops import rearrange 6 | import os, re 7 | from PIL import Image 8 | import torch 9 | import pandas as pd 10 | import numpy as np 11 | from random import randint 12 | from omegaconf import OmegaConf 13 | from PIL import Image 14 | from tqdm import tqdm, trange 15 | from itertools import islice 16 | from einops import rearrange 17 | from torchvision.utils import make_grid 18 | import time 19 | from pytorch_lightning import seed_everything 20 | from torch import autocast 21 | from contextlib import nullcontext 22 | from ldm.util import instantiate_from_config 23 | from optimUtils import split_weighted_subprompts, logger 24 | from transformers import logging 25 | logging.set_verbosity_error() 26 | import mimetypes 27 | mimetypes.init() 28 | mimetypes.add_type("application/javascript", ".js") 29 | 30 | 31 | def chunk(it, size): 32 | it = iter(it) 33 | return iter(lambda: tuple(islice(it, size)), ()) 34 | 35 | 36 | def load_model_from_config(ckpt, verbose=False): 37 | print(f"Loading model from {ckpt}") 38 | pl_sd = torch.load(ckpt, map_location="cpu") 39 | if "global_step" in pl_sd: 40 | print(f"Global Step: {pl_sd['global_step']}") 41 | sd = pl_sd["state_dict"] 42 | return sd 43 | 44 | config = "optimizedSD/v1-inference.yaml" 45 | ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" 46 | sd = load_model_from_config(f"{ckpt}") 47 | li, lo = [], [] 48 | for key, v_ in sd.items(): 49 | sp = key.split(".") 50 | if (sp[0]) == "model": 51 | if "input_blocks" in sp: 52 | li.append(key) 53 | elif "middle_block" in sp: 54 | li.append(key) 55 | elif "time_embed" in sp: 56 | li.append(key) 57 | else: 58 | lo.append(key) 59 | for key in li: 60 | sd["model1." + key[6:]] = sd.pop(key) 61 | for key in lo: 62 | sd["model2." 
+ key[6:]] = sd.pop(key) 63 | 64 | config = OmegaConf.load(f"{config}") 65 | 66 | model = instantiate_from_config(config.modelUNet) 67 | _, _ = model.load_state_dict(sd, strict=False) 68 | model.eval() 69 | 70 | modelCS = instantiate_from_config(config.modelCondStage) 71 | _, _ = modelCS.load_state_dict(sd, strict=False) 72 | modelCS.eval() 73 | 74 | modelFS = instantiate_from_config(config.modelFirstStage) 75 | _, _ = modelFS.load_state_dict(sd, strict=False) 76 | modelFS.eval() 77 | del sd 78 | 79 | 80 | def generate( 81 | prompt, 82 | ddim_steps, 83 | n_iter, 84 | batch_size, 85 | Height, 86 | Width, 87 | scale, 88 | ddim_eta, 89 | unet_bs, 90 | device, 91 | seed, 92 | outdir, 93 | img_format, 94 | turbo, 95 | full_precision, 96 | sampler, 97 | ): 98 | 99 | C = 4 100 | f = 8 101 | start_code = None 102 | model.unet_bs = unet_bs 103 | model.turbo = turbo 104 | model.cdevice = device 105 | modelCS.cond_stage_model.device = device 106 | 107 | if seed == "": 108 | seed = randint(0, 1000000) 109 | seed = int(seed) 110 | seed_everything(seed) 111 | # Logging 112 | logger(locals(), "logs/txt2img_gradio_logs.csv") 113 | 114 | if device != "cpu" and full_precision == False: 115 | model.half() 116 | modelFS.half() 117 | modelCS.half() 118 | 119 | tic = time.time() 120 | os.makedirs(outdir, exist_ok=True) 121 | outpath = outdir 122 | sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150] 123 | os.makedirs(sample_path, exist_ok=True) 124 | base_count = len(os.listdir(sample_path)) 125 | 126 | # n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 127 | assert prompt is not None 128 | data = [batch_size * [prompt]] 129 | 130 | if full_precision == False and device != "cpu": 131 | precision_scope = autocast 132 | else: 133 | precision_scope = nullcontext 134 | 135 | all_samples = [] 136 | seeds = "" 137 | with torch.no_grad(): 138 | 139 | all_samples = list() 140 | for _ in trange(n_iter, desc="Sampling"): 141 | for prompts in tqdm(data, desc="data"): 142 | with precision_scope("cuda"): 143 | modelCS.to(device) 144 | uc = None 145 | if scale != 1.0: 146 | uc = modelCS.get_learned_conditioning(batch_size * [""]) 147 | if isinstance(prompts, tuple): 148 | prompts = list(prompts) 149 | 150 | subprompts, weights = split_weighted_subprompts(prompts[0]) 151 | if len(subprompts) > 1: 152 | c = torch.zeros_like(uc) 153 | totalWeight = sum(weights) 154 | # normalize each "sub prompt" and add it 155 | for i in range(len(subprompts)): 156 | weight = weights[i] 157 | # if not skip_normalize: 158 | weight = weight / totalWeight 159 | c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) 160 | else: 161 | c = modelCS.get_learned_conditioning(prompts) 162 | 163 | shape = [batch_size, C, Height // f, Width // f] 164 | 165 | if device != "cpu": 166 | mem = torch.cuda.memory_allocated() / 1e6 167 | modelCS.to("cpu") 168 | while torch.cuda.memory_allocated() / 1e6 >= mem: 169 | time.sleep(1) 170 | 171 | samples_ddim = model.sample( 172 | S=ddim_steps, 173 | conditioning=c, 174 | seed=seed, 175 | shape=shape, 176 | verbose=False, 177 | unconditional_guidance_scale=scale, 178 | unconditional_conditioning=uc, 179 | eta=ddim_eta, 180 | x_T=start_code, 181 | sampler = sampler, 182 | ) 183 | 184 | modelFS.to(device) 185 | print("saving images") 186 | for i in range(batch_size): 187 | 188 | x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) 189 | x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 190 | 
all_samples.append(x_sample.to("cpu")) 191 | x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") 192 | Image.fromarray(x_sample.astype(np.uint8)).save( 193 | os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}") 194 | ) 195 | seeds += str(seed) + "," 196 | seed += 1 197 | base_count += 1 198 | 199 | if device != "cpu": 200 | mem = torch.cuda.memory_allocated() / 1e6 201 | modelFS.to("cpu") 202 | while torch.cuda.memory_allocated() / 1e6 >= mem: 203 | time.sleep(1) 204 | 205 | del samples_ddim 206 | del x_sample 207 | del x_samples_ddim 208 | print("memory_final = ", torch.cuda.memory_allocated() / 1e6) 209 | 210 | toc = time.time() 211 | 212 | time_taken = (toc - tic) / 60.0 213 | grid = torch.cat(all_samples, 0) 214 | grid = make_grid(grid, nrow=n_iter) 215 | grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() 216 | 217 | txt = ( 218 | "Samples finished in " 219 | + str(round(time_taken, 3)) 220 | + " minutes and exported to " 221 | + sample_path 222 | + "\nSeeds used = " 223 | + seeds[:-1] 224 | ) 225 | return Image.fromarray(grid.astype(np.uint8)), txt 226 | 227 | 228 | demo = gr.Interface( 229 | fn=generate, 230 | inputs=[ 231 | "text", 232 | gr.Slider(1, 1000, value=50), 233 | gr.Slider(1, 100, step=1), 234 | gr.Slider(1, 100, step=1), 235 | gr.Slider(64, 4096, value=512, step=64), 236 | gr.Slider(64, 4096, value=512, step=64), 237 | gr.Slider(0, 50, value=7.5, step=0.1), 238 | gr.Slider(0, 1, step=0.01), 239 | gr.Slider(1, 2, value=1, step=1), 240 | gr.Text(value="cuda"), 241 | "text", 242 | gr.Text(value="outputs/txt2img-samples"), 243 | gr.Radio(["png", "jpg"], value='png'), 244 | "checkbox", 245 | "checkbox", 246 | gr.Radio(["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"], value="plms"), 247 | ], 248 | outputs=["image", "text"], 249 | ) 250 | demo.launch() 251 | -------------------------------------------------------------------------------- /optimizedSD/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | modelUNet: 2 | base_learning_rate: 1.0e-04 3 | target: optimizedSD.ddpm.UNet 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unetConfigEncode: 21 | target: optimizedSD.openaimodelSplit.UNetModelEncode 22 | params: 23 | image_size: 32 # unused 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: [4, 2, 1] 28 | num_res_blocks: 2 29 | channel_mult: [1, 2, 4, 4] 30 | num_heads: 8 31 | use_spatial_transformer: True 32 | transformer_depth: 1 33 | context_dim: 768 34 | use_checkpoint: True 35 | legacy: False 36 | 37 | unetConfigDecode: 38 | target: optimizedSD.openaimodelSplit.UNetModelDecode 39 | params: 40 | image_size: 32 # unused 41 | in_channels: 4 42 | out_channels: 4 43 | model_channels: 320 44 | attention_resolutions: [4, 2, 1] 45 | num_res_blocks: 2 46 | channel_mult: [1, 2, 4, 4] 47 | num_heads: 8 48 | use_spatial_transformer: True 49 | transformer_depth: 1 50 | context_dim: 768 51 | use_checkpoint: True 52 | legacy: False 53 | 54 | modelFirstStage: 55 | target: optimizedSD.ddpm.FirstStage 56 | params: 57 | linear_start: 0.00085 
58 | linear_end: 0.0120 59 | num_timesteps_cond: 1 60 | log_every_t: 200 61 | timesteps: 1000 62 | first_stage_key: "jpg" 63 | cond_stage_key: "txt" 64 | image_size: 64 65 | channels: 4 66 | cond_stage_trainable: false # Note: different from the one we trained before 67 | conditioning_key: crossattn 68 | monitor: val/loss_simple_ema 69 | scale_factor: 0.18215 70 | use_ema: False 71 | first_stage_config: 72 | target: ldm.models.autoencoder.AutoencoderKL 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | double_z: true 78 | z_channels: 4 79 | resolution: 256 80 | in_channels: 3 81 | out_ch: 3 82 | ch: 128 83 | ch_mult: 84 | - 1 85 | - 2 86 | - 4 87 | - 4 88 | num_res_blocks: 2 89 | attn_resolutions: [] 90 | dropout: 0.0 91 | lossconfig: 92 | target: torch.nn.Identity 93 | 94 | modelCondStage: 95 | target: optimizedSD.ddpm.CondStage 96 | params: 97 | linear_start: 0.00085 98 | linear_end: 0.0120 99 | num_timesteps_cond: 1 100 | log_every_t: 200 101 | timesteps: 1000 102 | first_stage_key: "jpg" 103 | cond_stage_key: "txt" 104 | image_size: 64 105 | channels: 4 106 | cond_stage_trainable: false # Note: different from the one we trained before 107 | conditioning_key: crossattn 108 | monitor: val/loss_simple_ema 109 | scale_factor: 0.18215 110 | use_ema: False 111 | cond_stage_config: 112 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 113 | params: 114 | device: cpu 115 | -------------------------------------------------------------------------------- /scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
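The optimizedSD variant of v1-inference.yaml above replaces the single LatentDiffusion config with three top-level targets (modelUNet, modelCondStage, modelFirstStage), which is what lets optimizedSD/txt2img_gradio.py move the text encoder and autoencoder between CPU and GPU independently of the UNet. A minimal sketch of how such a config is consumed, mirroring the loading code in that script (nothing beyond the repository's own environment and config file is assumed):

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Build the three sub-models defined by optimizedSD/v1-inference.yaml.
config = OmegaConf.load("optimizedSD/v1-inference.yaml")

unet = instantiate_from_config(config.modelUNet)               # optimizedSD.ddpm.UNet (split diffusion UNet)
cond_stage = instantiate_from_config(config.modelCondStage)    # CondStage wrapping FrozenCLIPEmbedder
first_stage = instantiate_from_config(config.modelFirstStage)  # FirstStage wrapping AutoencoderKL

# Weights still come from a single v1 checkpoint: as in txt2img_gradio.py above,
# the state dict is loaded into each module with strict=False, after the UNet
# keys are renamed to the model1./model2. prefixes expected by the split UNet.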
-------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
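Both download scripts above follow the same pattern: wget each archive into the models/ tree, then unzip -o it in place. A hypothetical Python equivalent for a single model (the URL and target directory are taken from the script above; the helper itself is not part of the repository):

import os
import urllib.request
import zipfile

def fetch_and_unzip(url, target_dir, archive_name="model.zip"):
    """Download one checkpoint archive and unpack it in place."""
    os.makedirs(target_dir, exist_ok=True)
    archive_path = os.path.join(target_dir, archive_name)
    print(f"Downloading {url} -> {archive_path}")
    urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(target_dir)  # same effect as `unzip -o` in the model directory

# Example: the big inpainting model from the script above.
fetch_and_unzip("https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip",
                "models/ldm/inpainting_big")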
50 | -------------------------------------------------------------------------------- /scripts/img2img.py: -------------------------------------------------------------------------------- 1 | """make variations of input image""" 2 | 3 | import argparse, os, sys, glob 4 | import PIL 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from tqdm import tqdm, trange 10 | from itertools import islice 11 | from einops import rearrange, repeat 12 | from torchvision.utils import make_grid 13 | from torch import autocast 14 | from contextlib import nullcontext 15 | import time 16 | from pytorch_lightning import seed_everything 17 | 18 | from ldm.util import instantiate_from_config 19 | from ldm.models.diffusion.ddim import DDIMSampler 20 | from ldm.models.diffusion.plms import PLMSSampler 21 | 22 | 23 | def chunk(it, size): 24 | it = iter(it) 25 | return iter(lambda: tuple(islice(it, size)), ()) 26 | 27 | 28 | def load_model_from_config(config, ckpt, verbose=False): 29 | print(f"Loading model from {ckpt}") 30 | pl_sd = torch.load(ckpt, map_location="cpu") 31 | if "global_step" in pl_sd: 32 | print(f"Global Step: {pl_sd['global_step']}") 33 | sd = pl_sd["state_dict"] 34 | model = instantiate_from_config(config.model) 35 | m, u = model.load_state_dict(sd, strict=False) 36 | if len(m) > 0 and verbose: 37 | print("missing keys:") 38 | print(m) 39 | if len(u) > 0 and verbose: 40 | print("unexpected keys:") 41 | print(u) 42 | 43 | model.cuda() 44 | model.eval() 45 | return model 46 | 47 | 48 | def load_img(path): 49 | image = Image.open(path).convert("RGB") 50 | w, h = image.size 51 | print(f"loaded input image of size ({w}, {h}) from {path}") 52 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 53 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 54 | image = np.array(image).astype(np.float32) / 255.0 55 | image = image[None].transpose(0, 3, 1, 2) 56 | image = torch.from_numpy(image) 57 | return 2.*image - 1. 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser() 62 | 63 | parser.add_argument( 64 | "--prompt", 65 | type=str, 66 | nargs="?", 67 | default="a painting of a virus monster playing guitar", 68 | help="the prompt to render" 69 | ) 70 | 71 | parser.add_argument( 72 | "--init-img", 73 | type=str, 74 | nargs="?", 75 | help="path to the input image" 76 | ) 77 | 78 | parser.add_argument( 79 | "--outdir", 80 | type=str, 81 | nargs="?", 82 | help="dir to write results to", 83 | default="outputs/img2img-samples" 84 | ) 85 | 86 | parser.add_argument( 87 | "--skip_grid", 88 | action='store_true', 89 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", 90 | ) 91 | 92 | parser.add_argument( 93 | "--skip_save", 94 | action='store_true', 95 | help="do not save indiviual samples. 
For speed measurements.", 96 | ) 97 | 98 | parser.add_argument( 99 | "--ddim_steps", 100 | type=int, 101 | default=50, 102 | help="number of ddim sampling steps", 103 | ) 104 | 105 | parser.add_argument( 106 | "--plms", 107 | action='store_true', 108 | help="use plms sampling", 109 | ) 110 | parser.add_argument( 111 | "--fixed_code", 112 | action='store_true', 113 | help="if enabled, uses the same starting code across all samples ", 114 | ) 115 | 116 | parser.add_argument( 117 | "--ddim_eta", 118 | type=float, 119 | default=0.0, 120 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 121 | ) 122 | parser.add_argument( 123 | "--n_iter", 124 | type=int, 125 | default=1, 126 | help="sample this often", 127 | ) 128 | parser.add_argument( 129 | "--C", 130 | type=int, 131 | default=4, 132 | help="latent channels", 133 | ) 134 | parser.add_argument( 135 | "--f", 136 | type=int, 137 | default=8, 138 | help="downsampling factor, most often 8 or 16", 139 | ) 140 | parser.add_argument( 141 | "--n_samples", 142 | type=int, 143 | default=2, 144 | help="how many samples to produce for each given prompt. A.k.a batch size", 145 | ) 146 | parser.add_argument( 147 | "--n_rows", 148 | type=int, 149 | default=0, 150 | help="rows in the grid (default: n_samples)", 151 | ) 152 | parser.add_argument( 153 | "--scale", 154 | type=float, 155 | default=5.0, 156 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 157 | ) 158 | 159 | parser.add_argument( 160 | "--strength", 161 | type=float, 162 | default=0.75, 163 | help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", 164 | ) 165 | parser.add_argument( 166 | "--from-file", 167 | type=str, 168 | help="if specified, load prompts from this file", 169 | ) 170 | parser.add_argument( 171 | "--config", 172 | type=str, 173 | default="configs/stable-diffusion/v1-inference.yaml", 174 | help="path to config which constructs model", 175 | ) 176 | parser.add_argument( 177 | "--ckpt", 178 | type=str, 179 | default="models/ldm/stable-diffusion-v1/model.ckpt", 180 | help="path to checkpoint of model", 181 | ) 182 | parser.add_argument( 183 | "--seed", 184 | type=int, 185 | default=42, 186 | help="the seed (for reproducible sampling)", 187 | ) 188 | parser.add_argument( 189 | "--precision", 190 | type=str, 191 | help="evaluate at this precision", 192 | choices=["full", "autocast"], 193 | default="autocast" 194 | ) 195 | 196 | opt = parser.parse_args() 197 | seed_everything(opt.seed) 198 | 199 | config = OmegaConf.load(f"{opt.config}") 200 | model = load_model_from_config(config, f"{opt.ckpt}") 201 | 202 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 203 | model = model.to(device) 204 | 205 | if opt.plms: 206 | raise NotImplementedError("PLMS sampler not (yet) supported") 207 | sampler = PLMSSampler(model) 208 | else: 209 | sampler = DDIMSampler(model) 210 | 211 | os.makedirs(opt.outdir, exist_ok=True) 212 | outpath = opt.outdir 213 | 214 | batch_size = opt.n_samples 215 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 216 | if not opt.from_file: 217 | prompt = opt.prompt 218 | assert prompt is not None 219 | data = [batch_size * [prompt]] 220 | 221 | else: 222 | print(f"reading prompts from {opt.from_file}") 223 | with open(opt.from_file, "r") as f: 224 | data = f.read().splitlines() 225 | data = list(chunk(data, batch_size)) 226 | 227 | sample_path = os.path.join(outpath, "samples") 228 | os.makedirs(sample_path, 
exist_ok=True) 229 | base_count = len(os.listdir(sample_path)) 230 | grid_count = len(os.listdir(outpath)) - 1 231 | 232 | assert os.path.isfile(opt.init_img) 233 | init_image = load_img(opt.init_img).to(device) 234 | init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) 235 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space 236 | 237 | sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) 238 | 239 | assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' 240 | t_enc = int(opt.strength * opt.ddim_steps) 241 | print(f"target t_enc is {t_enc} steps") 242 | 243 | precision_scope = autocast if opt.precision == "autocast" else nullcontext 244 | with torch.no_grad(): 245 | with precision_scope("cuda"): 246 | with model.ema_scope(): 247 | tic = time.time() 248 | all_samples = list() 249 | for n in trange(opt.n_iter, desc="Sampling"): 250 | for prompts in tqdm(data, desc="data"): 251 | uc = None 252 | if opt.scale != 1.0: 253 | uc = model.get_learned_conditioning(batch_size * [""]) 254 | if isinstance(prompts, tuple): 255 | prompts = list(prompts) 256 | c = model.get_learned_conditioning(prompts) 257 | 258 | # encode (scaled latent) 259 | z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) 260 | # decode it 261 | samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, 262 | unconditional_conditioning=uc,) 263 | 264 | x_samples = model.decode_first_stage(samples) 265 | x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) 266 | 267 | if not opt.skip_save: 268 | for x_sample in x_samples: 269 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 270 | Image.fromarray(x_sample.astype(np.uint8)).save( 271 | os.path.join(sample_path, f"{base_count:05}.png")) 272 | base_count += 1 273 | all_samples.append(x_samples) 274 | 275 | if not opt.skip_grid: 276 | # additionally, save as grid 277 | grid = torch.stack(all_samples, 0) 278 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 279 | grid = make_grid(grid, nrow=n_rows) 280 | 281 | # to image 282 | grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() 283 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) 284 | grid_count += 1 285 | 286 | toc = time.time() 287 | 288 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n" 289 | f" \nEnjoy.") 290 | 291 | 292 | if __name__ == "__main__": 293 | main() 294 | -------------------------------------------------------------------------------- /scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = 
torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- /scripts/train_searcher.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import scann 4 | import argparse 5 | import glob 6 | from multiprocessing import cpu_count 7 | from tqdm import tqdm 8 | 9 | from ldm.util import parallel_data_prefetch 10 | 11 | 12 | def search_bruteforce(searcher): 13 | return searcher.score_brute_force().build() 14 | 15 | 16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, 17 | partioning_trainsize, num_leaves, num_leaves_to_search): 18 | return searcher.tree(num_leaves=num_leaves, 19 | num_leaves_to_search=num_leaves_to_search, 20 | training_sample_size=partioning_trainsize). \ 21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() 22 | 23 | 24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): 25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( 26 | reorder_k).build() 27 | 28 | def load_datapool(dpath): 29 | 30 | 31 | def load_single_file(saved_embeddings): 32 | compressed = np.load(saved_embeddings) 33 | database = {key: compressed[key] for key in compressed.files} 34 | return database 35 | 36 | def load_multi_files(data_archive): 37 | database = {key: [] for key in data_archive[0].files} 38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): 39 | for key in d.files: 40 | database[key].append(d[key]) 41 | 42 | return database 43 | 44 | print(f'Load saved patch embedding from "{dpath}"') 45 | file_content = glob.glob(os.path.join(dpath, '*.npz')) 46 | 47 | if len(file_content) == 1: 48 | data_pool = load_single_file(file_content[0]) 49 | elif len(file_content) > 1: 50 | data = [np.load(f) for f in file_content] 51 | prefetched_data = parallel_data_prefetch(load_multi_files, data, 52 | n_proc=min(len(data), cpu_count()), target_data_type='dict') 53 | 54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} 55 | else: 56 | raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') 57 | 58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') 59 | return data_pool 60 | 61 | 62 | def train_searcher(opt, 63 | metric='dot_product', 64 | partioning_trainsize=None, 65 | reorder_k=None, 66 | # todo tune 67 | aiq_thld=0.2, 68 | dims_per_block=2, 69 | num_leaves=None, 70 | num_leaves_to_search=None,): 71 | 72 | data_pool = load_datapool(opt.database) 73 | k = opt.knn 74 | 75 | if not reorder_k: 76 | reorder_k = 2 * k 77 | 78 | # normalize 79 | # embeddings = 80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) 81 | pool_size = data_pool['embedding'].shape[0] 82 | 83 | print(*(['#'] * 100)) 84 | print('Initializing scaNN searcher with the following values:') 85 | print(f'k: {k}') 86 | print(f'metric: {metric}') 87 | print(f'reorder_k: {reorder_k}') 88 | print(f'anisotropic_quantization_threshold: {aiq_thld}') 89 | print(f'dims_per_block: 
{dims_per_block}') 90 | print(*(['#'] * 100)) 91 | print('Start training searcher....') 92 | print(f'N samples in pool is {pool_size}') 93 | 94 | # this reflects the recommended design choices proposed at 95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md 96 | if pool_size < 2e4: 97 | print('Using brute force search.') 98 | searcher = search_bruteforce(searcher) 99 | elif 2e4 <= pool_size and pool_size < 1e5: 100 | print('Using asymmetric hashing search and reordering.') 101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 102 | else: 103 | print('Using using partioning, asymmetric hashing search and reordering.') 104 | 105 | if not partioning_trainsize: 106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10 107 | if not num_leaves: 108 | num_leaves = int(np.sqrt(pool_size)) 109 | 110 | if not num_leaves_to_search: 111 | num_leaves_to_search = max(num_leaves // 20, 1) 112 | 113 | print('Partitioning params:') 114 | print(f'num_leaves: {num_leaves}') 115 | print(f'num_leaves_to_search: {num_leaves_to_search}') 116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, 118 | partioning_trainsize, num_leaves, num_leaves_to_search) 119 | 120 | print('Finish training searcher') 121 | searcher_savedir = opt.target_path 122 | os.makedirs(searcher_savedir, exist_ok=True) 123 | searcher.serialize(searcher_savedir) 124 | print(f'Saved trained searcher under "{searcher_savedir}"') 125 | 126 | if __name__ == '__main__': 127 | sys.path.append(os.getcwd()) 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--database', 130 | '-d', 131 | default='data/rdm/retrieval_databases/openimages', 132 | type=str, 133 | help='path to folder containing the clip feature of the database') 134 | parser.add_argument('--target_path', 135 | '-t', 136 | default='data/rdm/searchers/openimages', 137 | type=str, 138 | help='path to the target folder where the searcher shall be stored.') 139 | parser.add_argument('--knn', 140 | '-k', 141 | default=20, 142 | type=int, 143 | help='number of nearest neighbors, for which the searcher shall be optimized') 144 | 145 | opt, _ = parser.parse_known_args() 146 | 147 | train_searcher(opt,) -------------------------------------------------------------------------------- /scripts/txt2img.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | from tqdm import tqdm, trange 7 | from itertools import islice 8 | from einops import rearrange 9 | from torchvision.utils import make_grid 10 | import time 11 | from pytorch_lightning import seed_everything 12 | from torch import autocast 13 | from contextlib import contextmanager, nullcontext 14 | 15 | from ldm.util import instantiate_from_config 16 | from ldm.models.diffusion.ddim import DDIMSampler 17 | from ldm.models.diffusion.plms import PLMSSampler 18 | 19 | 20 | def chunk(it, size): 21 | it = iter(it) 22 | return iter(lambda: tuple(islice(it, size)), ()) 23 | 24 | 25 | def load_model_from_config(config, ckpt, verbose=False): 26 | print(f"Loading model from {ckpt}") 27 | pl_sd = torch.load(ckpt, map_location="cpu") 28 | if "global_step" in pl_sd: 29 | print(f"Global Step: {pl_sd['global_step']}") 30 | sd = pl_sd["state_dict"] 31 | model = 
instantiate_from_config(config.model) 32 | m, u = model.load_state_dict(sd, strict=False) 33 | if len(m) > 0 and verbose: 34 | print("missing keys:") 35 | print(m) 36 | if len(u) > 0 and verbose: 37 | print("unexpected keys:") 38 | print(u) 39 | 40 | model.cuda() 41 | model.eval() 42 | return model 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | 48 | parser.add_argument( 49 | "--prompt", 50 | type=str, 51 | nargs="?", 52 | default="a painting of a virus monster playing guitar", 53 | help="the prompt to render" 54 | ) 55 | parser.add_argument( 56 | "--outdir", 57 | type=str, 58 | nargs="?", 59 | help="dir to write results to", 60 | default="outputs/txt2img-samples" 61 | ) 62 | parser.add_argument( 63 | "--skip_grid", 64 | action='store_true', 65 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", 66 | ) 67 | parser.add_argument( 68 | "--skip_save", 69 | action='store_true', 70 | help="do not save individual samples. For speed measurements.", 71 | ) 72 | parser.add_argument( 73 | "--ddim_steps", 74 | type=int, 75 | default=50, 76 | help="number of ddim sampling steps", 77 | ) 78 | parser.add_argument( 79 | "--plms", 80 | action='store_true', 81 | help="use plms sampling", 82 | ) 83 | parser.add_argument( 84 | "--laion400m", 85 | action='store_true', 86 | help="uses the LAION400M model", 87 | ) 88 | parser.add_argument( 89 | "--fixed_code", 90 | action='store_true', 91 | help="if enabled, uses the same starting code across samples ", 92 | ) 93 | parser.add_argument( 94 | "--ddim_eta", 95 | type=float, 96 | default=0.0, 97 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 98 | ) 99 | parser.add_argument( 100 | "--n_iter", 101 | type=int, 102 | default=2, 103 | help="sample this often", 104 | ) 105 | parser.add_argument( 106 | "--H", 107 | type=int, 108 | default=512, 109 | help="image height, in pixel space", 110 | ) 111 | parser.add_argument( 112 | "--W", 113 | type=int, 114 | default=512, 115 | help="image width, in pixel space", 116 | ) 117 | parser.add_argument( 118 | "--C", 119 | type=int, 120 | default=4, 121 | help="latent channels", 122 | ) 123 | parser.add_argument( 124 | "--f", 125 | type=int, 126 | default=8, 127 | help="downsampling factor", 128 | ) 129 | parser.add_argument( 130 | "--n_samples", 131 | type=int, 132 | default=3, 133 | help="how many samples to produce for each given prompt. A.k.a. 
batch size", 134 | ) 135 | parser.add_argument( 136 | "--n_rows", 137 | type=int, 138 | default=0, 139 | help="rows in the grid (default: n_samples)", 140 | ) 141 | parser.add_argument( 142 | "--scale", 143 | type=float, 144 | default=7.5, 145 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 146 | ) 147 | parser.add_argument( 148 | "--from-file", 149 | type=str, 150 | help="if specified, load prompts from this file", 151 | ) 152 | parser.add_argument( 153 | "--config", 154 | type=str, 155 | default="configs/stable-diffusion/v1-inference.yaml", 156 | help="path to config which constructs model", 157 | ) 158 | parser.add_argument( 159 | "--ckpt", 160 | type=str, 161 | default="models/ldm/stable-diffusion-v1/model.ckpt", 162 | help="path to checkpoint of model", 163 | ) 164 | parser.add_argument( 165 | "--seed", 166 | type=int, 167 | default=42, 168 | help="the seed (for reproducible sampling)", 169 | ) 170 | parser.add_argument( 171 | "--precision", 172 | type=str, 173 | help="evaluate at this precision", 174 | choices=["full", "autocast"], 175 | default="autocast" 176 | ) 177 | opt = parser.parse_args() 178 | 179 | if opt.laion400m: 180 | print("Falling back to LAION 400M model...") 181 | opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" 182 | opt.ckpt = "models/ldm/text2img-large/model.ckpt" 183 | opt.outdir = "outputs/txt2img-samples-laion400m" 184 | 185 | seed_everything(opt.seed) 186 | 187 | config = OmegaConf.load(f"{opt.config}") 188 | model = load_model_from_config(config, f"{opt.ckpt}") 189 | 190 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 191 | model = model.to(device) 192 | 193 | if opt.plms: 194 | sampler = PLMSSampler(model) 195 | else: 196 | sampler = DDIMSampler(model) 197 | 198 | os.makedirs(opt.outdir, exist_ok=True) 199 | outpath = opt.outdir 200 | 201 | batch_size = opt.n_samples 202 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 203 | if not opt.from_file: 204 | prompt = opt.prompt 205 | assert prompt is not None 206 | data = [batch_size * [prompt]] 207 | 208 | else: 209 | print(f"reading prompts from {opt.from_file}") 210 | with open(opt.from_file, "r") as f: 211 | data = f.read().splitlines() 212 | data = list(chunk(data, batch_size)) 213 | 214 | sample_path = os.path.join(outpath, "samples") 215 | os.makedirs(sample_path, exist_ok=True) 216 | base_count = len(os.listdir(sample_path)) 217 | grid_count = len(os.listdir(outpath)) - 1 218 | 219 | start_code = None 220 | if opt.fixed_code: 221 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) 222 | 223 | precision_scope = autocast if opt.precision=="autocast" else nullcontext 224 | with torch.no_grad(): 225 | with precision_scope("cuda"): 226 | with model.ema_scope(): 227 | tic = time.time() 228 | all_samples = list() 229 | for n in trange(opt.n_iter, desc="Sampling"): 230 | for prompts in tqdm(data, desc="data"): 231 | uc = None 232 | if opt.scale != 1.0: 233 | uc = model.get_learned_conditioning(batch_size * [""]) 234 | if isinstance(prompts, tuple): 235 | prompts = list(prompts) 236 | c = model.get_learned_conditioning(prompts) 237 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f] 238 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 239 | conditioning=c, 240 | batch_size=opt.n_samples, 241 | shape=shape, 242 | verbose=False, 243 | unconditional_guidance_scale=opt.scale, 244 | unconditional_conditioning=uc, 245 | eta=opt.ddim_eta, 246 | x_T=start_code) 247 
| 248 | x_samples_ddim = model.decode_first_stage(samples_ddim) 249 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 250 | 251 | if not opt.skip_save: 252 | for x_sample in x_samples_ddim: 253 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 254 | Image.fromarray(x_sample.astype(np.uint8)).save( 255 | os.path.join(sample_path, f"{base_count:05}.png")) 256 | base_count += 1 257 | 258 | if not opt.skip_grid: 259 | all_samples.append(x_samples_ddim) 260 | 261 | if not opt.skip_grid: 262 | # additionally, save as grid 263 | grid = torch.stack(all_samples, 0) 264 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 265 | grid = make_grid(grid, nrow=n_rows) 266 | 267 | # to image 268 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() 269 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) 270 | grid_count += 1 271 | 272 | toc = time.time() 273 | 274 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n" 275 | f" \nEnjoy.") 276 | 277 | 278 | if __name__ == "__main__": 279 | main() 280 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
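The --scale help strings in scripts/img2img.py and scripts/txt2img.py above spell out the classifier-free guidance rule, eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)). The DDIM/PLMS samplers in ldm/models/diffusion implement this internally; the sketch below only illustrates that single combination step, with dummy tensors shaped for the default 512x512 image, f=8, C=4 latent:

import torch

def guided_eps(eps_uncond, eps_cond, scale):
    # scale == 1.0 reduces to the purely conditional prediction, which is why the
    # scripts above skip building the unconditional batch in that case.
    return eps_uncond + scale * (eps_cond - eps_uncond)

# Dummy noise predictions at the default latent resolution (4 x 64 x 64).
e_uc = torch.randn(1, 4, 64, 64)
e_c = torch.randn(1, 4, 64, 64)
e = guided_eps(e_uc, e_c, scale=7.5)  # 7.5 is the txt2img default --scale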