├── LICENSE ├── README.md ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ ├── txt2img-1p4B-eval.yaml │ ├── txt2img-1p4B-eval_with_tokens.yaml │ ├── txt2img-1p4B-finetune.yaml │ └── txt2img-1p4B-finetune_style.yaml └── stable-diffusion │ ├── v1-finetune.yaml │ ├── v1-finetune_unfrozen.yaml │ └── v1-inference.yaml ├── environment.yaml ├── evaluation ├── __pycache__ │ ├── clip_eval.cpython-36.pyc │ └── clip_eval.cpython-38.pyc └── clip_eval.py ├── img ├── samples.jpg ├── style.jpg └── teaser.jpg ├── ldm ├── __pycache__ │ ├── util.cpython-36.pyc │ └── util.cpython-38.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── base.cpython-36.pyc │ │ ├── base.cpython-38.pyc │ │ ├── personalized.cpython-36.pyc │ │ ├── personalized.cpython-38.pyc │ │ ├── personalized_compose.cpython-38.pyc │ │ ├── personalized_detailed_text.cpython-36.pyc │ │ ├── personalized_style.cpython-36.pyc │ │ └── personalized_style.cpython-38.pyc │ ├── base.py │ ├── imagenet.py │ ├── lsun.py │ ├── personalized.py │ └── personalized_style.py ├── lr_scheduler.py ├── models │ ├── __pycache__ │ │ ├── autoencoder.cpython-36.pyc │ │ └── autoencoder.cpython-38.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── ddim.cpython-36.pyc │ │ ├── ddim.cpython-38.pyc │ │ ├── ddim_inversion.cpython-38.pyc │ │ ├── ddpm.cpython-36.pyc │ │ ├── ddpm.cpython-38.pyc │ │ ├── ddpm_pti.cpython-38.pyc │ │ ├── plms.cpython-36.pyc │ │ └── plms.cpython-38.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-36.pyc │ │ ├── attention.cpython-38.pyc │ │ ├── ema.cpython-36.pyc │ │ ├── ema.cpython-38.pyc │ │ ├── embedding_manager.cpython-36.pyc │ │ ├── embedding_manager.cpython-38.pyc │ │ ├── x_transformer.cpython-36.pyc │ │ └── x_transformer.cpython-38.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-36.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-36.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ ├── util.cpython-36.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── distributions.cpython-36.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── embedding_manager.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── modules.cpython-36.pyc │ │ │ └── modules.cpython-38.pyc │ │ ├── modules.py │ │ └── modules_bak.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── merge_embeddings.py ├── models ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── 
config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml └── ldm │ ├── bsr_sr │ └── config.yaml │ ├── celeba256 │ └── config.yaml │ ├── cin256 │ └── config.yaml │ ├── ffhq256 │ └── config.yaml │ ├── inpainting_big │ └── config.yaml │ ├── layout2img-openimages256 │ └── config.yaml │ ├── lsun_beds256 │ └── config.yaml │ ├── lsun_churches256 │ └── config.yaml │ ├── semantic_synthesis256 │ └── config.yaml │ ├── semantic_synthesis512 │ └── config.yaml │ └── text2img256 │ └── config.yaml ├── scripts ├── download_first_stages.sh ├── download_models.sh ├── evaluate_model.py ├── inpaint.py ├── latent_imagenet_diffusion.ipynb ├── sample_diffusion.py ├── stable_txt2img.py └── txt2img.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2208.01618-b31b1b.svg)](https://arxiv.org/abs/2208.01618) 4 | 5 | [[Project Website](https://textual-inversion.github.io/)] 6 | 7 | > **An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion**
8 | > Rinon Gal<sup>1,2</sup>, Yuval Alaluf<sup>1</sup>, Yuval Atzmon<sup>2</sup>, Or Patashnik<sup>1</sup>, Amit H. Bermano<sup>1</sup>, Gal Chechik<sup>2</sup>, Daniel Cohen-Or<sup>1</sup>
9 | > <sup>1</sup>Tel Aviv University, <sup>2</sup>NVIDIA 10 | 11 | >**Abstract**:
12 | > Text-to-image models offer unprecedented freedom to guide creation through natural language. 13 | Yet, it is unclear how such freedom can be exercised to generate images of specific unique concepts, modify their appearance, or compose them in new roles and novel scenes. 14 | In other words, we ask: how can we use language-guided models to turn our cat into a painting, or imagine a new product based on our favorite toy? 15 | Here we present a simple approach that allows such creative freedom. 16 | Using only 3-5 images of a user-provided concept, like an object or a style, we learn to represent it through new "words" in the embedding space of a frozen text-to-image model. 17 | These "words" can be composed into natural language sentences, guiding personalized creation in an intuitive way. 18 | Notably, we find evidence that a single word embedding is sufficient for capturing unique and varied concepts. 19 | We compare our approach to a wide range of baselines, and demonstrate that it can more faithfully portray the concepts across a range of applications and tasks. 20 | 21 | ## Description 22 | This repo contains the official code, data and sample inversions for our Textual Inversion paper. 23 | 24 | ## Updates 25 | **29/08/2022** Merge embeddings now supports SD embeddings. Added SD pivotal tuning code (WIP); fixed training duration and checkpoint save iterations. 26 | **21/08/2022** Code released! 27 | 28 | ## TODO: 29 | - [x] Release code! 30 | - [x] Optimize gradient storing / checkpointing. Memory requirements and training times reduced by ~55% 31 | - [x] Release data sets 32 | - [ ] Release pre-trained embeddings 33 | - [ ] Add Stable Diffusion support 34 | 35 | ## Setup 36 | 37 | Our code builds on, and shares requirements with [Latent Diffusion Models (LDM)](https://github.com/CompVis/latent-diffusion). To set up their environment, please run: 38 | 39 | ``` 40 | conda env create -f environment.yaml 41 | conda activate ldm 42 | ``` 43 | 44 | You will also need the official LDM text-to-image checkpoint, available through the [LDM project page](https://github.com/CompVis/latent-diffusion). 45 | 46 | Currently, the model can be downloaded by running: 47 | 48 | ``` 49 | mkdir -p models/ldm/text2img-large/ 50 | wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt 51 | ``` 52 | 53 | ## Usage 54 | 55 | ### Inversion 56 | 57 | To invert an image set, run: 58 | 59 | ``` 60 | python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml 61 | -t 62 | --actual_resume /path/to/pretrained/model.ckpt 63 | -n <run_name> 64 | --gpus 0, 65 | --data_root /path/to/directory/with/images 66 | --init_word <initialization_word> 67 | ``` 68 | 69 | where the initialization word should be a single-token rough description of the object (e.g., 'toy', 'painting', 'sculpture'). If the input comprises more than a single token, you will be prompted to replace it. 70 | 71 | Please note that `init_word` is *not* the placeholder string that will later represent the concept. It is only used as a starting point for the optimization scheme. 72 | 73 | In the paper, we use 5k training iterations. However, some concepts (particularly styles) can converge much faster. 74 | 75 | To run on multiple GPUs, provide a comma-delimited list of GPU indices to the --gpus argument (e.g., ``--gpus 0,3,7,8``). 76 | 77 | Embeddings and output images will be saved in the log directory.
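If you want to sanity-check what was learned, the saved embedding files (e.g. `embeddings_gs-5049.pt`) can be opened directly with PyTorch. The following is a minimal sketch, assuming the file is a plain `torch.save` dict as written by the `EmbeddingManager` (see `ldm/modules/embedding_manager.py` for the exact keys); the path below is a placeholder for a checkpoint in your own log directory:

```
import torch

# Placeholder path -- point this at a checkpoint inside your log directory.
ckpt = torch.load("logs/run_name/checkpoints/embeddings_gs-5049.pt", map_location="cpu")

# Walk the saved dict and print the stored placeholder strings and the shapes
# of the associated embedding tensors (or the types of any other entries).
for key, value in ckpt.items():
    if isinstance(value, dict):
        for placeholder, entry in value.items():
            shape = getattr(entry, "shape", type(entry).__name__)
            print(f"{key}: {placeholder} -> {shape}")
    else:
        print(f"{key}: {type(value).__name__}")
```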
78 | 79 | See `configs/latent-diffusion/txt2img-1p4B-finetune.yaml` for more options, such as: changing the placeholder string which denotes the concept (defaults to "*"), changing the maximal number of training iterations, changing how often checkpoints are saved, and more. 80 | 81 | **Important** All training set images should be upright. If you are using phone-captured images, check the inputs_gs*.jpg files in the output image directory and make sure they are oriented correctly. Many phones capture images with a 90-degree rotation and denote this in the image metadata. Windows parses this metadata correctly, but PIL does not. Hence, you will need to correct the images manually (e.g., by pasting them into Paint and re-saving) or wait until we add metadata parsing. 82 | 83 | ### Generation 84 | 85 | To generate new images of the learned concept, run: 86 | ``` 87 | python scripts/txt2img.py --ddim_eta 0.0 88 | --n_samples 8 89 | --n_iter 2 90 | --scale 10.0 91 | --ddim_steps 50 92 | --embedding_path /path/to/logs/trained_model/checkpoints/embeddings_gs-5049.pt 93 | --ckpt_path /path/to/pretrained/model.ckpt 94 | --prompt "a photo of *" 95 | ``` 96 | 97 | where `*` is the placeholder string used during inversion. 98 | 99 | ### Merging Checkpoints 100 | 101 | LDM embedding checkpoints can be merged into a single file by running: 102 | 103 | ``` 104 | python merge_embeddings.py 105 | --manager_ckpts /path/to/first/embedding.pt /path/to/second/embedding.pt [...] 106 | --output_path /path/to/output/embedding.pt 107 | ``` 108 | 109 | For SD embeddings, simply add the `-sd` or `--stable_diffusion` flag. 110 | 111 | If the checkpoints contain conflicting placeholder strings, you will be prompted to select new placeholders. The merged checkpoint can later be used to prompt multiple concepts at once ("A photo of * in the style of @"). 112 | 113 | ### Pretrained Models / Data 114 | 115 | Datasets which appear in the paper are being uploaded [here](https://drive.google.com/drive/folders/1d2UXkX0GWM-4qUwThjNhFIPP7S6WUbQJ). Some sets are unavailable due to image ownership. We will upload more as we receive permissions to do so. 116 | 117 | Pretrained models are coming soon. 118 | 119 | ## Stable Diffusion 120 | 121 | Stable Diffusion support is a work in progress and will be completed soon™. 122 | 123 | ## Tips and Tricks 124 | - Adding "a photo of" to the prompt usually results in better target consistency. 125 | - Results can be seed-sensitive. If you're unsatisfied with the model, try re-inverting with a new seed (by adding `--seed <#>` to the command). 126 | 127 | 128 | ## Citation 129 | 130 | If you make use of our work, please cite our paper: 131 | 132 | ``` 133 | @misc{gal2022textual, 134 | doi = {10.48550/ARXIV.2208.01618}, 135 | url = {https://arxiv.org/abs/2208.01618}, 136 | author = {Gal, Rinon and Alaluf, Yuval and Atzmon, Yuval and Patashnik, Or and Bermano, Amit H. and Chechik, Gal and Cohen-Or, Daniel}, 137 | title = {An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion}, 138 | publisher = {arXiv}, 139 | year = {2022}, 140 | primaryClass = {cs.CV} 141 | } 142 | ``` 143 | 144 | ## Results 145 | Here are some sample results. Please visit our [project page](https://textual-inversion.github.io/) or read our paper for more!
146 | 147 | ![](img/teaser.jpg) 148 | 149 | ![](img/samples.jpg) 150 | 151 | ![](img/style.jpg) -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | 
attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | 
timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 
7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | personalization_config: 21 | target: ldm.modules.embedding_manager.EmbeddingManager 22 | params: 23 | placeholder_strings: ["*"] 24 | initializer_words: [] 25 | 26 | unet_config: 27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 28 | params: 29 | image_size: 32 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 320 33 | attention_resolutions: 34 | - 4 35 | - 2 36 | - 1 37 | num_res_blocks: 2 38 | channel_mult: 39 | - 1 40 | - 2 41 | - 4 42 | - 4 43 | num_heads: 8 44 | use_spatial_transformer: true 45 | transformer_depth: 1 46 | context_dim: 1280 47 | use_checkpoint: true 48 | legacy: False 49 | 50 | first_stage_config: 51 | target: ldm.models.autoencoder.AutoencoderKL 52 | params: 53 | embed_dim: 4 54 | monitor: val/rec_loss 55 | ddconfig: 56 | double_z: true 57 | z_channels: 4 58 | resolution: 256 59 | in_channels: 3 60 | out_ch: 3 61 | ch: 128 62 | ch_mult: 63 | - 1 64 | - 2 65 | - 4 66 | - 4 67 | num_res_blocks: 2 68 | attn_resolutions: [] 69 | dropout: 0.0 70 | lossconfig: 71 | target: torch.nn.Identity 72 | 73 | cond_stage_config: 74 | target: ldm.modules.encoders.modules.BERTEmbedder 75 | params: 76 | n_embed: 1280 77 | n_layer: 32 78 | -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-3 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | 21 | personalization_config: 22 | target: ldm.modules.embedding_manager.EmbeddingManager 23 | params: 24 | placeholder_strings: ["*"] 25 | initializer_words: ["sculpture"] 26 | per_image_tokens: false 27 | num_vectors_per_token: 1 28 | progressive_words: False 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | image_size: 32 34 | in_channels: 4 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: 38 | - 4 39 | - 2 40 | - 1 41 | num_res_blocks: 2 42 | channel_mult: 43 | - 1 44 | - 2 45 | - 4 46 | - 4 47 | num_heads: 8 48 | use_spatial_transformer: true 49 | transformer_depth: 1 50 | context_dim: 1280 51 | use_checkpoint: true 52 | legacy: False 53 | 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: 78 | target: ldm.modules.encoders.modules.BERTEmbedder 79 | params: 80 | n_embed: 1280 81 | n_layer: 32 82 | 83 | 84 | data: 85 | target: 
main.DataModuleFromConfig 86 | params: 87 | batch_size: 4 88 | num_workers: 2 89 | wrap: false 90 | train: 91 | target: ldm.data.personalized.PersonalizedBase 92 | params: 93 | size: 256 94 | set: train 95 | per_image_tokens: false 96 | repeats: 100 97 | validation: 98 | target: ldm.data.personalized.PersonalizedBase 99 | params: 100 | size: 256 101 | set: val 102 | per_image_tokens: false 103 | repeats: 10 104 | 105 | lightning: 106 | modelcheckpoint: 107 | params: 108 | every_n_train_steps: 500 109 | callbacks: 110 | image_logger: 111 | target: main.ImageLogger 112 | params: 113 | batch_frequency: 500 114 | max_images: 8 115 | increase_log_steps: False 116 | 117 | trainer: 118 | benchmark: True 119 | max_steps: 6100 -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-3 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | 21 | personalization_config: 22 | target: ldm.modules.embedding_manager.EmbeddingManager 23 | params: 24 | placeholder_strings: ["*"] 25 | initializer_words: ["painting"] 26 | per_image_tokens: false 27 | num_vectors_per_token: 1 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: 37 | - 4 38 | - 2 39 | - 1 40 | num_res_blocks: 2 41 | channel_mult: 42 | - 1 43 | - 2 44 | - 4 45 | - 4 46 | num_heads: 8 47 | use_spatial_transformer: true 48 | transformer_depth: 1 49 | context_dim: 1280 50 | use_checkpoint: true 51 | legacy: False 52 | 53 | first_stage_config: 54 | target: ldm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 256 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: ldm.modules.encoders.modules.BERTEmbedder 78 | params: 79 | n_embed: 1280 80 | n_layer: 32 81 | 82 | 83 | data: 84 | target: main.DataModuleFromConfig 85 | params: 86 | batch_size: 4 87 | num_workers: 4 88 | wrap: false 89 | train: 90 | target: ldm.data.personalized_style.PersonalizedBase 91 | params: 92 | size: 256 93 | set: train 94 | per_image_tokens: false 95 | repeats: 100 96 | validation: 97 | target: ldm.data.personalized_style.PersonalizedBase 98 | params: 99 | size: 256 100 | set: val 101 | per_image_tokens: false 102 | repeats: 10 103 | 104 | lightning: 105 | modelcheckpoint: 106 | params: 107 | every_n_train_steps: 500 108 | callbacks: 109 | image_logger: 110 | target: main.ImageLogger 111 | params: 112 | batch_frequency: 500 113 | max_images: 8 114 | increase_log_steps: False 115 | 116 | trainer: 117 | benchmark: True -------------------------------------------------------------------------------- 
/configs/stable-diffusion/v1-finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-03 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: true # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | unfreeze_model: False 21 | model_lr: 0.0 22 | 23 | personalization_config: 24 | target: ldm.modules.embedding_manager.EmbeddingManager 25 | params: 26 | placeholder_strings: ["*"] 27 | initializer_words: ["sculpture"] 28 | per_image_tokens: false 29 | num_vectors_per_token: 1 30 | progressive_words: False 31 | 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 # unused 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 320 39 | attention_resolutions: [ 4, 2, 1 ] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 4, 4 ] 42 | num_heads: 8 43 | use_spatial_transformer: True 44 | transformer_depth: 1 45 | context_dim: 768 46 | use_checkpoint: True 47 | legacy: False 48 | 49 | first_stage_config: 50 | target: ldm.models.autoencoder.AutoencoderKL 51 | params: 52 | embed_dim: 4 53 | monitor: val/rec_loss 54 | ddconfig: 55 | double_z: true 56 | z_channels: 4 57 | resolution: 512 58 | in_channels: 3 59 | out_ch: 3 60 | ch: 128 61 | ch_mult: 62 | - 1 63 | - 2 64 | - 4 65 | - 4 66 | num_res_blocks: 2 67 | attn_resolutions: [] 68 | dropout: 0.0 69 | lossconfig: 70 | target: torch.nn.Identity 71 | 72 | cond_stage_config: 73 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 74 | 75 | data: 76 | target: main.DataModuleFromConfig 77 | params: 78 | batch_size: 2 79 | num_workers: 2 80 | wrap: false 81 | train: 82 | target: ldm.data.personalized.PersonalizedBase 83 | params: 84 | size: 512 85 | set: train 86 | per_image_tokens: false 87 | repeats: 100 88 | validation: 89 | target: ldm.data.personalized.PersonalizedBase 90 | params: 91 | size: 512 92 | set: val 93 | per_image_tokens: false 94 | repeats: 10 95 | 96 | lightning: 97 | modelcheckpoint: 98 | params: 99 | every_n_train_steps: 500 100 | callbacks: 101 | image_logger: 102 | target: main.ImageLogger 103 | params: 104 | batch_frequency: 500 105 | max_images: 8 106 | increase_log_steps: False 107 | 108 | trainer: 109 | benchmark: True 110 | max_steps: 6100 -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-finetune_unfrozen.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: true # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | unfreeze_model: True 21 | model_lr: 1.0e-7 22 | 23 | personalization_config: 24 | target: 
ldm.modules.embedding_manager.EmbeddingManager 25 | params: 26 | placeholder_strings: ["*"] 27 | initializer_words: ["sculpture"] 28 | per_image_tokens: false 29 | num_vectors_per_token: 1 30 | progressive_words: False 31 | 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 # unused 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 320 39 | attention_resolutions: [ 4, 2, 1 ] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 4, 4 ] 42 | num_heads: 8 43 | use_spatial_transformer: True 44 | transformer_depth: 1 45 | context_dim: 768 46 | use_checkpoint: True 47 | legacy: False 48 | 49 | first_stage_config: 50 | target: ldm.models.autoencoder.AutoencoderKL 51 | params: 52 | embed_dim: 4 53 | monitor: val/rec_loss 54 | ddconfig: 55 | double_z: true 56 | z_channels: 4 57 | resolution: 512 58 | in_channels: 3 59 | out_ch: 3 60 | ch: 128 61 | ch_mult: 62 | - 1 63 | - 2 64 | - 4 65 | - 4 66 | num_res_blocks: 2 67 | attn_resolutions: [] 68 | dropout: 0.0 69 | lossconfig: 70 | target: torch.nn.Identity 71 | 72 | cond_stage_config: 73 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 74 | 75 | data: 76 | target: main.DataModuleFromConfig 77 | params: 78 | batch_size: 1 79 | num_workers: 2 80 | wrap: false 81 | train: 82 | target: ldm.data.personalized.PersonalizedBase 83 | params: 84 | size: 512 85 | set: train 86 | per_image_tokens: false 87 | repeats: 100 88 | validation: 89 | target: ldm.data.personalized.PersonalizedBase 90 | params: 91 | size: 512 92 | set: val 93 | per_image_tokens: false 94 | repeats: 10 95 | 96 | lightning: 97 | modelcheckpoint: 98 | params: 99 | every_n_train_steps: 500 100 | callbacks: 101 | image_logger: 102 | target: main.ImageLogger 103 | params: 104 | batch_frequency: 500 105 | max_images: 8 106 | increase_log_steps: False 107 | 108 | trainer: 109 | benchmark: True 110 | max_steps: 4200 -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | personalization_config: 21 | target: ldm.modules.embedding_manager.EmbeddingManager 22 | params: 23 | placeholder_strings: ["*"] 24 | initializer_words: ["sculpture"] 25 | per_image_tokens: false 26 | num_vectors_per_token: 1 27 | progressive_words: False 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 
256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.10 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.10.2 10 | - torchvision=0.11.3 11 | - numpy=1.22.3 12 | - pip: 13 | - albumentations==1.1.0 14 | - opencv-python==4.2.0.34 15 | - pudb==2019.2 16 | - imageio==2.14.1 17 | - imageio-ffmpeg==0.4.7 18 | - pytorch-lightning==1.5.9 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - setuptools==59.5.0 23 | - pillow==9.0.1 24 | - einops==0.4.1 25 | - torch-fidelity==0.3.0 26 | - transformers==4.18.0 27 | - torchmetrics==0.6.0 28 | - kornia==0.6 29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | - -e . 32 | -------------------------------------------------------------------------------- /evaluation/__pycache__/clip_eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/evaluation/__pycache__/clip_eval.cpython-36.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/clip_eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/evaluation/__pycache__/clip_eval.cpython-38.pyc -------------------------------------------------------------------------------- /evaluation/clip_eval.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | from torchvision import transforms 4 | 5 | from ldm.models.diffusion.ddim import DDIMSampler 6 | 7 | class CLIPEvaluator(object): 8 | def __init__(self, device, clip_model='ViT-B/32') -> None: 9 | self.device = device 10 | self.model, clip_preprocess = clip.load(clip_model, device=self.device) 11 | 12 | self.clip_preprocess = clip_preprocess 13 | 14 | self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1]. 
15 | clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions 16 | clip_preprocess.transforms[4:]) # + skip convert PIL to tensor 17 | 18 | def tokenize(self, strings: list): 19 | return clip.tokenize(strings).to(self.device) 20 | 21 | @torch.no_grad() 22 | def encode_text(self, tokens: list) -> torch.Tensor: 23 | return self.model.encode_text(tokens) 24 | 25 | @torch.no_grad() 26 | def encode_images(self, images: torch.Tensor) -> torch.Tensor: 27 | images = self.preprocess(images).to(self.device) 28 | return self.model.encode_image(images) 29 | 30 | def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor: 31 | 32 | tokens = clip.tokenize(text).to(self.device) 33 | 34 | text_features = self.encode_text(tokens).detach() 35 | 36 | if norm: 37 | text_features /= text_features.norm(dim=-1, keepdim=True) 38 | 39 | return text_features 40 | 41 | def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor: 42 | image_features = self.encode_images(img) 43 | 44 | if norm: 45 | image_features /= image_features.clone().norm(dim=-1, keepdim=True) 46 | 47 | return image_features 48 | 49 | def img_to_img_similarity(self, src_images, generated_images): 50 | src_img_features = self.get_image_features(src_images) 51 | gen_img_features = self.get_image_features(generated_images) 52 | 53 | return (src_img_features @ gen_img_features.T).mean() 54 | 55 | def txt_to_img_similarity(self, text, generated_images): 56 | text_features = self.get_text_features(text) 57 | gen_img_features = self.get_image_features(generated_images) 58 | 59 | return (text_features @ gen_img_features.T).mean() 60 | 61 | 62 | class LDMCLIPEvaluator(CLIPEvaluator): 63 | def __init__(self, device, clip_model='ViT-B/32') -> None: 64 | super().__init__(device, clip_model) 65 | 66 | def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50): 67 | 68 | sampler = DDIMSampler(ldm_model) 69 | 70 | samples_per_batch = 8 71 | n_batches = n_samples // samples_per_batch 72 | 73 | # generate samples 74 | all_samples=list() 75 | with torch.no_grad(): 76 | with ldm_model.ema_scope(): 77 | uc = ldm_model.get_learned_conditioning(samples_per_batch * [""]) 78 | 79 | for batch in range(n_batches): 80 | c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text]) 81 | shape = [4, 256//8, 256//8] 82 | samples_ddim, _ = sampler.sample(S=n_steps, 83 | conditioning=c, 84 | batch_size=samples_per_batch, 85 | shape=shape, 86 | verbose=False, 87 | unconditional_guidance_scale=5.0, 88 | unconditional_conditioning=uc, 89 | eta=0.0) 90 | 91 | x_samples_ddim = ldm_model.decode_first_stage(samples_ddim) 92 | x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0) 93 | 94 | all_samples.append(x_samples_ddim) 95 | 96 | all_samples = torch.cat(all_samples, axis=0) 97 | 98 | sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples) 99 | sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples) 100 | 101 | return sim_samples_to_img, sim_samples_to_text 102 | 103 | 104 | class ImageDirEvaluator(CLIPEvaluator): 105 | def __init__(self, device, clip_model='ViT-B/32') -> None: 106 | super().__init__(device, clip_model) 107 | 108 | def evaluate(self, gen_samples, src_images, target_text): 109 | 110 | sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples) 111 | sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples) 112 | 113 | return sim_samples_to_img, sim_samples_to_text 
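For reference, a minimal sketch of how the evaluator above might be driven, assuming the repository root is on `PYTHONPATH`; the image tensors and prompt are hypothetical stand-ins, and inputs are assumed to be batched RGB tensors in [-1, 1] (the range the `preprocess` pipeline un-normalizes from):

```
import torch
from evaluation.clip_eval import ImageDirEvaluator

device = "cuda" if torch.cuda.is_available() else "cpu"
evaluator = ImageDirEvaluator(device)

# Hypothetical placeholders: batched (N, 3, H, W) RGB tensors in [-1, 1],
# standing in for training-set images and generated samples respectively.
src_images = torch.rand(4, 3, 256, 256) * 2 - 1
gen_samples = torch.rand(8, 3, 256, 256) * 2 - 1

image_sim, text_sim = evaluator.evaluate(gen_samples, src_images, "a photo of *")
print(f"image-image similarity: {image_sim.item():.4f}")
print(f"text-image similarity:  {text_sim.item():.4f}")
```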
-------------------------------------------------------------------------------- /img/samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/img/samples.jpg -------------------------------------------------------------------------------- /img/style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/img/style.jpg -------------------------------------------------------------------------------- /img/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/img/teaser.jpg -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/base.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/personalized.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/personalized.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/personalized_compose.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_compose.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/personalized_style.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_style.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/personalized_style.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_style.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = 
f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/data/personalized.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | import random 9 | 10 | imagenet_templates_smallest = [ 11 | 'a photo of a {}', 12 | ] 13 | 14 | imagenet_templates_small = [ 15 | 'a photo of a {}', 16 | 'a rendering of a {}', 17 | 'a cropped photo of the {}', 18 | 'the photo of a {}', 19 | 'a photo of a clean {}', 20 | 'a photo of a dirty {}', 21 | 'a dark photo of the {}', 22 | 'a photo of my {}', 23 | 'a photo of the cool {}', 24 | 'a close-up photo of a {}', 25 | 'a bright photo of the {}', 26 | 'a cropped photo of a {}', 27 | 'a photo of the {}', 28 | 'a good photo of the {}', 29 | 'a photo of one {}', 30 | 'a close-up photo of the {}', 31 | 'a 
rendition of the {}', 32 | 'a photo of the clean {}', 33 | 'a rendition of a {}', 34 | 'a photo of a nice {}', 35 | 'a good photo of a {}', 36 | 'a photo of the nice {}', 37 | 'a photo of the small {}', 38 | 'a photo of the weird {}', 39 | 'a photo of the large {}', 40 | 'a photo of a cool {}', 41 | 'a photo of a small {}', 42 | 'an illustration of a {}', 43 | 'a rendering of a {}', 44 | 'a cropped photo of the {}', 45 | 'the photo of a {}', 46 | 'an illustration of a clean {}', 47 | 'an illustration of a dirty {}', 48 | 'a dark photo of the {}', 49 | 'an illustration of my {}', 50 | 'an illustration of the cool {}', 51 | 'a close-up photo of a {}', 52 | 'a bright photo of the {}', 53 | 'a cropped photo of a {}', 54 | 'an illustration of the {}', 55 | 'a good photo of the {}', 56 | 'an illustration of one {}', 57 | 'a close-up photo of the {}', 58 | 'a rendition of the {}', 59 | 'an illustration of the clean {}', 60 | 'a rendition of a {}', 61 | 'an illustration of a nice {}', 62 | 'a good photo of a {}', 63 | 'an illustration of the nice {}', 64 | 'an illustration of the small {}', 65 | 'an illustration of the weird {}', 66 | 'an illustration of the large {}', 67 | 'an illustration of a cool {}', 68 | 'an illustration of a small {}', 69 | 'a depiction of a {}', 70 | 'a rendering of a {}', 71 | 'a cropped photo of the {}', 72 | 'the photo of a {}', 73 | 'a depiction of a clean {}', 74 | 'a depiction of a dirty {}', 75 | 'a dark photo of the {}', 76 | 'a depiction of my {}', 77 | 'a depiction of the cool {}', 78 | 'a close-up photo of a {}', 79 | 'a bright photo of the {}', 80 | 'a cropped photo of a {}', 81 | 'a depiction of the {}', 82 | 'a good photo of the {}', 83 | 'a depiction of one {}', 84 | 'a close-up photo of the {}', 85 | 'a rendition of the {}', 86 | 'a depiction of the clean {}', 87 | 'a rendition of a {}', 88 | 'a depiction of a nice {}', 89 | 'a good photo of a {}', 90 | 'a depiction of the nice {}', 91 | 'a depiction of the small {}', 92 | 'a depiction of the weird {}', 93 | 'a depiction of the large {}', 94 | 'a depiction of a cool {}', 95 | 'a depiction of a small {}', 96 | ] 97 | 98 | imagenet_dual_templates_small = [ 99 | 'a photo of a {} with {}', 100 | 'a rendering of a {} with {}', 101 | 'a cropped photo of the {} with {}', 102 | 'the photo of a {} with {}', 103 | 'a photo of a clean {} with {}', 104 | 'a photo of a dirty {} with {}', 105 | 'a dark photo of the {} with {}', 106 | 'a photo of my {} with {}', 107 | 'a photo of the cool {} with {}', 108 | 'a close-up photo of a {} with {}', 109 | 'a bright photo of the {} with {}', 110 | 'a cropped photo of a {} with {}', 111 | 'a photo of the {} with {}', 112 | 'a good photo of the {} with {}', 113 | 'a photo of one {} with {}', 114 | 'a close-up photo of the {} with {}', 115 | 'a rendition of the {} with {}', 116 | 'a photo of the clean {} with {}', 117 | 'a rendition of a {} with {}', 118 | 'a photo of a nice {} with {}', 119 | 'a good photo of a {} with {}', 120 | 'a photo of the nice {} with {}', 121 | 'a photo of the small {} with {}', 122 | 'a photo of the weird {} with {}', 123 | 'a photo of the large {} with {}', 124 | 'a photo of a cool {} with {}', 125 | 'a photo of a small {} with {}', 126 | ] 127 | 128 | per_img_token_list = [ 129 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', 130 | ] 131 | 132 | class PersonalizedBase(Dataset): 133 | def __init__(self, 134 | data_root, 135 | size=None, 136 | repeats=100, 137 | interpolation="bicubic", 138 | 
flip_p=0.5, 139 | set="train", 140 | placeholder_token="*", 141 | per_image_tokens=False, 142 | center_crop=False, 143 | mixing_prob=0.25, 144 | coarse_class_text=None, 145 | ): 146 | 147 | self.data_root = data_root 148 | 149 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 150 | 151 | # self._length = len(self.image_paths) 152 | self.num_images = len(self.image_paths) 153 | self._length = self.num_images 154 | 155 | self.placeholder_token = placeholder_token 156 | 157 | self.per_image_tokens = per_image_tokens 158 | self.center_crop = center_crop 159 | self.mixing_prob = mixing_prob 160 | 161 | self.coarse_class_text = coarse_class_text 162 | 163 | if per_image_tokens: 164 | assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 165 | 166 | if set == "train": 167 | self._length = self.num_images * repeats 168 | 169 | self.size = size 170 | self.interpolation = {"linear": PIL.Image.LINEAR, 171 | "bilinear": PIL.Image.BILINEAR, 172 | "bicubic": PIL.Image.BICUBIC, 173 | "lanczos": PIL.Image.LANCZOS, 174 | }[interpolation] 175 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 176 | 177 | def __len__(self): 178 | return self._length 179 | 180 | def __getitem__(self, i): 181 | example = {} 182 | image = Image.open(self.image_paths[i % self.num_images]) 183 | 184 | if not image.mode == "RGB": 185 | image = image.convert("RGB") 186 | 187 | placeholder_string = self.placeholder_token 188 | if self.coarse_class_text: 189 | placeholder_string = f"{self.coarse_class_text} {placeholder_string}" 190 | 191 | if self.per_image_tokens and np.random.uniform() < self.mixing_prob: 192 | text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images]) 193 | else: 194 | text = random.choice(imagenet_templates_small).format(placeholder_string) 195 | 196 | example["caption"] = text 197 | 198 | # default to score-sde preprocessing 199 | img = np.array(image).astype(np.uint8) 200 | 201 | if self.center_crop: 202 | crop = min(img.shape[0], img.shape[1]) 203 | h, w, = img.shape[0], img.shape[1] 204 | img = img[(h - crop) // 2:(h + crop) // 2, 205 | (w - crop) // 2:(w + crop) // 2] 206 | 207 | image = Image.fromarray(img) 208 | if self.size is not None: 209 | image = image.resize((self.size, self.size), resample=self.interpolation) 210 | 211 | image = self.flip(image) 212 | image = np.array(image).astype(np.uint8) 213 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 214 | return example -------------------------------------------------------------------------------- /ldm/data/personalized_style.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | import random 9 | 10 | imagenet_templates_small = [ 11 | 'a painting in the style of {}', 12 | 'a rendering in the style of {}', 13 | 'a cropped painting in the style of {}', 14 | 'the painting in the style of {}', 15 | 'a clean painting in the style of {}', 16 | 'a dirty painting in the style of {}', 17 | 'a dark painting in the style of {}', 18 | 'a picture in the style of {}', 19 | 'a cool painting in the style of {}', 20 | 'a close-up painting in the style of {}', 21 | 'a bright 
painting in the style of {}', 22 | 'a cropped painting in the style of {}', 23 | 'a good painting in the style of {}', 24 | 'a close-up painting in the style of {}', 25 | 'a rendition in the style of {}', 26 | 'a nice painting in the style of {}', 27 | 'a small painting in the style of {}', 28 | 'a weird painting in the style of {}', 29 | 'a large painting in the style of {}', 30 | ] 31 | 32 | imagenet_dual_templates_small = [ 33 | 'a painting in the style of {} with {}', 34 | 'a rendering in the style of {} with {}', 35 | 'a cropped painting in the style of {} with {}', 36 | 'the painting in the style of {} with {}', 37 | 'a clean painting in the style of {} with {}', 38 | 'a dirty painting in the style of {} with {}', 39 | 'a dark painting in the style of {} with {}', 40 | 'a cool painting in the style of {} with {}', 41 | 'a close-up painting in the style of {} with {}', 42 | 'a bright painting in the style of {} with {}', 43 | 'a cropped painting in the style of {} with {}', 44 | 'a good painting in the style of {} with {}', 45 | 'a painting of one {} in the style of {}', 46 | 'a nice painting in the style of {} with {}', 47 | 'a small painting in the style of {} with {}', 48 | 'a weird painting in the style of {} with {}', 49 | 'a large painting in the style of {} with {}', 50 | ] 51 | 52 | per_img_token_list = [ 53 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', 54 | ] 55 | 56 | class PersonalizedBase(Dataset): 57 | def __init__(self, 58 | data_root, 59 | size=None, 60 | repeats=100, 61 | interpolation="bicubic", 62 | flip_p=0.5, 63 | set="train", 64 | placeholder_token="*", 65 | per_image_tokens=False, 66 | center_crop=False, 67 | ): 68 | 69 | self.data_root = data_root 70 | 71 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 72 | 73 | # self._length = len(self.image_paths) 74 | self.num_images = len(self.image_paths) 75 | self._length = self.num_images 76 | 77 | self.placeholder_token = placeholder_token 78 | 79 | self.per_image_tokens = per_image_tokens 80 | self.center_crop = center_crop 81 | 82 | if per_image_tokens: 83 | assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
84 | 85 | if set == "train": 86 | self._length = self.num_images * repeats 87 | 88 | self.size = size 89 | self.interpolation = {"linear": PIL.Image.LINEAR, 90 | "bilinear": PIL.Image.BILINEAR, 91 | "bicubic": PIL.Image.BICUBIC, 92 | "lanczos": PIL.Image.LANCZOS, 93 | }[interpolation] 94 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 95 | 96 | def __len__(self): 97 | return self._length 98 | 99 | def __getitem__(self, i): 100 | example = {} 101 | image = Image.open(self.image_paths[i % self.num_images]) 102 | 103 | if not image.mode == "RGB": 104 | image = image.convert("RGB") 105 | 106 | if self.per_image_tokens and np.random.uniform() < 0.25: 107 | text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images]) 108 | else: 109 | text = random.choice(imagenet_templates_small).format(self.placeholder_token) 110 | 111 | example["caption"] = text 112 | 113 | # default to score-sde preprocessing 114 | img = np.array(image).astype(np.uint8) 115 | 116 | if self.center_crop: 117 | crop = min(img.shape[0], img.shape[1]) 118 | h, w, = img.shape[0], img.shape[1] 119 | img = img[(h - crop) // 2:(h + crop) // 2, 120 | (w - crop) // 2:(w + crop) // 2] 121 | 122 | image = Image.fromarray(img) 123 | if self.size is not None: 124 | image = image.resize((self.size, self.size), resample=self.interpolation) 125 | 126 | image = self.flip(image) 127 | image = np.array(image).astype(np.uint8) 128 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 129 | return example -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/__pycache__/autoencoder.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/plms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/plms.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/plms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/attention.cpython-36.pyc 
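Relating back to ldm/lr_scheduler.py above: a sketch (not from the repository) of plugging one of these multiplier schedules into a standard PyTorch LambdaLR, following the file's note that the base learning rate should be 1.0 so the schedule values act as the effective learning rate. The warm-up length, cycle length, and f values below are illustrative only.

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)                               # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)   # base_lr of 1.0, per the docstring

# One cycle: 100 warm-up steps from 1e-6 up to 1e-4, then linear decay toward f_min
# over the 10_000-step cycle.
schedule = LambdaLinearScheduler(warm_up_steps=[100], f_min=[1e-6], f_max=[1e-4],
                                 f_start=[1e-6], cycle_lengths=[10_000])
scheduler = LambdaLR(optimizer, lr_lambda=schedule.schedule)

for step in range(5):                                       # toy loop; real training steps both per iteration
    optimizer.step()
    scheduler.step()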
-------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/ema.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/embedding_manager.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/embedding_manager.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/embedding_manager.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/embedding_manager.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/x_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/x_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, 
dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | 
self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc 
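A small shape-check sketch for the SpatialTransformer defined in ldm/modules/attention.py above; the channel count, head sizes, and context dimension are arbitrary placeholders, not values taken from the repository's configs.

import torch
from ldm.modules.attention import SpatialTransformer

# Placeholder feature map and text conditioning: 64 channels attended over 77 context tokens.
x = torch.randn(2, 64, 32, 32)        # (batch, channels, height, width)
context = torch.randn(2, 77, 512)     # (batch, tokens, context_dim)

block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=512)
out = block(x, context=context)       # residual output, same shape as x
print(out.shape)                      # torch.Size([2, 64, 32, 32])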
-------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if False: # disabled checkpointing to allow requires_grad = False for main model 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 
177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 
71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/embedding_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ldm.data.personalized import per_img_token_list 5 | from transformers import CLIPTokenizer 6 | from functools import partial 7 | 8 | DEFAULT_PLACEHOLDER_TOKEN = ["*"] 9 | 10 | PROGRESSIVE_SCALE = 2000 11 | 12 | def get_clip_token_for_string(tokenizer, string): 13 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, 14 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 15 | tokens = batch_encoding["input_ids"] 16 | assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" 17 | 18 | return tokens[0, 1] 19 | 20 | def get_bert_token_for_string(tokenizer, string): 21 | token = tokenizer(string) 22 | assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" 23 | 24 | token = token[0, 1] 25 | 26 | return token 27 | 28 | def get_embedding_for_clip_token(embedder, token): 29 | return embedder(token.unsqueeze(0))[0, 0] 30 | 31 | 32 | class EmbeddingManager(nn.Module): 33 | def __init__( 34 | self, 35 | embedder, 36 | placeholder_strings=None, 37 | initializer_words=None, 38 | per_image_tokens=False, 39 | num_vectors_per_token=1, 40 | progressive_words=False, 41 | **kwargs 42 | ): 43 | super().__init__() 44 | 45 | self.string_to_token_dict = {} 46 | 47 | self.string_to_param_dict = nn.ParameterDict() 48 | 49 | self.initial_embeddings = nn.ParameterDict() # These should not be optimized 50 | 51 | self.progressive_words = progressive_words 52 | self.progressive_counter = 0 53 | 54 | self.max_vectors_per_token = num_vectors_per_token 55 | 56 | if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder 57 | self.is_clip = True 58 | get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) 59 | get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) 60 | token_dim = 768 61 | else: # using LDM's BERT encoder 62 | self.is_clip = False 63 | get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) 64 | get_embedding_for_tkn = embedder.transformer.token_emb 65 | token_dim = 1280 66 | 67 | if per_image_tokens: 68 | placeholder_strings.extend(per_img_token_list) 69 | 70 | for idx, placeholder_string in enumerate(placeholder_strings): 71 | 72 | token = get_token_for_string(placeholder_string) 73 | 74 | if initializer_words and idx < len(initializer_words): 75 | init_word_token = get_token_for_string(initializer_words[idx]) 76 | 77 | with torch.no_grad(): 78 | init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) 79 | 80 | token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) 81 | self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) 82 | else: 83 | token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) 84 | 85 | self.string_to_token_dict[placeholder_string] = token 86 | self.string_to_param_dict[placeholder_string] = 
token_params 87 | 88 | def forward( 89 | self, 90 | tokenized_text, 91 | embedded_text, 92 | ): 93 | b, n, device = *tokenized_text.shape, tokenized_text.device 94 | 95 | for placeholder_string, placeholder_token in self.string_to_token_dict.items(): 96 | 97 | placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) 98 | 99 | if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement 100 | placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) 101 | embedded_text[placeholder_idx] = placeholder_embedding 102 | else: # otherwise, need to insert and keep track of changing indices 103 | if self.progressive_words: 104 | self.progressive_counter += 1 105 | max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE 106 | else: 107 | max_step_tokens = self.max_vectors_per_token 108 | 109 | num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) 110 | 111 | placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) 112 | 113 | if placeholder_rows.nelement() == 0: 114 | continue 115 | 116 | sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) 117 | sorted_rows = placeholder_rows[sort_idx] 118 | 119 | for idx in range(len(sorted_rows)): 120 | row = sorted_rows[idx] 121 | col = sorted_cols[idx] 122 | 123 | new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] 124 | new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] 125 | 126 | embedded_text[row] = new_embed_row 127 | tokenized_text[row] = new_token_row 128 | 129 | return embedded_text 130 | 131 | def save(self, ckpt_path): 132 | torch.save({"string_to_token": self.string_to_token_dict, 133 | "string_to_param": self.string_to_param_dict}, ckpt_path) 134 | 135 | def load(self, ckpt_path): 136 | ckpt = torch.load(ckpt_path, map_location='cpu') 137 | 138 | self.string_to_token_dict = ckpt["string_to_token"] 139 | self.string_to_param_dict = ckpt["string_to_param"] 140 | 141 | def get_embedding_norms_squared(self): 142 | all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim 143 | param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders 144 | 145 | return param_norm_squared 146 | 147 | def embedding_parameters(self): 148 | return self.string_to_param_dict.parameters() 149 | 150 | def embedding_to_coarse_loss(self): 151 | 152 | loss = 0. 
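A self-contained toy illustration (no CLIP or BERT needed) of the single-vector replacement that forward above performs when max_vectors_per_token == 1: every position holding the placeholder token id gets its embedding row overwritten by the learned vector. The token id and sizes below are made up for the example; in the real manager the learned vector is an nn.Parameter from string_to_param_dict.

import torch

token_dim = 8
placeholder_token = torch.tensor(265)                 # made-up id standing in for "*"
placeholder_embedding = torch.rand(1, token_dim)      # stand-in for the learned parameter vector

tokenized_text = torch.randint(300, 1000, (2, 10))    # (batch, seq_len) of ordinary token ids
tokenized_text[0, 3] = placeholder_token              # plant the placeholder in the first prompt
embedded_text = torch.randn(2, 10, token_dim)         # embeddings from the frozen text encoder

placeholder_idx = torch.where(tokenized_text == placeholder_token)
embedded_text[placeholder_idx] = placeholder_embedding    # broadcast the learned vector into place
print(placeholder_idx, embedded_text[0, 3, :4])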
153 | num_embeddings = len(self.initial_embeddings) 154 | 155 | for key in self.initial_embeddings: 156 | optimized = self.string_to_param_dict[key] 157 | coarse = self.initial_embeddings[key].clone().to(optimized.device) 158 | 159 | loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings 160 | 161 | return loss -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming 
dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
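Quick, self-contained numeric checks for the two helpers defined earlier in this file (re-stated inline so the snippet runs without the taming dependency): adopt_weight keeps the GAN term switched off until global_step reaches disc_start, and measure_perplexity reaches n_embed exactly when every codebook entry is used equally.

import torch
import torch.nn.functional as F

def adopt_weight(weight, global_step, threshold=0, value=0.):
    return value if global_step < threshold else weight

print(adopt_weight(1.0, global_step=100,   threshold=50001))   # 0.0 -> discriminator still disabled
print(adopt_weight(1.0, global_step=60000, threshold=50001))   # 1.0 -> enabled after warm-up

n_embed = 4
predicted_indices = torch.arange(16) % n_embed                 # perfectly uniform code usage
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
print(perplexity, torch.sum(avg_probs > 0))                    # ~4.0 perplexity, all 4 clusters used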
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.load_default() 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | 
draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config, **kwargs): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
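A hedged usage sketch for parallel_data_prefetch (the helper being defined here), mapping a top-level function over a list with worker processes; the __main__ guard matters because multiprocessing pickles the function, and results come back ordered by worker index.

from ldm.util import parallel_data_prefetch

def square_chunk(chunk):
    # each worker receives one contiguous slice of the data
    return [x * x for x in chunk]

if __name__ == "__main__":
    out = parallel_data_prefetch(square_chunk, list(range(100)), n_proc=4, target_data_type="list")
    print(len(out), out[:5])   # 100 [0, 1, 4, 9, 16]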
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /merge_embeddings.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder 2 | from ldm.modules.embedding_manager import EmbeddingManager 3 | 4 | import argparse, os 5 | from functools import partial 6 | 7 | import torch 8 | 9 | def get_placeholder_loop(placeholder_string, embedder, is_sd): 10 | 11 | new_placeholder = None 12 | 13 | while True: 14 | if new_placeholder is None: 15 | new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ") 16 | else: 17 | new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. 
Please enter another string: ") 18 | 19 | token = get_clip_token_for_string(embedder.tokenizer, new_placeholder) if is_sd else get_bert_token_for_string(embedder.tknz_fn, new_placeholder) 20 | 21 | if token is not None: 22 | return new_placeholder, token 23 | 24 | def get_clip_token_for_string(tokenizer, string): 25 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, 26 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 27 | tokens = batch_encoding["input_ids"] 28 | 29 | if torch.count_nonzero(tokens - 49407) == 2: 30 | return tokens[0, 1] 31 | 32 | return None 33 | 34 | def get_bert_token_for_string(tokenizer, string): 35 | token = tokenizer(string) 36 | if torch.count_nonzero(token) == 3: 37 | return token[0, 1] 38 | 39 | return None 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | parser = argparse.ArgumentParser() 45 | 46 | parser.add_argument( 47 | "--manager_ckpts", 48 | type=str, 49 | nargs="+", 50 | required=True, 51 | help="Paths to a set of embedding managers to be merged." 52 | ) 53 | 54 | parser.add_argument( 55 | "--output_path", 56 | type=str, 57 | required=True, 58 | help="Output path for the merged manager", 59 | ) 60 | 61 | parser.add_argument( 62 | "-sd", "--stable_diffusion", 63 | action="store_true", 64 | help="Flag to denote that we are merging stable diffusion embeddings" 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | if args.stable_diffusion: 70 | embedder = FrozenCLIPEmbedder().cuda() 71 | else: 72 | embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda() 73 | 74 | EmbeddingManager = partial(EmbeddingManager, embedder, ["*"]) 75 | 76 | string_to_token_dict = {} 77 | string_to_param_dict = torch.nn.ParameterDict() 78 | 79 | placeholder_to_src = {} 80 | 81 | for manager_ckpt in args.manager_ckpts: 82 | print(f"Parsing {manager_ckpt}...") 83 | 84 | manager = EmbeddingManager() 85 | manager.load(manager_ckpt) 86 | 87 | for placeholder_string in manager.string_to_token_dict: 88 | if not placeholder_string in string_to_token_dict: 89 | string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string] 90 | string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string] 91 | 92 | placeholder_to_src[placeholder_string] = manager_ckpt 93 | else: 94 | new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, is_sd=args.stable_diffusion) 95 | string_to_token_dict[new_placeholder] = new_token 96 | string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string] 97 | 98 | placeholder_to_src[new_placeholder] = manager_ckpt 99 | 100 | print("Saving combined manager...") 101 | merged_manager = EmbeddingManager() 102 | merged_manager.string_to_param_dict = string_to_param_dict 103 | merged_manager.string_to_token_dict = string_to_token_dict 104 | merged_manager.save(args.output_path) 105 | 106 | print("Managers merged. 
Final list of placeholders: ") 107 | print(placeholder_to_src) 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f8/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- 
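The config.yaml files in this models/ tree are consumed through ldm.util.instantiate_from_config (shown earlier): each block is just a `target` import path plus `params` kwargs. The `f` in kl-f4/f8/f16/f32 and vq-f4/f8/f16 is the spatial downsampling factor, 2^(len(ch_mult) - 1), which is why kl-f8 lists four ch_mult entries and kl-f32 six. A minimal sketch of the loading mechanics follows, using a torch class so it runs without checkpoints (OmegaConf and the repo are assumed to be installed).

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Any block with "target"/"params" can be instantiated; a torch layer keeps this checkpoint-free.
cfg = OmegaConf.create("""
target: torch.nn.Conv2d
params:
  in_channels: 3
  out_channels: 8
  kernel_size: 3
""")
print(instantiate_from_config(cfg))

# The real configs work the same way (weights still have to be loaded separately):
# config = OmegaConf.load("models/first_stage_models/kl-f8/config.yaml")
# model = instantiate_from_config(config.model)   # -> ldm.models.autoencoder.AutoencoderKL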
/models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | -------------------------------------------------------------------------------- /models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | 
params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | 
validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: 
false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | 
num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | 
num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip 
-o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
50 | -------------------------------------------------------------------------------- /scripts/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | 3 | sys.path.append(os.path.join(sys.path[0], '..')) 4 | 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from tqdm import tqdm, trange 10 | from einops import rearrange 11 | from torchvision.utils import make_grid 12 | 13 | from ldm.util import instantiate_from_config 14 | from ldm.models.diffusion.ddim import DDIMSampler 15 | from ldm.models.diffusion.plms import PLMSSampler 16 | from ldm.data.personalized import PersonalizedBase 17 | from evaluation.clip_eval import LDMCLIPEvaluator 18 | 19 | def load_model_from_config(config, ckpt, verbose=False): 20 | print(f"Loading model from {ckpt}") 21 | pl_sd = torch.load(ckpt, map_location="cpu") 22 | sd = pl_sd["state_dict"] 23 | model = instantiate_from_config(config.model) 24 | m, u = model.load_state_dict(sd, strict=False) 25 | if len(m) > 0 and verbose: 26 | print("missing keys:") 27 | print(m) 28 | if len(u) > 0 and verbose: 29 | print("unexpected keys:") 30 | print(u) 31 | 32 | model.cuda() 33 | model.eval() 34 | return model 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | 40 | parser.add_argument( 41 | "--prompt", 42 | type=str, 43 | nargs="?", 44 | default="a painting of a virus monster playing guitar", 45 | help="the prompt to render" 46 | ) 47 | 48 | parser.add_argument( 49 | "--ckpt_path", 50 | type=str, 51 | default="/data/pretrained_models/ldm/text2img-large/model.ckpt", 52 | help="Path to pretrained ldm text2img model") 53 | 54 | parser.add_argument( 55 | "--embedding_path", 56 | type=str, 57 | help="Path to a pre-trained embedding manager checkpoint") 58 | 59 | parser.add_argument( 60 | "--data_dir", 61 | type=str, 62 | help="Path to directory with images used to train the embedding vectors" 63 | ) 64 | parser.add_argument("--out_dir", type=str, default="outputs", help="Output directory for evaluation results") 65 | opt = parser.parse_args() 66 | 67 | 68 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and change this logic 69 | model = load_model_from_config(config, opt.ckpt_path) # TODO: check path 70 | model.embedding_manager.load(opt.embedding_path) 71 | 72 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 73 | model = model.to(device) 74 | 75 | evaluator = LDMCLIPEvaluator(device) 76 | 77 | prompt = opt.prompt 78 | 79 | data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0) 80 | 81 | images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1) for i in range(data_loader.num_images)] 82 | images = torch.stack(images, dim=0) 83 | 84 | sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt) 85 | 86 | output_dir = os.path.join(opt.out_dir, prompt.replace(" ", "-")) 87 | 88 | print("Image similarity: ", sim_img) 89 | print("Text similarity: ", sim_text) -------------------------------------------------------------------------------- /scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = 
np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- /scripts/stable_txt2img.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | from tqdm import tqdm, trange 7 | from itertools import islice 8 | from einops import rearrange 9 | from torchvision.utils import make_grid 10 | import time 11 | from pytorch_lightning import seed_everything 12 | from torch import autocast 13 | from contextlib import contextmanager, nullcontext 14 | 15 | from ldm.util import 
instantiate_from_config 16 | from ldm.models.diffusion.ddim import DDIMSampler 17 | from ldm.models.diffusion.plms import PLMSSampler 18 | 19 | 20 | def chunk(it, size): 21 | it = iter(it) 22 | return iter(lambda: tuple(islice(it, size)), ()) 23 | 24 | 25 | def load_model_from_config(config, ckpt, verbose=False): 26 | print(f"Loading model from {ckpt}") 27 | pl_sd = torch.load(ckpt, map_location="cpu") 28 | if "global_step" in pl_sd: 29 | print(f"Global Step: {pl_sd['global_step']}") 30 | sd = pl_sd["state_dict"] 31 | model = instantiate_from_config(config.model) 32 | m, u = model.load_state_dict(sd, strict=False) 33 | if len(m) > 0 and verbose: 34 | print("missing keys:") 35 | print(m) 36 | if len(u) > 0 and verbose: 37 | print("unexpected keys:") 38 | print(u) 39 | 40 | model.cuda() 41 | model.eval() 42 | return model 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | 48 | parser.add_argument( 49 | "--prompt", 50 | type=str, 51 | nargs="?", 52 | default="a painting of a virus monster playing guitar", 53 | help="the prompt to render" 54 | ) 55 | parser.add_argument( 56 | "--outdir", 57 | type=str, 58 | nargs="?", 59 | help="dir to write results to", 60 | default="outputs/txt2img-samples" 61 | ) 62 | parser.add_argument( 63 | "--skip_grid", 64 | action='store_true', 65 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", 66 | ) 67 | parser.add_argument( 68 | "--skip_save", 69 | action='store_true', 70 | help="do not save individual samples. For speed measurements.", 71 | ) 72 | parser.add_argument( 73 | "--ddim_steps", 74 | type=int, 75 | default=50, 76 | help="number of ddim sampling steps", 77 | ) 78 | parser.add_argument( 79 | "--plms", 80 | action='store_true', 81 | help="use plms sampling", 82 | ) 83 | parser.add_argument( 84 | "--laion400m", 85 | action='store_true', 86 | help="uses the LAION400M model", 87 | ) 88 | parser.add_argument( 89 | "--fixed_code", 90 | action='store_true', 91 | help="if enabled, uses the same starting code across samples ", 92 | ) 93 | parser.add_argument( 94 | "--ddim_eta", 95 | type=float, 96 | default=0.0, 97 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 98 | ) 99 | parser.add_argument( 100 | "--n_iter", 101 | type=int, 102 | default=2, 103 | help="sample this often", 104 | ) 105 | parser.add_argument( 106 | "--H", 107 | type=int, 108 | default=512, 109 | help="image height, in pixel space", 110 | ) 111 | parser.add_argument( 112 | "--W", 113 | type=int, 114 | default=512, 115 | help="image width, in pixel space", 116 | ) 117 | parser.add_argument( 118 | "--C", 119 | type=int, 120 | default=4, 121 | help="latent channels", 122 | ) 123 | parser.add_argument( 124 | "--f", 125 | type=int, 126 | default=8, 127 | help="downsampling factor", 128 | ) 129 | parser.add_argument( 130 | "--n_samples", 131 | type=int, 132 | default=3, 133 | help="how many samples to produce for each given prompt. A.k.a. 
batch size", 134 | ) 135 | parser.add_argument( 136 | "--n_rows", 137 | type=int, 138 | default=0, 139 | help="rows in the grid (default: n_samples)", 140 | ) 141 | parser.add_argument( 142 | "--scale", 143 | type=float, 144 | default=7.5, 145 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 146 | ) 147 | parser.add_argument( 148 | "--from-file", 149 | type=str, 150 | help="if specified, load prompts from this file", 151 | ) 152 | parser.add_argument( 153 | "--config", 154 | type=str, 155 | default="configs/stable-diffusion/v1-inference.yaml", 156 | help="path to config which constructs model", 157 | ) 158 | parser.add_argument( 159 | "--ckpt", 160 | type=str, 161 | default="models/ldm/stable-diffusion-v1/model.ckpt", 162 | help="path to checkpoint of model", 163 | ) 164 | parser.add_argument( 165 | "--seed", 166 | type=int, 167 | default=42, 168 | help="the seed (for reproducible sampling)", 169 | ) 170 | parser.add_argument( 171 | "--precision", 172 | type=str, 173 | help="evaluate at this precision", 174 | choices=["full", "autocast"], 175 | default="autocast" 176 | ) 177 | 178 | 179 | parser.add_argument( 180 | "--embedding_path", 181 | type=str, 182 | help="Path to a pre-trained embedding manager checkpoint") 183 | 184 | opt = parser.parse_args() 185 | 186 | if opt.laion400m: 187 | print("Falling back to LAION 400M model...") 188 | opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" 189 | opt.ckpt = "models/ldm/text2img-large/model.ckpt" 190 | opt.outdir = "outputs/txt2img-samples-laion400m" 191 | 192 | seed_everything(opt.seed) 193 | 194 | config = OmegaConf.load(f"{opt.config}") 195 | model = load_model_from_config(config, f"{opt.ckpt}") 196 | model.embedding_manager.load(opt.embedding_path) 197 | 198 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 199 | model = model.to(device) 200 | 201 | if opt.plms: 202 | sampler = PLMSSampler(model) 203 | else: 204 | sampler = DDIMSampler(model) 205 | 206 | os.makedirs(opt.outdir, exist_ok=True) 207 | outpath = opt.outdir 208 | 209 | batch_size = opt.n_samples 210 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 211 | if not opt.from_file: 212 | prompt = opt.prompt 213 | assert prompt is not None 214 | data = [batch_size * [prompt]] 215 | 216 | else: 217 | print(f"reading prompts from {opt.from_file}") 218 | with open(opt.from_file, "r") as f: 219 | data = f.read().splitlines() 220 | data = list(chunk(data, batch_size)) 221 | 222 | sample_path = os.path.join(outpath, "samples") 223 | os.makedirs(sample_path, exist_ok=True) 224 | base_count = len(os.listdir(sample_path)) 225 | grid_count = len(os.listdir(outpath)) - 1 226 | 227 | start_code = None 228 | if opt.fixed_code: 229 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) 230 | 231 | precision_scope = autocast if opt.precision=="autocast" else nullcontext 232 | with torch.no_grad(): 233 | with precision_scope("cuda"): 234 | with model.ema_scope(): 235 | tic = time.time() 236 | all_samples = list() 237 | for n in trange(opt.n_iter, desc="Sampling"): 238 | for prompts in tqdm(data, desc="data"): 239 | uc = None 240 | if opt.scale != 1.0: 241 | uc = model.get_learned_conditioning(batch_size * [""]) 242 | if isinstance(prompts, tuple): 243 | prompts = list(prompts) 244 | c = model.get_learned_conditioning(prompts) 245 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f] 246 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 247 | 
conditioning=c, 248 | batch_size=opt.n_samples, 249 | shape=shape, 250 | verbose=False, 251 | unconditional_guidance_scale=opt.scale, 252 | unconditional_conditioning=uc, 253 | eta=opt.ddim_eta, 254 | x_T=start_code) 255 | 256 | x_samples_ddim = model.decode_first_stage(samples_ddim) 257 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 258 | 259 | if not opt.skip_save: 260 | for x_sample in x_samples_ddim: 261 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 262 | Image.fromarray(x_sample.astype(np.uint8)).save( 263 | os.path.join(sample_path, f"{base_count:05}.jpg")) 264 | base_count += 1 265 | 266 | if not opt.skip_grid: 267 | all_samples.append(x_samples_ddim) 268 | 269 | if not opt.skip_grid: 270 | # additionally, save as grid 271 | grid = torch.stack(all_samples, 0) 272 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 273 | grid = make_grid(grid, nrow=n_rows) 274 | 275 | # to image 276 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() 277 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}-{grid_count:04}.jpg')) 278 | grid_count += 1 279 | 280 | toc = time.time() 281 | 282 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n" 283 | f" \nEnjoy.") 284 | 285 | 286 | if __name__ == "__main__": 287 | main() 288 | -------------------------------------------------------------------------------- /scripts/txt2img.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | from tqdm import tqdm, trange 7 | from einops import rearrange 8 | from torchvision.utils import make_grid 9 | 10 | from ldm.util import instantiate_from_config 11 | from ldm.models.diffusion.ddim import DDIMSampler 12 | from ldm.models.diffusion.plms import PLMSSampler 13 | 14 | def load_model_from_config(config, ckpt, verbose=False): 15 | print(f"Loading model from {ckpt}") 16 | pl_sd = torch.load(ckpt, map_location="cpu") 17 | sd = pl_sd["state_dict"] 18 | model = instantiate_from_config(config.model) 19 | m, u = model.load_state_dict(sd, strict=False) 20 | if len(m) > 0 and verbose: 21 | print("missing keys:") 22 | print(m) 23 | if len(u) > 0 and verbose: 24 | print("unexpected keys:") 25 | print(u) 26 | 27 | model.cuda() 28 | model.eval() 29 | return model 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument( 36 | "--prompt", 37 | type=str, 38 | nargs="?", 39 | default="a painting of a virus monster playing guitar", 40 | help="the prompt to render" 41 | ) 42 | 43 | parser.add_argument( 44 | "--outdir", 45 | type=str, 46 | nargs="?", 47 | help="dir to write results to", 48 | default="outputs/txt2img-samples" 49 | ) 50 | parser.add_argument( 51 | "--ddim_steps", 52 | type=int, 53 | default=200, 54 | help="number of ddim sampling steps", 55 | ) 56 | 57 | parser.add_argument( 58 | "--plms", 59 | action='store_true', 60 | help="use plms sampling", 61 | ) 62 | 63 | parser.add_argument( 64 | "--ddim_eta", 65 | type=float, 66 | default=0.0, 67 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 68 | ) 69 | parser.add_argument( 70 | "--n_iter", 71 | type=int, 72 | default=1, 73 | help="sample this often", 74 | ) 75 | 76 | parser.add_argument( 77 | "--H", 78 | type=int, 79 | default=256, 80 | help="image height, in pixel space", 81 | ) 82 | 83 | 
parser.add_argument( 84 | "--W", 85 | type=int, 86 | default=256, 87 | help="image width, in pixel space", 88 | ) 89 | 90 | parser.add_argument( 91 | "--n_samples", 92 | type=int, 93 | default=4, 94 | help="how many samples to produce for the given prompt", 95 | ) 96 | 97 | parser.add_argument( 98 | "--scale", 99 | type=float, 100 | default=5.0, 101 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 102 | ) 103 | 104 | parser.add_argument( 105 | "--ckpt_path", 106 | type=str, 107 | default="/data/pretrained_models/ldm/text2img-large/model.ckpt", 108 | help="Path to pretrained ldm text2img model") 109 | 110 | parser.add_argument( 111 | "--embedding_path", 112 | type=str, 113 | help="Path to a pre-trained embedding manager checkpoint") 114 | 115 | opt = parser.parse_args() 116 | 117 | 118 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and change this logic 119 | model = load_model_from_config(config, opt.ckpt_path) # TODO: check path 120 | model.embedding_manager.load(opt.embedding_path) 121 | 122 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 123 | model = model.to(device) 124 | 125 | if opt.plms: 126 | sampler = PLMSSampler(model) 127 | else: 128 | sampler = DDIMSampler(model) 129 | 130 | os.makedirs(opt.outdir, exist_ok=True) 131 | outpath = opt.outdir 132 | 133 | prompt = opt.prompt 134 | 135 | 136 | sample_path = os.path.join(outpath, "samples") 137 | os.makedirs(sample_path, exist_ok=True) 138 | base_count = len(os.listdir(sample_path)) 139 | 140 | all_samples = list() 141 | with torch.no_grad(): 142 | with model.ema_scope(): 143 | uc = None 144 | if opt.scale != 1.0: 145 | uc = model.get_learned_conditioning(opt.n_samples * [""]) 146 | for n in trange(opt.n_iter, desc="Sampling"): 147 | c = model.get_learned_conditioning(opt.n_samples * [prompt]) 148 | shape = [4, opt.H//8, opt.W//8] 149 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 150 | conditioning=c, 151 | batch_size=opt.n_samples, 152 | shape=shape, 153 | verbose=False, 154 | unconditional_guidance_scale=opt.scale, 155 | unconditional_conditioning=uc, 156 | eta=opt.ddim_eta) 157 | 158 | x_samples_ddim = model.decode_first_stage(samples_ddim) 159 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 160 | 161 | for x_sample in x_samples_ddim: 162 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 163 | Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.jpg")) 164 | base_count += 1 165 | all_samples.append(x_samples_ddim) 166 | 167 | 168 | # additionally, save as grid 169 | grid = torch.stack(all_samples, 0) 170 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 171 | grid = make_grid(grid, nrow=opt.n_samples) 172 | 173 | # to image 174 | grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() 175 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.jpg')) 176 | 177 | print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") 178 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
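Usage note: the configs and scripts above all follow one loading pattern — OmegaConf.load reads a YAML config, instantiate_from_config builds the LatentDiffusion model from config.model, the checkpoint's state_dict is loaded with strict=False, and a DDIMSampler produces latents that decode_first_stage turns into images. The sketch below condenses that flow for the text2img256 model whose config is listed above. It is a minimal sketch, not a file from the repository: the checkpoint path (assumed to unzip to models/ldm/text2img256/model.ckpt after running scripts/download_models.sh), the prompt, and the sampler settings are illustrative assumptions, and the call pattern simply mirrors scripts/txt2img.py.

# Minimal sketch: load an LDM config + checkpoint and sample with DDIM,
# following the same call pattern as scripts/txt2img.py above.
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from einops import rearrange

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

config = OmegaConf.load("models/ldm/text2img256/config.yaml")
ckpt = "models/ldm/text2img256/model.ckpt"   # assumed name of the unzipped checkpoint
sd = torch.load(ckpt, map_location="cpu")["state_dict"]

model = instantiate_from_config(config.model)      # builds LatentDiffusion from the YAML
model.load_state_dict(sd, strict=False)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

sampler = DDIMSampler(model)
prompt, n_samples, scale = "a painting of a virus monster playing guitar", 4, 5.0
shape = [config.model.params.channels,       # 3 latent channels for this config
         config.model.params.image_size,     # 64x64 latent -> 256x256 image via the VQ first stage
         config.model.params.image_size]

with torch.no_grad(), model.ema_scope():
    uc = model.get_learned_conditioning(n_samples * [""])   # empty-prompt embedding for guidance
    c = model.get_learned_conditioning(n_samples * [prompt])
    samples, _ = sampler.sample(S=50, conditioning=c, batch_size=n_samples,
                                shape=shape, verbose=False,
                                unconditional_guidance_scale=scale,
                                unconditional_conditioning=uc, eta=0.0)
    imgs = torch.clamp((model.decode_first_stage(samples) + 1.0) / 2.0, min=0.0, max=1.0)

for i, x in enumerate(imgs):
    x = 255. * rearrange(x.cpu().numpy(), 'c h w -> h w c')
    Image.fromarray(x.astype(np.uint8)).save(f"sample_{i:02}.jpg")

The shape passed to the sampler is read straight from the config's channels and image_size fields, which is why the same code works for the other latent-diffusion configs listed above once the paths are swapped.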