├── LICENSE
├── README.md
├── configs
│   ├── autoencoder
│   │   ├── autoencoder_kl_16x16x16.yaml
│   │   ├── autoencoder_kl_32x32x4.yaml
│   │   ├── autoencoder_kl_64x64x3.yaml
│   │   └── autoencoder_kl_8x8x64.yaml
│   ├── latent-diffusion
│   │   ├── celebahq-ldm-vq-4.yaml
│   │   ├── cin-ldm-vq-f8.yaml
│   │   ├── cin256-v2.yaml
│   │   ├── ffhq-ldm-vq-4.yaml
│   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   ├── txt2img-1p4B-eval.yaml
│   │   ├── txt2img-1p4B-eval_with_tokens.yaml
│   │   ├── txt2img-1p4B-finetune.yaml
│   │   └── txt2img-1p4B-finetune_style.yaml
│   └── stable-diffusion
│       ├── v1-finetune.yaml
│       ├── v1-finetune_unfrozen.yaml
│       └── v1-inference.yaml
├── environment.yaml
├── evaluation
│   ├── __pycache__
│   │   ├── clip_eval.cpython-36.pyc
│   │   └── clip_eval.cpython-38.pyc
│   └── clip_eval.py
├── img
│   ├── samples.jpg
│   ├── style.jpg
│   └── teaser.jpg
├── ldm
│   ├── __pycache__
│   │   ├── util.cpython-36.pyc
│   │   └── util.cpython-38.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── base.cpython-36.pyc
│   │   │   ├── base.cpython-38.pyc
│   │   │   ├── personalized.cpython-36.pyc
│   │   │   ├── personalized.cpython-38.pyc
│   │   │   ├── personalized_compose.cpython-38.pyc
│   │   │   ├── personalized_detailed_text.cpython-36.pyc
│   │   │   ├── personalized_style.cpython-36.pyc
│   │   │   └── personalized_style.cpython-38.pyc
│   │   ├── base.py
│   │   ├── imagenet.py
│   │   ├── lsun.py
│   │   ├── personalized.py
│   │   └── personalized_style.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── autoencoder.cpython-36.pyc
│   │   │   └── autoencoder.cpython-38.pyc
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── ddim.cpython-36.pyc
│   │       │   ├── ddim.cpython-38.pyc
│   │       │   ├── ddim_inversion.cpython-38.pyc
│   │       │   ├── ddpm.cpython-36.pyc
│   │       │   ├── ddpm.cpython-38.pyc
│   │       │   ├── ddpm_pti.cpython-38.pyc
│   │       │   ├── plms.cpython-36.pyc
│   │       │   └── plms.cpython-38.pyc
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       └── plms.py
│   ├── modules
│   │   ├── __pycache__
│   │   │   ├── attention.cpython-36.pyc
│   │   │   ├── attention.cpython-38.pyc
│   │   │   ├── ema.cpython-36.pyc
│   │   │   ├── ema.cpython-38.pyc
│   │   │   ├── embedding_manager.cpython-36.pyc
│   │   │   ├── embedding_manager.cpython-38.pyc
│   │   │   ├── x_transformer.cpython-36.pyc
│   │   │   └── x_transformer.cpython-38.pyc
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── model.cpython-36.pyc
│   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   ├── openaimodel.cpython-36.pyc
│   │   │   │   ├── openaimodel.cpython-38.pyc
│   │   │   │   ├── util.cpython-36.pyc
│   │   │   │   └── util.cpython-38.pyc
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── distributions.cpython-36.pyc
│   │   │   │   └── distributions.cpython-38.pyc
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── embedding_manager.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── modules.cpython-36.pyc
│   │   │   │   └── modules.cpython-38.pyc
│   │   │   ├── modules.py
│   │   │   └── modules_bak.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   └── util.py
├── main.py
├── merge_embeddings.py
├── models
│   ├── first_stage_models
│   │   ├── kl-f16
│   │   │   └── config.yaml
│   │   ├── kl-f32
│   │   │   └── config.yaml
│   │   ├── kl-f4
│   │   │   └── config.yaml
│   │   ├── kl-f8
│   │   │   └── config.yaml
│   │   ├── vq-f16
│   │   │   └── config.yaml
│   │   ├── vq-f4-noattn
│   │   │   └── config.yaml
│   │   ├── vq-f4
│   │   │   └── config.yaml
│   │   ├── vq-f8-n256
│   │   │   └── config.yaml
│   │   └── vq-f8
│   │       └── config.yaml
│   └── ldm
│       ├── bsr_sr
│       │   └── config.yaml
│       ├── celeba256
│       │   └── config.yaml
│       ├── cin256
│       │   └── config.yaml
│       ├── ffhq256
│       │   └── config.yaml
│       ├── inpainting_big
│       │   └── config.yaml
│       ├── layout2img-openimages256
│       │   └── config.yaml
│       ├── lsun_beds256
│       │   └── config.yaml
│       ├── lsun_churches256
│       │   └── config.yaml
│       ├── semantic_synthesis256
│       │   └── config.yaml
│       ├── semantic_synthesis512
│       │   └── config.yaml
│       └── text2img256
│           └── config.yaml
├── scripts
│   ├── download_first_stages.sh
│   ├── download_models.sh
│   ├── evaluate_model.py
│   ├── inpaint.py
│   ├── latent_imagenet_diffusion.ipynb
│   ├── sample_diffusion.py
│   ├── stable_txt2img.py
│   └── txt2img.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion
2 |
3 | [arXiv](https://arxiv.org/abs/2208.01618)
4 |
5 | [[Project Website](https://textual-inversion.github.io/)]
6 |
7 | > **An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion**
8 | > Rinon Gal<sup>1,2</sup>, Yuval Alaluf<sup>1</sup>, Yuval Atzmon<sup>2</sup>, Or Patashnik<sup>1</sup>, Amit H. Bermano<sup>1</sup>, Gal Chechik<sup>2</sup>, Daniel Cohen-Or<sup>1</sup>
9 | > <sup>1</sup>Tel Aviv University, <sup>2</sup>NVIDIA
10 |
11 | >**Abstract**:
12 | > Text-to-image models offer unprecedented freedom to guide creation through natural language.
13 | Yet, it is unclear how such freedom can be exercised to generate images of specific unique concepts, modify their appearance, or compose them in new roles and novel scenes.
14 | In other words, we ask: how can we use language-guided models to turn our cat into a painting, or imagine a new product based on our favorite toy?
15 | Here we present a simple approach that allows such creative freedom.
16 | Using only 3-5 images of a user-provided concept, like an object or a style, we learn to represent it through new "words" in the embedding space of a frozen text-to-image model.
17 | These "words" can be composed into natural language sentences, guiding personalized creation in an intuitive way.
18 | Notably, we find evidence that a single word embedding is sufficient for capturing unique and varied concepts.
19 | We compare our approach to a wide range of baselines, and demonstrate that it can more faithfully portray the concepts across a range of applications and tasks.
20 |
21 | ## Description
22 | This repo contains the official code, data and sample inversions for our Textual Inversion paper.
23 |
24 | ## Updates
25 | **29/08/2022** Merge embeddings now supports SD embeddings. Added SD pivotal tuning code (WIP); fixed training duration and checkpoint-save iterations.
26 | **21/08/2022** Code released!
27 |
28 | ## TODO:
29 | - [x] Release code!
30 | - [x] Optimize gradient storage / checkpointing. Memory requirements and training times reduced by ~55%
31 | - [x] Release data sets
32 | - [ ] Release pre-trained embeddings
33 | - [ ] Add Stable Diffusion support
34 |
35 | ## Setup
36 |
37 | Our code builds on, and shares requirements with [Latent Diffusion Models (LDM)](https://github.com/CompVis/latent-diffusion). To set up their environment, please run:
38 |
39 | ```
40 | conda env create -f environment.yaml
41 | conda activate ldm
42 | ```
43 |
44 | You will also need the official LDM text-to-image checkpoint, available through the [LDM project page](https://github.com/CompVis/latent-diffusion).
45 |
46 | Currently, the model can be downloaded by running:
47 |
48 | ```
49 | mkdir -p models/ldm/text2img-large/
50 | wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
51 | ```
52 |
53 | ## Usage
54 |
55 | ### Inversion
56 |
57 | To invert an image set, run:
58 |
59 | ```
60 | python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml
61 | -t
62 | --actual_resume /path/to/pretrained/model.ckpt
63 | -n <run_name>
64 | --gpus 0,
65 | --data_root /path/to/directory/with/images
66 | --init_word <initialization_word>
67 | ```
68 |
69 | where the initialization word should be a single-token rough description of the object (e.g., 'toy', 'painting', 'sculpture'). If the input consists of more than one token, you will be prompted to replace it.
70 |
71 | Please note that `init_word` is *not* the placeholder string that will later represent the concept. It is only used as a starting point for the optimization scheme.
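
If you want to sanity-check a candidate word before launching a run, the sketch below counts how many tokens it occupies. It uses the OpenAI CLIP tokenizer (installed via `environment.yaml`) purely as an illustration; the LDM branch tokenizes with its own BERT tokenizer, so treat this as a rough pre-check rather than the exact logic in `main.py`.

```
import clip  # installed from the environment.yaml pip requirements


def count_clip_tokens(word: str) -> int:
    """Rough estimate of how many CLIP BPE tokens a word occupies."""
    tokens = clip.tokenize(word)[0]              # padded sequence incl. start/end markers
    return int((tokens != 0).sum().item()) - 2   # drop the start/end markers


print(count_clip_tokens("toy"))          # 1 -> usable as an initialization word
print(count_clip_tokens("rubber duck"))  # 2 -> you would be prompted to replace it
```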
72 |
73 | In the paper, we use 5k training iterations. However, some concepts (particularly styles) can converge much faster.
74 |
75 | To run on multiple GPUs, provide a comma-delimited list of GPU indices to the --gpus argument (e.g., ``--gpus 0,3,7,8``)
76 |
77 | Embeddings and output images will be saved in the log directory.
78 |
79 | See `configs/latent-diffusion/txt2img-1p4B-finetune.yaml` for more options, such as: changing the placeholder string which denotes the concept (defaults to "*"), changing the maximal number of training iterations, changing how often checkpoints are saved and more.
80 |
81 | **Important:** All training-set images should be upright. If you are using phone-captured images, check the inputs_gs*.jpg files in the output image directory and make sure they are oriented correctly. Many phones capture images with a 90-degree rotation and record this in the image metadata. Windows parses this metadata correctly, but PIL does not, so you will need to correct the images manually (e.g., by pasting them into Paint and re-saving) or wait until we add metadata parsing.
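
Until metadata parsing is added, a few lines of Pillow (already pinned in `environment.yaml`) can bake the EXIF rotation into the pixels before training. This is a minimal sketch, not part of the repo; it assumes your training images are JPEGs in a single directory and overwrites them in place, so keep a backup:

```
from pathlib import Path

from PIL import Image, ImageOps

data_root = Path("/path/to/directory/with/images")  # same folder you pass via --data_root

for path in sorted(data_root.glob("*.jpg")):
    with Image.open(path) as img:
        # Apply the EXIF orientation tag to the pixel data, then save the upright image.
        upright = ImageOps.exif_transpose(img)
        upright.save(path)
```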
82 |
83 | ### Generation
84 |
85 | To generate new images of the learned concept, run:
86 | ```
87 | python scripts/txt2img.py --ddim_eta 0.0
88 | --n_samples 8
89 | --n_iter 2
90 | --scale 10.0
91 | --ddim_steps 50
92 | --embedding_path /path/to/logs/trained_model/checkpoints/embeddings_gs-5049.pt
93 | --ckpt_path /path/to/pretrained/model.ckpt
94 | --prompt "a photo of *"
95 | ```
96 |
97 | where * is the placeholder string used during inversion.
98 |
99 | ### Merging Checkpoints
100 |
101 | LDM embedding checkpoints can be merged into a single file by running:
102 |
103 | ```
104 | python merge_embeddings.py
105 | --manager_ckpts /path/to/first/embedding.pt /path/to/second/embedding.pt [...]
106 | --output_path /path/to/output/embedding.pt
107 | ```
108 |
109 | For SD embeddings, simply add the flag: `-sd` or `--stable_diffusion`.
110 |
111 | If the checkpoints contain conflicting placeholder strings, you will be prompted to select new placeholders. The merged checkpoint can later be used to prompt multiple concepts at once ("A photo of * in the style of @").
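
For reference, the merge conceptually amounts to combining the placeholder-to-embedding mappings of the individual checkpoints. The sketch below is illustrative only and assumes each LDM embedding checkpoint stores that mapping under a `string_to_param` key (as written by `ldm.modules.embedding_manager.EmbeddingManager`); use `merge_embeddings.py` for real merges, since it also handles token bookkeeping and the conflict prompts described above.

```
import torch

# Illustrative sketch only -- merge_embeddings.py is the supported path.
ckpt_paths = ["/path/to/first/embedding.pt", "/path/to/second/embedding.pt"]

merged = {}
for path in ckpt_paths:
    ckpt = torch.load(path, map_location="cpu")
    # Assumption: a placeholder-string -> embedding-tensor mapping lives here.
    merged.update(ckpt["string_to_param"])

print(list(merged.keys()))  # e.g. ['*', '@'] once placeholders are distinct
```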
112 |
113 | ### Pretrained Models / Data
114 |
115 | Datasets which appear in the paper are being uploaded [here](https://drive.google.com/drive/folders/1d2UXkX0GWM-4qUwThjNhFIPP7S6WUbQJ). Some sets are unavailable due to image ownership. We will upload more as we receive permissions to do so.
116 |
117 | Pretrained models coming soon.
118 |
119 | ## Stable Diffusion
120 |
121 | Stable Diffusion support is a work in progress and will be completed soon™.
122 |
123 | ## Tips and Tricks
124 | - Adding "a photo of" to the prompt usually results in better target consistency.
125 | - Results can be seed-sensitive. If you're unsatisfied with the model, try re-inverting with a new seed (by adding `--seed <#>` to the training command).
126 |
127 |
128 | ## Citation
129 |
130 | If you make use of our work, please cite our paper:
131 |
132 | ```
133 | @misc{gal2022textual,
134 | doi = {10.48550/ARXIV.2208.01618},
135 | url = {https://arxiv.org/abs/2208.01618},
136 | author = {Gal, Rinon and Alaluf, Yuval and Atzmon, Yuval and Patashnik, Or and Bermano, Amit H. and Chechik, Gal and Cohen-Or, Daniel},
137 | title = {An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion},
138 | publisher = {arXiv},
139 | year = {2022},
140 | primaryClass = {cs.CV}
141 | }
142 | ```
143 |
144 | ## Results
145 | Here are some sample results. Please visit our [project page](https://textual-inversion.github.io/) or read our paper for more!
146 |
147 | 
148 |
149 | 
150 |
151 | 
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 | # note: this isn't actually the resolution but
24 | # the downsampling factor, i.e. this corresponds to
25 | # attention on spatial resolution 8,16,32, as the
26 | # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | # note: this isn't actually the resolution but
26 | # the downsampling factor, i.e. this corresponds to
27 | # attention on spatial resolution 8,16,32, as the
28 | # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | personalization_config:
21 | target: ldm.modules.embedding_manager.EmbeddingManager
22 | params:
23 | placeholder_strings: ["*"]
24 | initializer_words: []
25 |
26 | unet_config:
27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
28 | params:
29 | image_size: 32
30 | in_channels: 4
31 | out_channels: 4
32 | model_channels: 320
33 | attention_resolutions:
34 | - 4
35 | - 2
36 | - 1
37 | num_res_blocks: 2
38 | channel_mult:
39 | - 1
40 | - 2
41 | - 4
42 | - 4
43 | num_heads: 8
44 | use_spatial_transformer: true
45 | transformer_depth: 1
46 | context_dim: 1280
47 | use_checkpoint: true
48 | legacy: False
49 |
50 | first_stage_config:
51 | target: ldm.models.autoencoder.AutoencoderKL
52 | params:
53 | embed_dim: 4
54 | monitor: val/rec_loss
55 | ddconfig:
56 | double_z: true
57 | z_channels: 4
58 | resolution: 256
59 | in_channels: 3
60 | out_ch: 3
61 | ch: 128
62 | ch_mult:
63 | - 1
64 | - 2
65 | - 4
66 | - 4
67 | num_res_blocks: 2
68 | attn_resolutions: []
69 | dropout: 0.0
70 | lossconfig:
71 | target: torch.nn.Identity
72 |
73 | cond_stage_config:
74 | target: ldm.modules.encoders.modules.BERTEmbedder
75 | params:
76 | n_embed: 1280
77 | n_layer: 32
78 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-finetune.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-3
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | embedding_reg_weight: 0.0
20 |
21 | personalization_config:
22 | target: ldm.modules.embedding_manager.EmbeddingManager
23 | params:
24 | placeholder_strings: ["*"]
25 | initializer_words: ["sculpture"]
26 | per_image_tokens: false
27 | num_vectors_per_token: 1
28 | progressive_words: False
29 |
30 | unet_config:
31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32 | params:
33 | image_size: 32
34 | in_channels: 4
35 | out_channels: 4
36 | model_channels: 320
37 | attention_resolutions:
38 | - 4
39 | - 2
40 | - 1
41 | num_res_blocks: 2
42 | channel_mult:
43 | - 1
44 | - 2
45 | - 4
46 | - 4
47 | num_heads: 8
48 | use_spatial_transformer: true
49 | transformer_depth: 1
50 | context_dim: 1280
51 | use_checkpoint: true
52 | legacy: False
53 |
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config:
78 | target: ldm.modules.encoders.modules.BERTEmbedder
79 | params:
80 | n_embed: 1280
81 | n_layer: 32
82 |
83 |
84 | data:
85 | target: main.DataModuleFromConfig
86 | params:
87 | batch_size: 4
88 | num_workers: 2
89 | wrap: false
90 | train:
91 | target: ldm.data.personalized.PersonalizedBase
92 | params:
93 | size: 256
94 | set: train
95 | per_image_tokens: false
96 | repeats: 100
97 | validation:
98 | target: ldm.data.personalized.PersonalizedBase
99 | params:
100 | size: 256
101 | set: val
102 | per_image_tokens: false
103 | repeats: 10
104 |
105 | lightning:
106 | modelcheckpoint:
107 | params:
108 | every_n_train_steps: 500
109 | callbacks:
110 | image_logger:
111 | target: main.ImageLogger
112 | params:
113 | batch_frequency: 500
114 | max_images: 8
115 | increase_log_steps: False
116 |
117 | trainer:
118 | benchmark: True
119 | max_steps: 6100
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-3
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | embedding_reg_weight: 0.0
20 |
21 | personalization_config:
22 | target: ldm.modules.embedding_manager.EmbeddingManager
23 | params:
24 | placeholder_strings: ["*"]
25 | initializer_words: ["painting"]
26 | per_image_tokens: false
27 | num_vectors_per_token: 1
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions:
37 | - 4
38 | - 2
39 | - 1
40 | num_res_blocks: 2
41 | channel_mult:
42 | - 1
43 | - 2
44 | - 4
45 | - 4
46 | num_heads: 8
47 | use_spatial_transformer: true
48 | transformer_depth: 1
49 | context_dim: 1280
50 | use_checkpoint: true
51 | legacy: False
52 |
53 | first_stage_config:
54 | target: ldm.models.autoencoder.AutoencoderKL
55 | params:
56 | embed_dim: 4
57 | monitor: val/rec_loss
58 | ddconfig:
59 | double_z: true
60 | z_channels: 4
61 | resolution: 256
62 | in_channels: 3
63 | out_ch: 3
64 | ch: 128
65 | ch_mult:
66 | - 1
67 | - 2
68 | - 4
69 | - 4
70 | num_res_blocks: 2
71 | attn_resolutions: []
72 | dropout: 0.0
73 | lossconfig:
74 | target: torch.nn.Identity
75 |
76 | cond_stage_config:
77 | target: ldm.modules.encoders.modules.BERTEmbedder
78 | params:
79 | n_embed: 1280
80 | n_layer: 32
81 |
82 |
83 | data:
84 | target: main.DataModuleFromConfig
85 | params:
86 | batch_size: 4
87 | num_workers: 4
88 | wrap: false
89 | train:
90 | target: ldm.data.personalized_style.PersonalizedBase
91 | params:
92 | size: 256
93 | set: train
94 | per_image_tokens: false
95 | repeats: 100
96 | validation:
97 | target: ldm.data.personalized_style.PersonalizedBase
98 | params:
99 | size: 256
100 | set: val
101 | per_image_tokens: false
102 | repeats: 10
103 |
104 | lightning:
105 | modelcheckpoint:
106 | params:
107 | every_n_train_steps: 500
108 | callbacks:
109 | image_logger:
110 | target: main.ImageLogger
111 | params:
112 | batch_frequency: 500
113 | max_images: 8
114 | increase_log_steps: False
115 |
116 | trainer:
117 | benchmark: True
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-finetune.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-03
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: true # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | embedding_reg_weight: 0.0
20 | unfreeze_model: False
21 | model_lr: 0.0
22 |
23 | personalization_config:
24 | target: ldm.modules.embedding_manager.EmbeddingManager
25 | params:
26 | placeholder_strings: ["*"]
27 | initializer_words: ["sculpture"]
28 | per_image_tokens: false
29 | num_vectors_per_token: 1
30 | progressive_words: False
31 |
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32 # unused
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 320
39 | attention_resolutions: [ 4, 2, 1 ]
40 | num_res_blocks: 2
41 | channel_mult: [ 1, 2, 4, 4 ]
42 | num_heads: 8
43 | use_spatial_transformer: True
44 | transformer_depth: 1
45 | context_dim: 768
46 | use_checkpoint: True
47 | legacy: False
48 |
49 | first_stage_config:
50 | target: ldm.models.autoencoder.AutoencoderKL
51 | params:
52 | embed_dim: 4
53 | monitor: val/rec_loss
54 | ddconfig:
55 | double_z: true
56 | z_channels: 4
57 | resolution: 512
58 | in_channels: 3
59 | out_ch: 3
60 | ch: 128
61 | ch_mult:
62 | - 1
63 | - 2
64 | - 4
65 | - 4
66 | num_res_blocks: 2
67 | attn_resolutions: []
68 | dropout: 0.0
69 | lossconfig:
70 | target: torch.nn.Identity
71 |
72 | cond_stage_config:
73 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
74 |
75 | data:
76 | target: main.DataModuleFromConfig
77 | params:
78 | batch_size: 2
79 | num_workers: 2
80 | wrap: false
81 | train:
82 | target: ldm.data.personalized.PersonalizedBase
83 | params:
84 | size: 512
85 | set: train
86 | per_image_tokens: false
87 | repeats: 100
88 | validation:
89 | target: ldm.data.personalized.PersonalizedBase
90 | params:
91 | size: 512
92 | set: val
93 | per_image_tokens: false
94 | repeats: 10
95 |
96 | lightning:
97 | modelcheckpoint:
98 | params:
99 | every_n_train_steps: 500
100 | callbacks:
101 | image_logger:
102 | target: main.ImageLogger
103 | params:
104 | batch_frequency: 500
105 | max_images: 8
106 | increase_log_steps: False
107 |
108 | trainer:
109 | benchmark: True
110 | max_steps: 6100
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-finetune_unfrozen.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: true # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 | embedding_reg_weight: 0.0
20 | unfreeze_model: True
21 | model_lr: 1.0e-7
22 |
23 | personalization_config:
24 | target: ldm.modules.embedding_manager.EmbeddingManager
25 | params:
26 | placeholder_strings: ["*"]
27 | initializer_words: ["sculpture"]
28 | per_image_tokens: false
29 | num_vectors_per_token: 1
30 | progressive_words: False
31 |
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32 # unused
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 320
39 | attention_resolutions: [ 4, 2, 1 ]
40 | num_res_blocks: 2
41 | channel_mult: [ 1, 2, 4, 4 ]
42 | num_heads: 8
43 | use_spatial_transformer: True
44 | transformer_depth: 1
45 | context_dim: 768
46 | use_checkpoint: True
47 | legacy: False
48 |
49 | first_stage_config:
50 | target: ldm.models.autoencoder.AutoencoderKL
51 | params:
52 | embed_dim: 4
53 | monitor: val/rec_loss
54 | ddconfig:
55 | double_z: true
56 | z_channels: 4
57 | resolution: 512
58 | in_channels: 3
59 | out_ch: 3
60 | ch: 128
61 | ch_mult:
62 | - 1
63 | - 2
64 | - 4
65 | - 4
66 | num_res_blocks: 2
67 | attn_resolutions: []
68 | dropout: 0.0
69 | lossconfig:
70 | target: torch.nn.Identity
71 |
72 | cond_stage_config:
73 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
74 |
75 | data:
76 | target: main.DataModuleFromConfig
77 | params:
78 | batch_size: 1
79 | num_workers: 2
80 | wrap: false
81 | train:
82 | target: ldm.data.personalized.PersonalizedBase
83 | params:
84 | size: 512
85 | set: train
86 | per_image_tokens: false
87 | repeats: 100
88 | validation:
89 | target: ldm.data.personalized.PersonalizedBase
90 | params:
91 | size: 512
92 | set: val
93 | per_image_tokens: false
94 | repeats: 10
95 |
96 | lightning:
97 | modelcheckpoint:
98 | params:
99 | every_n_train_steps: 500
100 | callbacks:
101 | image_logger:
102 | target: main.ImageLogger
103 | params:
104 | batch_frequency: 500
105 | max_images: 8
106 | increase_log_steps: False
107 |
108 | trainer:
109 | benchmark: True
110 | max_steps: 4200
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | personalization_config:
21 | target: ldm.modules.embedding_manager.EmbeddingManager
22 | params:
23 | placeholder_strings: ["*"]
24 | initializer_words: ["sculpture"]
25 | per_image_tokens: false
26 | num_vectors_per_token: 1
27 | progressive_words: False
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.10
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.10.2
10 | - torchvision=0.11.3
11 | - numpy=1.22.3
12 | - pip:
13 | - albumentations==1.1.0
14 | - opencv-python==4.2.0.34
15 | - pudb==2019.2
16 | - imageio==2.14.1
17 | - imageio-ffmpeg==0.4.7
18 | - pytorch-lightning==1.5.9
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - setuptools==59.5.0
23 | - pillow==9.0.1
24 | - einops==0.4.1
25 | - torch-fidelity==0.3.0
26 | - transformers==4.18.0
27 | - torchmetrics==0.6.0
28 | - kornia==0.6
29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31 | - -e .
32 |
--------------------------------------------------------------------------------
/evaluation/__pycache__/clip_eval.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/evaluation/__pycache__/clip_eval.cpython-36.pyc
--------------------------------------------------------------------------------
/evaluation/__pycache__/clip_eval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/evaluation/__pycache__/clip_eval.cpython-38.pyc
--------------------------------------------------------------------------------
/evaluation/clip_eval.py:
--------------------------------------------------------------------------------
1 | import clip
2 | import torch
3 | from torchvision import transforms
4 |
5 | from ldm.models.diffusion.ddim import DDIMSampler
6 |
7 | class CLIPEvaluator(object):
8 | def __init__(self, device, clip_model='ViT-B/32') -> None:
9 | self.device = device
10 | self.model, clip_preprocess = clip.load(clip_model, device=self.device)
11 |
12 | self.clip_preprocess = clip_preprocess
13 |
14 | self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1].
15 | clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
16 | clip_preprocess.transforms[4:]) # skip the PIL-to-tensor conversion (inputs are already tensors)
17 |
18 | def tokenize(self, strings: list):
19 | return clip.tokenize(strings).to(self.device)
20 |
21 | @torch.no_grad()
22 | def encode_text(self, tokens: list) -> torch.Tensor:
23 | return self.model.encode_text(tokens)
24 |
25 | @torch.no_grad()
26 | def encode_images(self, images: torch.Tensor) -> torch.Tensor:
27 | images = self.preprocess(images).to(self.device)
28 | return self.model.encode_image(images)
29 |
30 | def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:
31 |
32 | tokens = clip.tokenize(text).to(self.device)
33 |
34 | text_features = self.encode_text(tokens).detach()
35 |
36 | if norm:
37 | text_features /= text_features.norm(dim=-1, keepdim=True)
38 |
39 | return text_features
40 |
41 | def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
42 | image_features = self.encode_images(img)
43 |
44 | if norm:
45 | image_features /= image_features.clone().norm(dim=-1, keepdim=True)
46 |
47 | return image_features
48 |
49 | def img_to_img_similarity(self, src_images, generated_images):
50 | src_img_features = self.get_image_features(src_images)
51 | gen_img_features = self.get_image_features(generated_images)
52 |
53 | return (src_img_features @ gen_img_features.T).mean()
54 |
55 | def txt_to_img_similarity(self, text, generated_images):
56 | text_features = self.get_text_features(text)
57 | gen_img_features = self.get_image_features(generated_images)
58 |
59 | return (text_features @ gen_img_features.T).mean()
60 |
61 |
62 | class LDMCLIPEvaluator(CLIPEvaluator):
63 | def __init__(self, device, clip_model='ViT-B/32') -> None:
64 | super().__init__(device, clip_model)
65 |
66 | def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50):
67 |
68 | sampler = DDIMSampler(ldm_model)
69 |
70 | samples_per_batch = 8
71 | n_batches = n_samples // samples_per_batch
72 |
73 | # generate samples
74 | all_samples=list()
75 | with torch.no_grad():
76 | with ldm_model.ema_scope():
77 | uc = ldm_model.get_learned_conditioning(samples_per_batch * [""])
78 |
79 | for batch in range(n_batches):
80 | c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text])
81 | shape = [4, 256//8, 256//8]
82 | samples_ddim, _ = sampler.sample(S=n_steps,
83 | conditioning=c,
84 | batch_size=samples_per_batch,
85 | shape=shape,
86 | verbose=False,
87 | unconditional_guidance_scale=5.0,
88 | unconditional_conditioning=uc,
89 | eta=0.0)
90 |
91 | x_samples_ddim = ldm_model.decode_first_stage(samples_ddim)
92 | x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
93 |
94 | all_samples.append(x_samples_ddim)
95 |
96 | all_samples = torch.cat(all_samples, axis=0)
97 |
98 | sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples)
99 | sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples)
100 |
101 | return sim_samples_to_img, sim_samples_to_text
102 |
103 |
104 | class ImageDirEvaluator(CLIPEvaluator):
105 | def __init__(self, device, clip_model='ViT-B/32') -> None:
106 | super().__init__(device, clip_model)
107 |
108 | def evaluate(self, gen_samples, src_images, target_text):
109 |
110 | sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples)
111 | sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples)
112 |
113 | return sim_samples_to_img, sim_samples_to_text
--------------------------------------------------------------------------------
/img/samples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/img/samples.jpg
--------------------------------------------------------------------------------
/img/style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/img/style.jpg
--------------------------------------------------------------------------------
/img/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/img/teaser.jpg
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/base.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/base.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/base.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/base.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/personalized.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/personalized.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/personalized_compose.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_compose.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/personalized_style.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_style.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/data/__pycache__/personalized_style.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/data/__pycache__/personalized_style.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
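Since __iter__ is the only abstract hook, a subclass just has to yield example dicts. A toy sketch (illustrative only; the field names mirror the map-style datasets below):

import torch
from ldm.data.base import Txt2ImgIterableBaseDataset

class DummyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    """Illustrative subclass that yields random images with a fixed caption per record id."""
    def __iter__(self):
        for sample_id in self.sample_ids:
            yield {
                "image": torch.rand(self.size, self.size, 3) * 2 - 1,  # roughly [-1, 1], HWC
                "caption": f"a photo of record {sample_id}",
            }

ds = DummyTxt2ImgDataset(num_records=4, valid_ids=list(range(4)), size=64)
for example in ds:
    pass  # chain several of these with torch.utils.data.ChainDataset if needed
--------------------------------------------------------------------------------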
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
/ldm/data/personalized.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 | import random
9 |
10 | imagenet_templates_smallest = [
11 | 'a photo of a {}',
12 | ]
13 |
14 | imagenet_templates_small = [
15 | 'a photo of a {}',
16 | 'a rendering of a {}',
17 | 'a cropped photo of the {}',
18 | 'the photo of a {}',
19 | 'a photo of a clean {}',
20 | 'a photo of a dirty {}',
21 | 'a dark photo of the {}',
22 | 'a photo of my {}',
23 | 'a photo of the cool {}',
24 | 'a close-up photo of a {}',
25 | 'a bright photo of the {}',
26 | 'a cropped photo of a {}',
27 | 'a photo of the {}',
28 | 'a good photo of the {}',
29 | 'a photo of one {}',
30 | 'a close-up photo of the {}',
31 | 'a rendition of the {}',
32 | 'a photo of the clean {}',
33 | 'a rendition of a {}',
34 | 'a photo of a nice {}',
35 | 'a good photo of a {}',
36 | 'a photo of the nice {}',
37 | 'a photo of the small {}',
38 | 'a photo of the weird {}',
39 | 'a photo of the large {}',
40 | 'a photo of a cool {}',
41 | 'a photo of a small {}',
42 | 'an illustration of a {}',
43 | 'a rendering of a {}',
44 | 'a cropped photo of the {}',
45 | 'the photo of a {}',
46 | 'an illustration of a clean {}',
47 | 'an illustration of a dirty {}',
48 | 'a dark photo of the {}',
49 | 'an illustration of my {}',
50 | 'an illustration of the cool {}',
51 | 'a close-up photo of a {}',
52 | 'a bright photo of the {}',
53 | 'a cropped photo of a {}',
54 | 'an illustration of the {}',
55 | 'a good photo of the {}',
56 | 'an illustration of one {}',
57 | 'a close-up photo of the {}',
58 | 'a rendition of the {}',
59 | 'an illustration of the clean {}',
60 | 'a rendition of a {}',
61 | 'an illustration of a nice {}',
62 | 'a good photo of a {}',
63 | 'an illustration of the nice {}',
64 | 'an illustration of the small {}',
65 | 'an illustration of the weird {}',
66 | 'an illustration of the large {}',
67 | 'an illustration of a cool {}',
68 | 'an illustration of a small {}',
69 | 'a depiction of a {}',
70 | 'a rendering of a {}',
71 | 'a cropped photo of the {}',
72 | 'the photo of a {}',
73 | 'a depiction of a clean {}',
74 | 'a depiction of a dirty {}',
75 | 'a dark photo of the {}',
76 | 'a depiction of my {}',
77 | 'a depiction of the cool {}',
78 | 'a close-up photo of a {}',
79 | 'a bright photo of the {}',
80 | 'a cropped photo of a {}',
81 | 'a depiction of the {}',
82 | 'a good photo of the {}',
83 | 'a depiction of one {}',
84 | 'a close-up photo of the {}',
85 | 'a rendition of the {}',
86 | 'a depiction of the clean {}',
87 | 'a rendition of a {}',
88 | 'a depiction of a nice {}',
89 | 'a good photo of a {}',
90 | 'a depiction of the nice {}',
91 | 'a depiction of the small {}',
92 | 'a depiction of the weird {}',
93 | 'a depiction of the large {}',
94 | 'a depiction of a cool {}',
95 | 'a depiction of a small {}',
96 | ]
97 |
98 | imagenet_dual_templates_small = [
99 | 'a photo of a {} with {}',
100 | 'a rendering of a {} with {}',
101 | 'a cropped photo of the {} with {}',
102 | 'the photo of a {} with {}',
103 | 'a photo of a clean {} with {}',
104 | 'a photo of a dirty {} with {}',
105 | 'a dark photo of the {} with {}',
106 | 'a photo of my {} with {}',
107 | 'a photo of the cool {} with {}',
108 | 'a close-up photo of a {} with {}',
109 | 'a bright photo of the {} with {}',
110 | 'a cropped photo of a {} with {}',
111 | 'a photo of the {} with {}',
112 | 'a good photo of the {} with {}',
113 | 'a photo of one {} with {}',
114 | 'a close-up photo of the {} with {}',
115 | 'a rendition of the {} with {}',
116 | 'a photo of the clean {} with {}',
117 | 'a rendition of a {} with {}',
118 | 'a photo of a nice {} with {}',
119 | 'a good photo of a {} with {}',
120 | 'a photo of the nice {} with {}',
121 | 'a photo of the small {} with {}',
122 | 'a photo of the weird {} with {}',
123 | 'a photo of the large {} with {}',
124 | 'a photo of a cool {} with {}',
125 | 'a photo of a small {} with {}',
126 | ]
127 |
128 | per_img_token_list = [
129 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
130 | ]
131 |
132 | class PersonalizedBase(Dataset):
133 | def __init__(self,
134 | data_root,
135 | size=None,
136 | repeats=100,
137 | interpolation="bicubic",
138 | flip_p=0.5,
139 | set="train",
140 | placeholder_token="*",
141 | per_image_tokens=False,
142 | center_crop=False,
143 | mixing_prob=0.25,
144 | coarse_class_text=None,
145 | ):
146 |
147 | self.data_root = data_root
148 |
149 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
150 |
151 | # self._length = len(self.image_paths)
152 | self.num_images = len(self.image_paths)
153 | self._length = self.num_images
154 |
155 | self.placeholder_token = placeholder_token
156 |
157 | self.per_image_tokens = per_image_tokens
158 | self.center_crop = center_crop
159 | self.mixing_prob = mixing_prob
160 |
161 | self.coarse_class_text = coarse_class_text
162 |
163 | if per_image_tokens:
164 |             assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
165 |
166 | if set == "train":
167 | self._length = self.num_images * repeats
168 |
169 | self.size = size
170 | self.interpolation = {"linear": PIL.Image.LINEAR,
171 | "bilinear": PIL.Image.BILINEAR,
172 | "bicubic": PIL.Image.BICUBIC,
173 | "lanczos": PIL.Image.LANCZOS,
174 | }[interpolation]
175 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
176 |
177 | def __len__(self):
178 | return self._length
179 |
180 | def __getitem__(self, i):
181 | example = {}
182 | image = Image.open(self.image_paths[i % self.num_images])
183 |
184 | if not image.mode == "RGB":
185 | image = image.convert("RGB")
186 |
187 | placeholder_string = self.placeholder_token
188 | if self.coarse_class_text:
189 | placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
190 |
191 | if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
192 | text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images])
193 | else:
194 | text = random.choice(imagenet_templates_small).format(placeholder_string)
195 |
196 | example["caption"] = text
197 |
198 | # default to score-sde preprocessing
199 | img = np.array(image).astype(np.uint8)
200 |
201 | if self.center_crop:
202 | crop = min(img.shape[0], img.shape[1])
203 | h, w, = img.shape[0], img.shape[1]
204 | img = img[(h - crop) // 2:(h + crop) // 2,
205 | (w - crop) // 2:(w + crop) // 2]
206 |
207 | image = Image.fromarray(img)
208 | if self.size is not None:
209 | image = image.resize((self.size, self.size), resample=self.interpolation)
210 |
211 | image = self.flip(image)
212 | image = np.array(image).astype(np.uint8)
213 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
214 | return example
--------------------------------------------------------------------------------
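A short sketch of how this dataset is typically consumed (the folder name is a placeholder for a directory holding a handful of images of the new concept; the keys match what __getitem__ returns above):

from torch.utils.data import DataLoader
from ldm.data.personalized import PersonalizedBase

train_ds = PersonalizedBase(data_root="./my_concept_images",  # hypothetical image folder
                            size=256, repeats=100, set="train", placeholder_token="*")
loader = DataLoader(train_ds, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch["caption"][:2])   # e.g. ['a photo of a *', 'a rendition of the *']
print(batch["image"].shape)   # torch.Size([4, 256, 256, 3]), values in [-1, 1]
--------------------------------------------------------------------------------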
/ldm/data/personalized_style.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 | import random
9 |
10 | imagenet_templates_small = [
11 | 'a painting in the style of {}',
12 | 'a rendering in the style of {}',
13 | 'a cropped painting in the style of {}',
14 | 'the painting in the style of {}',
15 | 'a clean painting in the style of {}',
16 | 'a dirty painting in the style of {}',
17 | 'a dark painting in the style of {}',
18 | 'a picture in the style of {}',
19 | 'a cool painting in the style of {}',
20 | 'a close-up painting in the style of {}',
21 | 'a bright painting in the style of {}',
22 | 'a cropped painting in the style of {}',
23 | 'a good painting in the style of {}',
24 | 'a close-up painting in the style of {}',
25 | 'a rendition in the style of {}',
26 | 'a nice painting in the style of {}',
27 | 'a small painting in the style of {}',
28 | 'a weird painting in the style of {}',
29 | 'a large painting in the style of {}',
30 | ]
31 |
32 | imagenet_dual_templates_small = [
33 | 'a painting in the style of {} with {}',
34 | 'a rendering in the style of {} with {}',
35 | 'a cropped painting in the style of {} with {}',
36 | 'the painting in the style of {} with {}',
37 | 'a clean painting in the style of {} with {}',
38 | 'a dirty painting in the style of {} with {}',
39 | 'a dark painting in the style of {} with {}',
40 | 'a cool painting in the style of {} with {}',
41 | 'a close-up painting in the style of {} with {}',
42 | 'a bright painting in the style of {} with {}',
43 | 'a cropped painting in the style of {} with {}',
44 | 'a good painting in the style of {} with {}',
45 | 'a painting of one {} in the style of {}',
46 | 'a nice painting in the style of {} with {}',
47 | 'a small painting in the style of {} with {}',
48 | 'a weird painting in the style of {} with {}',
49 | 'a large painting in the style of {} with {}',
50 | ]
51 |
52 | per_img_token_list = [
53 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
54 | ]
55 |
56 | class PersonalizedBase(Dataset):
57 | def __init__(self,
58 | data_root,
59 | size=None,
60 | repeats=100,
61 | interpolation="bicubic",
62 | flip_p=0.5,
63 | set="train",
64 | placeholder_token="*",
65 | per_image_tokens=False,
66 | center_crop=False,
67 | ):
68 |
69 | self.data_root = data_root
70 |
71 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
72 |
73 | # self._length = len(self.image_paths)
74 | self.num_images = len(self.image_paths)
75 | self._length = self.num_images
76 |
77 | self.placeholder_token = placeholder_token
78 |
79 | self.per_image_tokens = per_image_tokens
80 | self.center_crop = center_crop
81 |
82 | if per_image_tokens:
83 |             assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
84 |
85 | if set == "train":
86 | self._length = self.num_images * repeats
87 |
88 | self.size = size
89 | self.interpolation = {"linear": PIL.Image.LINEAR,
90 | "bilinear": PIL.Image.BILINEAR,
91 | "bicubic": PIL.Image.BICUBIC,
92 | "lanczos": PIL.Image.LANCZOS,
93 | }[interpolation]
94 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
95 |
96 | def __len__(self):
97 | return self._length
98 |
99 | def __getitem__(self, i):
100 | example = {}
101 | image = Image.open(self.image_paths[i % self.num_images])
102 |
103 | if not image.mode == "RGB":
104 | image = image.convert("RGB")
105 |
106 | if self.per_image_tokens and np.random.uniform() < 0.25:
107 | text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
108 | else:
109 | text = random.choice(imagenet_templates_small).format(self.placeholder_token)
110 |
111 | example["caption"] = text
112 |
113 | # default to score-sde preprocessing
114 | img = np.array(image).astype(np.uint8)
115 |
116 | if self.center_crop:
117 | crop = min(img.shape[0], img.shape[1])
118 | h, w, = img.shape[0], img.shape[1]
119 | img = img[(h - crop) // 2:(h + crop) // 2,
120 | (w - crop) // 2:(w + crop) // 2]
121 |
122 | image = Image.fromarray(img)
123 | if self.size is not None:
124 | image = image.resize((self.size, self.size), resample=self.interpolation)
125 |
126 | image = self.flip(image)
127 | image = np.array(image).astype(np.uint8)
128 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
129 | return example
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
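These schedulers return a learning-rate multiplier for a global step, so they are meant to be wrapped in torch.optim.lr_scheduler.LambdaLR (the first class additionally expects the optimizer's base lr to be 1.0, per its docstring). A minimal sketch with illustrative hyperparameters:

import torch
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-3)

# One cycle: ramp from ~0 to 1.0 over 100 warm-up steps, then hold (values are illustrative).
schedule = LambdaLinearScheduler(warm_up_steps=[100], f_start=[1.0e-6], f_min=[1.0], f_max=[1.0],
                                 cycle_lengths=[10000])
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule.schedule)

for step in range(200):
    optimizer.step()      # a real loop would compute a loss and backprop first
    lr_scheduler.step()   # multiplies the base lr by schedule(step)
--------------------------------------------------------------------------------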
/ldm/models/__pycache__/autoencoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/__pycache__/autoencoder.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddim_inversion.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/ddpm_pti.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/plms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/plms.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/attention.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/ema.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/embedding_manager.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/embedding_manager.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/embedding_manager.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/embedding_manager.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/x_transformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/x_transformer.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/x_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/__pycache__/x_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
--------------------------------------------------------------------------------
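A shape-level sketch of the two main entry points above: CrossAttention works on (batch, tokens, dim) sequences, and SpatialTransformer wraps it for (batch, channels, height, width) feature maps with an optional conditioning context. The sizes below are illustrative, chosen to resemble a text-conditioned UNet stage:

import torch
from ldm.modules.attention import CrossAttention, SpatialTransformer

# Sequence-level cross-attention: queries come from x, keys/values from the context.
attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 1024, 320)        # e.g. a flattened 32x32 feature map
context = torch.randn(2, 77, 768)    # e.g. per-token text embeddings
out = attn(x, context=context)       # -> (2, 1024, 320)

# Image-like wrapper: project channels in, run the transformer blocks, project back out.
st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
feat = torch.randn(2, 320, 32, 32)
out_map = st(feat, context=context)  # -> (2, 320, 32, 32)
--------------------------------------------------------------------------------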
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if False: # disabled checkpointing to allow requires_grad = False for main model
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
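A short sketch of the schedule helpers above: make_beta_schedule produces per-step betas as a numpy array, make_ddim_timesteps picks a DDIM subsequence of the DDPM steps, and timestep_embedding turns integer timesteps into the sinusoidal vectors the UNet consumes (the numbers are illustrative):

import torch
from ldm.modules.diffusionmodules.util import make_beta_schedule, make_ddim_timesteps, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000, linear_start=1e-4, linear_end=2e-2)
print(betas.shape)     # (1000,); first value ~1e-4, last ~2e-2

ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
print(ddim_steps[:5])  # 50 evenly spaced DDPM indices, shifted by one

t = torch.randint(0, 1000, (4,))
emb = timestep_embedding(t, dim=320)
print(emb.shape)       # torch.Size([4, 320])
--------------------------------------------------------------------------------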
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/distributions.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
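A minimal usage sketch for the DiagonalGaussianDistribution above (all shapes and values are illustrative): the encoder output is split in half along the channel axis into mean and log-variance, sample() draws a reparameterized latent, and kl() and nll() reduce over the non-batch dimensions.

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# Hypothetical encoder output: 2 * z_channels along dim=1 (mean and logvar halves).
moments = torch.randn(4, 8, 16, 16)   # batch=4, 2*z_channels=8, 16x16 spatial latent
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                # (4, 4, 16, 16) reparameterized sample
kl = posterior.kl()                   # (4,) per-example KL(q(z|x) || N(0, I))
nll = posterior.nll(z)                # (4,) negative log-likelihood of that sample
--------------------------------------------------------------------------------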
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
 6 |     def __init__(self, model, decay=0.9999, use_num_updates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 |         self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 |                 # remove '.' since it is not allowed in buffer names
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
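A minimal sketch of the intended LitEma call pattern (the linear model, optimizer and step count below are placeholders): update the shadow weights after every optimizer step, then use store / copy_to / restore to evaluate with the EMA weights without losing the raw ones.

import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)                       # stand-in for the actual network
ema = LitEma(model, decay=0.9999)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

for _ in range(10):                           # toy training steps
    loss = model(torch.randn(2, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                                # update the EMA shadow buffers

ema.store(model.parameters())                 # keep the raw weights around
ema.copy_to(model)                            # swap in the EMA weights for eval/sampling
# ... validation or sampling would go here ...
ema.restore(model.parameters())               # put the raw weights back for training
--------------------------------------------------------------------------------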
/ldm/modules/embedding_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ldm.data.personalized import per_img_token_list
5 | from transformers import CLIPTokenizer
6 | from functools import partial
7 |
8 | DEFAULT_PLACEHOLDER_TOKEN = ["*"]
9 |
10 | PROGRESSIVE_SCALE = 2000
11 |
12 | def get_clip_token_for_string(tokenizer, string):
13 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
14 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
15 | tokens = batch_encoding["input_ids"]
16 | assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
17 |
18 | return tokens[0, 1]
19 |
20 | def get_bert_token_for_string(tokenizer, string):
21 | token = tokenizer(string)
22 | assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
23 |
24 | token = token[0, 1]
25 |
26 | return token
27 |
28 | def get_embedding_for_clip_token(embedder, token):
29 | return embedder(token.unsqueeze(0))[0, 0]
30 |
31 |
32 | class EmbeddingManager(nn.Module):
33 | def __init__(
34 | self,
35 | embedder,
36 | placeholder_strings=None,
37 | initializer_words=None,
38 | per_image_tokens=False,
39 | num_vectors_per_token=1,
40 | progressive_words=False,
41 | **kwargs
42 | ):
43 | super().__init__()
44 |
45 | self.string_to_token_dict = {}
46 |
47 | self.string_to_param_dict = nn.ParameterDict()
48 |
49 | self.initial_embeddings = nn.ParameterDict() # These should not be optimized
50 |
51 | self.progressive_words = progressive_words
52 | self.progressive_counter = 0
53 |
54 | self.max_vectors_per_token = num_vectors_per_token
55 |
56 | if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
57 | self.is_clip = True
58 | get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
59 | get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
60 | token_dim = 768
61 | else: # using LDM's BERT encoder
62 | self.is_clip = False
63 | get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
64 | get_embedding_for_tkn = embedder.transformer.token_emb
65 | token_dim = 1280
66 |
67 | if per_image_tokens:
68 | placeholder_strings.extend(per_img_token_list)
69 |
70 | for idx, placeholder_string in enumerate(placeholder_strings):
71 |
72 | token = get_token_for_string(placeholder_string)
73 |
74 | if initializer_words and idx < len(initializer_words):
75 | init_word_token = get_token_for_string(initializer_words[idx])
76 |
77 | with torch.no_grad():
78 | init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
79 |
80 | token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
81 | self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
82 | else:
83 | token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
84 |
85 | self.string_to_token_dict[placeholder_string] = token
86 | self.string_to_param_dict[placeholder_string] = token_params
87 |
88 | def forward(
89 | self,
90 | tokenized_text,
91 | embedded_text,
92 | ):
93 | b, n, device = *tokenized_text.shape, tokenized_text.device
94 |
95 | for placeholder_string, placeholder_token in self.string_to_token_dict.items():
96 |
97 | placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
98 |
99 | if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
100 | placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
101 | embedded_text[placeholder_idx] = placeholder_embedding
102 | else: # otherwise, need to insert and keep track of changing indices
103 | if self.progressive_words:
104 | self.progressive_counter += 1
105 | max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
106 | else:
107 | max_step_tokens = self.max_vectors_per_token
108 |
109 | num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
110 |
111 | placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
112 |
113 | if placeholder_rows.nelement() == 0:
114 | continue
115 |
116 | sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
117 | sorted_rows = placeholder_rows[sort_idx]
118 |
119 | for idx in range(len(sorted_rows)):
120 | row = sorted_rows[idx]
121 | col = sorted_cols[idx]
122 |
123 | new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
124 | new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
125 |
126 | embedded_text[row] = new_embed_row
127 | tokenized_text[row] = new_token_row
128 |
129 | return embedded_text
130 |
131 | def save(self, ckpt_path):
132 | torch.save({"string_to_token": self.string_to_token_dict,
133 | "string_to_param": self.string_to_param_dict}, ckpt_path)
134 |
135 | def load(self, ckpt_path):
136 | ckpt = torch.load(ckpt_path, map_location='cpu')
137 |
138 | self.string_to_token_dict = ckpt["string_to_token"]
139 | self.string_to_param_dict = ckpt["string_to_param"]
140 |
141 | def get_embedding_norms_squared(self):
142 | all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
143 | param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
144 |
145 | return param_norm_squared
146 |
147 | def embedding_parameters(self):
148 | return self.string_to_param_dict.parameters()
149 |
150 | def embedding_to_coarse_loss(self):
151 |
152 | loss = 0.
153 | num_embeddings = len(self.initial_embeddings)
154 |
155 | for key in self.initial_embeddings:
156 | optimized = self.string_to_param_dict[key]
157 | coarse = self.initial_embeddings[key].clone().to(optimized.device)
158 |
159 | loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
160 |
161 | return loss
--------------------------------------------------------------------------------
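A sketch of how EmbeddingManager is typically set up for a single new concept; the placeholder string, initializer word, learning rate and checkpoint path are illustrative. Only the placeholder vectors returned by embedding_parameters() are optimized, the text encoder itself stays frozen.

import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
from ldm.modules.embedding_manager import EmbeddingManager

embedder = FrozenCLIPEmbedder()               # Stable Diffusion's CLIP text encoder
manager = EmbeddingManager(
    embedder,
    placeholder_strings=["*"],                # token that will carry the new concept
    initializer_words=["toy"],                # coarse single-token word used to initialize it
    num_vectors_per_token=1,
)

# Optimize only the placeholder embedding(s); everything else stays frozen.
opt = torch.optim.AdamW(manager.embedding_parameters(), lr=5e-3)

# After training, the learned embedding can be saved and merged or re-loaded later.
manager.save("learned_embedding.pt")          # hypothetical path
manager.load("learned_embedding.pt")
--------------------------------------------------------------------------------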
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rinongal/textual_inversion/424192de1518a358f1b648e0e781fdbe3f40c210/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
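A sketch of how this two-optimizer objective is driven, using random tensors, a stand-in convolution as last_layer, and hyperparameters taken from the kl-f* configs further below; in the repo the AutoencoderKL training step plays this role. optimizer_idx 0 returns the autoencoder loss (reconstruction/NLL + KL + adaptively weighted GAN term), optimizer_idx 1 the discriminator loss.

import torch
from torch import nn
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)

decoder_out = nn.Conv2d(3, 3, 3, padding=1)        # stand-in for the decoder's last layer
inputs = torch.randn(2, 3, 64, 64)
reconstructions = decoder_out(torch.randn(2, 3, 64, 64))
posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))

# optimizer_idx == 0: autoencoder update (NLL + KL + adaptively weighted GAN loss)
aeloss, log_ae = loss_fn(inputs, reconstructions, posterior, 0,
                         global_step=0, last_layer=decoder_out.weight, split="train")

# optimizer_idx == 1: discriminator update on real vs. reconstructed images
discloss, log_disc = loss_fn(inputs, reconstructions, posterior, 1,
                             global_step=0, last_layer=decoder_out.weight, split="train")
--------------------------------------------------------------------------------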
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:  # explicit None check; 'exists' from ldm.util is not imported in this file
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
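Two of the helpers above are easy to sanity-check in isolation (the numbers are illustrative): adopt_weight holds the GAN weight at zero until the disc_start threshold, and measure_perplexity returns the codebook size when every code is used equally often.

import torch
from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

# Discriminator weight is suppressed until `threshold` steps have passed.
print(adopt_weight(1.0, global_step=1000, threshold=250001))    # 0.0
print(adopt_weight(1.0, global_step=300000, threshold=250001))  # 1.0

# With 16 codes each used equally often, perplexity == n_embed and all codes count as used.
codes = torch.arange(16).repeat(8)
perplexity, cluster_use = measure_perplexity(codes, n_embed=16)
print(perplexity.item(), cluster_use.item())                    # ~16.0, 16
--------------------------------------------------------------------------------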
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.load_default()
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config, **kwargs):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
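instantiate_from_config is the glue used by every config file in this dump: "target" is a dotted import path and "params" are passed as keyword arguments. A small sketch with a stock PyTorch class (the choice of torch.nn.Conv2d is arbitrary):

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config, get_obj_from_str

cfg = OmegaConf.create({
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
})
conv = instantiate_from_config(cfg)            # equivalent to torch.nn.Conv2d(3, 8, 3)

Conv2d = get_obj_from_str("torch.nn.Conv2d")   # resolve the class without instantiating it
print(type(conv) is Conv2d)                    # True

# The special strings '__is_first_stage__' and '__is_unconditional__' used by some
# configs below are handled explicitly and simply return None.
print(instantiate_from_config("__is_unconditional__"))   # None
--------------------------------------------------------------------------------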
/merge_embeddings.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
2 | from ldm.modules.embedding_manager import EmbeddingManager
3 |
4 | import argparse, os
5 | from functools import partial
6 |
7 | import torch
8 |
9 | def get_placeholder_loop(placeholder_string, embedder, is_sd):
10 |
11 | new_placeholder = None
12 |
13 | while True:
14 | if new_placeholder is None:
15 | new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
16 | else:
17 | new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
18 |
19 | token = get_clip_token_for_string(embedder.tokenizer, new_placeholder) if is_sd else get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
20 |
21 | if token is not None:
22 | return new_placeholder, token
23 |
24 | def get_clip_token_for_string(tokenizer, string):
25 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
26 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
27 | tokens = batch_encoding["input_ids"]
28 |
29 | if torch.count_nonzero(tokens - 49407) == 2:
30 | return tokens[0, 1]
31 |
32 | return None
33 |
34 | def get_bert_token_for_string(tokenizer, string):
35 | token = tokenizer(string)
36 | if torch.count_nonzero(token) == 3:
37 | return token[0, 1]
38 |
39 | return None
40 |
41 |
42 | if __name__ == "__main__":
43 |
44 | parser = argparse.ArgumentParser()
45 |
46 | parser.add_argument(
47 | "--manager_ckpts",
48 | type=str,
49 | nargs="+",
50 | required=True,
51 | help="Paths to a set of embedding managers to be merged."
52 | )
53 |
54 | parser.add_argument(
55 | "--output_path",
56 | type=str,
57 | required=True,
58 | help="Output path for the merged manager",
59 | )
60 |
61 | parser.add_argument(
62 | "-sd", "--stable_diffusion",
63 | action="store_true",
64 | help="Flag to denote that we are merging stable diffusion embeddings"
65 | )
66 |
67 | args = parser.parse_args()
68 |
69 | if args.stable_diffusion:
70 | embedder = FrozenCLIPEmbedder().cuda()
71 | else:
72 | embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
73 |
74 | EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
75 |
76 | string_to_token_dict = {}
77 | string_to_param_dict = torch.nn.ParameterDict()
78 |
79 | placeholder_to_src = {}
80 |
81 | for manager_ckpt in args.manager_ckpts:
82 | print(f"Parsing {manager_ckpt}...")
83 |
84 | manager = EmbeddingManager()
85 | manager.load(manager_ckpt)
86 |
87 | for placeholder_string in manager.string_to_token_dict:
88 | if not placeholder_string in string_to_token_dict:
89 | string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
90 | string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
91 |
92 | placeholder_to_src[placeholder_string] = manager_ckpt
93 | else:
94 | new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, is_sd=args.stable_diffusion)
95 | string_to_token_dict[new_placeholder] = new_token
96 | string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
97 |
98 | placeholder_to_src[new_placeholder] = manager_ckpt
99 |
100 | print("Saving combined manager...")
101 | merged_manager = EmbeddingManager()
102 | merged_manager.string_to_param_dict = string_to_param_dict
103 | merged_manager.string_to_token_dict = string_to_token_dict
104 | merged_manager.save(args.output_path)
105 |
106 | print("Managers merged. Final list of placeholders: ")
107 | print(placeholder_to_src)
108 |
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
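An example invocation for two Stable Diffusion embedding checkpoints (the checkpoint paths are illustrative):

python merge_embeddings.py --manager_ckpts embeddings_cat.pt embeddings_dog.pt --output_path merged.pt -sd

Without -sd the script instantiates the BERT/LDM embedder instead. If two checkpoints reuse the same placeholder string, it prompts interactively for a replacement string that maps to a single token.
--------------------------------------------------------------------------------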
/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
--------------------------------------------------------------------------------
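The "f16" in the model name follows from ch_mult: the encoder halves the spatial resolution between consecutive entries, so the downsampling factor is 2 to the power of len(ch_mult) - 1. A quick arithmetic check against the config above (matching the autoencoder_kl_16x16x16.yaml naming under configs/autoencoder):

ch_mult = [1, 1, 2, 2, 4]              # from the config above
f = 2 ** (len(ch_mult) - 1)            # 16 -> "kl-f16"
latent_hw = 256 // f                   # 16 -> 256x256 images become 16x16 latents
z_channels = 16                        # so the latent is 16x16x16
print(f, latent_hw, z_channels)
--------------------------------------------------------------------------------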
/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
--------------------------------------------------------------------------------
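With concat_mode the conditioning is concatenated to the noisy latent along the channel axis, which is where in_channels: 6 comes from: 3 channels of VQ-f4 image latent plus the 3-channel low-resolution conditioning image passed through the Identity cond stage, while out_channels stays 3. A shape sketch (tensors are illustrative):

import torch
# concat_mode: condition by channel-wise concatenation.
z_noisy = torch.randn(1, 3, 64, 64)    # noisy VQ-f4 latent of the 256px crop
c_lr = torch.randn(1, 3, 64, 64)       # 64px low-resolution image (Identity cond stage)
unet_in = torch.cat([z_noisy, c_lr], dim=1)
print(unet_in.shape)                   # (1, 6, 64, 64) -> matches in_channels: 6
--------------------------------------------------------------------------------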
/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
--------------------------------------------------------------------------------
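Here conditioning is injected by cross-attention instead of concatenation: the trainable ClassEmbedder maps each class label to a single 512-dimensional context token, and the UNet's context_dim: 512 has to match that embed_dim. A sketch assuming the standard latent-diffusion ClassEmbedder interface (a batch dict in, one context token per example out; the class ids are arbitrary):

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cond_cfg = OmegaConf.create({
    "target": "ldm.modules.encoders.modules.ClassEmbedder",
    "params": {"embed_dim": 512, "key": "class_label"},
})
embedder = instantiate_from_config(cond_cfg)

batch = {"class_label": torch.tensor([207, 360])}   # two arbitrary ImageNet class ids
context = embedder(batch)
print(context.shape)    # expected (2, 1, 512): one context token per image, matching context_dim
--------------------------------------------------------------------------------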
/models/ldm/ffhq256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 42
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
--------------------------------------------------------------------------------
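The in_channels: 7 here is consistent with concat conditioning on both the masked image and its mask: 3 channels of noisy VQ latent, 3 channels for the VQ-encoded masked image (cond_stage_config: __is_first_stage__ reuses the first-stage encoder), and a 1-channel mask resized to the latent resolution. A shape sketch (tensors illustrative; the channel breakdown is an assumption based on the usual latent-diffusion inpainting setup):

import torch
z_noisy = torch.randn(1, 3, 64, 64)    # noisy VQ latent of the image
c_masked = torch.randn(1, 3, 64, 64)   # VQ latent of the masked image (__is_first_stage__)
c_mask = torch.randn(1, 1, 64, 64)     # binary mask resized to the latent resolution
unet_in = torch.cat([z_noisy, c_masked, c_mask], dim=1)
print(unet_in.shape)                   # (1, 7, 64, 64) -> matches in_channels: 7
--------------------------------------------------------------------------------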
/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/models/ldm/lsun_beds256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.lsun.LSUNBedroomsTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.lsun.LSUNBedroomsValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
--------------------------------------------------------------------------------
/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11 |
12 |
13 |
14 | cd models/first_stage_models/kl-f4
15 | unzip -o model.zip
16 |
17 | cd ../kl-f8
18 | unzip -o model.zip
19 |
20 | cd ../kl-f16
21 | unzip -o model.zip
22 |
23 | cd ../kl-f32
24 | unzip -o model.zip
25 |
26 | cd ../vq-f4
27 | unzip -o model.zip
28 |
29 | cd ../vq-f4-noattn
30 | unzip -o model.zip
31 |
32 | cd ../vq-f8
33 | unzip -o model.zip
34 |
35 | cd ../vq-f8-n256
36 | unzip -o model.zip
37 |
38 | cd ../vq-f16
39 | unzip -o model.zip
40 |
41 | cd ../..
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13 |
14 |
15 |
16 | cd models/ldm/celeba256
17 | unzip -o celeba-256.zip
18 |
19 | cd ../ffhq256
20 | unzip -o ffhq-256.zip
21 |
22 | cd ../lsun_churches256
23 | unzip -o lsun_churches-256.zip
24 |
25 | cd ../lsun_beds256
26 | unzip -o lsun_beds-256.zip
27 |
28 | cd ../text2img256
29 | unzip -o model.zip
30 |
31 | cd ../cin256
32 | unzip -o model.zip
33 |
34 | cd ../semantic_synthesis512
35 | unzip -o model.zip
36 |
37 | cd ../semantic_synthesis256
38 | unzip -o model.zip
39 |
40 | cd ../bsr_sr
41 | unzip -o model.zip
42 |
43 | cd ../layout2img-openimages256
44 | unzip -o model.zip
45 |
46 | cd ../inpainting_big
47 | unzip -o model.zip
48 |
49 | cd ../..
50 |
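
Both download scripts use paths relative to the repository root (models/first_stage_models/... and models/ldm/...) and wget -O into those directories, so they are meant to be run from the top of the checkout with the target folders already present. A minimal sketch of the assumed invocation:

    # run from the repository root; script names as listed above
    bash scripts/download_first_stages.sh   # first-stage autoencoders (kl-f*, vq-f*)
    bash scripts/download_models.sh         # full LDM checkpoints (celeba256, ffhq256, ...)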
--------------------------------------------------------------------------------
/scripts/evaluate_model.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 |
3 | sys.path.append(os.path.join(sys.path[0], '..'))
4 |
5 | import torch
6 | import numpy as np
7 | from omegaconf import OmegaConf
8 | from PIL import Image
9 | from tqdm import tqdm, trange
10 | from einops import rearrange
11 | from torchvision.utils import make_grid
12 |
13 | from ldm.util import instantiate_from_config
14 | from ldm.models.diffusion.ddim import DDIMSampler
15 | from ldm.models.diffusion.plms import PLMSSampler
16 | from ldm.data.personalized import PersonalizedBase
17 | from evaluation.clip_eval import LDMCLIPEvaluator
18 |
19 | def load_model_from_config(config, ckpt, verbose=False):
20 | print(f"Loading model from {ckpt}")
21 | pl_sd = torch.load(ckpt, map_location="cpu")
22 | sd = pl_sd["state_dict"]
23 | model = instantiate_from_config(config.model)
24 | m, u = model.load_state_dict(sd, strict=False)
25 | if len(m) > 0 and verbose:
26 | print("missing keys:")
27 | print(m)
28 | if len(u) > 0 and verbose:
29 | print("unexpected keys:")
30 | print(u)
31 |
32 | model.cuda()
33 | model.eval()
34 | return model
35 |
36 |
37 | if __name__ == "__main__":
38 | parser = argparse.ArgumentParser()
39 |
40 | parser.add_argument(
41 | "--prompt",
42 | type=str,
43 | nargs="?",
44 | default="a painting of a virus monster playing guitar",
45 | help="the prompt to render"
46 | )
47 |
48 | parser.add_argument(
49 | "--ckpt_path",
50 | type=str,
51 | default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
52 | help="Path to pretrained ldm text2img model")
53 |
54 | parser.add_argument(
55 | "--embedding_path",
56 | type=str,
57 | help="Path to a pre-trained embedding manager checkpoint")
58 |
59 | parser.add_argument(
60 | "--data_dir",
61 | type=str,
62 | help="Path to directory with images used to train the embedding vectors"
63 | )
64 |
65 | opt = parser.parse_args()
66 |
67 |
68 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and change this logic
69 | model = load_model_from_config(config, opt.ckpt_path) # TODO: check path
70 | model.embedding_manager.load(opt.embedding_path)
71 |
72 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
73 | model = model.to(device)
74 |
75 | evaluator = LDMCLIPEvaluator(device)
76 |
77 | prompt = opt.prompt
78 |
79 | data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0)
80 |
81 | images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1) for i in range(data_loader.num_images)]
82 | images = torch.stack(images, axis=0)
83 |
84 | sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt)
85 |
86 | # output_dir = os.path.join(opt.out_dir, prompt.replace(" ", "-"))  # unused, and the parser defines no --out_dir argument
87 |
88 | print("Image similarity: ", sim_img)
89 | print("Text similarity: ", sim_text)
--------------------------------------------------------------------------------
/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | opt = parser.parse_args()
54 |
55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
56 | images = [x.replace("_mask.png", ".png") for x in masks]
57 | print(f"Found {len(masks)} inputs.")
58 |
59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
60 | model = instantiate_from_config(config.model)
61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
62 | strict=False)
63 |
64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
65 | model = model.to(device)
66 | sampler = DDIMSampler(model)
67 |
68 | os.makedirs(opt.outdir, exist_ok=True)
69 | with torch.no_grad():
70 | with model.ema_scope():
71 | for image, mask in tqdm(zip(images, masks)):
72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
73 | batch = make_batch(image, mask, device=device)
74 |
75 | # encode masked image and concat downsampled mask
76 | c = model.cond_stage_model.encode(batch["masked_image"])
77 | cc = torch.nn.functional.interpolate(batch["mask"],
78 | size=c.shape[-2:])
79 | c = torch.cat((c, cc), dim=1)
80 |
81 | shape = (c.shape[1]-1,)+c.shape[2:]
82 | samples_ddim, _ = sampler.sample(S=opt.steps,
83 | conditioning=c,
84 | batch_size=c.shape[0],
85 | shape=shape,
86 | verbose=False)
87 | x_samples_ddim = model.decode_first_stage(samples_ddim)
88 |
89 | image = torch.clamp((batch["image"]+1.0)/2.0,
90 | min=0.0, max=1.0)
91 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
92 | min=0.0, max=1.0)
93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
94 | min=0.0, max=1.0)
95 |
96 | inpainted = (1-mask)*image+mask*predicted_image
97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
99 |
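inpaint.py expects the checkpoint at models/ldm/inpainting_big/last.ckpt and the config at models/ldm/inpainting_big/config.yaml (both hard-coded above), and pairs every *_mask.png in --indir with the matching .png. A sketch of the assumed invocation, with placeholder directories:

    # hypothetical paths; each example.png in --indir needs a matching example_mask.png
    python scripts/inpaint.py \
        --indir /path/to/image_mask_pairs \
        --outdir outputs/inpainting_results \
        --steps 50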
--------------------------------------------------------------------------------
/scripts/stable_txt2img.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | import torch
3 | import numpy as np
4 | from omegaconf import OmegaConf
5 | from PIL import Image
6 | from tqdm import tqdm, trange
7 | from itertools import islice
8 | from einops import rearrange
9 | from torchvision.utils import make_grid
10 | import time
11 | from pytorch_lightning import seed_everything
12 | from torch import autocast
13 | from contextlib import contextmanager, nullcontext
14 |
15 | from ldm.util import instantiate_from_config
16 | from ldm.models.diffusion.ddim import DDIMSampler
17 | from ldm.models.diffusion.plms import PLMSSampler
18 |
19 |
20 | def chunk(it, size):
21 | it = iter(it)
22 | return iter(lambda: tuple(islice(it, size)), ())
23 |
24 |
25 | def load_model_from_config(config, ckpt, verbose=False):
26 | print(f"Loading model from {ckpt}")
27 | pl_sd = torch.load(ckpt, map_location="cpu")
28 | if "global_step" in pl_sd:
29 | print(f"Global Step: {pl_sd['global_step']}")
30 | sd = pl_sd["state_dict"]
31 | model = instantiate_from_config(config.model)
32 | m, u = model.load_state_dict(sd, strict=False)
33 | if len(m) > 0 and verbose:
34 | print("missing keys:")
35 | print(m)
36 | if len(u) > 0 and verbose:
37 | print("unexpected keys:")
38 | print(u)
39 |
40 | model.cuda()
41 | model.eval()
42 | return model
43 |
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 |
48 | parser.add_argument(
49 | "--prompt",
50 | type=str,
51 | nargs="?",
52 | default="a painting of a virus monster playing guitar",
53 | help="the prompt to render"
54 | )
55 | parser.add_argument(
56 | "--outdir",
57 | type=str,
58 | nargs="?",
59 | help="dir to write results to",
60 | default="outputs/txt2img-samples"
61 | )
62 | parser.add_argument(
63 | "--skip_grid",
64 | action='store_true',
65 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
66 | )
67 | parser.add_argument(
68 | "--skip_save",
69 | action='store_true',
70 | help="do not save individual samples. For speed measurements.",
71 | )
72 | parser.add_argument(
73 | "--ddim_steps",
74 | type=int,
75 | default=50,
76 | help="number of ddim sampling steps",
77 | )
78 | parser.add_argument(
79 | "--plms",
80 | action='store_true',
81 | help="use plms sampling",
82 | )
83 | parser.add_argument(
84 | "--laion400m",
85 | action='store_true',
86 | help="uses the LAION400M model",
87 | )
88 | parser.add_argument(
89 | "--fixed_code",
90 | action='store_true',
91 | help="if enabled, uses the same starting code across samples ",
92 | )
93 | parser.add_argument(
94 | "--ddim_eta",
95 | type=float,
96 | default=0.0,
97 | help="ddim eta (eta=0.0 corresponds to deterministic sampling",
98 | )
99 | parser.add_argument(
100 | "--n_iter",
101 | type=int,
102 | default=2,
103 | help="sample this often",
104 | )
105 | parser.add_argument(
106 | "--H",
107 | type=int,
108 | default=512,
109 | help="image height, in pixel space",
110 | )
111 | parser.add_argument(
112 | "--W",
113 | type=int,
114 | default=512,
115 | help="image width, in pixel space",
116 | )
117 | parser.add_argument(
118 | "--C",
119 | type=int,
120 | default=4,
121 | help="latent channels",
122 | )
123 | parser.add_argument(
124 | "--f",
125 | type=int,
126 | default=8,
127 | help="downsampling factor",
128 | )
129 | parser.add_argument(
130 | "--n_samples",
131 | type=int,
132 | default=3,
133 | help="how many samples to produce for each given prompt. A.k.a. batch size",
134 | )
135 | parser.add_argument(
136 | "--n_rows",
137 | type=int,
138 | default=0,
139 | help="rows in the grid (default: n_samples)",
140 | )
141 | parser.add_argument(
142 | "--scale",
143 | type=float,
144 | default=7.5,
145 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
146 | )
147 | parser.add_argument(
148 | "--from-file",
149 | type=str,
150 | help="if specified, load prompts from this file",
151 | )
152 | parser.add_argument(
153 | "--config",
154 | type=str,
155 | default="configs/stable-diffusion/v1-inference.yaml",
156 | help="path to config which constructs model",
157 | )
158 | parser.add_argument(
159 | "--ckpt",
160 | type=str,
161 | default="models/ldm/stable-diffusion-v1/model.ckpt",
162 | help="path to checkpoint of model",
163 | )
164 | parser.add_argument(
165 | "--seed",
166 | type=int,
167 | default=42,
168 | help="the seed (for reproducible sampling)",
169 | )
170 | parser.add_argument(
171 | "--precision",
172 | type=str,
173 | help="evaluate at this precision",
174 | choices=["full", "autocast"],
175 | default="autocast"
176 | )
177 |
178 |
179 | parser.add_argument(
180 | "--embedding_path",
181 | type=str,
182 | help="Path to a pre-trained embedding manager checkpoint")
183 |
184 | opt = parser.parse_args()
185 |
186 | if opt.laion400m:
187 | print("Falling back to LAION 400M model...")
188 | opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
189 | opt.ckpt = "models/ldm/text2img-large/model.ckpt"
190 | opt.outdir = "outputs/txt2img-samples-laion400m"
191 |
192 | seed_everything(opt.seed)
193 |
194 | config = OmegaConf.load(f"{opt.config}")
195 | model = load_model_from_config(config, f"{opt.ckpt}")
196 | model.embedding_manager.load(opt.embedding_path)
197 |
198 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
199 | model = model.to(device)
200 |
201 | if opt.plms:
202 | sampler = PLMSSampler(model)
203 | else:
204 | sampler = DDIMSampler(model)
205 |
206 | os.makedirs(opt.outdir, exist_ok=True)
207 | outpath = opt.outdir
208 |
209 | batch_size = opt.n_samples
210 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
211 | if not opt.from_file:
212 | prompt = opt.prompt
213 | assert prompt is not None
214 | data = [batch_size * [prompt]]
215 |
216 | else:
217 | print(f"reading prompts from {opt.from_file}")
218 | with open(opt.from_file, "r") as f:
219 | data = f.read().splitlines()
220 | data = list(chunk(data, batch_size))
221 |
222 | sample_path = os.path.join(outpath, "samples")
223 | os.makedirs(sample_path, exist_ok=True)
224 | base_count = len(os.listdir(sample_path))
225 | grid_count = len(os.listdir(outpath)) - 1
226 |
227 | start_code = None
228 | if opt.fixed_code:
229 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
230 |
231 | precision_scope = autocast if opt.precision=="autocast" else nullcontext
232 | with torch.no_grad():
233 | with precision_scope("cuda"):
234 | with model.ema_scope():
235 | tic = time.time()
236 | all_samples = list()
237 | for n in trange(opt.n_iter, desc="Sampling"):
238 | for prompts in tqdm(data, desc="data"):
239 | uc = None
240 | if opt.scale != 1.0:
241 | uc = model.get_learned_conditioning(batch_size * [""])
242 | if isinstance(prompts, tuple):
243 | prompts = list(prompts)
244 | c = model.get_learned_conditioning(prompts)
245 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
246 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
247 | conditioning=c,
248 | batch_size=opt.n_samples,
249 | shape=shape,
250 | verbose=False,
251 | unconditional_guidance_scale=opt.scale,
252 | unconditional_conditioning=uc,
253 | eta=opt.ddim_eta,
254 | x_T=start_code)
255 |
256 | x_samples_ddim = model.decode_first_stage(samples_ddim)
257 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
258 |
259 | if not opt.skip_save:
260 | for x_sample in x_samples_ddim:
261 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
262 | Image.fromarray(x_sample.astype(np.uint8)).save(
263 | os.path.join(sample_path, f"{base_count:05}.jpg"))
264 | base_count += 1
265 |
266 | if not opt.skip_grid:
267 | all_samples.append(x_samples_ddim)
268 |
269 | if not opt.skip_grid:
270 | # additionally, save as grid
271 | grid = torch.stack(all_samples, 0)
272 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
273 | grid = make_grid(grid, nrow=n_rows)
274 |
275 | # to image
276 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
277 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}-{grid_count:04}.jpg'))
278 | grid_count += 1
279 |
280 | toc = time.time()
281 |
282 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
283 | f" \nEnjoy.")
284 |
285 |
286 | if __name__ == "__main__":
287 | main()
288 |
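Note that stable_txt2img.py unconditionally calls model.embedding_manager.load(opt.embedding_path), so --embedding_path is effectively required even though argparse does not mark it as such. A hypothetical invocation with placeholder paths (the "*" placeholder token in the prompt is an assumption):

    python scripts/stable_txt2img.py \
        --config configs/stable-diffusion/v1-inference.yaml \
        --ckpt models/ldm/stable-diffusion-v1/model.ckpt \
        --embedding_path /path/to/embeddings.pt \
        --prompt "a photo of *" \
        --plms --n_samples 3 --n_iter 2 --scale 7.5 --ddim_steps 50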
--------------------------------------------------------------------------------
/scripts/txt2img.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | import torch
3 | import numpy as np
4 | from omegaconf import OmegaConf
5 | from PIL import Image
6 | from tqdm import tqdm, trange
7 | from einops import rearrange
8 | from torchvision.utils import make_grid
9 |
10 | from ldm.util import instantiate_from_config
11 | from ldm.models.diffusion.ddim import DDIMSampler
12 | from ldm.models.diffusion.plms import PLMSSampler
13 |
14 | def load_model_from_config(config, ckpt, verbose=False):
15 | print(f"Loading model from {ckpt}")
16 | pl_sd = torch.load(ckpt, map_location="cpu")
17 | sd = pl_sd["state_dict"]
18 | model = instantiate_from_config(config.model)
19 | m, u = model.load_state_dict(sd, strict=False)
20 | if len(m) > 0 and verbose:
21 | print("missing keys:")
22 | print(m)
23 | if len(u) > 0 and verbose:
24 | print("unexpected keys:")
25 | print(u)
26 |
27 | model.cuda()
28 | model.eval()
29 | return model
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser()
34 |
35 | parser.add_argument(
36 | "--prompt",
37 | type=str,
38 | nargs="?",
39 | default="a painting of a virus monster playing guitar",
40 | help="the prompt to render"
41 | )
42 |
43 | parser.add_argument(
44 | "--outdir",
45 | type=str,
46 | nargs="?",
47 | help="dir to write results to",
48 | default="outputs/txt2img-samples"
49 | )
50 | parser.add_argument(
51 | "--ddim_steps",
52 | type=int,
53 | default=200,
54 | help="number of ddim sampling steps",
55 | )
56 |
57 | parser.add_argument(
58 | "--plms",
59 | action='store_true',
60 | help="use plms sampling",
61 | )
62 |
63 | parser.add_argument(
64 | "--ddim_eta",
65 | type=float,
66 | default=0.0,
67 | help="ddim eta (eta=0.0 corresponds to deterministic sampling",
68 | )
69 | parser.add_argument(
70 | "--n_iter",
71 | type=int,
72 | default=1,
73 | help="sample this often",
74 | )
75 |
76 | parser.add_argument(
77 | "--H",
78 | type=int,
79 | default=256,
80 | help="image height, in pixel space",
81 | )
82 |
83 | parser.add_argument(
84 | "--W",
85 | type=int,
86 | default=256,
87 | help="image width, in pixel space",
88 | )
89 |
90 | parser.add_argument(
91 | "--n_samples",
92 | type=int,
93 | default=4,
94 | help="how many samples to produce for the given prompt",
95 | )
96 |
97 | parser.add_argument(
98 | "--scale",
99 | type=float,
100 | default=5.0,
101 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
102 | )
103 |
104 | parser.add_argument(
105 | "--ckpt_path",
106 | type=str,
107 | default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
108 | help="Path to pretrained ldm text2img model")
109 |
110 | parser.add_argument(
111 | "--embedding_path",
112 | type=str,
113 | help="Path to a pre-trained embedding manager checkpoint")
114 |
115 | opt = parser.parse_args()
116 |
117 |
118 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and change this logic
119 | model = load_model_from_config(config, opt.ckpt_path) # TODO: check path
120 | model.embedding_manager.load(opt.embedding_path)
121 |
122 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
123 | model = model.to(device)
124 |
125 | if opt.plms:
126 | sampler = PLMSSampler(model)
127 | else:
128 | sampler = DDIMSampler(model)
129 |
130 | os.makedirs(opt.outdir, exist_ok=True)
131 | outpath = opt.outdir
132 |
133 | prompt = opt.prompt
134 |
135 |
136 | sample_path = os.path.join(outpath, "samples")
137 | os.makedirs(sample_path, exist_ok=True)
138 | base_count = len(os.listdir(sample_path))
139 |
140 | all_samples=list()
141 | with torch.no_grad():
142 | with model.ema_scope():
143 | uc = None
144 | if opt.scale != 1.0:
145 | uc = model.get_learned_conditioning(opt.n_samples * [""])
146 | for n in trange(opt.n_iter, desc="Sampling"):
147 | c = model.get_learned_conditioning(opt.n_samples * [prompt])
148 | shape = [4, opt.H//8, opt.W//8]
149 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
150 | conditioning=c,
151 | batch_size=opt.n_samples,
152 | shape=shape,
153 | verbose=False,
154 | unconditional_guidance_scale=opt.scale,
155 | unconditional_conditioning=uc,
156 | eta=opt.ddim_eta)
157 |
158 | x_samples_ddim = model.decode_first_stage(samples_ddim)
159 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
160 |
161 | for x_sample in x_samples_ddim:
162 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
163 | Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.jpg"))
164 | base_count += 1
165 | all_samples.append(x_samples_ddim)
166 |
167 |
168 | # additionally, save as grid
169 | grid = torch.stack(all_samples, 0)
170 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
171 | grid = make_grid(grid, nrow=opt.n_samples)
172 |
173 | # to image
174 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
175 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.jpg'))
176 |
177 | print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.")
178 |
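txt2img.py is the analogous sampler for the 256x256 LDM text-to-image checkpoint: it hard-codes txt2img-1p4B-eval_with_tokens.yaml and likewise loads --embedding_path unconditionally. A hypothetical call with placeholder paths:

    python scripts/txt2img.py \
        --ckpt_path /data/pretrained_models/ldm/text2img-large/model.ckpt \
        --embedding_path /path/to/embeddings.pt \
        --prompt "a photo of *" \
        --n_samples 4 --ddim_steps 200 --scale 5.0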
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
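
setup.py only declares torch, numpy and tqdm, while the scripts above additionally import omegaconf, Pillow, einops, torchvision and pytorch-lightning, so an editable install of this package alone does not pull in everything they need. A minimal sketch of the editable install, assuming the remaining dependencies are provided separately:

    # from the repository root
    pip install -e .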
--------------------------------------------------------------------------------