├── structured_stable_diffusion
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── lsun.py
│   │   └── imagenet.py
│   ├── models
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── ddim.cpython-38.pyc
│   │   │   │   ├── ddpm.cpython-38.pyc
│   │   │   │   ├── plms.cpython-38.pyc
│   │   │   │   └── __init__.cpython-38.pyc
│   │   │   ├── classifier.py
│   │   │   ├── ddim.py
│   │   │   └── plms.py
│   │   ├── __pycache__
│   │   │   └── autoencoder.cpython-38.pyc
│   │   └── autoencoder.py
│   ├── modules
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── modules.cpython-38.pyc
│   │   │   └── modules.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── distributions.cpython-38.pyc
│   │   │   └── distributions.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   ├── util.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── openaimodel.cpython-38.pyc
│   │   │   └── util.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   ├── __pycache__
│   │   │   ├── ema.cpython-38.pyc
│   │   │   ├── attention.cpython-38.pyc
│   │   │   └── x_transformer.cpython-38.pyc
│   │   ├── image_degradation
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── __init__.py
│   │   ├── ema.py
│   │   └── attention.py
│   ├── __pycache__
│   │   └── util.cpython-38.pyc
│   ├── lr_scheduler.py
│   └── util.py
├── .gitignore
├── assets
│   ├── ablation.png
│   ├── attention.jpg
│   ├── a red apple and a green bird.png
│   ├── a dog standing on a surfboard on the beach.png
│   ├── a spacious kitchen has white walls,red countertops,and a large stove.jpg
│   └── Merry Christmas! a red teddy bear in a christmas hat sitting next to a glass, a can and a plastic cup full of liquid, on a table in a living room.png
├── setup.py
├── environment.yaml
├── models
│   └── first_stage_models
│       ├── kl-f4
│       │   └── config.yaml
│       ├── kl-f8
│       │   └── config.yaml
│       ├── kl-f16
│       │   └── config.yaml
│       ├── kl-f32
│       │   └── config.yaml
│       ├── vq-f4
│       │   └── config.yaml
│       ├── vq-f4-noattn
│       │   └── config.yaml
│       ├── vq-f8-n256
│       │   └── config.yaml
│       ├── vq-f16
│       │   └── config.yaml
│       └── vq-f8
│           └── config.yaml
├── configs
│   ├── autoencoder
│   │   ├── autoencoder_kl_32x32x4.yaml
│   │   ├── autoencoder_kl_64x64x3.yaml
│   │   ├── autoencoder_kl_8x8x64.yaml
│   │   └── autoencoder_kl_16x16x16.yaml
│   └── stable-diffusion
│       └── v1-inference.yaml
├── GLIP_eval
│   └── eval.py
├── README.md
├── LICENSE
└── scripts
    └── txt2img_demo.py
/structured_stable_diffusion/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs
2 | *.egg-info
3 | src
4 | test.ipynb
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/assets/ablation.png
--------------------------------------------------------------------------------
/assets/attention.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/assets/attention.jpg
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from structured_stable_diffusion.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/assets/a red apple and a green bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/assets/a red apple and a green bird.png
--------------------------------------------------------------------------------
/assets/a dog standing on a surfboard on the beach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/assets/a dog standing on a surfboard on the beach.png
--------------------------------------------------------------------------------
/structured_stable_diffusion/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/models/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/__pycache__/plms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/models/diffusion/__pycache__/plms.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/__pycache__/x_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/__pycache__/x_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/assets/a spacious kitchen has white walls,red countertops,and a large stove.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/assets/a spacious kitchen has white walls,red countertops,and a large stove.jpg
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/structured_stable_diffusion/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='structure-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from structured_stable_diffusion.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from structured_stable_diffusion.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/assets/Merry Christmas! a red teddy bear in a christmas hat sitting next to a glass, a can and a plastic cup full of liquid, on a table in a living room.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixi-feng/Structured-Diffusion-Guidance/HEAD/assets/Merry Christmas! a red teddy bear in a christmas hat sitting next to a glass, a can and a plastic cup full of liquid, on a table in a living room.png
--------------------------------------------------------------------------------
/structured_stable_diffusion/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
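For reference, a minimal sketch (hypothetical, not part of this repository) of how a concrete dataset could subclass `Txt2ImgIterableBaseDataset`: only `__iter__` has to be provided on top of the bookkeeping done in `__init__`. The class name and the sample keys below are illustrative.

```python
import numpy as np
from structured_stable_diffusion.data.base import Txt2ImgIterableBaseDataset


class DummyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    """Hypothetical subclass used only to illustrate the interface."""

    def __iter__(self):
        for idx in range(self.num_records):
            # Key names are illustrative; the real training code decides
            # which keys (image, caption, ...) it reads from each sample.
            yield {
                "image": np.zeros((self.size, self.size, 3), dtype=np.float32),
                "caption": f"dummy caption {idx}",
            }


ds = DummyTxt2ImgDataset(num_records=4, size=256)
for sample in ds:
    print(sample["caption"], sample["image"].shape)
```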
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: structure_diffusion
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - diffusers
15 | - opencv-python==4.1.2.30
16 | - pudb==2019.2
17 | - invisible-watermark
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - transformers==4.19.2
27 | - torchmetrics==0.6.0
28 | - kornia==0.6
29 | - stanza==1.4.2
30 | - nltk==3.7
31 | - scenegraphparser==0.1.0
32 | - tqdm==4.64.1
33 | - matplotlib==3.6.2
34 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
35 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
36 | - -e .
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
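A quick sanity check of how the folder names relate to these `ddconfig` blocks: the downsampling factor is 2 to the power of `len(ch_mult) - 1`, so this kl-f4 config maps 256x256 images to a 64x64x3 latent. Illustrative arithmetic only, not code from the repository; the same relation explains kl-f8, kl-f16, and kl-f32 below.

```python
# kl-f4 ddconfig values copied from the YAML above
ch_mult = [1, 2, 4]
resolution, z_channels = 256, 3

f = 2 ** (len(ch_mult) - 1)       # 4  -> the "f4" in kl-f4
latent_hw = resolution // f       # 64
print(f, (latent_hw, latent_hw, z_channels))  # 4 (64, 64, 3)
```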
/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: structured_stable_diffusion.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: structured_stable_diffusion.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: structured_stable_diffusion.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: structured_stable_diffusion.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: structured_stable_diffusion.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: structured_stable_diffusion.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: structured_stable_diffusion.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: structured_stable_diffusion.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: structured_stable_diffusion.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: structured_stable_diffusion.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: structured_stable_diffusion.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: structured_stable_diffusion.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: structured_stable_diffusion.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: structured_stable_diffusion.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: structured_stable_diffusion.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: structured_stable_diffusion.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: structured_stable_diffusion.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: structured_stable_diffusion.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: structured_stable_diffusion.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 | save_map: False # set True to visualize attention maps
46 |
47 | first_stage_config:
48 | target: structured_stable_diffusion.models.autoencoder.AutoencoderKL
49 | params:
50 | embed_dim: 4
51 | monitor: val/rec_loss
52 | ddconfig:
53 | double_z: true
54 | z_channels: 4
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 2
62 | - 4
63 | - 4
64 | num_res_blocks: 2
65 | attn_resolutions: []
66 | dropout: 0.0
67 | lossconfig:
68 | target: torch.nn.Identity
69 |
70 | cond_stage_config:
71 | target: structured_stable_diffusion.modules.encoders.modules.FrozenCLIPEmbedder
72 |
--------------------------------------------------------------------------------
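All of these YAML files follow the stable-diffusion `target`/`params` convention, where `target` is a dotted path to a class and `params` holds its constructor arguments. A minimal sketch of that pattern is shown below; the helper is written out explicitly here, and the example assumes the package and its dependencies are installed (loading weights from a checkpoint is a separate step).

```python
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(config):
    # "target" names a class; "params" are its keyword arguments.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", dict()))


config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = instantiate_from_config(config.model)  # builds the LatentDiffusion model
```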
/structured_stable_diffusion/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
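A hypothetical usage sketch of `LitEma` (not taken from this repository's training loop): update the shadow weights after each optimizer step, then temporarily swap them in for evaluation.

```python
import torch

from structured_stable_diffusion.modules.ema import LitEma

model = torch.nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema(model)                      # update the shadow (EMA) parameters

ema.store(model.parameters())       # stash the current training weights
ema.copy_to(model)                  # evaluate/sample with EMA weights
# ... validation would go here ...
ema.restore(model.parameters())     # switch back to the training weights
```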
/structured_stable_diffusion/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
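An illustrative example (not from the repository) of how `DiagonalGaussianDistribution` is meant to be used: the encoder outputs 2*C channels, which the class splits into mean and log-variance along dim 1.

```python
import torch

from structured_stable_diffusion.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

moments = torch.randn(2, 8, 32, 32)      # 2 * z_channels, e.g. z_channels = 4
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                   # (2, 4, 32, 32) latent sample
kl = posterior.kl()                      # per-example KL to a standard normal
print(z.shape, kl.shape)                 # torch.Size([2, 4, 32, 32]) torch.Size([2])
```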
/structured_stable_diffusion/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
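A hypothetical instantiation of `LSUNBase` (the paths below are placeholders and must point to real data): it reads one relative image path per line from `txt_file`, center-crops to a square, resizes, and scales pixel values to [-1, 1].

```python
from structured_stable_diffusion.data.lsun import LSUNBase

dataset = LSUNBase(
    txt_file="data/lsun/church_outdoor_val.txt",  # one relative path per line
    data_root="data/lsun/churches",
    size=256,
    interpolation="bicubic",
    flip_p=0.0,
)
example = dataset[0]
print(example["image"].shape, example["image"].dtype)  # (256, 256, 3) float32
```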
/GLIP_eval/eval.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import requests
4 | from io import BytesIO
5 | from PIL import Image
6 | import numpy as np
7 | from maskrcnn_benchmark.config import cfg
8 | from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo
9 |
10 | import os
11 | from collections import defaultdict
12 | import json
13 | from tqdm import tqdm
14 | import sys
15 |
16 |
17 | def blockPrint():
18 | sys.stdout = open(os.devnull, 'w')
19 |
20 |
21 | # Restore
22 | def enablePrint():
23 | sys.stdout = sys.__stdout__
24 |
25 |
26 | def load(dir):
27 | """
28 |     Given a path to an image, loads it and
29 |     returns the image as a BGR numpy array
30 | """
31 | pil_image = Image.open(dir).convert("RGB")
32 | # convert to BGR format
33 | image = np.array(pil_image)[:, :, [2, 1, 0]]
34 | return image
35 |
36 | def imshow(img, caption):
37 | plt.imshow(img[:, :, [2, 1, 0]])
38 | plt.axis("off")
39 | plt.figtext(0.5, 0.09, caption, wrap=True, horizontalalignment='center', fontsize=20)
40 |
41 |
42 |
43 | if __name__ == '__main__':
44 | from argparse import ArgumentParser
45 | parser = ArgumentParser()
46 | parser.add_argument("--image_dir", type=str, required=True)
47 | parser.add_argument("--output_dir", type=str, required=True)
48 | parser.add_argument("--thresh", type=float, default=0.5)
49 | parser.add_argument("--dataset", type=str)
50 | args = parser.parse_args()
51 |
52 | image_dir = args.image_dir
53 | output_dir = args.output_dir
54 |
55 | os.makedirs(output_dir, exist_ok=True)
56 | image_names = sorted(os.listdir(image_dir))
57 |
58 | config_file = "configs/pretrain/glip_Swin_L.yaml"
59 | weight_file = "MODEL/glip_large_model.pth" # NOTE
60 |
61 | # update the config options with the config file
62 | # manual override some options
63 | cfg.local_rank = 0
64 | cfg.num_gpus = 1
65 | cfg.merge_from_file(config_file)
66 | cfg.merge_from_list(["MODEL.WEIGHT", weight_file])
67 | cfg.merge_from_list(["MODEL.DEVICE", "cuda"])
68 |
69 | glip_demo = GLIPDemo(
70 | cfg,
71 | min_image_size=800,
72 | confidence_threshold=0.7,
73 | show_mask_heatmaps=False
74 | )
75 |
76 | plus = 1 if glip_demo.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD" else 0
77 |
78 | grounding_results = {}
79 | blockPrint()
80 | for file in tqdm(image_names):
81 | image = load(os.path.join(image_dir, file))
82 | caption = os.path.splitext(file.split("-")[-1])[0] # NOTE
83 | result, top_predictions = glip_demo.run_on_web_image(image, caption, args.thresh)
84 | fig = plt.figure(figsize=(5,5))
85 | plt.imshow(result[:, :, [2, 1, 0]])
86 | plt.axis("off")
87 | plt.tight_layout()
88 | plt.savefig(f"{args.output_dir}/{file}")
89 | plt.close()
90 |
91 | scores = top_predictions.get_field("scores")
92 | labels = top_predictions.get_field("labels")
93 | bbox = top_predictions.bbox
94 | entities = glip_demo.entities
95 |
96 | new_labels = []
97 | for i in labels:
98 | if i <= len(entities):
99 | new_labels.append(entities[i-plus])
100 | else:
101 | new_labels.append("object")
102 |
103 | bbox_by_entities = defaultdict(list)
104 | for l, score, coord in zip(new_labels, scores, bbox):
105 | bbox_by_entities[l].append((score.item(), coord.tolist()))
106 | grounding_results[file] = bbox_by_entities
107 |
108 | with open(f"{output_dir}/glip_results.json", "w") as file:
109 | json.dump(grounding_results, file, indent=4, separators=(",",":"), sort_keys=True)
--------------------------------------------------------------------------------
/structured_stable_diffusion/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
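A hypothetical way to wire these schedulers into PyTorch (not code from the repository): they return a learning-rate multiplier for a given global step, so they plug directly into `torch.optim.lr_scheduler.LambdaLR`. The scheduler values below mirror `configs/stable-diffusion/v1-inference.yaml`.

```python
import torch

from structured_stable_diffusion.lr_scheduler import LambdaLinearScheduler

schedule = LambdaLinearScheduler(
    warm_up_steps=[10000],
    f_min=[1.0],
    f_max=[1.0],
    f_start=[1.0e-6],
    cycle_lengths=[10000000000000],
)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(5):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()   # next lr = base lr * schedule(global step)
```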
/README.md:
--------------------------------------------------------------------------------
1 | # Structured Diffusion Guidance (ICLR 2023)
2 | ## We propose a method to fuse language structures into diffusion guidance for compositional text-to-image generation.
3 |
4 | ### [Project Page](https://weixi-feng.github.io/structure-diffusion-guidance/) | [arxiv](https://arxiv.org/abs/2212.05032) | [OpenReview](https://openreview.net/forum?id=PUIqjT4rzq7)
5 |
6 |
7 | This is the official codebase for **Training-Free Structured Diffusion Guidance for Compositional Text-to-Image Synthesis**.
8 |
9 | [Training-Free Structured Diffusion Guidance for Compositional Text-to-Image Synthesis](https://weixi-feng.github.io/structure-diffusion-guidance/)
10 |
11 | [Weixi Feng](https://weixi-feng.github.io/)<sup>1</sup>,
12 | [Xuehai He](https://scholar.google.com/citations?user=kDzxOzUAAAAJ&)<sup>2</sup>,
13 | [Tsu-Jui Fu](https://tsujuifu.github.io/)<sup>1</sup>,
14 | [Varun Jampani](https://varunjampani.github.io/)<sup>3</sup>,
15 | [Arjun Akula](https://www.arjunakula.com/)<sup>3</sup>,
16 | [Pradyumna Narayana](https://scholar.google.com/citations?user=BV2dbjEAAAAJ&)<sup>3</sup>,
17 | [Sugato Basu](https://sites.google.com/site/sugatobasu/)<sup>3</sup>,
18 | [Xin Eric Wang](https://eric-xw.github.io/)<sup>2</sup>,
19 | [William Yang Wang](https://sites.cs.ucsb.edu/~william/)<sup>1</sup>
20 |
21 | <sup>1</sup>UCSB, <sup>2</sup>UCSC, <sup>3</sup>Google
22 |
23 |
24 | ## Update:
25 | Apr. 4th: updated links, uploaded benchmarks and GLIP eval scripts, updated bibtex.
26 |
27 | ## Setup
28 |
29 | Clone this repository and then create a conda environment with:
30 | ```
31 | conda env create -f environment.yaml
32 | conda activate structure_diffusion
33 | ```
34 | If you already have a [stable diffusion](https://github.com/CompVis/stable-diffusion/) environment, you can run the following commands:
35 | ```
36 | pip install stanza nltk scenegraphparser tqdm matplotlib
37 | pip install -e .
38 | ```
39 |
40 | ## Inference
41 | This repository supports stable diffusion 1.4 for now. Please refer to the official [stable-diffusion](https://github.com/CompVis/stable-diffusion/#weights) repository to download the pre-trained model and put it under ```models/ldm/stable-diffusion-v1/```.
42 | Our method is training-free and can be applied to the trained stable diffusion checkpoint directly.
43 |
44 | To generate an image, run
45 | ```
46 | python scripts/txt2img_demo.py --prompt "A red teddy bear in a christmas hat sitting next to a glass" --plms --parser_type constituency
47 | ```
48 |
49 | By default, the guidance scale is set to 7.5 and the output image size is 512x512. We currently only support PLMS sampling with a batch size of 1.
50 | Apart from the default arguments from [Stable Diffusion](https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/scripts/txt2img.py), we add ```--parser_type```, ```--conjunction```, and ```--save_attn_maps```.
51 |
52 | ```
53 | usage: txt2img_demo.py [-h] [--prompt [PROMPT]] ...
54 | [--parser_type {constituency,scene_graph}] [--conjunction] [--save_attn_maps]
55 |
56 | optional arguments:
57 | ...
58 | --parser_type {constituency,scene_graph}
59 | --conjunction If True, the input prompt is a conjunction of two concepts like "A and B"
60 |   --save_attn_maps      If True, the attention maps will be saved as a .pth file with the same name as the image
61 | ```
62 |
63 | Without specifying the ```conjunction``` argument, the model applies one ```key``` and multiple ```values``` for each cross-attention layer.
64 | For concept conjunction prompts, you can run:
65 | ```
66 | python scripts/txt2img_demo.py --prompt "A red car and a white sheep" --plms --parser_type constituency --conjunction
67 | ```
68 |
69 | Overall, compositional prompts remain a challenge for Stable Diffusion v1.4. It may still take several attempts to get a correct image with our method.
70 | The improvement is system-level instead of sample-level, and we are still looking for good evaluation metrics for compositional T2I synthesis.
71 | We observe fewer missing objects in [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion), and we are implementing our method on top of it as well.
72 | Please feel free to reach out for a discussion.
73 |
74 | ## Benchmarks
75 | - CC-500.txt: concept conjunctions of two objects with different colors (lines 1-446).
76 | - ABC-6K.txt: ~6K attribute binding prompts collected and created from COCO captions.
77 |
78 | ## GLIP Eval
79 | For our GLIP eval, please first clone and set up your environment according to the official [GLIP](https://github.com/microsoft/GLIP) repo and download the model checkpoint(s). Then refer to our ```GLIP_eval/eval.py```; you may need to modify lines 59 and 82. We assume that each image file name contains the text prompt.
80 |
81 | ## Comments
82 | Our codebase builds heavily on [Stable Diffusion](https://github.com/CompVis/stable-diffusion). Thanks for open-sourcing!
83 |
84 |
85 | ## Citing our Paper
86 |
87 | If you find our code or paper useful for your research, please consider citing
88 | ```
89 | @inproceedings{feng2023trainingfree,
90 | title={Training-Free Structured Diffusion Guidance for Compositional Text-to-Image Synthesis},
91 | author={Weixi Feng and Xuehai He and Tsu-Jui Fu and Varun Jampani and Arjun Reddy Akula and Pradyumna Narayana and Sugato Basu and Xin Eric Wang and William Yang Wang},
92 | booktitle={The Eleventh International Conference on Learning Representations },
93 | year={2023},
94 | url={https://openreview.net/forum?id=PUIqjT4rzq7}
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
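To make the README's "one ```key``` and multiple ```values```" remark concrete, here is an illustrative sketch of the idea (this is not the repository's actual cross-attention implementation, and the shapes are arbitrary): the attention weights are computed once from the full-prompt keys, then applied to several value sequences, one per parsed segment, and the attended outputs are averaged.

```python
import torch


def structured_cross_attention(q, k_full, value_list, scale):
    """Sketch only: one set of keys, multiple value sequences."""
    # q: (B, Nq, d), k_full: (B, Nk, d), each v in value_list: (B, Nk, d)
    attn = torch.softmax(q @ k_full.transpose(-2, -1) * scale, dim=-1)
    outs = [attn @ v for v in value_list]          # same weights, different values
    return torch.stack(outs, dim=0).mean(dim=0)    # average the attended outputs


B, Nq, Nk, d = 1, 64, 77, 40
q = torch.randn(B, Nq, d)
k = torch.randn(B, Nk, d)
values = [torch.randn(B, Nk, d) for _ in range(3)]  # e.g. full prompt + noun phrases
out = structured_cross_attention(q, k, values, scale=d ** -0.5)
print(out.shape)  # torch.Size([1, 64, 40])
```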
/structured_stable_diffusion/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 |     if "target" not in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Weixi Feng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
24 | Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
25 |
26 | 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
27 | Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
28 | You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
29 | You must cause any modified files to carry prominent notices stating that You changed the files;
30 | You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
31 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
32 | 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
33 | 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
34 |
35 | Section IV: OTHER PROVISIONS
36 |
37 | 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
38 | 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
39 | 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
40 | 10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
41 | 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
42 | 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
43 |
44 | END OF TERMS AND CONDITIONS
45 |
46 |
47 |
48 |
49 | Attachment A
50 |
51 | Use Restrictions
52 |
53 | You agree not to use the Model or Derivatives of the Model:
54 | - In any way that violates any applicable national, federal, state, local or international law or regulation;
55 | - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
56 | - To generate or disseminate verifiably false information and/or content with the purpose of harming others;
57 | - To generate or disseminate personal identifiable information that can be used to harm an individual;
58 | - To defame, disparage or otherwise harass others;
59 | - For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
60 | - For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
61 | - To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
62 | - For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
63 | - To provide medical advice and medical results interpretation;
64 | - To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
65 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 | from structured_stable_diffusion.util import exists  # needed for the `exists` check in VQLPIPSWithDiscriminator.forward
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if not exists(codebook_loss):
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from structured_stable_diffusion.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 |     """Uses the BERT tokenizer and adds some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 | class FrozenCLIPEmbedder(AbstractEncoder):
138 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
140 | super().__init__()
141 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
142 | self.transformer = CLIPTextModel.from_pretrained(version)
143 | self.device = device
144 | self.max_length = max_length
145 | self.freeze()
146 |
147 | def freeze(self):
148 | self.transformer = self.transformer.eval()
149 | for param in self.parameters():
150 | param.requires_grad = False
151 |
152 | def forward(self, text):
153 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155 | tokens = batch_encoding["input_ids"].to(self.device)
156 | outputs = self.transformer(input_ids=tokens)
157 |
158 | z = outputs.last_hidden_state
159 | return z
160 |
161 | def encode(self, text):
162 | return self(text)
163 |
164 |
165 | class FrozenCLIPTextEmbedder(nn.Module):
166 | """
167 | Uses the CLIP transformer encoder for text.
168 | """
169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
170 | super().__init__()
171 | self.model, _ = clip.load(version, jit=False, device="cpu")
172 | self.device = device
173 | self.max_length = max_length
174 | self.n_repeat = n_repeat
175 | self.normalize = normalize
176 |
177 | def freeze(self):
178 | self.model = self.model.eval()
179 | for param in self.parameters():
180 | param.requires_grad = False
181 |
182 | def forward(self, text):
183 | tokens = clip.tokenize(text).to(self.device)
184 | z = self.model.encode_text(tokens)
185 | if self.normalize:
186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187 | return z
188 |
189 | def encode(self, text):
190 | z = self(text)
191 | if z.ndim==2:
192 | z = z[:, None, :]
193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194 | return z
195 |
196 |
197 | class FrozenClipImageEmbedder(nn.Module):
198 | """
199 | Uses the CLIP image encoder.
200 | """
201 | def __init__(
202 | self,
203 | model,
204 | jit=False,
205 | device='cuda' if torch.cuda.is_available() else 'cpu',
206 | antialias=False,
207 | ):
208 | super().__init__()
209 | self.model, _ = clip.load(name=model, device=device, jit=jit)
210 |
211 | self.antialias = antialias
212 |
213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
215 |
216 | def preprocess(self, x):
217 | # normalize to [0,1]
218 | x = kornia.geometry.resize(x, (224, 224),
219 | interpolation='bicubic',align_corners=True,
220 | antialias=self.antialias)
221 | x = (x + 1.) / 2.
222 | # renormalize according to clip
223 | x = kornia.enhance.normalize(x, self.mean, self.std)
224 | return x
225 |
226 | def forward(self, x):
227 | # x is assumed to be in range [-1,1]
228 | return self.model.encode_image(self.preprocess(x))
229 |
230 |
231 | if __name__ == "__main__":
232 | from structured_stable_diffusion.util import count_params
233 | model = FrozenCLIPEmbedder()
234 | count_params(model, verbose=True)
235 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
/structured_stable_diffusion/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from structured_stable_diffusion.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from structured_stable_diffusion.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from structured_stable_diffusion.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/structured_stable_diffusion/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from structured_stable_diffusion.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9 | extract_into_tensor
10 |
11 |
12 | class DDIMSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28 | alphas_cumprod = self.model.alphas_cumprod
29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31 |
32 | self.register_buffer('betas', to_torch(self.model.betas))
33 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35 |
36 | # calculations for diffusion q(x_t | x_{t-1}) and others
37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42 |
43 | # ddim sampling parameters
44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45 | ddim_timesteps=self.ddim_timesteps,
46 | eta=ddim_eta,verbose=verbose)
47 | self.register_buffer('ddim_sigmas', ddim_sigmas)
48 | self.register_buffer('ddim_alphas', ddim_alphas)
49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55 |
56 | @torch.no_grad()
57 | def sample(self,
58 | S,
59 | batch_size,
60 | shape,
61 | conditioning=None,
62 | callback=None,
63 | normals_sequence=None,
64 | img_callback=None,
65 | quantize_x0=False,
66 | eta=0.,
67 | mask=None,
68 | x0=None,
69 | temperature=1.,
70 | noise_dropout=0.,
71 | score_corrector=None,
72 | corrector_kwargs=None,
73 | verbose=True,
74 | x_T=None,
75 | log_every_t=100,
76 | unconditional_guidance_scale=1.,
77 | unconditional_conditioning=None,
78 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79 | **kwargs
80 | ):
81 | if conditioning is not None:
82 | if isinstance(conditioning, dict):
83 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
84 | if cbs != batch_size:
85 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
86 | else:
87 | if conditioning.shape[0] != batch_size:
88 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
89 |
90 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
91 | # sampling
92 | C, H, W = shape
93 | size = (batch_size, C, H, W)
94 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
95 |
96 | samples, intermediates = self.ddim_sampling(conditioning, size,
97 | callback=callback,
98 | img_callback=img_callback,
99 | quantize_denoised=quantize_x0,
100 | mask=mask, x0=x0,
101 | ddim_use_original_steps=False,
102 | noise_dropout=noise_dropout,
103 | temperature=temperature,
104 | score_corrector=score_corrector,
105 | corrector_kwargs=corrector_kwargs,
106 | x_T=x_T,
107 | log_every_t=log_every_t,
108 | unconditional_guidance_scale=unconditional_guidance_scale,
109 | unconditional_conditioning=unconditional_conditioning,
110 | )
111 | return samples, intermediates
112 |
113 | @torch.no_grad()
114 | def ddim_sampling(self, cond, shape,
115 | x_T=None, ddim_use_original_steps=False,
116 | callback=None, timesteps=None, quantize_denoised=False,
117 | mask=None, x0=None, img_callback=None, log_every_t=100,
118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
119 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
120 | device = self.model.betas.device
121 | b = shape[0]
122 | if x_T is None:
123 | img = torch.randn(shape, device=device)
124 | else:
125 | img = x_T
126 |
127 | if timesteps is None:
128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
129 | elif timesteps is not None and not ddim_use_original_steps:
130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
131 | timesteps = self.ddim_timesteps[:subset_end]
132 |
133 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
134 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
135 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
136 | print(f"Running DDIM Sampling with {total_steps} timesteps")
137 |
138 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
139 |
140 | for i, step in enumerate(iterator):
141 | index = total_steps - i - 1
142 | ts = torch.full((b,), step, device=device, dtype=torch.long)
143 |
144 | if mask is not None:
145 | assert x0 is not None
146 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
147 | img = img_orig * mask + (1. - mask) * img
148 |
149 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
150 | quantize_denoised=quantize_denoised, temperature=temperature,
151 | noise_dropout=noise_dropout, score_corrector=score_corrector,
152 | corrector_kwargs=corrector_kwargs,
153 | unconditional_guidance_scale=unconditional_guidance_scale,
154 | unconditional_conditioning=unconditional_conditioning)
155 | img, pred_x0 = outs
156 | if callback: callback(i)
157 | if img_callback: img_callback(pred_x0, i)
158 |
159 | if index % log_every_t == 0 or index == total_steps - 1:
160 | intermediates['x_inter'].append(img)
161 | intermediates['pred_x0'].append(pred_x0)
162 |
163 | return img, intermediates
164 |
165 | @torch.no_grad()
166 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
167 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
168 | unconditional_guidance_scale=1., unconditional_conditioning=None):
169 | b, *_, device = *x.shape, x.device
170 |
171 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
172 | e_t = self.model.apply_model(x, t, c)
173 | else:
174 | x_in = torch.cat([x] * 2)
175 | t_in = torch.cat([t] * 2)
176 | c_in = torch.cat([unconditional_conditioning, c])
177 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
178 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
179 |
180 | if score_corrector is not None:
181 | assert self.model.parameterization == "eps"
182 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
183 |
184 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
185 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
186 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
187 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
188 | # select parameters corresponding to the currently considered timestep
189 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
190 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
191 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
192 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
193 |
194 | # current prediction for x_0
195 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
196 | if quantize_denoised:
197 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
198 | # direction pointing to x_t
199 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
200 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
201 | if noise_dropout > 0.:
202 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
203 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
204 | return x_prev, pred_x0
205 |
206 | @torch.no_grad()
207 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
208 | # fast, but does not allow for exact reconstruction
209 | # t serves as an index to gather the correct alphas
210 | if use_original_steps:
211 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
212 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
213 | else:
214 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
215 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
216 |
217 | if noise is None:
218 | noise = torch.randn_like(x0)
219 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
220 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
221 |
222 | @torch.no_grad()
223 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
224 | use_original_steps=False):
225 |
226 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
227 | timesteps = timesteps[:t_start]
228 |
229 | time_range = np.flip(timesteps)
230 | total_steps = timesteps.shape[0]
231 | print(f"Running DDIM Sampling with {total_steps} timesteps")
232 |
233 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
234 | x_dec = x_latent
235 | for i, step in enumerate(iterator):
236 | index = total_steps - i - 1
237 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
238 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239 | unconditional_guidance_scale=unconditional_guidance_scale,
240 | unconditional_conditioning=unconditional_conditioning)
241 | return x_dec
--------------------------------------------------------------------------------
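`DDIMSampler` is driven from the sampling scripts rather than used standalone. A usage sketch, assuming a LatentDiffusion `model` already loaded from a checkpoint and conditioning tensors `c`/`uc` obtained from its text encoder (none of these are defined in this file; the latent shape shown is the usual 4x64x64 for 512x512 outputs):

import torch
from structured_stable_diffusion.models.diffusion.ddim import DDIMSampler

# `model`, `c` and `uc` are assumed to exist, e.g. c = model.get_learned_conditioning(prompts)
sampler = DDIMSampler(model)
with torch.no_grad():
    samples, intermediates = sampler.sample(
        S=50,                               # number of DDIM steps
        batch_size=4,
        shape=(4, 64, 64),                  # latent (C, H, W)
        conditioning=c,
        eta=0.0,                            # deterministic DDIM
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uc,
    )
    images = model.decode_first_stage(samples)  # decode latents back to pixel space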
/structured_stable_diffusion/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import pdb
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, einsum
7 | from einops import rearrange, repeat
8 |
9 | from structured_stable_diffusion.modules.diffusionmodules.util import checkpoint
10 |
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 |
16 | def uniq(arr):
17 | return{el: True for el in arr}.keys()
18 |
19 |
20 | def default(val, d):
21 | if exists(val):
22 | return val
23 | return d() if isfunction(d) else d
24 |
25 |
26 | def max_neg_value(t):
27 | return -torch.finfo(t.dtype).max
28 |
29 |
30 | def init_(tensor):
31 | dim = tensor.shape[-1]
32 | std = 1 / math.sqrt(dim)
33 | tensor.uniform_(-std, std)
34 | return tensor
35 |
36 |
37 | # feedforward
38 | class GEGLU(nn.Module):
39 | def __init__(self, dim_in, dim_out):
40 | super().__init__()
41 | self.proj = nn.Linear(dim_in, dim_out * 2)
42 |
43 | def forward(self, x):
44 | x, gate = self.proj(x).chunk(2, dim=-1)
45 | return x * F.gelu(gate)
46 |
47 |
48 | class FeedForward(nn.Module):
49 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
50 | super().__init__()
51 | inner_dim = int(dim * mult)
52 | dim_out = default(dim_out, dim)
53 | project_in = nn.Sequential(
54 | nn.Linear(dim, inner_dim),
55 | nn.GELU()
56 | ) if not glu else GEGLU(dim, inner_dim)
57 |
58 | self.net = nn.Sequential(
59 | project_in,
60 | nn.Dropout(dropout),
61 | nn.Linear(inner_dim, dim_out)
62 | )
63 |
64 | def forward(self, x):
65 | return self.net(x)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def Normalize(in_channels):
78 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
79 |
80 |
81 | class LinearAttention(nn.Module):
82 | def __init__(self, dim, heads=4, dim_head=32):
83 | super().__init__()
84 | self.heads = heads
85 | hidden_dim = dim_head * heads
86 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
87 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88 |
89 | def forward(self, x):
90 | b, c, h, w = x.shape
91 | qkv = self.to_qkv(x)
92 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
93 | k = k.softmax(dim=-1)
94 | context = torch.einsum('bhdn,bhen->bhde', k, v)
95 | out = torch.einsum('bhde,bhdn->bhen', context, q)
96 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
97 | return self.to_out(out)
98 |
99 |
100 | class SpatialSelfAttention(nn.Module):
101 | def __init__(self, in_channels):
102 | super().__init__()
103 | self.in_channels = in_channels
104 |
105 | self.norm = Normalize(in_channels)
106 | self.q = torch.nn.Conv2d(in_channels,
107 | in_channels,
108 | kernel_size=1,
109 | stride=1,
110 | padding=0)
111 | self.k = torch.nn.Conv2d(in_channels,
112 | in_channels,
113 | kernel_size=1,
114 | stride=1,
115 | padding=0)
116 | self.v = torch.nn.Conv2d(in_channels,
117 | in_channels,
118 | kernel_size=1,
119 | stride=1,
120 | padding=0)
121 | self.proj_out = torch.nn.Conv2d(in_channels,
122 | in_channels,
123 | kernel_size=1,
124 | stride=1,
125 | padding=0)
126 |
127 | def forward(self, x):
128 | h_ = x
129 | h_ = self.norm(h_)
130 | q = self.q(h_)
131 | k = self.k(h_)
132 | v = self.v(h_)
133 |
134 | # compute attention
135 | b,c,h,w = q.shape
136 | q = rearrange(q, 'b c h w -> b (h w) c')
137 | k = rearrange(k, 'b c h w -> b c (h w)')
138 | w_ = torch.einsum('bij,bjk->bik', q, k)
139 |
140 | w_ = w_ * (int(c)**(-0.5))
141 | w_ = torch.nn.functional.softmax(w_, dim=2)
142 |
143 | # attend to values
144 | v = rearrange(v, 'b c h w -> b c (h w)')
145 | w_ = rearrange(w_, 'b i j -> b j i')
146 | h_ = torch.einsum('bij,bjk->bik', v, w_)
147 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
148 | h_ = self.proj_out(h_)
149 |
150 | return x+h_
151 |
152 |
153 | class CrossAttention(nn.Module):
154 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., struct_attn=False, save_map=False):
155 | super().__init__()
156 | inner_dim = dim_head * heads
157 | context_dim = default(context_dim, query_dim)
158 |
159 | self.scale = dim_head ** -0.5
160 | self.heads = heads
161 |
162 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
163 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
164 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
165 |
166 | self.to_out = nn.Sequential(
167 | nn.Linear(inner_dim, query_dim),
168 | nn.Dropout(dropout)
169 | )
170 |
171 | self.struct_attn = struct_attn
172 | self.save_map = save_map
173 |
174 | def forward(self, x, context=None, mask=None):
175 | q = self.to_q(x)
176 |
177 | if isinstance(context, list):
178 | if self.struct_attn:
179 | out = self.struct_qkv(q, context, mask)
180 | else:
181 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context
182 | out = self.normal_qkv(q, context, mask)
183 | else:
184 | context = default(context, x)
185 | out = self.normal_qkv(q, context, mask)
186 |
187 | return self.to_out(out)
188 |
189 | def struct_qkv(self, q, context, mask):
190 | """
191 | context: list of [uc, list of conditional context]
192 | """
193 | uc_context = context[0]
194 | context_k, context_v = context[1]['k'], context[1]['v']
195 |
196 | if isinstance(context_k, list) and isinstance(context_v, list):
197 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
198 | else:
199 | raise NotImplementedError
200 | return out
201 |
202 | def multi_qkv(self, q, uc_context, context_k, context_v, mask):
203 | h = self.heads
204 |
205 | assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
206 | true_bs = uc_context.size(0) * h
207 |
208 | k_uc, v_uc = self.get_kv(uc_context)
209 | k_c = [self.to_k(c_k) for c_k in context_k]
210 | v_c = [self.to_v(c_v) for c_v in context_v]
211 |
212 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
213 |
214 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h)
215 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h)
216 |
217 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point
218 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c]
219 |
220 | # get composition
221 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale
222 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c]
223 |
224 | attn_uc = sim_uc.softmax(dim=-1)
225 | attn_c = [sim.softmax(dim=-1) for sim in sim_c]
226 |
227 | if self.save_map and sim_uc.size(1) != sim_uc.size(2):
228 | self.save_attn_maps(attn_c)
229 |
230 | # get uc output
231 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc)
232 |
233 | # get c output
234 | n_keys, n_values = len(k_c), len(v_c)
235 | if n_keys == n_values:
236 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c)
237 | else:
238 | assert n_keys == 1 or n_values == 1
239 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn in attn_c for v in v_c]) / (n_keys * n_values)
240 |
241 | out = torch.cat([out_uc, out_c], dim=0)
242 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
243 |
244 | return out
245 |
246 | def normal_qkv(self, q, context, mask):
247 | h = self.heads
248 | k = self.to_k(context)
249 | v = self.to_v(context)
250 |
251 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
252 |
253 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
254 |
255 | if exists(mask):
256 | mask = rearrange(mask, 'b ... -> b (...)')
257 | max_neg_value = -torch.finfo(sim.dtype).max
258 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
259 | sim.masked_fill_(~mask, max_neg_value)
260 |
261 | # attention, what we cannot get enough of
262 | attn = sim.softmax(dim=-1)
263 |
264 | if self.save_map and sim.size(1) != sim.size(2):
265 | self.save_attn_maps(attn.chunk(2)[1])
266 |
267 | out = einsum('b i j, b j d -> b i d', attn, v)
268 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
269 |
270 | return out
271 |
272 | def get_kv(self, context):
273 | return self.to_k(context), self.to_v(context)
274 |
275 | def save_attn_maps(self, attn):
276 | h = self.heads
277 | if isinstance(attn, list):
278 | height = width = int(math.sqrt(attn[0].size(1)))
279 | self.attn_maps = [rearrange(m.detach(), '(b x) (h w) l -> b x h w l', x=h, h=height, w=width)[...,:20].cpu() for m in attn]
280 | else:
281 | height = width = int(math.sqrt(attn.size(1)))
282 | self.attn_maps = rearrange(attn.detach(), '(b x) (h w) l -> b x h w l', x=h, h=height, w=width)[...,:20].cpu()
283 |
284 |
285 | class BasicTransformerBlock(nn.Module):
286 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, struct_attn=False, save_map=False):
287 | super().__init__()
288 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
289 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
290 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
291 | heads=n_heads, dim_head=d_head, dropout=dropout,
292 | struct_attn=struct_attn, save_map=save_map) # is self-attn if context is none
293 | self.norm1 = nn.LayerNorm(dim)
294 | self.norm2 = nn.LayerNorm(dim)
295 | self.norm3 = nn.LayerNorm(dim)
296 | self.checkpoint = checkpoint
297 |
298 | def forward(self, x, context=None):
299 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
300 |
301 | def _forward(self, x, context=None):
302 | x = self.attn1(self.norm1(x)) + x
303 | x = self.attn2(self.norm2(x), context=context) + x
304 | x = self.ff(self.norm3(x)) + x
305 | return x
306 |
307 |
308 | class SpatialTransformer(nn.Module):
309 | """
310 | Transformer block for image-like data.
311 | First, project the input (aka embedding)
312 | and reshape to b, t, d.
313 | Then apply standard transformer action.
314 | Finally, reshape to image
315 | """
316 | def __init__(self, in_channels, n_heads, d_head,
317 | depth=1, dropout=0., context_dim=None, struct_attn=False, save_map=False):
318 | super().__init__()
319 | self.in_channels = in_channels
320 | inner_dim = n_heads * d_head
321 | self.norm = Normalize(in_channels)
322 |
323 | self.proj_in = nn.Conv2d(in_channels,
324 | inner_dim,
325 | kernel_size=1,
326 | stride=1,
327 | padding=0)
328 |
329 | self.transformer_blocks = nn.ModuleList(
330 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, struct_attn=struct_attn, save_map=save_map)
331 | for d in range(depth)]
332 | )
333 |
334 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
335 | in_channels,
336 | kernel_size=1,
337 | stride=1,
338 | padding=0))
339 |
340 | def forward(self, x, context=None):
341 | # note: if no context is given, cross-attention defaults to self-attention
342 | b, c, h, w = x.shape
343 | x_in = x
344 | x = self.norm(x)
345 | x = self.proj_in(x)
346 | x = rearrange(x, 'b c h w -> b (h w) c')
347 | for block in self.transformer_blocks:
348 | x = block(x, context=context)
349 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
350 | x = self.proj_out(x)
351 | return x + x_in
--------------------------------------------------------------------------------
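The structured path in `CrossAttention.forward` expects `context` to be `[uc_context, {'k': [...], 'v': [...]}]`, with the query batch stacking the unconditional and conditional halves as in classifier-free guidance. An illustrative shape check with random tensors (the sizes here are arbitrary choices, not taken from any config):

import torch
from structured_stable_diffusion.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40, struct_attn=True)

x  = torch.randn(2, 4096, 320)   # [uncond; cond] queries stacked along the batch dim
uc = torch.randn(1, 77, 768)     # unconditional text embedding
ks = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # per-segment key contexts
vs = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # per-segment value contexts

out = attn(x, context=[uc, {'k': ks, 'v': vs}])
print(out.shape)                 # torch.Size([2, 4096, 320])

Since `multi_qkv` averages the per-segment attention outputs, the result has the same shape as the plain single-context path.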
/structured_stable_diffusion/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | from collections import defaultdict
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | from functools import partial
8 |
9 | from structured_stable_diffusion.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
10 |
11 |
12 | class PLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81 | skip=False,
82 | quiet=False,
83 | **kwargs
84 | ):
85 | if conditioning is not None:
86 | if isinstance(conditioning, dict):
87 | cbs = len(conditioning[list(conditioning.keys())[0]])
88 | if cbs != batch_size:
89 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
90 | else:
91 | if len(conditioning) != batch_size:
92 | print(f"Warning: Got {len(conditioning)} conditionings but batch-size is {batch_size}")
93 |
94 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
95 | # sampling
96 | C, H, W = shape
97 | size = (batch_size, C, H, W)
98 | # print(f'Data shape for PLMS sampling is {size}')
99 |
100 | samples, intermediates = self.plms_sampling(conditioning, size,
101 | callback=callback,
102 | img_callback=img_callback,
103 | quantize_denoised=quantize_x0,
104 | mask=mask, x0=x0,
105 | ddim_use_original_steps=False,
106 | noise_dropout=noise_dropout,
107 | temperature=temperature,
108 | score_corrector=score_corrector,
109 | corrector_kwargs=corrector_kwargs,
110 | x_T=x_T,
111 | log_every_t=log_every_t,
112 | unconditional_guidance_scale=unconditional_guidance_scale,
113 | unconditional_conditioning=unconditional_conditioning,
114 | skip=skip,
115 | quiet=quiet,
116 | **kwargs
117 | )
118 | return samples, intermediates
119 |
120 | @torch.no_grad()
121 | def plms_sampling(self, cond, shape,
122 | x_T=None, ddim_use_original_steps=False,
123 | callback=None, timesteps=None, quantize_denoised=False,
124 | mask=None, x0=None, img_callback=None, log_every_t=100,
125 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
126 | unconditional_guidance_scale=1., unconditional_conditioning=None, skip=False, quiet=False, **kwargs):
127 | device = self.model.betas.device
128 | b = shape[0]
129 | if x_T is None:
130 | img = torch.randn(shape, device=device)
131 | else:
132 | img = x_T
133 |
134 | if timesteps is None:
135 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
136 | elif timesteps is not None and not ddim_use_original_steps:
137 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
138 | timesteps = self.ddim_timesteps[:subset_end]
139 |
140 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
141 | if skip:
142 | return img, intermediates
143 |
144 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
145 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146 | print(f"Running PLMS Sampling with {total_steps} timesteps")
147 |
148 | if not quiet:
149 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
150 | else:
151 | iterator = time_range
152 | old_eps = []
153 | self.attn_maps = defaultdict(list)
154 |
155 | for i, step in enumerate(iterator):
156 | index = total_steps - i - 1
157 | ts = torch.full((b,), step, device=device, dtype=torch.long)
158 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
159 |
160 | if mask is not None:
161 | assert x0 is not None
162 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
163 | img = img_orig * mask + (1. - mask) * img
164 |
165 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
166 | quantize_denoised=quantize_denoised, temperature=temperature,
167 | noise_dropout=noise_dropout, score_corrector=score_corrector,
168 | corrector_kwargs=corrector_kwargs,
169 | unconditional_guidance_scale=unconditional_guidance_scale,
170 | unconditional_conditioning=unconditional_conditioning,
171 | old_eps=old_eps, t_next=ts_next)
172 | img, pred_x0, e_t = outs
173 | old_eps.append(e_t)
174 | if len(old_eps) >= 4:
175 | old_eps.pop(0)
176 | if callback: callback(i)
177 | if img_callback: img_callback(pred_x0, i)
178 |
179 | if index % log_every_t == 0 or index == total_steps - 1:
180 | intermediates['x_inter'].append(img)
181 | intermediates['pred_x0'].append(pred_x0)
182 |
183 | if kwargs.get('save_attn_maps', False) and i % 5 == 0:
184 | for name, module in self.model.model.diffusion_model.named_modules():
185 | module_name = type(module).__name__
186 | if module_name == 'CrossAttention' and 'attn2' in name:
187 | self.attn_maps[name].append(module.attn_maps)
188 |
189 | return img, intermediates
190 |
191 | @torch.no_grad()
192 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
193 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
194 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
195 | b, *_, device = *x.shape, x.device
196 |
197 | def get_model_output(x, t):
198 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
199 | e_t = self.model.apply_model(x, t, c)
200 | else:
201 | x_in = torch.cat([x] * 2)
202 | t_in = torch.cat([t] * 2)
203 | if isinstance(c, (list, dict)):
204 | c_in = [unconditional_conditioning, c]
205 | else:
206 | c_in = torch.cat([unconditional_conditioning, c])
207 |
208 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
209 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
210 |
211 | if score_corrector is not None:
212 | assert self.model.parameterization == "eps"
213 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
214 |
215 | return e_t
216 |
217 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
218 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
219 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
220 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
221 |
222 | def get_x_prev_and_pred_x0(e_t, index):
223 | # select parameters corresponding to the currently considered timestep
224 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
225 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
226 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
227 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
228 |
229 | # current prediction for x_0
230 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
231 | if quantize_denoised:
232 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
233 | # direction pointing to x_t
234 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
235 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
236 | if noise_dropout > 0.:
237 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
238 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
239 | return x_prev, pred_x0
240 |
241 | e_t = get_model_output(x, t)
242 | if len(old_eps) == 0:
243 | # Pseudo Improved Euler (2nd order)
244 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
245 | e_t_next = get_model_output(x_prev, t_next)
246 | e_t_prime = (e_t + e_t_next) / 2
247 | elif len(old_eps) == 1:
248 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
249 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
250 | elif len(old_eps) == 2:
251 | # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
252 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
253 | elif len(old_eps) >= 3:
254 | # 4th order Pseudo Linear Multistep (Adams-Bashforth)
255 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
256 |
257 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
258 |
259 | return x_prev, pred_x0, e_t
260 |
--------------------------------------------------------------------------------
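The branching on `len(old_eps)` in `p_sample_plms` is an Adams-Bashforth style blend of the current and previous noise predictions. The same weights, factored into a small standalone helper purely for illustration (not part of this repo):

def plms_blend_eps(e_t, old_eps):
    # With no history the sampler makes a second model call instead
    # (the pseudo improved Euler start-up), so require at least one old prediction here.
    assert len(old_eps) >= 1
    if len(old_eps) == 1:    # 2nd order
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:    # 3rd order
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    # 4th order; the sampler keeps at most the three most recent predictions
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

The blended prediction is what gets passed to `get_x_prev_and_pred_x0` in place of the raw `e_t`.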
/structured_stable_diffusion/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os, yaml, pickle, shutil, tarfile, glob
2 | import cv2
3 | import albumentations
4 | import PIL
5 | import numpy as np
6 | import torchvision.transforms.functional as TF
7 | from omegaconf import OmegaConf
8 | from functools import partial
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torch.utils.data import Dataset, Subset
12 |
13 | import taming.data.utils as tdu
14 | from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15 | from taming.data.imagenet import ImagePaths
16 |
17 | from structured_stable_diffusion.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18 |
19 |
20 | def synset2idx(path_to_yaml="data/index_synset.yaml"):
21 | with open(path_to_yaml) as f:
22 | di2s = yaml.safe_load(f)
23 | return dict((v,k) for k,v in di2s.items())
24 |
25 |
26 | class ImageNetBase(Dataset):
27 | def __init__(self, config=None):
28 | self.config = config or OmegaConf.create()
29 | if not type(self.config)==dict:
30 | self.config = OmegaConf.to_container(self.config)
31 | self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32 | self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33 | self._prepare()
34 | self._prepare_synset_to_human()
35 | self._prepare_idx_to_synset()
36 | self._prepare_human_to_integer_label()
37 | self._load()
38 |
39 | def __len__(self):
40 | return len(self.data)
41 |
42 | def __getitem__(self, i):
43 | return self.data[i]
44 |
45 | def _prepare(self):
46 | raise NotImplementedError()
47 |
48 | def _filter_relpaths(self, relpaths):
49 | ignore = set([
50 | "n06596364_9591.JPEG",
51 | ])
52 | relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53 | if "sub_indices" in self.config:
54 | indices = str_to_indices(self.config["sub_indices"])
55 | synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56 | self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57 | files = []
58 | for rpath in relpaths:
59 | syn = rpath.split("/")[0]
60 | if syn in synsets:
61 | files.append(rpath)
62 | return files
63 | else:
64 | return relpaths
65 |
66 | def _prepare_synset_to_human(self):
67 | SIZE = 2655750
68 | URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69 | self.human_dict = os.path.join(self.root, "synset_human.txt")
70 | if (not os.path.exists(self.human_dict) or
71 | not os.path.getsize(self.human_dict)==SIZE):
72 | download(URL, self.human_dict)
73 |
74 | def _prepare_idx_to_synset(self):
75 | URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76 | self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77 | if (not os.path.exists(self.idx2syn)):
78 | download(URL, self.idx2syn)
79 |
80 | def _prepare_human_to_integer_label(self):
81 | URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82 | self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83 | if (not os.path.exists(self.human2integer)):
84 | download(URL, self.human2integer)
85 | with open(self.human2integer, "r") as f:
86 | lines = f.read().splitlines()
87 | assert len(lines) == 1000
88 | self.human2integer_dict = dict()
89 | for line in lines:
90 | value, key = line.split(":")
91 | self.human2integer_dict[key] = int(value)
92 |
93 | def _load(self):
94 | with open(self.txt_filelist, "r") as f:
95 | self.relpaths = f.read().splitlines()
96 | l1 = len(self.relpaths)
97 | self.relpaths = self._filter_relpaths(self.relpaths)
98 | print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99 |
100 | self.synsets = [p.split("/")[0] for p in self.relpaths]
101 | self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102 |
103 | unique_synsets = np.unique(self.synsets)
104 | class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105 | if not self.keep_orig_class_label:
106 | self.class_labels = [class_dict[s] for s in self.synsets]
107 | else:
108 | self.class_labels = [self.synset2idx[s] for s in self.synsets]
109 |
110 | with open(self.human_dict, "r") as f:
111 | human_dict = f.read().splitlines()
112 | human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113 |
114 | self.human_labels = [human_dict[s] for s in self.synsets]
115 |
116 | labels = {
117 | "relpath": np.array(self.relpaths),
118 | "synsets": np.array(self.synsets),
119 | "class_label": np.array(self.class_labels),
120 | "human_label": np.array(self.human_labels),
121 | }
122 |
123 | if self.process_images:
124 | self.size = retrieve(self.config, "size", default=256)
125 | self.data = ImagePaths(self.abspaths,
126 | labels=labels,
127 | size=self.size,
128 | random_crop=self.random_crop,
129 | )
130 | else:
131 | self.data = self.abspaths
132 |
133 |
134 | class ImageNetTrain(ImageNetBase):
135 | NAME = "ILSVRC2012_train"
136 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137 | AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138 | FILES = [
139 | "ILSVRC2012_img_train.tar",
140 | ]
141 | SIZES = [
142 | 147897477120,
143 | ]
144 |
145 | def __init__(self, process_images=True, data_root=None, **kwargs):
146 | self.process_images = process_images
147 | self.data_root = data_root
148 | super().__init__(**kwargs)
149 |
150 | def _prepare(self):
151 | if self.data_root:
152 | self.root = os.path.join(self.data_root, self.NAME)
153 | else:
154 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156 |
157 | self.datadir = os.path.join(self.root, "data")
158 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
159 | self.expected_length = 1281167
160 | self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161 | default=True)
162 | if not tdu.is_prepared(self.root):
163 | # prep
164 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
165 |
166 | datadir = self.datadir
167 | if not os.path.exists(datadir):
168 | path = os.path.join(self.root, self.FILES[0])
169 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170 | import academictorrents as at
171 | atpath = at.get(self.AT_HASH, datastore=self.root)
172 | assert atpath == path
173 |
174 | print("Extracting {} to {}".format(path, datadir))
175 | os.makedirs(datadir, exist_ok=True)
176 | with tarfile.open(path, "r:") as tar:
177 | tar.extractall(path=datadir)
178 |
179 | print("Extracting sub-tars.")
180 | subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181 | for subpath in tqdm(subpaths):
182 | subdir = subpath[:-len(".tar")]
183 | os.makedirs(subdir, exist_ok=True)
184 | with tarfile.open(subpath, "r:") as tar:
185 | tar.extractall(path=subdir)
186 |
187 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189 | filelist = sorted(filelist)
190 | filelist = "\n".join(filelist)+"\n"
191 | with open(self.txt_filelist, "w") as f:
192 | f.write(filelist)
193 |
194 | tdu.mark_prepared(self.root)
195 |
196 |
197 | class ImageNetValidation(ImageNetBase):
198 | NAME = "ILSVRC2012_validation"
199 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200 | AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201 | VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202 | FILES = [
203 | "ILSVRC2012_img_val.tar",
204 | "validation_synset.txt",
205 | ]
206 | SIZES = [
207 | 6744924160,
208 | 1950000,
209 | ]
210 |
211 | def __init__(self, process_images=True, data_root=None, **kwargs):
212 | self.data_root = data_root
213 | self.process_images = process_images
214 | super().__init__(**kwargs)
215 |
216 | def _prepare(self):
217 | if self.data_root:
218 | self.root = os.path.join(self.data_root, self.NAME)
219 | else:
220 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222 | self.datadir = os.path.join(self.root, "data")
223 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
224 | self.expected_length = 50000
225 | self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226 | default=False)
227 | if not tdu.is_prepared(self.root):
228 | # prep
229 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
230 |
231 | datadir = self.datadir
232 | if not os.path.exists(datadir):
233 | path = os.path.join(self.root, self.FILES[0])
234 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235 | import academictorrents as at
236 | atpath = at.get(self.AT_HASH, datastore=self.root)
237 | assert atpath == path
238 |
239 | print("Extracting {} to {}".format(path, datadir))
240 | os.makedirs(datadir, exist_ok=True)
241 | with tarfile.open(path, "r:") as tar:
242 | tar.extractall(path=datadir)
243 |
244 | vspath = os.path.join(self.root, self.FILES[1])
245 | if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246 | download(self.VS_URL, vspath)
247 |
248 | with open(vspath, "r") as f:
249 | synset_dict = f.read().splitlines()
250 | synset_dict = dict(line.split() for line in synset_dict)
251 |
252 | print("Reorganizing into synset folders")
253 | synsets = np.unique(list(synset_dict.values()))
254 | for s in synsets:
255 | os.makedirs(os.path.join(datadir, s), exist_ok=True)
256 | for k, v in synset_dict.items():
257 | src = os.path.join(datadir, k)
258 | dst = os.path.join(datadir, v)
259 | shutil.move(src, dst)
260 |
261 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263 | filelist = sorted(filelist)
264 | filelist = "\n".join(filelist)+"\n"
265 | with open(self.txt_filelist, "w") as f:
266 | f.write(filelist)
267 |
268 | tdu.mark_prepared(self.root)
269 |
270 |
271 |
272 | class ImageNetSR(Dataset):
273 | def __init__(self, size=None,
274 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275 | random_crop=True):
276 | """
277 | Imagenet Superresolution Dataloader
278 | Performs the following ops in order:
279 | 1. crops a crop of size s from image either as random or center crop
280 | 2. resizes crop to size with cv2.area_interpolation
281 | 3. degrades resized crop with degradation_fn
282 |
283 | :param size: resizing to size after cropping
284 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285 | :param downscale_f: Low Resolution Downsample factor
286 | :param min_crop_f: determines crop size s,
287 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288 | :param max_crop_f: upper bound of the interval from which c is sampled (see min_crop_f)
289 | :param data_root:
290 | :param random_crop: if True, take a random crop; otherwise take a center crop
291 | """
292 | self.base = self.get_base()
293 | assert size
294 | assert (size / downscale_f).is_integer()
295 | self.size = size
296 | self.LR_size = int(size / downscale_f)
297 | self.min_crop_f = min_crop_f
298 | self.max_crop_f = max_crop_f
299 | assert(max_crop_f <= 1.)
300 | self.center_crop = not random_crop
301 |
302 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303 |
304 | self.pil_interpolation = False # gets reset later in case interp_op is from pillow
305 |
306 | if degradation == "bsrgan":
307 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308 |
309 | elif degradation == "bsrgan_light":
310 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311 |
312 | else:
313 | interpolation_fn = {
314 | "cv_nearest": cv2.INTER_NEAREST,
315 | "cv_bilinear": cv2.INTER_LINEAR,
316 | "cv_bicubic": cv2.INTER_CUBIC,
317 | "cv_area": cv2.INTER_AREA,
318 | "cv_lanczos": cv2.INTER_LANCZOS4,
319 | "pil_nearest": PIL.Image.NEAREST,
320 | "pil_bilinear": PIL.Image.BILINEAR,
321 | "pil_bicubic": PIL.Image.BICUBIC,
322 | "pil_box": PIL.Image.BOX,
323 | "pil_hamming": PIL.Image.HAMMING,
324 | "pil_lanczos": PIL.Image.LANCZOS,
325 | }[degradation]
326 |
327 | self.pil_interpolation = degradation.startswith("pil_")
328 |
329 | if self.pil_interpolation:
330 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331 |
332 | else:
333 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334 | interpolation=interpolation_fn)
335 |
336 | def __len__(self):
337 | return len(self.base)
338 |
339 | def __getitem__(self, i):
340 | example = self.base[i]
341 | image = Image.open(example["file_path_"])
342 |
343 | if not image.mode == "RGB":
344 | image = image.convert("RGB")
345 |
346 | image = np.array(image).astype(np.uint8)
347 |
348 | min_side_len = min(image.shape[:2])
349 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350 | crop_side_len = int(crop_side_len)
351 |
352 | if self.center_crop:
353 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354 |
355 | else:
356 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357 |
358 | image = self.cropper(image=image)["image"]
359 | image = self.image_rescaler(image=image)["image"]
360 |
361 | if self.pil_interpolation:
362 | image_pil = PIL.Image.fromarray(image)
363 | LR_image = self.degradation_process(image_pil)
364 | LR_image = np.array(LR_image).astype(np.uint8)
365 |
366 | else:
367 | LR_image = self.degradation_process(image=image)["image"]
368 |
369 | example["image"] = (image/127.5 - 1.0).astype(np.float32)
370 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371 |
372 | return example
373 |
374 |
375 | class ImageNetSRTrain(ImageNetSR):
376 | def __init__(self, **kwargs):
377 | super().__init__(**kwargs)
378 |
379 | def get_base(self):
380 | with open("data/imagenet_train_hr_indices.p", "rb") as f:
381 | indices = pickle.load(f)
382 | dset = ImageNetTrain(process_images=False,)
383 | return Subset(dset, indices)
384 |
385 |
386 | class ImageNetSRValidation(ImageNetSR):
387 | def __init__(self, **kwargs):
388 | super().__init__(**kwargs)
389 |
390 | def get_base(self):
391 | with open("data/imagenet_val_hr_indices.p", "rb") as f:
392 | indices = pickle.load(f)
393 | dset = ImageNetValidation(process_images=False,)
394 | return Subset(dset, indices)
395 |
--------------------------------------------------------------------------------
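The datasets above follow the taming-transformers conventions and cache everything under `$XDG_CACHE_HOME/autoencoders/data` unless `data_root` is passed. A hypothetical usage sketch, assuming the ILSVRC2012 validation tar has already been placed there so `_prepare` does not need to fetch it:

from structured_stable_diffusion.data.imagenet import ImageNetValidation

# `size` is read from the config via retrieve(); 256 is also its default.
dset = ImageNetValidation(process_images=True, config={"size": 256})
print(len(dset))                   # 50000 examples once the split is prepared

example = dset[0]
print(example["image"].shape)      # HWC float image scaled to [-1, 1]
print(example["class_label"], example["human_label"])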
/structured_stable_diffusion/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pytorch_lightning as pl
4 | import torch.nn.functional as F
5 | from contextlib import contextmanager
6 | from packaging import version
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9 | from structured_stable_diffusion.modules.diffusionmodules.model import Encoder, Decoder
10 | from structured_stable_diffusion.modules.distributions.distributions import DiagonalGaussianDistribution
11 | from structured_stable_diffusion.modules.ema import LitEma
12 | from structured_stable_diffusion.util import instantiate_from_config
13 |
14 | class VQModel(pl.LightningModule):
15 | def __init__(self,
16 | ddconfig,
17 | lossconfig,
18 | n_embed,
19 | embed_dim,
20 | ckpt_path=None,
21 | ignore_keys=[],
22 | image_key="image",
23 | colorize_nlabels=None,
24 | monitor=None,
25 | batch_resize_range=None,
26 | scheduler_config=None,
27 | lr_g_factor=1.0,
28 | remap=None,
29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
30 | use_ema=False
31 | ):
32 | super().__init__()
33 | self.embed_dim = embed_dim
34 | self.n_embed = n_embed
35 | self.image_key = image_key
36 | self.encoder = Encoder(**ddconfig)
37 | self.decoder = Decoder(**ddconfig)
38 | self.loss = instantiate_from_config(lossconfig)
39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40 | remap=remap,
41 | sane_index_shape=sane_index_shape)
42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44 | if colorize_nlabels is not None:
45 | assert type(colorize_nlabels)==int
46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47 | if monitor is not None:
48 | self.monitor = monitor
49 | self.batch_resize_range = batch_resize_range
50 | if self.batch_resize_range is not None:
51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52 |
53 | self.use_ema = use_ema
54 | if self.use_ema:
55 | self.model_ema = LitEma(self)
56 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57 |
58 | if ckpt_path is not None:
59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60 | self.scheduler_config = scheduler_config
61 | self.lr_g_factor = lr_g_factor
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def init_from_ckpt(self, path, ignore_keys=list()):
79 | sd = torch.load(path, map_location="cpu")["state_dict"]
80 | keys = list(sd.keys())
81 | for k in keys:
82 | for ik in ignore_keys:
83 | if k.startswith(ik):
84 | print("Deleting key {} from state_dict.".format(k))
85 | del sd[k]
86 | missing, unexpected = self.load_state_dict(sd, strict=False)
87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88 | if len(missing) > 0:
89 | print(f"Missing Keys: {missing}")
90 | print(f"Unexpected Keys: {unexpected}")
91 |
92 | def on_train_batch_end(self, *args, **kwargs):
93 | if self.use_ema:
94 | self.model_ema(self)
95 |
96 | def encode(self, x):
97 | h = self.encoder(x)
98 | h = self.quant_conv(h)
99 | quant, emb_loss, info = self.quantize(h)
100 | return quant, emb_loss, info
101 |
102 | def encode_to_prequant(self, x):
103 | h = self.encoder(x)
104 | h = self.quant_conv(h)
105 | return h
106 |
107 | def decode(self, quant):
108 | quant = self.post_quant_conv(quant)
109 | dec = self.decoder(quant)
110 | return dec
111 |
112 | def decode_code(self, code_b):
113 | quant_b = self.quantize.embed_code(code_b)
114 | dec = self.decode(quant_b)
115 | return dec
116 |
117 | def forward(self, input, return_pred_indices=False):
118 | quant, diff, (_,_,ind) = self.encode(input)
119 | dec = self.decode(quant)
120 | if return_pred_indices:
121 | return dec, diff, ind
122 | return dec, diff
123 |
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129 | if self.batch_resize_range is not None:
130 | lower_size = self.batch_resize_range[0]
131 | upper_size = self.batch_resize_range[1]
132 | if self.global_step <= 4:
133 | # do the first few batches with max size to avoid later oom
134 | new_resize = upper_size
135 | else:
136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137 | if new_resize != x.shape[2]:
138 | x = F.interpolate(x, size=new_resize, mode="bicubic")
139 | x = x.detach()
140 | return x
141 |
142 | def training_step(self, batch, batch_idx, optimizer_idx):
143 | # https://github.com/pytorch/pytorch/issues/37142
144 | # try not to fool the heuristics
145 | x = self.get_input(batch, self.image_key)
146 | xrec, qloss, ind = self(x, return_pred_indices=True)
147 |
148 | if optimizer_idx == 0:
149 | # autoencode
150 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151 | last_layer=self.get_last_layer(), split="train",
152 | predicted_indices=ind)
153 |
154 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155 | return aeloss
156 |
157 | if optimizer_idx == 1:
158 | # discriminator
159 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160 | last_layer=self.get_last_layer(), split="train")
161 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162 | return discloss
163 |
164 | def validation_step(self, batch, batch_idx):
165 | log_dict = self._validation_step(batch, batch_idx)
166 | with self.ema_scope():
167 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168 | return log_dict
169 |
170 | def _validation_step(self, batch, batch_idx, suffix=""):
171 | x = self.get_input(batch, self.image_key)
172 | xrec, qloss, ind = self(x, return_pred_indices=True)
173 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174 | self.global_step,
175 | last_layer=self.get_last_layer(),
176 | split="val"+suffix,
177 | predicted_indices=ind
178 | )
179 |
180 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181 | self.global_step,
182 | last_layer=self.get_last_layer(),
183 | split="val"+suffix,
184 | predicted_indices=ind
185 | )
186 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187 | self.log(f"val{suffix}/rec_loss", rec_loss,
188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189 | self.log(f"val{suffix}/aeloss", aeloss,
190 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191 | if version.parse(pl.__version__) >= version.parse('1.4.0'):
192 | del log_dict_ae[f"val{suffix}/rec_loss"]
193 | self.log_dict(log_dict_ae)
194 | self.log_dict(log_dict_disc)
195 | return self.log_dict
196 |
197 | def configure_optimizers(self):
198 | lr_d = self.learning_rate
199 | lr_g = self.lr_g_factor*self.learning_rate
200 | print("lr_d", lr_d)
201 | print("lr_g", lr_g)
202 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203 | list(self.decoder.parameters())+
204 | list(self.quantize.parameters())+
205 | list(self.quant_conv.parameters())+
206 | list(self.post_quant_conv.parameters()),
207 | lr=lr_g, betas=(0.5, 0.9))
208 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209 | lr=lr_d, betas=(0.5, 0.9))
210 |
211 | if self.scheduler_config is not None:
212 | scheduler = instantiate_from_config(self.scheduler_config)
213 |
214 | print("Setting up LambdaLR scheduler...")
215 | scheduler = [
216 | {
217 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218 | 'interval': 'step',
219 | 'frequency': 1
220 | },
221 | {
222 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223 | 'interval': 'step',
224 | 'frequency': 1
225 | },
226 | ]
227 | return [opt_ae, opt_disc], scheduler
228 | return [opt_ae, opt_disc], []
229 |
230 | def get_last_layer(self):
231 | return self.decoder.conv_out.weight
232 |
233 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234 | log = dict()
235 | x = self.get_input(batch, self.image_key)
236 | x = x.to(self.device)
237 | if only_inputs:
238 | log["inputs"] = x
239 | return log
240 | xrec, _ = self(x)
241 | if x.shape[1] > 3:
242 | # colorize with random projection
243 | assert xrec.shape[1] > 3
244 | x = self.to_rgb(x)
245 | xrec = self.to_rgb(xrec)
246 | log["inputs"] = x
247 | log["reconstructions"] = xrec
248 | if plot_ema:
249 | with self.ema_scope():
250 | xrec_ema, _ = self(x)
251 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252 | log["reconstructions_ema"] = xrec_ema
253 | return log
254 |
255 | def to_rgb(self, x):
256 | assert self.image_key == "segmentation"
257 | if not hasattr(self, "colorize"):
258 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259 | x = F.conv2d(x, weight=self.colorize)
260 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261 | return x
262 |
263 |
264 | class VQModelInterface(VQModel):
265 | def __init__(self, embed_dim, *args, **kwargs):
266 | super().__init__(embed_dim=embed_dim, *args, **kwargs)
267 | self.embed_dim = embed_dim
268 |
269 | def encode(self, x):
270 | h = self.encoder(x)
271 | h = self.quant_conv(h)
272 | return h
273 |
274 | def decode(self, h, force_not_quantize=False):
275 | # also go through quantization layer
276 | if not force_not_quantize:
277 | quant, emb_loss, info = self.quantize(h)
278 | else:
279 | quant = h
280 | quant = self.post_quant_conv(quant)
281 | dec = self.decoder(quant)
282 | return dec
283 |
284 |
285 | class AutoencoderKL(pl.LightningModule):
286 | def __init__(self,
287 | ddconfig,
288 | lossconfig,
289 | embed_dim,
290 | ckpt_path=None,
291 | ignore_keys=[],
292 | image_key="image",
293 | colorize_nlabels=None,
294 | monitor=None,
295 | ):
296 | super().__init__()
297 | self.image_key = image_key
298 | self.encoder = Encoder(**ddconfig)
299 | self.decoder = Decoder(**ddconfig)
300 | self.loss = instantiate_from_config(lossconfig)
301 | assert ddconfig["double_z"]
302 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304 | self.embed_dim = embed_dim
305 | if colorize_nlabels is not None:
306 | assert type(colorize_nlabels)==int
307 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308 | if monitor is not None:
309 | self.monitor = monitor
310 | if ckpt_path is not None:
311 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312 |
313 | def init_from_ckpt(self, path, ignore_keys=list()):
314 | sd = torch.load(path, map_location="cpu")["state_dict"]
315 | keys = list(sd.keys())
316 | for k in keys:
317 | for ik in ignore_keys:
318 | if k.startswith(ik):
319 | print("Deleting key {} from state_dict.".format(k))
320 | del sd[k]
321 | self.load_state_dict(sd, strict=False)
322 | print(f"Restored from {path}")
323 |
324 | def encode(self, x):
325 | h = self.encoder(x)
326 | moments = self.quant_conv(h)
327 | posterior = DiagonalGaussianDistribution(moments)
328 | return posterior
329 |
330 | def decode(self, z):
331 | z = self.post_quant_conv(z)
332 | dec = self.decoder(z)
333 | return dec
334 |
335 | def forward(self, input, sample_posterior=True):
336 | posterior = self.encode(input)
337 | if sample_posterior:
338 | z = posterior.sample()
339 | else:
340 | z = posterior.mode()
341 | dec = self.decode(z)
342 | return dec, posterior
343 |
344 | def get_input(self, batch, k):
345 | x = batch[k]
346 | if len(x.shape) == 3:
347 | x = x[..., None]
348 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349 | return x
350 |
351 | def training_step(self, batch, batch_idx, optimizer_idx):
352 | inputs = self.get_input(batch, self.image_key)
353 | reconstructions, posterior = self(inputs)
354 |
355 | if optimizer_idx == 0:
356 | # train encoder+decoder+logvar
357 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358 | last_layer=self.get_last_layer(), split="train")
359 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361 | return aeloss
362 |
363 | if optimizer_idx == 1:
364 | # train the discriminator
365 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366 | last_layer=self.get_last_layer(), split="train")
367 |
368 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370 | return discloss
371 |
372 | def validation_step(self, batch, batch_idx):
373 | inputs = self.get_input(batch, self.image_key)
374 | reconstructions, posterior = self(inputs)
375 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376 | last_layer=self.get_last_layer(), split="val")
377 |
378 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379 | last_layer=self.get_last_layer(), split="val")
380 |
381 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382 | self.log_dict(log_dict_ae)
383 | self.log_dict(log_dict_disc)
384 | return self.log_dict
385 |
386 | def configure_optimizers(self):
387 | lr = self.learning_rate
388 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389 | list(self.decoder.parameters())+
390 | list(self.quant_conv.parameters())+
391 | list(self.post_quant_conv.parameters()),
392 | lr=lr, betas=(0.5, 0.9))
393 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394 | lr=lr, betas=(0.5, 0.9))
395 | return [opt_ae, opt_disc], []
396 |
397 | def get_last_layer(self):
398 | return self.decoder.conv_out.weight
399 |
400 | @torch.no_grad()
401 | def log_images(self, batch, only_inputs=False, **kwargs):
402 | log = dict()
403 | x = self.get_input(batch, self.image_key)
404 | x = x.to(self.device)
405 | if not only_inputs:
406 | xrec, posterior = self(x)
407 | if x.shape[1] > 3:
408 | # colorize with random projection
409 | assert xrec.shape[1] > 3
410 | x = self.to_rgb(x)
411 | xrec = self.to_rgb(xrec)
412 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413 | log["reconstructions"] = xrec
414 | log["inputs"] = x
415 | return log
416 |
417 | def to_rgb(self, x):
418 | assert self.image_key == "segmentation"
419 | if not hasattr(self, "colorize"):
420 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421 | x = F.conv2d(x, weight=self.colorize)
422 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423 | return x
424 |
425 |
426 | class IdentityFirstStage(torch.nn.Module):
427 | def __init__(self, *args, vq_interface=False, **kwargs):
428 |         super().__init__()
429 |         self.vq_interface = vq_interface  # TODO: should default to True, but verify it does not break older checkpoints
430 |
431 | def encode(self, x, *args, **kwargs):
432 | return x
433 |
434 | def decode(self, x, *args, **kwargs):
435 | return x
436 |
437 | def quantize(self, x, *args, **kwargs):
438 | if self.vq_interface:
439 | return x, None, [None, None, None]
440 | return x
441 |
442 | def forward(self, x, *args, **kwargs):
443 | return x
444 |
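445 | # Illustrative usage sketch for AutoencoderKL as a standalone first stage. The
446 | # `ddconfig`, `lossconfig` and `embed_dim` values below are placeholders and should
447 | # be taken from one of the kl-f* first-stage configs shipped with the repo:
448 | #
449 | #   model = AutoencoderKL(ddconfig=ddconfig, lossconfig=lossconfig, embed_dim=4)
450 | #   posterior = model.encode(images)   # DiagonalGaussianDistribution over latents
451 | #   z = posterior.sample()             # latent tensor with `embed_dim` channels
452 | #   x_rec = model.decode(z)            # decoded reconstruction in image space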
--------------------------------------------------------------------------------
/scripts/txt2img_demo.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from collections import defaultdict
3 | # from ossaudiodev import SNDCTL_SEQ_CTRLRATE  # unused accidental import; ossaudiodev is Linux-only
4 | # from ast import parse  # unused accidental import
5 | import cv2
6 | import torch
7 | import numpy as np
8 | from omegaconf import OmegaConf
9 | from PIL import Image
10 | from tqdm import tqdm, trange
11 | from itertools import islice
12 | from einops import rearrange
13 | from torchvision.utils import make_grid
14 | import time
15 | from pytorch_lightning import seed_everything
16 | from torch import autocast
17 | from contextlib import contextmanager, nullcontext
18 |
19 | from structured_stable_diffusion.util import instantiate_from_config
20 | from structured_stable_diffusion.models.diffusion.ddim import DDIMSampler
21 | from structured_stable_diffusion.models.diffusion.plms import PLMSSampler
22 |
23 | import sng_parser
24 | import stanza
25 | from nltk.tree import Tree
26 | nlp = stanza.Pipeline(lang='en', processors='tokenize,pos,constituency')
27 | import pdb
28 | import json
29 |
30 |
31 | def preprocess_prompts(prompts):
32 | if isinstance(prompts, (list, tuple)):
33 | return [p.lower().strip().strip(".").strip() for p in prompts]
34 | elif isinstance(prompts, str):
35 | return prompts.lower().strip().strip(".").strip()
36 | else:
37 | raise NotImplementedError
38 |
39 |
40 | def get_all_nps(tree, full_sent, tokens=None, highest_only=False, lowest_only=False):
41 | start = 0
42 | end = len(tree.leaves())
43 |
44 | idx_map = get_token_alignment_map(tree, tokens)
45 |
46 | def get_sub_nps(tree, left, right):
47 | if isinstance(tree, str) or len(tree.leaves()) == 1:
48 | return []
49 | sub_nps = []
50 | n_leaves = len(tree.leaves())
51 | n_subtree_leaves = [len(t.leaves()) for t in tree]
52 | offset = np.cumsum([0] + n_subtree_leaves)[:len(n_subtree_leaves)]
53 | assert right - left == n_leaves
54 | if tree.label() == 'NP' and n_leaves > 1:
55 | sub_nps.append([" ".join(tree.leaves()), (int(min(idx_map[left])), int(min(idx_map[right])))])
56 | if highest_only and sub_nps[-1][0] != full_sent: return sub_nps
57 | for i, subtree in enumerate(tree):
58 | sub_nps += get_sub_nps(subtree, left=left+offset[i], right=left+offset[i]+n_subtree_leaves[i])
59 | return sub_nps
60 |
61 | all_nps = get_sub_nps(tree, left=start, right=end)
62 | lowest_nps = []
63 | for i in range(len(all_nps)):
64 | span = all_nps[i][1]
65 | lowest = True
66 | for j in range(len(all_nps)):
67 | if i == j: continue
68 | span2 = all_nps[j][1]
69 | if span2[0] >= span[0] and span2[1] <= span[1]:
70 | lowest = False
71 | break
72 | if lowest:
73 | lowest_nps.append(all_nps[i])
74 |
75 | if lowest_only:
76 | all_nps = lowest_nps
77 |
78 | if len(all_nps) == 0:
79 | all_nps = []
80 | spans = []
81 | else:
82 | all_nps, spans = map(list, zip(*all_nps))
83 | if full_sent not in all_nps:
84 | all_nps = [full_sent] + all_nps
85 | spans = [(min(idx_map[start]), min(idx_map[end]))] + spans
86 |
87 | return all_nps, spans, lowest_nps
88 |
89 |
90 | def get_token_alignment_map(tree, tokens):
91 | if tokens is None:
92 | return {i:[i] for i in range(len(tree.leaves())+1)}
93 |
94 | def get_token(token):
95 |         return token[:-4] if token.endswith("</w>") else token  # strip the CLIP BPE end-of-word marker
96 |
97 | idx_map = {}
98 | j = 0
99 | max_offset = np.abs(len(tokens) - len(tree.leaves()))
100 | mytree_prev_leaf = ""
101 | for i, w in enumerate(tree.leaves()):
102 | token = get_token(tokens[j])
103 | idx_map[i] = [j]
104 | if token == mytree_prev_leaf+w:
105 | mytree_prev_leaf = ""
106 | j += 1
107 | else:
108 | if len(token) < len(w):
109 | prev = ""
110 | while prev + token != w:
111 | prev += token
112 | j += 1
113 | token = get_token(tokens[j])
114 | idx_map[i].append(j)
115 | # assert j - i <= max_offset
116 | else:
117 | mytree_prev_leaf += w
118 | j -= 1
119 | j += 1
120 | idx_map[i+1] = [j]
121 | return idx_map
122 |
123 |
124 | def get_all_spans_from_scene_graph(caption):
125 | caption = caption.strip()
126 | graph = sng_parser.parse(caption)
127 | nps = []
128 | spans = []
129 | words = caption.split()
130 | for e in graph['entities']:
131 | start, end = e['span_bounds']
132 | if e['span'] == caption: continue
133 | if end-start == 1: continue
134 | nps.append(e['span'])
135 | spans.append(e['span_bounds'])
136 | for r in graph['relations']:
137 | start1, end1 = graph['entities'][r['subject']]['span_bounds']
138 | start2, end2 = graph['entities'][r['object']]['span_bounds']
139 | start = min(start1, start2)
140 | end = max(end1, end2)
141 | if " ".join(words[start:end]) == caption: continue
142 | nps.append(" ".join(words[start:end]))
143 | spans.append((start, end))
144 |
145 | return [caption] + nps, [(0, len(words))] + spans, None
146 |
147 |
148 | def single_align(main_seq, seqs, spans, dim=1):
149 | main_seq = main_seq.transpose(0, dim)
150 | for seq, span in zip(seqs, spans):
151 | seq = seq.transpose(0, dim)
152 | start, end = span[0]+1, span[1]+1
153 | seg_length = end - start
154 | main_seq[start:end] = seq[1:1+seg_length]
155 |
156 | return main_seq.transpose(0, dim)
157 |
158 |
159 | def multi_align(main_seq, seq, span, dim=1):
160 | seq = seq.transpose(0, dim)
161 | main_seq = main_seq.transpose(0, dim)
162 | start, end = span[0]+1, span[1]+1
163 | seg_length = end - start
164 |
165 | main_seq[start:end] = seq[1:1+seg_length]
166 |
167 | return main_seq.transpose(0, dim)
168 |
169 |
170 | def align_sequence(main_seq, seqs, spans, dim=1, single=False):
171 | aligned_seqs = []
172 | if single:
173 | return [single_align(main_seq, seqs, spans, dim=dim)]
174 | else:
175 | for seq, span in zip(seqs, spans):
176 | aligned_seqs.append(multi_align(main_seq.clone(), seq, span, dim=dim))
177 | return aligned_seqs
178 |
179 |
180 | def chunk(it, size):
181 | it = iter(it)
182 | return iter(lambda: tuple(islice(it, size)), ())
183 |
184 |
185 | def load_model_from_config(config, ckpt, verbose=False):
186 | print(f"Loading model from {ckpt}")
187 | pl_sd = torch.load(ckpt, map_location="cpu")
188 | if "global_step" in pl_sd:
189 | print(f"Global Step: {pl_sd['global_step']}")
190 | sd = pl_sd["state_dict"]
191 | model = instantiate_from_config(config.model)
192 | m, u = model.load_state_dict(sd, strict=False)
193 | if len(m) > 0 and verbose:
194 | print("missing keys:")
195 | print(m)
196 | if len(u) > 0 and verbose:
197 | print("unexpected keys:")
198 | print(u)
199 |
200 | model.cuda()
201 | model.eval()
202 | return model
203 |
204 |
205 | def load_replacement(x):
206 | try:
207 | hwc = x.shape
208 | y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
209 | y = (np.array(y)/255.0).astype(x.dtype)
210 | assert y.shape == x.shape
211 | return y
212 | except Exception:
213 | return x
214 |
215 |
216 | def main():
217 | parser = argparse.ArgumentParser()
218 |
219 | parser.add_argument(
220 | "--prompt",
221 | type=str,
222 | nargs="?",
223 | default="a painting of a virus monster playing guitar",
224 | help="the prompt to render"
225 | )
226 | parser.add_argument(
227 | "--outdir",
228 | type=str,
229 | nargs="?",
230 | help="dir to write results to",
231 | default="outputs/txt2img-samples"
232 | )
233 | parser.add_argument(
234 | "--skip_grid",
235 | action='store_true',
236 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
237 | )
238 | parser.add_argument(
239 | "--skip_save",
240 | action='store_true',
241 | help="do not save individual samples. For speed measurements.",
242 | )
243 | parser.add_argument(
244 | "--ddim_steps",
245 | type=int,
246 | default=50,
247 | help="number of ddim sampling steps",
248 | )
249 | parser.add_argument(
250 | "--plms",
251 | action='store_true',
252 | help="use plms sampling",
253 | )
254 | parser.add_argument(
255 | "--laion400m",
256 | action='store_true',
257 | help="uses the LAION400M model",
258 | )
259 | parser.add_argument(
260 | "--fixed_code",
261 | action='store_true',
262 | help="if enabled, uses the same starting code across samples ",
263 | )
264 | parser.add_argument(
265 | "--ddim_eta",
266 | type=float,
267 | default=0.0,
268 | help="ddim eta (eta=0.0 corresponds to deterministic sampling",
269 | )
270 | parser.add_argument(
271 | "--n_iter",
272 | type=int,
273 | default=2,
274 | help="sample this often",
275 | )
276 | parser.add_argument(
277 | "--H",
278 | type=int,
279 | default=512,
280 | help="image height, in pixel space",
281 | )
282 | parser.add_argument(
283 | "--W",
284 | type=int,
285 | default=512,
286 | help="image width, in pixel space",
287 | )
288 | parser.add_argument(
289 | "--C",
290 | type=int,
291 | default=4,
292 | help="latent channels",
293 | )
294 | parser.add_argument(
295 | "--f",
296 | type=int,
297 | default=8,
298 | help="downsampling factor",
299 | )
300 | parser.add_argument(
301 | "--n_samples",
302 | type=int,
303 | default=1,
304 | help="how many samples to produce for each given prompt. A.k.a. batch size",
305 | )
306 | parser.add_argument(
307 | "--n_rows",
308 | type=int,
309 | default=0,
310 | help="rows in the grid (default: n_samples)",
311 | )
312 | parser.add_argument(
313 | "--scale",
314 | type=float,
315 | default=7.5,
316 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
317 | )
318 | parser.add_argument(
319 | "--from-file",
320 | type=str,
321 | help="if specified, load prompts from this file",
322 | )
323 | parser.add_argument(
324 | "--config",
325 | type=str,
326 | default="configs/stable-diffusion/v1-inference.yaml",
327 | help="path to config which constructs model",
328 | )
329 | parser.add_argument(
330 | "--ckpt",
331 | type=str,
332 | default="models/ldm/stable-diffusion-v1/model.ckpt",
333 | help="path to checkpoint of model",
334 | )
335 | parser.add_argument(
336 | "--seed",
337 | type=int,
338 | default=42,
339 | help="the seed (for reproducible sampling)",
340 | )
341 | parser.add_argument(
342 | "--precision",
343 | type=str,
344 | help="evaluate at this precision",
345 | choices=["full", "autocast"],
346 | default="autocast"
347 | )
348 | parser.add_argument(
349 | "--parser_type",
350 | type=str,
351 | choices=['constituency', 'scene_graph'],
352 | default='constituency'
353 | )
354 | parser.add_argument(
355 | "--conjunction",
356 | action='store_true',
357 | help='If True, the input prompt is a conjunction of two concepts like "A and B"'
358 | )
359 | parser.add_argument(
360 | "--save_attn_maps",
361 | action='store_true',
362 | help='If True, the attention maps will be saved as a .pth file with the name same as the image'
363 | )
364 |
365 | opt = parser.parse_args()
366 |
367 | if opt.laion400m:
368 | print("Falling back to LAION 400M model...")
369 | opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
370 | opt.ckpt = "models/ldm/text2img-large/model.ckpt"
371 | opt.outdir = "outputs/txt2img-samples-laion400m"
372 |
373 | seed_everything(opt.seed)
374 |
375 | config = OmegaConf.load(f"{opt.config}")
376 | model = load_model_from_config(config, f"{opt.ckpt}")
377 |
378 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
379 | model = model.to(device)
380 |
381 | if opt.plms:
382 | sampler = PLMSSampler(model)
383 | else:
384 | sampler = DDIMSampler(model)
385 |
386 | os.makedirs(opt.outdir, exist_ok=True)
387 | outpath = opt.outdir
388 |
389 |
390 | batch_size = opt.n_samples
391 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
392 | if not opt.from_file:
393 | prompt = opt.prompt
394 | assert prompt is not None
395 | data = [batch_size * [prompt]]
396 | opt.from_file = ""
397 | else:
398 | print(f"reading prompts from {opt.from_file}")
399 | with open(opt.from_file, "r") as f:
400 | data = f.read().splitlines()
401 |         try:
402 |             end_idx = getattr(opt, "end_idx", -1)  # optional cap on the number of prompts; -1 means "use all"
403 |             data = data if end_idx == -1 else data[:end_idx]
404 |             data, filenames = zip(*[d.strip("\n").split("\t") for d in data])  # "prompt\tfilename" pairs
405 |             data = list(chunk(data, batch_size))
406 |         except Exception:  # plain prompt file: one prompt per line, no filenames
407 |             data = [batch_size * [d] for d in data]
408 |
409 | sample_path = os.path.join(outpath, "samples")
410 | os.makedirs(sample_path, exist_ok=True)
411 | base_count = len(os.listdir(sample_path))
412 | grid_count = len(os.listdir(outpath)) - 1
413 |
414 | start_code = None
415 | if opt.fixed_code:
416 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
417 |
418 | precision_scope = autocast if opt.precision=="autocast" else nullcontext
419 | with torch.no_grad():
420 | with precision_scope("cuda"):
421 | with model.ema_scope():
422 |
423 | all_samples = list()
424 | for n in trange(opt.n_iter, desc="Sampling"):
425 | for bid, prompts in enumerate(tqdm(data, desc="data")):
426 | prompts = preprocess_prompts(prompts)
427 |
428 | uc = None
429 | if opt.scale != 1.0:
430 | uc = model.get_learned_conditioning(batch_size * [""])
431 |
432 | c = model.get_learned_conditioning(prompts)
433 |
434 | if opt.parser_type == 'constituency':
435 | doc = nlp(prompts[0])
436 | mytree = Tree.fromstring(str(doc.sentences[0].constituency))
437 | tokens = model.cond_stage_model.tokenizer.tokenize(prompts[0])
438 | nps, spans, noun_chunk = get_all_nps(mytree, prompts[0], tokens)
439 | elif opt.parser_type == 'scene_graph':
440 | nps, spans, noun_chunk = get_all_spans_from_scene_graph(prompts[0].split("\t")[0])
441 | else:
442 | raise NotImplementedError
443 |
444 |                         nps = [[np]*len(prompts) for np in nps]  # repeat each noun phrase across the batch
445 | 
446 |                         if opt.conjunction:
447 |                             c = [model.get_learned_conditioning(np) for np in nps]  # one text embedding per noun phrase, full prompt first
448 |                             k_c = [c[0]] + align_sequence(c[0], c[1:], spans[1:])  # keys: full prompt plus one copy per NP with its span overwritten
449 |                             v_c = align_sequence(c[0], c[1:], spans[1:], single=True)  # values: all NP embeddings spliced into one sequence
450 |                             c = {'k': k_c, 'v': v_c}
451 |                         else:
452 |                             c = [model.get_learned_conditioning(np) for np in nps]
453 |                             k_c = c[:1]  # keys: only the full-prompt embedding
454 |                             v_c = [c[0]] + align_sequence(c[0], c[1:], spans[1:])  # values: full prompt plus one spliced copy per NP
455 |                             c = {'k': k_c, 'v': v_c}
456 |
457 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
458 | samples_ddim, intermediates = sampler.sample(S=opt.ddim_steps,
459 | conditioning=c,
460 | batch_size=opt.n_samples,
461 | shape=shape,
462 | verbose=False,
463 | unconditional_guidance_scale=opt.scale,
464 | unconditional_conditioning=uc,
465 | eta=opt.ddim_eta,
466 | x_T=start_code,
467 | save_attn_maps=opt.save_attn_maps)
468 |
469 | x_samples_ddim = model.decode_first_stage(samples_ddim)
470 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
471 | x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
472 |
473 | x_checked_image = x_samples_ddim
474 |
475 | x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
476 |
477 | if not opt.skip_save:
478 | for sid, x_sample in enumerate(x_checked_image_torch):
479 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
480 | img = Image.fromarray(x_sample.astype(np.uint8))
481 | try:
482 | count = bid * opt.n_samples + sid
483 | safe_filename = f"{n}-{count}-" + (filenames[count][:-4])[:150] + ".jpg"
484 |                                 except Exception:  # no per-prompt filenames available; fall back to a generic name
485 | safe_filename = f"{base_count:05}-{n}-{prompts[0]}"[:100] + ".jpg"
486 | img.save(os.path.join(sample_path, f"{safe_filename}"))
487 |
488 | if opt.save_attn_maps:
489 | torch.save(sampler.attn_maps, os.path.join(sample_path, f'{safe_filename}.pt'))
490 |
491 | base_count += 1
492 |
493 | if not opt.skip_grid:
494 | all_samples.append(x_checked_image_torch)
495 |
496 | if not opt.skip_grid:
497 | # additionally, save as grid
498 | grid = torch.stack(all_samples, 0)
499 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
500 | grid = make_grid(grid, nrow=n_rows)
501 |
502 | # to image
503 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
504 | img = Image.fromarray(grid.astype(np.uint8))
505 | img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
506 | grid_count += 1
507 |
508 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
509 | f" \nEnjoy.")
510 |
511 |
512 | if __name__ == "__main__":
513 | main()
514 |
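515 | # Example invocation (a sketch; the config/ckpt paths mirror the argparse defaults
516 | # above, and the checkpoint location is an assumption about your local setup):
517 | #
518 | #   python scripts/txt2img_demo.py \
519 | #       --prompt "a red apple and a green bird" \
520 | #       --plms --conjunction \
521 | #       --config configs/stable-diffusion/v1-inference.yaml \
522 | #       --ckpt models/ldm/stable-diffusion-v1/model.ckpt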
--------------------------------------------------------------------------------