├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── lsun.py
│   │   └── celebahq.py
│   ├── models
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       └── compose_modules.py
│   ├── modules
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   └── util.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   ├── image_degradation
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── __init__.py
│   │   ├── ema.py
│   │   └── attention.py
│   ├── lr_scheduler.py
│   └── util.py
├── assets
│   ├── fig_teaser.jpg
│   └── fig_framework.jpg
├── test_data
│   ├── 256_masks
│   │   └── 29980.png
│   ├── 512_masks
│   │   ├── 27007.png
│   │   └── 29980.png
│   └── test_mask_edit
│       ├── 256_input_image
│       │   └── 27044.jpg
│       └── 256_edited_masks
│           └── 27044_0_remove_smile_and_rings.png
├── freeu
│   ├── assets
│   │   ├── mask2face_27007.jpeg
│   │   ├── mask2face_29980.jpeg
│   │   ├── text2face_female.jpeg
│   │   └── text2face_male.jpeg
│   ├── README.md
│   ├── text2image_freeu.py
│   └── mask2image_freeu.py
├── .gitignore
├── setup.py
├── environment.yaml
├── LICENSE
├── configs
│   ├── 512_vae.yaml
│   ├── 256_vae.yaml
│   ├── 512_text.yaml
│   ├── 512_codiff_mask_text.yaml
│   ├── 256_text.yaml
│   ├── 512_mask.yaml
│   ├── 256_codiff_mask_text.yaml
│   └── 256_mask.yaml
├── text2image.py
├── mask2image.py
├── generate.py
└── editing
    ├── imagic_edit_text.py
    ├── collaborative_edit.py
    └── imagic_edit_mask.py
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/assets/fig_teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/assets/fig_teaser.jpg
--------------------------------------------------------------------------------
/assets/fig_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/assets/fig_framework.jpg
--------------------------------------------------------------------------------
/test_data/256_masks/29980.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/256_masks/29980.png
--------------------------------------------------------------------------------
/test_data/512_masks/27007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/512_masks/27007.png
--------------------------------------------------------------------------------
/test_data/512_masks/29980.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/512_masks/29980.png
--------------------------------------------------------------------------------
/freeu/assets/mask2face_27007.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/mask2face_27007.jpeg
--------------------------------------------------------------------------------
/freeu/assets/mask2face_29980.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/mask2face_29980.jpeg
--------------------------------------------------------------------------------
/freeu/assets/text2face_female.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/text2face_female.jpeg
--------------------------------------------------------------------------------
/freeu/assets/text2face_male.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/text2face_male.jpeg
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/test_data/test_mask_edit/256_input_image/27044.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/test_mask_edit/256_input_image/27044.jpg
--------------------------------------------------------------------------------
/test_data/test_mask_edit/256_edited_masks/27044_0_remove_smile_and_rings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/test_mask_edit/256_edited_masks/27044_0_remove_smile_and_rings.png
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # postscripts
2 | *.ckpt
3 | *.zip
4 | *.pyc
5 | *.pt
6 | *.pth
7 |
8 | # directories
9 | pretrained
10 | dataset
11 | outputs
12 | src
13 | trash
14 | stashed_for_git
15 | private_scripts
16 | ldm/data/trash
17 |
18 | latent_diffusion.egg-info
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: codiff
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.4.2
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - torch-fidelity==0.3.0
24 | - transformers==4.3.1
25 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
26 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
27 | - -e .
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
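Note: the base class above only fixes `num_records`, `valid_ids`, `size`, and the abstract `__iter__` contract; as a minimal sketch (hypothetical subclass and names, not part of the repo), a concrete dataset can be chained the way the docstring describes:

```python
import torch
from torch.utils.data import ChainDataset, DataLoader

from ldm.data.base import Txt2ImgIterableBaseDataset


class ToyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    """Hypothetical subclass: yields dummy image/caption pairs for its ids."""

    def __iter__(self):
        for sample_id in self.sample_ids:
            yield {
                "image": torch.zeros(3, self.size, self.size),  # placeholder image
                "caption": f"sample {sample_id}",
            }


# Shards implementing the same iterable interface can be chained transparently.
shard_a = ToyTxt2ImgDataset(num_records=2, valid_ids=[0, 1])
shard_b = ToyTxt2ImgDataset(num_records=2, valid_ids=[2, 3])
loader = DataLoader(ChainDataset([shard_a, shard_b]), batch_size=2)
for batch in loader:
    print(batch["image"].shape, batch["caption"])
```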
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2023 S-Lab
4 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
7 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
8 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
9 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
--------------------------------------------------------------------------------
/configs/512_vae.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 512
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4]
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 2
30 | wrap: True
31 | train:
32 | target: ldm.data.celebahq.CelebAConditionalDataset
33 | params:
34 | phase: train
35 | im_preprocessor_config:
36 | target: ldm.data.celebahq.DalleTransformerPreprocessor
37 | params:
38 | size: 512
39 | phase: train
40 | test_dataset_size: 3000
41 | conditions: []
42 | validation:
43 | target: ldm.data.celebahq.CelebAConditionalDataset
44 | params:
45 | phase: test
46 | im_preprocessor_config:
47 | target: ldm.data.celebahq.DalleTransformerPreprocessor
48 | params:
49 | size: 512
50 | phase: val
51 | test_dataset_size: 3000
52 | conditions: []
53 |
54 | lightning:
55 | callbacks:
56 | image_logger:
57 | target: main.ImageLogger
58 | params:
59 | batch_frequency: 1000
60 | max_images: 8
61 | increase_log_steps: True
62 |
63 | trainer:
64 | benchmark: True
65 | accumulate_grad_batches: 2
66 |
--------------------------------------------------------------------------------
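Note: the `target`/`params` layout in these configs is consumed by `instantiate_from_config` (defined in `ldm/util.py` below), following the same pattern as `text2image.py`. A minimal sketch, assuming the `ldm` package and its dependencies (including taming-transformers for the LPIPS loss) are installed:

```python
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

# Build the KL autoencoder directly from the YAML above
# (untrained here, since this training config loads no checkpoint).
config = OmegaConf.load("configs/512_vae.yaml")
vae = instantiate_from_config(config["model"]).eval()

# Round-trip a dummy 512x512 image: encode -> sample latent -> decode.
with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)
    posterior = vae.encode(x)   # DiagonalGaussianDistribution over a 3x64x64 latent
    z = posterior.sample()
    x_rec = vae.decode(z)
print(z.shape, x_rec.shape)
```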
/configs/256_vae.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ]
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 8
31 | wrap: True
32 | train:
33 | target: ldm.data.celebahq.CelebAConditionalDataset
34 | params:
35 | phase: train
36 | im_preprocessor_config:
37 | target: ldm.data.celebahq.DalleTransformerPreprocessor
38 | params:
39 | size: 256
40 | phase: train
41 | test_dataset_size: 3000
42 | conditions: []
43 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
44 | validation:
45 | target: ldm.data.celebahq.CelebAConditionalDataset
46 | params:
47 | phase: test
48 | im_preprocessor_config:
49 | target: ldm.data.celebahq.DalleTransformerPreprocessor
50 | params:
51 | size: 256
52 | phase: val
53 | test_dataset_size: 3000
54 | conditions: []
55 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
56 |
57 | lightning:
58 | callbacks:
59 | image_logger:
60 | target: main.ImageLogger
61 | params:
62 | batch_frequency: 1000
63 | max_images: 8
64 | increase_log_steps: True
65 |
66 | trainer:
67 | benchmark: True
68 | accumulate_grad_batches: 2
69 |
--------------------------------------------------------------------------------
/freeu/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Face Generation + FreeU
3 | We have now integrated [FreeU](https://chenyangsi.top/FreeU/) ([Hugging Face demo](https://huggingface.co/spaces/ChenyangSi/FreeU)) into the LDMs to further boost synthesis quality.
4 | We adapted FreeU's [code](https://github.com/ChenyangSi/FreeU) for our face generation diffusion models.
5 | For more details about FreeU, please refer to the [paper](https://arxiv.org/abs/2309.11497).
6 |
7 | ## Mask-to-Face Generation:
8 |
9 |
10 |
11 | 1. without FreeU
12 | ```bash
13 | python freeu/mask2image_freeu.py \
14 | --mask_path "test_data/512_masks/27007.png"
15 | ```
16 | 2. with FreeU
17 | ```bash
18 | python freeu/mask2image_freeu.py \
19 | --mask_path "test_data/512_masks/27007.png" \
20 | --enable_freeu \
21 | --b1 1.1 \
22 | --b2 1.2 \
23 | --s1 1 \
24 | --s2 1
25 | ```
26 | ## Text-to-Face Generation:
27 |
28 |
29 |
30 |
31 | 1. without FreeU
32 | ```bash
33 | python freeu/text2image_freeu.py \
34 | --input_text "This man has beard of medium length. He is in his thirties."
35 | ```
36 | 2. with FreeU
37 | ```bash
38 | python freeu/text2image_freeu.py \
39 | --input_text "This man has beard of medium length. He is in his thirties." \
40 | --enable_freeu \
41 | --b1 1.1 \
42 | --b2 1.2 \
43 | --s1 1 \
44 | --s2 1
45 | ```
46 | Another example:
47 | 1. without FreeU
48 | ```bash
49 | python freeu/text2image_freeu.py \
50 | --input_text "This woman is in her forties."
51 | ```
52 | 2. with FreeU
53 | ```bash
54 | python freeu/text2image_freeu.py \
55 | --input_text "This woman is in her forties." \
56 | --enable_freeu \
57 | --b1 1.1 \
58 | --b2 1.2 \
59 | --s1 1 \
60 | --s2 1
61 | ```
62 |
--------------------------------------------------------------------------------
/configs/512_text.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.141
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 64
24 | in_channels: 3
25 | out_channels: 3
26 | model_channels: 192
27 | attention_resolutions:
28 | - 8
29 | - 4
30 | - 2
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 3
36 | - 5
37 | num_heads: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 640
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 3
48 | monitor: val/rec_loss
49 | ckpt_path: pretrained/512_vae.ckpt
50 | ddconfig:
51 | double_z: true
52 | z_channels: 3
53 | resolution: 512
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [1, 2, 4, 4]
58 | num_res_blocks: 2
59 | attn_resolutions: [ ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: torch.nn.Identity
63 |
64 | cond_stage_config:
65 | target: ldm.modules.encoders.modules.BERTEmbedder
66 | params:
67 | n_embed: 640
68 | n_layer: 32
69 |
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 8
74 | wrap: True
75 | train:
76 | target: ldm.data.celebahq.CelebAConditionalDataset
77 | params:
78 | phase: train
79 | im_preprocessor_config:
80 | target: ldm.data.celebahq.DalleTransformerPreprocessor
81 | params:
82 | size: 512
83 | phase: train
84 | test_dataset_size: 3000
85 | conditions:
86 | - 'text'
87 | validation:
88 | target: ldm.data.celebahq.CelebAConditionalDataset
89 | params:
90 | phase: test
91 | im_preprocessor_config:
92 | target: ldm.data.celebahq.DalleTransformerPreprocessor
93 | params:
94 | size: 512
95 | phase: val
96 | test_dataset_size: 3000
97 | conditions:
98 | - 'text'
99 |
100 |
101 | lightning:
102 | callbacks:
103 | image_logger:
104 | target: main.ImageLogger
105 | params:
106 | batch_frequency: 1000
107 | max_images: 8
108 | increase_log_steps: False
109 |
110 | trainer:
111 | benchmark: True
--------------------------------------------------------------------------------
/configs/512_codiff_mask_text.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm_compose.LatentDiffusionCompose
4 | params:
5 |
6 | linear_start: 0.0015
7 | linear_end: 0.0195
8 | num_timesteps_cond: 1
9 | log_every_t: 200
10 | timesteps: 1000
11 | first_stage_key: image
12 | cond_stage_key: conditions
13 | image_size: 64
14 | channels: 3
15 |     cond_stage_trainable: True # so that we don't need to change LatentDiffusion's forward code
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple
18 | scale_factor: 0.141
19 | use_ema: False
20 |
21 | seg_mask_ldm_config_path: configs/512_mask.yaml
22 | seg_mask_ldm_ckpt_path: pretrained/512_mask.ckpt
23 | text_ldm_config_path: configs/512_text.yaml
24 | text_ldm_ckpt_path: pretrained/512_text.ckpt
25 |
26 | compose_unet_config:
27 | target: ldm.models.diffusion.compose_modules.ComposeUNet
28 | params:
29 | confidence_conditioning_key: crossattn
30 | confidence_input: x_t
31 | confidence_map_predictor_config:
32 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
33 | params:
34 | image_size: 64
35 | in_channels: 3
36 | out_channels: 1
37 | model_channels: 32
38 | attention_resolutions:
39 | - 8
40 | - 4
41 | - 2
42 | num_res_blocks: 2
43 | channel_mult:
44 | - 1
45 | - 2
46 | - 3
47 | - 5
48 | num_heads: 32
49 | use_spatial_transformer: true
50 | transformer_depth: 1
51 | context_dim: 640
52 | use_checkpoint: true
53 | legacy: False
54 |
55 |
56 | compose_cond_stage_config:
57 | target: ldm.models.diffusion.compose_modules.ComposeCondStageModel
58 | params: {}
59 |
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 8
64 | wrap: True
65 | train:
66 | target: ldm.data.celebahq.CelebAConditionalDataset
67 | params:
68 | phase: train
69 | im_preprocessor_config:
70 | target: ldm.data.celebahq.DalleTransformerPreprocessor
71 | params:
72 | size: 512
73 | phase: train
74 | test_dataset_size: 3000
75 | conditions:
76 | - 'text'
77 | - 'seg_mask'
78 | validation:
79 | target: ldm.data.celebahq.CelebAConditionalDataset
80 | params:
81 | phase: test
82 | im_preprocessor_config:
83 | target: ldm.data.celebahq.DalleTransformerPreprocessor
84 | params:
85 | size: 512
86 | phase: val
87 | test_dataset_size: 3000
88 | conditions:
89 | - 'text'
90 | - 'seg_mask'
91 |
92 |
93 | lightning:
94 | callbacks:
95 | image_logger:
96 | target: main.ImageLogger
97 | params:
98 | batch_frequency: 1000
99 | max_images: 8
100 | increase_log_steps: False
101 |
102 | trainer:
103 | benchmark: True
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
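A minimal usage sketch of `LitEma` (a hypothetical toy loop; in this codebase the EMA is driven from the diffusion model's Lightning hooks):

```python
import torch
from torch import nn

from ldm.modules.ema import LitEma

model = nn.Linear(8, 8)
ema = LitEma(model, decay=0.9999)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(100):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema(model)  # update the shadow (EMA) copies of all trainable parameters

# Swap in EMA weights for evaluation, then restore the raw training weights.
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation with EMA weights here ...
ema.restore(model.parameters())
```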
/configs/256_text.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.058
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 64
24 | in_channels: 3
25 | out_channels: 3
26 | model_channels: 192
27 | attention_resolutions:
28 | - 8
29 | - 4
30 | - 2
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 3
36 | - 5
37 | num_heads: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 640
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 3
48 | monitor: val/rec_loss
49 | ckpt_path: pretrained/256_vae.ckpt
50 | ddconfig:
51 | double_z: true
52 | z_channels: 3
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [1, 2, 4]
58 | num_res_blocks: 2
59 | attn_resolutions: [ ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: torch.nn.Identity
63 |
64 | cond_stage_config:
65 | target: ldm.modules.encoders.modules.BERTEmbedder
66 | params:
67 | n_embed: 640
68 | n_layer: 32
69 |
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 8
74 | wrap: True
75 | train:
76 | target: ldm.data.celebahq.CelebAConditionalDataset
77 | params:
78 | phase: train
79 | im_preprocessor_config:
80 | target: ldm.data.celebahq.DalleTransformerPreprocessor
81 | params:
82 | size: 256
83 | phase: train
84 | test_dataset_size: 3000
85 | conditions:
86 | - 'text'
87 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
88 | validation:
89 | target: ldm.data.celebahq.CelebAConditionalDataset
90 | params:
91 | phase: test
92 | im_preprocessor_config:
93 | target: ldm.data.celebahq.DalleTransformerPreprocessor
94 | params:
95 | size: 256
96 | phase: val
97 | test_dataset_size: 3000
98 | conditions:
99 | - 'text'
100 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
101 |
102 |
103 | lightning:
104 | callbacks:
105 | image_logger:
106 | target: main.ImageLogger
107 | params:
108 | batch_frequency: 1000
109 | max_images: 8
110 | increase_log_steps: False
111 |
112 | trainer:
113 | benchmark: True
--------------------------------------------------------------------------------
/configs/512_mask.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: seg_mask
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple
17 | scale_factor: 0.141
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 64
24 | in_channels: 3
25 | out_channels: 3
26 | model_channels: 192
27 | attention_resolutions:
28 | - 8
29 | - 4
30 | - 2
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 3
36 | - 5
37 | num_heads: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 640
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 3
48 | monitor: val/rec_loss
49 | ckpt_path: pretrained/512_vae.ckpt
50 | ddconfig:
51 | double_z: true
52 | z_channels: 3
53 | resolution: 512
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [1, 2, 4, 4]
58 | num_res_blocks: 2
59 | attn_resolutions: [ ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: torch.nn.Identity
63 |
64 | cond_stage_config:
65 | target: ldm.modules.encoders.modules.SegMaskEncoder
66 | params:
67 | seg_mask_encoder_config:
68 | target: ldm.modules.encoders.modules.PassSegMaskEncoder
69 | params: {}
70 | mask_embed_dim: 1024
71 | context_dim: 640
72 |
73 | data:
74 | target: main.DataModuleFromConfig
75 | params:
76 | batch_size: 8
77 | wrap: True
78 | train:
79 | target: ldm.data.celebahq.CelebAConditionalDataset
80 | params:
81 | phase: train
82 | im_preprocessor_config:
83 | target: ldm.data.celebahq.DalleTransformerPreprocessor
84 | params:
85 | size: 512
86 | phase: train
87 | test_dataset_size: 3000
88 | conditions:
89 | - 'seg_mask'
90 | validation:
91 | target: ldm.data.celebahq.CelebAConditionalDataset
92 | params:
93 | phase: test
94 | im_preprocessor_config:
95 | target: ldm.data.celebahq.DalleTransformerPreprocessor
96 | params:
97 | size: 512
98 | phase: val
99 | test_dataset_size: 3000
100 | conditions:
101 | - 'seg_mask'
102 |
103 |
104 | lightning:
105 | callbacks:
106 | image_logger:
107 | target: main.ImageLogger
108 | params:
109 | batch_frequency: 1000
110 | max_images: 8
111 | increase_log_steps: False
112 |
113 | trainer:
114 | benchmark: True
--------------------------------------------------------------------------------
/configs/256_codiff_mask_text.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm_compose.LatentDiffusionCompose
4 | params:
5 |
6 | linear_start: 0.0015
7 | linear_end: 0.0195
8 | num_timesteps_cond: 1
9 | log_every_t: 200
10 | timesteps: 1000
11 | first_stage_key: image
12 | cond_stage_key: conditions
13 | image_size: 64
14 | channels: 3
15 |     cond_stage_trainable: True # so that we don't need to change LatentDiffusion's forward code
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple
18 | scale_factor: 0.058
19 | use_ema: False
20 |
21 | seg_mask_ldm_config_path: configs/256_mask.yaml
22 | seg_mask_ldm_ckpt_path: pretrained/256_mask.ckpt
23 | text_ldm_config_path: configs/256_text.yaml
24 | text_ldm_ckpt_path: pretrained/256_text.ckpt
25 |
26 | compose_unet_config:
27 | target: ldm.models.diffusion.compose_modules.ComposeUNet
28 | params:
29 | confidence_conditioning_key: crossattn
30 | confidence_input: x_t
31 | confidence_map_predictor_config:
32 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
33 | params:
34 | image_size: 64
35 | in_channels: 3
36 | out_channels: 1
37 | model_channels: 32
38 | attention_resolutions:
39 | - 8
40 | - 4
41 | - 2
42 | num_res_blocks: 2
43 | channel_mult:
44 | - 1
45 | - 2
46 | - 3
47 | - 5
48 | num_heads: 32
49 | use_spatial_transformer: true
50 | transformer_depth: 1
51 | context_dim: 640
52 | use_checkpoint: true
53 | legacy: False
54 |
55 |
56 | compose_cond_stage_config:
57 | target: ldm.models.diffusion.compose_modules.ComposeCondStageModel
58 | params: {}
59 |
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 8
64 | wrap: True
65 | train:
66 | target: ldm.data.celebahq.CelebAConditionalDataset
67 | params:
68 | phase: train
69 | im_preprocessor_config:
70 | target: ldm.data.celebahq.DalleTransformerPreprocessor
71 | params:
72 | size: 256
73 | phase: train
74 | test_dataset_size: 3000
75 | conditions:
76 | - 'text'
77 | - 'seg_mask'
78 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
79 | validation:
80 | target: ldm.data.celebahq.CelebAConditionalDataset
81 | params:
82 | phase: test
83 | im_preprocessor_config:
84 | target: ldm.data.celebahq.DalleTransformerPreprocessor
85 | params:
86 | size: 256
87 | phase: val
88 | test_dataset_size: 3000
89 | conditions:
90 | - 'text'
91 | - 'seg_mask'
92 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
93 |
94 | lightning:
95 | callbacks:
96 | image_logger:
97 | target: main.ImageLogger
98 | params:
99 | batch_frequency: 1000
100 | max_images: 8
101 | increase_log_steps: False
102 |
103 | trainer:
104 | benchmark: True
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
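A short sketch of `DiagonalGaussianDistribution`, whose `parameters` tensor packs mean and log-variance halves along the channel dimension (as the VAE encoder produces them):

```python
import torch

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# 8 channels = 4 mean channels + 4 log-variance channels.
params = torch.randn(2, 8, 16, 16)
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()    # reparameterized sample, shape [2, 4, 16, 16]
mode = posterior.mode()   # the mean
kl = posterior.kl()       # KL against a standard normal, one value per batch item
nll = posterior.nll(z)    # negative log-likelihood of a sample, one value per batch item
print(z.shape, kl.shape, nll.shape)
```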
/configs/256_mask.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: seg_mask
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.058
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 64
24 | in_channels: 3
25 | out_channels: 3
26 | model_channels: 192
27 | attention_resolutions:
28 | - 8
29 | - 4
30 | - 2
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 3
36 | - 5
37 | num_heads: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 640
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 3
48 | monitor: val/rec_loss
49 | ckpt_path: pretrained/256_vae.ckpt
50 | ddconfig:
51 | double_z: true
52 | z_channels: 3
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [1, 2, 4]
58 | num_res_blocks: 2
59 | attn_resolutions: [ ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: torch.nn.Identity
63 |
64 | cond_stage_config:
65 | target: ldm.modules.encoders.modules.SegMaskEncoder
66 | params:
67 | seg_mask_encoder_config:
68 | target: ldm.modules.encoders.modules.PassSegMaskEncoder
69 | params: {}
70 | mask_embed_dim: 1024
71 | context_dim: 640
72 |
73 |
74 | data:
75 | target: main.DataModuleFromConfig
76 | params:
77 | batch_size: 8
78 | wrap: True
79 | train:
80 | target: ldm.data.celebahq.CelebAConditionalDataset
81 | params:
82 | phase: train
83 | im_preprocessor_config:
84 | target: ldm.data.celebahq.DalleTransformerPreprocessor
85 | params:
86 | size: 256
87 | phase: train
88 | test_dataset_size: 3000
89 | conditions:
90 | - 'seg_mask'
91 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
92 | validation:
93 | target: ldm.data.celebahq.CelebAConditionalDataset
94 | params:
95 | phase: test
96 | im_preprocessor_config:
97 | target: ldm.data.celebahq.DalleTransformerPreprocessor
98 | params:
99 | size: 256
100 | phase: val
101 | test_dataset_size: 3000
102 | conditions:
103 | - 'seg_mask'
104 | image_folder: 'datasets/image_256_downsampled_from_hq_1024'
105 |
106 | lightning:
107 | callbacks:
108 | image_logger:
109 | target: main.ImageLogger
110 | params:
111 | batch_frequency: 1000
112 | max_images: 8
113 | increase_log_steps: False
114 |
115 | trainer:
116 | benchmark: True
--------------------------------------------------------------------------------
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
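Usage sketch (assuming the LSUN file lists and images have been prepared under `data/lsun/`, as the class defaults expect):

```python
from torch.utils.data import DataLoader

from ldm.data.lsun import LSUNChurchesTrain

# Needs data/lsun/church_outdoor_train.txt plus the images under data/lsun/churches.
dataset = LSUNChurchesTrain(size=256)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["image"].shape)  # [4, 256, 256, 3], float32 in [-1, 1]
```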
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
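These schedulers return the learning-rate value itself (hence the "use with a base_lr of 1.0" note), so they are typically wrapped in `torch.optim.lr_scheduler.LambdaLR`; a minimal sketch with hypothetical hyperparameters:

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

model = nn.Linear(8, 8)
# Base lr of 1.0: the scheduler's return value becomes the effective learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4, lr_start=0.0, max_decay_steps=10000)
scheduler = LambdaLR(optimizer, lr_lambda=schedule)

for step in range(3000):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # effective lr = 1.0 * schedule(step)
```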
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss # = rec_loss + p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar # = rec_loss + p_loss
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 |
84 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss # # = rec_loss + p_loss + 1e-06 * kl_loss = l1_loss + p_loss + 1e-06 * kl_loss
85 |
86 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
87 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
88 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
89 | "{}/d_weight".format(split): d_weight.detach(),
90 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
91 | "{}/g_loss".format(split): g_loss.detach().mean(),
92 | }
93 | return loss, log
94 |
95 | if optimizer_idx == 1:
96 | # second pass for discriminator update
97 | if cond is None:
98 | logits_real = self.discriminator(inputs.contiguous().detach())
99 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
100 | else:
101 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
102 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
103 |
104 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
105 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
106 |
107 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
108 | "{}/logits_real".format(split): logits_real.detach().mean(),
109 | "{}/logits_fake".format(split): logits_fake.detach().mean()
110 | }
111 | return d_loss, log
112 |
113 |
--------------------------------------------------------------------------------
/text2image.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from omegaconf import OmegaConf
9 | from PIL import Image
10 | from torchvision.utils import make_grid
11 |
12 | from ldm.models.diffusion.ddim import DDIMSampler
13 | from ldm.util import instantiate_from_config
14 |
15 | """
16 | Inference script for text-to-face generation at 512x512 resolution
17 | """
18 |
19 |
20 | def parse_args():
21 |
22 | parser = argparse.ArgumentParser(description="")
23 |
24 | # conditions
25 | parser.add_argument(
26 | "--input_text",
27 | type=str,
28 | default="This man has beard of medium length. He is in his thirties.",
29 | help="text condition")
30 |
31 | # paths
32 | parser.add_argument(
33 | "--config_path",
34 | type=str,
35 | default="configs/512_text.yaml",
36 | help="path to model config")
37 | parser.add_argument(
38 | "--ckpt_path",
39 | type=str,
40 | default="pretrained/512_text.ckpt",
41 | help="path to model checkpoint")
42 | parser.add_argument(
43 | "--save_folder",
44 | type=str,
45 | default="outputs/512_text2image",
46 | help="folder to save synthesis outputs")
47 |
48 | # batch size and ddim steps
49 | parser.add_argument(
50 | "--batch_size",
51 | type=int,
52 | default=4,
53 | help="number of images to generate")
54 | parser.add_argument(
55 | "--ddim_steps",
56 | type=int,
57 |         default=50,
58 |         help=
59 |         "number of DDIM steps (between 20 and 1000; larger is slower but gives better quality)"
60 | )
61 |
62 | # whether save intermediate outputs
63 | parser.add_argument(
64 | "--save_z",
65 | type=bool,
66 | default=False,
67 | help=
68 |         "whether to visualize the VAE latent codes and save them in the output folder",
69 | )
70 | parser.add_argument(
71 | "--return_influence_function",
72 | type=bool,
73 | default=False,
74 | help=
75 |         "whether to visualize the Influence Functions and save them in the output folder",
76 | )
77 | parser.add_argument(
78 | "--display_x_inter",
79 | type=bool,
80 | default=False,
81 | help=
82 |         "whether to display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder",
83 | )
84 |
85 | args = parser.parse_args()
86 | return args
87 |
88 |
89 | def main():
90 |
91 | args = parse_args()
92 |
93 | # ========== set up model ==========
94 | print(f'Set up model')
95 | config = OmegaConf.load(args.config_path)
96 | model_config = config['model']
97 | model = instantiate_from_config(model_config)
98 | model.init_from_ckpt(args.ckpt_path)
99 | model = model.cuda()
100 | model.eval()
101 |
102 | # ========== set output directory ==========
103 | os.makedirs(args.save_folder, exist_ok=True)
104 | # save a copy of this python script being used
105 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__))
106 |
107 | print(
108 | f'================================================================================'
109 | )
110 | print(f'text: {args.input_text}')
111 |
112 | # prepare directories
113 | save_sub_folder = os.path.join(args.save_folder, str(args.input_text))
114 | os.makedirs(save_sub_folder, exist_ok=True)
115 |
116 | # ========== inference ==========
117 | with torch.no_grad():
118 |
119 | # encode condition
120 | condition = []
121 | for i in range(args.batch_size):
122 | condition.append(args.input_text.lower())
123 |
124 | with model.ema_scope("Plotting"):
125 |
126 | # encode condition
127 | condition = model.get_learned_conditioning(
128 | condition) # [1, 77, 640]
129 | print(f'condition.shape={condition.shape}') # [B, 77, 640]
130 |
131 | # DDIM sampling
132 | ddim_sampler = DDIMSampler(model)
133 | z_0_batch, intermediates = ddim_sampler.sample(
134 | S=args.ddim_steps,
135 | batch_size=args.batch_size,
136 | shape=(3, 64, 64),
137 | conditioning=condition,
138 | verbose=False,
139 | eta=1.0,
140 | log_every_t=1)
141 |
142 | # decode VAE latent z_0 to image x_0
143 |             x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 512, 512]
144 |
145 | # ========== save outputs ==========
146 | for idx in range(args.batch_size):
147 |
148 | # ========== save synthesized image x_0 ==========
149 | save_x_0_path = os.path.join(save_sub_folder,
150 | f'{str(idx).zfill(6)}_x_0.png')
151 |         x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 512, 512]
152 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy()
153 | x_0 = (x_0 + 1.0) * 127.5
154 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255
155 | x_0 = x_0.astype(np.uint8)
156 | x_0 = Image.fromarray(x_0[0])
157 | x_0.save(save_x_0_path)
158 |
159 | # save intermediate x_t and pred_x_0
160 | if args.display_x_inter:
161 | for cond_name in ['x_inter', 'pred_x0']:
162 | save_conf_path = os.path.join(
163 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png')
164 | conf = intermediates[f'{cond_name}']
165 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64
166 | conf = conf[:, idx, :, :, :] # 50x3x64x64
167 | print('decoding x_inter ......')
168 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256]
169 | conf = make_grid(
170 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10]
171 | conf = conf.permute(1, 2,
172 | 0).to('cpu').numpy() # cxhxh -> hxhxc
173 | conf = (conf + 1.0) * 127.5
174 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
175 | conf = conf.astype(np.uint8)
176 | conf = Image.fromarray(conf)
177 | conf.save(save_conf_path)
178 |
179 | # save latent z_0
180 | if args.save_z:
181 | save_z_0_path = os.path.join(save_sub_folder,
182 | f'{str(idx).zfill(6)}_z_0.png')
183 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64]
184 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy()
185 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization
186 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255
187 | z_0 = z_0.astype(np.uint8)
188 | z_0 = Image.fromarray(z_0[0])
189 | z_0.save(save_z_0_path)
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
--------------------------------------------------------------------------------
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 |     if "target" not in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 | def instantiate_from_config_vq_diffusion(config):
88 | """the VQ-Diffusion version"""
89 | if config is None:
90 | return None
91 |     if "target" not in config:
92 | raise KeyError("Expected key `target` to instantiate.")
93 | module, cls = config["target"].rsplit(".", 1)
94 | print(f'instantiate_from_config --- module: {module}, cls: {cls}') # ziqi added
95 | cls = getattr(importlib.import_module(module, package=None), cls)
96 | return cls(**config.get("params", dict()))
97 |
98 |
99 | def get_obj_from_str(string, reload=False):
100 | module, cls = string.rsplit(".", 1)
101 | if reload:
102 | module_imp = importlib.import_module(module)
103 | importlib.reload(module_imp)
104 | return getattr(importlib.import_module(module, package=None), cls)
105 |
106 |
107 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
108 | # create dummy dataset instance
109 |
110 | # run prefetching
111 | if idx_to_fn:
112 | res = func(data, worker_id=idx)
113 | else:
114 | res = func(data)
115 | Q.put([idx, res])
116 | Q.put("Done")
117 |
118 |
119 | def parallel_data_prefetch(
120 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
121 | ):
122 | # if target_data_type not in ["ndarray", "list"]:
123 | # raise ValueError(
124 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
125 | # )
126 | if isinstance(data, np.ndarray) and target_data_type == "list":
127 | raise ValueError("list expected but function got ndarray.")
128 | elif isinstance(data, abc.Iterable):
129 | if isinstance(data, dict):
130 | print(
131 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
132 | )
133 | data = list(data.values())
134 | if target_data_type == "ndarray":
135 | data = np.asarray(data)
136 | else:
137 | data = list(data)
138 | else:
139 | raise TypeError(
140 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
141 | )
142 |
143 | if cpu_intensive:
144 | Q = mp.Queue(1000)
145 | proc = mp.Process
146 | else:
147 | Q = Queue(1000)
148 | proc = Thread
149 | # spawn processes
150 | if target_data_type == "ndarray":
151 | arguments = [
152 | [func, Q, part, i, use_worker_id]
153 | for i, part in enumerate(np.array_split(data, n_proc))
154 | ]
155 | else:
156 | step = (
157 | int(len(data) / n_proc + 1)
158 | if len(data) % n_proc != 0
159 | else int(len(data) / n_proc)
160 | )
161 | arguments = [
162 | [func, Q, part, i, use_worker_id]
163 | for i, part in enumerate(
164 | [data[i: i + step] for i in range(0, len(data), step)]
165 | )
166 | ]
167 | processes = []
168 | for i in range(n_proc):
169 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
170 | processes += [p]
171 |
172 | # start processes
173 | print(f"Start prefetching...")
174 | import time
175 |
176 | start = time.time()
177 | gather_res = [[] for _ in range(n_proc)]
178 | try:
179 | for p in processes:
180 | p.start()
181 |
182 | k = 0
183 | while k < n_proc:
184 | # get result
185 | res = Q.get()
186 | if res == "Done":
187 | k += 1
188 | else:
189 | gather_res[res[0]] = res[1]
190 |
191 | except Exception as e:
192 | print("Exception: ", e)
193 | for p in processes:
194 | p.terminate()
195 |
196 | raise e
197 | finally:
198 | for p in processes:
199 | p.join()
200 | print(f"Prefetching complete. [{time.time() - start} sec.]")
201 |
202 | if target_data_type == 'ndarray':
203 | if not isinstance(gather_res[0], np.ndarray):
204 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
205 |
206 | # order outputs
207 | return np.concatenate(gather_res, axis=0)
208 | elif target_data_type == 'list':
209 | out = []
210 | for r in gather_res:
211 | out.extend(r)
212 | return out
213 | else:
214 | return gather_res
215 |
--------------------------------------------------------------------------------
/ldm/data/celebahq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import json
7 | import random
8 | from ldm.util import instantiate_from_config_vq_diffusion
9 | import albumentations
10 | from torchvision import transforms as trans
11 |
12 | def load_img(filepath):
13 | img = Image.open(filepath).convert('RGB')
14 | return img
15 |
16 |
17 | class DalleTransformerPreprocessor(object):
18 | def __init__(self,
19 | size=256,
20 | phase='train',
21 | additional_targets=None):
22 |
23 | self.size = size
24 | self.phase = phase
25 | # ddc: following dalle to use randomcrop
26 | self.train_preprocessor = albumentations.Compose([albumentations.RandomCrop(height=size, width=size)],
27 | additional_targets=additional_targets)
28 | self.val_preprocessor = albumentations.Compose([albumentations.CenterCrop(height=size, width=size)],
29 | additional_targets=additional_targets)
30 |
31 |
32 |     def __call__(self, image, **kwargs):
33 | """
34 | image: PIL.Image
35 | """
36 | if isinstance(image, np.ndarray):
37 | image = Image.fromarray(image.astype(np.uint8))
38 |
39 | w, h = image.size
40 | s_min = min(h, w)
41 |
42 | if self.phase == 'train':
43 | off_h = int(random.uniform(3*(h-s_min)//8, max(3*(h-s_min)//8+1, 5*(h-s_min)//8)))
44 | off_w = int(random.uniform(3*(w-s_min)//8, max(3*(w-s_min)//8+1, 5*(w-s_min)//8)))
45 |
46 | image = image.crop((off_w, off_h, off_w + s_min, off_h + s_min))
47 |
48 | # resize image
49 | t_max = min(s_min, round(9/8*self.size))
50 | t_max = max(t_max, self.size)
51 | t = int(random.uniform(self.size, t_max+1))
52 | image = image.resize((t, t))
53 | image = np.array(image).astype(np.uint8)
54 | image = self.train_preprocessor(image=image)
55 | else:
56 | if w < h:
57 | w_ = self.size
58 | h_ = int(h * w_/w)
59 | else:
60 | h_ = self.size
61 | w_ = int(w * h_/h)
62 | image = image.resize((w_, h_))
63 | image = np.array(image).astype(np.uint8)
64 | image = self.val_preprocessor(image=image)
65 | return image
66 |
67 |
68 |
69 | class CelebAConditionalDataset(Dataset):
70 |
71 | """
72 | This Dataset can be used for:
73 | - image-only: setting 'conditions' = []
74 | - image and multi-modal 'conditions': setting conditions as the list of modalities you need
75 |
76 | To toggle between 256 and 512 image resolution, simply change the 'image_folder'
77 | """
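   |     # Illustrative usage (a sketch; `preprocessor_cfg` is an assumed config node pointing
   |     # at a preprocessor such as DalleTransformerPreprocessor above, not a value taken
   |     # from this repo's configs):
   |     #     dataset = CelebAConditionalDataset(
   |     #         phase='train',
   |     #         im_preprocessor_config=preprocessor_cfg,
   |     #         conditions=['seg_mask', 'text'])
   |     #     sample = dataset[0]  # {'image': ..., 'conditions': {'seg_mask': ..., 'text': ...}}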
78 |
79 | def __init__(self,
80 | phase = 'train',
81 | im_preprocessor_config=None,
82 | test_dataset_size=3000,
83 | conditions = ['seg_mask', 'text', 'sketch'],
84 | image_folder = 'datasets/image/image_512_downsampled_from_hq_1024',
85 | text_file = 'datasets/text/captions_hq_beard_and_age_2022-08-19.json',
86 | mask_folder = 'datasets/mask/CelebAMask-HQ-mask-color-palette_32_nearest_downsampled_from_hq_512_one_hot_2d_tensor',
87 | sketch_folder = 'datasets/sketch/sketch_1x1024_tensor',
88 | ):
89 |
90 | self.transform = instantiate_from_config_vq_diffusion(im_preprocessor_config)
91 | self.conditions = conditions
92 | print(f'conditions = {conditions}')
93 |
94 | self.image_folder = image_folder
95 | print(f'self.image_folder = {self.image_folder}')
96 |
97 | # conditions directory
98 | self.text_file = text_file
99 | print(f'self.text_file = {self.text_file}')
100 | print(f'start loading text')
101 | with open(self.text_file, 'r') as f:
102 | self.text_file_content = json.load(f)
103 | print(f'end loading text')
104 | if 'seg_mask' in self.conditions:
105 | self.mask_folder = mask_folder
106 | print(f'self.mask_folder = {self.mask_folder}')
107 | if 'sketch' in self.conditions:
108 | self.sketch_folder = sketch_folder
109 | print(f'self.sketch_folder = {self.sketch_folder}')
110 |
111 | # list of valid image names & train test split
112 | self.image_name_list = list(self.text_file_content.keys())
113 |
114 | # train test split
115 | if phase == 'train':
116 | self.image_name_list = self.image_name_list[:-test_dataset_size]
117 | elif phase == 'test':
118 | self.image_name_list = self.image_name_list[-test_dataset_size:]
119 | else:
120 | raise NotImplementedError
121 | self.num = len(self.image_name_list)
122 |
123 | # verbose
124 | print(f'phase = {phase}')
125 | print(f'number of samples = {self.num}')
126 | print(f'self.image_name_list[:10] = {self.image_name_list[:10]}')
127 | print(f'self.image_name_list[-10:] = {self.image_name_list[-10:]}\n')
128 |
129 |
130 | def __len__(self):
131 | return self.num
132 |
133 | def __getitem__(self, index):
134 |
135 | # ---------- (1) get image ----------
136 | image_name = self.image_name_list[index]
137 | image_path = os.path.join(self.image_folder, image_name)
138 | image = load_img(image_path)
139 | image = np.array(image).astype(np.uint8)
140 | image = self.transform(image = image)['image']
141 | image = image.astype(np.float32)/127.5 - 1.0
142 |
143 | # record into data entry
144 | if len(self.conditions) == 1:
145 | data = {
146 | 'image': image,
147 | }
148 | else:
149 | data = {
150 | 'image': image,
151 | 'conditions': {}
152 | }
153 |
154 | # ---------- (2) get text ----------
155 | if 'text' in self.conditions:
156 | text = self.text_file_content[image_name]["Beard_and_Age"].lower()
157 | # record into data entry
158 | if len(self.conditions) == 1:
159 | data['caption'] = text
160 | else:
161 | data['conditions']['text'] = text
162 |
163 | # ---------- (3) get mask ----------
164 | if 'seg_mask' in self.conditions:
165 | mask_idx = image_name.split('.')[0]
166 | mask_name = f'{mask_idx}.pt'
167 | mask_path = os.path.join(self.mask_folder, mask_name)
168 | mask_one_hot_tensor = torch.load(mask_path)
169 |
170 | # record into data entry
171 | if len(self.conditions) == 1:
172 | data['seg_mask'] = mask_one_hot_tensor
173 | else:
174 | data['conditions']['seg_mask'] = mask_one_hot_tensor
175 |
176 |
177 | # ---------- (4) get sketch ----------
178 | if 'sketch' in self.conditions:
179 | sketch_idx = image_name.split('.')[0]
180 | sketch_name = f'{sketch_idx}.pt'
181 | sketch_path = os.path.join(self.sketch_folder, sketch_name)
182 | sketch_one_hot_tensor = torch.load(sketch_path)
183 |
184 | # record into data entry
185 | if len(self.conditions) == 1:
186 | data['sketch'] = sketch_one_hot_tensor
187 | else:
188 | data['conditions']['sketch'] = sketch_one_hot_tensor
189 |
190 |
191 | return data
192 |
--------------------------------------------------------------------------------
/mask2image.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 | from torchvision.utils import make_grid
12 |
13 | from ldm.models.diffusion.ddim import DDIMSampler
14 | from ldm.util import instantiate_from_config
15 |
16 |
17 | def parse_args():
18 |
19 | parser = argparse.ArgumentParser(description="")
20 |
21 | # conditions
22 | parser.add_argument(
23 | "--mask_path",
24 | type=str,
25 | default="test_data/512_masks/27007.png",
26 | help="path to the segmentation mask")
27 |
28 | # paths
29 | parser.add_argument(
30 | "--config_path",
31 | type=str,
32 | default="configs/512_mask.yaml",
33 | help="path to model config")
34 | parser.add_argument(
35 | "--ckpt_path",
36 | type=str,
37 | default="pretrained/512_mask.ckpt",
38 | help="path to model checkpoint")
39 | parser.add_argument(
40 | "--save_folder",
41 | type=str,
42 | default="outputs/512_mask2image",
43 | help="folder to save synthesis outputs")
44 |
45 | # batch size and ddim steps
46 | parser.add_argument(
47 | "--batch_size",
48 | type=int,
49 | default=4,
50 | help="number of images to generate")
51 | parser.add_argument(
52 | "--ddim_steps",
53 | type=int,
54 |         default=50,
55 |         help=
56 |         "number of DDIM steps (between 20 and 1000; more steps are slower but give better quality)"
57 | )
58 |
59 | # whether save intermediate outputs
60 | parser.add_argument(
61 | "--save_z",
62 | type=bool,
63 | default=False,
64 | help=
65 |         "whether to visualize the VAE latent codes and save them in the output folder",
66 | )
67 | parser.add_argument(
68 | "--return_influence_function",
69 | type=bool,
70 | default=False,
71 | help=
72 |         "whether to visualize the Influence Functions and save them in the output folder",
73 | )
74 | parser.add_argument(
75 | "--display_x_inter",
76 | type=bool,
77 | default=False,
78 | help=
79 |         "whether to display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder",
80 | )
81 | parser.add_argument(
82 | "--save_mixed",
83 | type=bool,
84 | default=False,
85 | help=
86 |         "whether to overlay the segmentation mask on the synthesized image to visualize mask consistency",
87 | )
88 |
89 | args = parser.parse_args()
90 | return args
91 |
92 |
93 | def main():
94 |
95 | args = parse_args()
96 |
97 | # ========== set up model ==========
98 | print(f'Set up model')
99 | config = OmegaConf.load(args.config_path)
100 | model_config = config['model']
101 | model = instantiate_from_config(model_config)
102 | model.init_from_ckpt(args.ckpt_path)
103 | model = model.cuda()
104 | model.eval()
105 |
106 | # ========== set output directory ==========
107 | os.makedirs(args.save_folder, exist_ok=True)
108 | # save a copy of this python script being used
109 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__))
110 |
111 | # ========== prepare seg mask for model ==========
112 | with open(args.mask_path, 'rb') as f:
113 | img = Image.open(f)
114 | resized_img = img.resize((32, 32), Image.NEAREST) # resize
115 | flattened_img = list(resized_img.getdata())
116 | flattened_img_tensor = torch.tensor(flattened_img) # flatten
117 | flattened_img_tensor_one_hot = F.one_hot(
118 | flattened_img_tensor, num_classes=19) # one hot
119 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose(
120 | 0, 1)
121 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze(
122 | flattened_img_tensor_one_hot_transpose,
123 | 0).cuda() # add batch dimension
124 |
125 | # ========== prepare mask for visualization ==========
126 | mask = Image.open(args.mask_path)
127 | mask = mask.convert('RGB')
128 | mask = np.array(mask).astype(np.uint8) # three channel integer
129 | input_mask = mask
130 |
131 | print(
132 | f'================================================================================'
133 | )
134 | print(f'mask_path: {args.mask_path}')
135 |
136 | # prepare directories
137 | mask_name = args.mask_path.split('/')[-1]
138 | save_sub_folder = os.path.join(args.save_folder, mask_name)
139 | os.makedirs(save_sub_folder, exist_ok=True)
140 |
141 | # save seg_mask
142 | save_path_mask = os.path.join(save_sub_folder, mask_name)
143 | mask_ = Image.fromarray(input_mask)
144 | mask_.save(save_path_mask)
145 |
146 | # ========== inference ==========
147 | with torch.no_grad():
148 |
149 | condition = flattened_img_tensor_one_hot_transpose
150 |
151 | with model.ema_scope("Plotting"):
152 |
153 | # encode condition
154 | condition = model.get_learned_conditioning(
155 | condition) # [1, 96, 640]
156 | condition = condition.repeat(args.batch_size, 1, 1) # [B, 96, 640]
157 |
158 | ddim_sampler = DDIMSampler(model)
159 | z_0_batch, intermediates = ddim_sampler.sample(
160 | S=args.ddim_steps,
161 | batch_size=args.batch_size,
162 | shape=(3, 64, 64),
163 | conditioning=condition,
164 | verbose=False,
165 | eta=1.0,
166 | log_every_t=1)
167 |
168 | # decode latent z_0 to image x_0
169 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256]
170 |
171 | for idx in range(args.batch_size):
172 |
173 | # ========== save synthesized image x_0 ==========
174 | save_x_0_path = os.path.join(save_sub_folder,
175 | f'{str(idx).zfill(6)}_x_0.png')
176 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256]
177 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy()
178 | x_0 = (x_0 + 1.0) * 127.5
179 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255
180 | x_0 = x_0.astype(np.uint8)
181 | x_0 = Image.fromarray(x_0[0])
182 | x_0.save(save_x_0_path)
183 |
184 | # save intermediate x_t and pred_x_0
185 | if args.display_x_inter:
186 | for cond_name in ['x_inter', 'pred_x0']:
187 | save_conf_path = os.path.join(
188 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png')
189 | conf = intermediates[f'{cond_name}']
190 |                     conf = torch.stack(conf, dim=0)  # [num_logged_steps, B, 3, 64, 64]
191 |                     conf = conf[:, idx, :, :, :]  # [num_logged_steps, 3, 64, 64]
192 | print('decoding x_inter ......')
193 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256]
194 | conf = make_grid(
195 |                         conf, nrow=10)  # 10 images per row (5 rows for 50 logged steps)
196 | conf = conf.permute(1, 2,
197 |                                         0).to('cpu').numpy()  # [c, h, w] -> [h, w, c]
198 | conf = (conf + 1.0) * 127.5
199 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
200 | conf = conf.astype(np.uint8)
201 | conf = Image.fromarray(conf)
202 | conf.save(save_conf_path)
203 |
204 | # save latent z_0
205 | if args.save_z:
206 | save_z_0_path = os.path.join(save_sub_folder,
207 | f'{str(idx).zfill(6)}_z_0.png')
208 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64]
209 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy()
210 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization
211 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255
212 | z_0 = z_0.astype(np.uint8)
213 | z_0 = Image.fromarray(z_0[0])
214 | z_0.save(save_z_0_path)
215 |
216 |             # overlay the segmentation mask on the synthesized image to visualize mask consistency
217 |             if args.save_mixed:
218 |                 save_mixed_path = os.path.join(save_sub_folder, f'{str(idx).zfill(6)}_mixed.png')
219 |                 Image.blend(x_0, mask_, 0.3).save(save_mixed_path)
220 |
221 |
222 | if __name__ == "__main__":
223 | main()
--------------------------------------------------------------------------------
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 | from ldm.util import exists  # `exists` is used in forward() below but was not imported
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if not exists(codebook_loss):
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
/freeu/text2image_freeu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from omegaconf import OmegaConf, open_dict
9 | from PIL import Image
10 | from torchvision.utils import make_grid
11 |
12 | from ldm.models.diffusion.ddim import DDIMSampler
13 | from ldm.util import instantiate_from_config
14 |
15 | """
16 | Inference script for text-to-face generation at 512x512 resolution
17 | """
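   | # Example invocation (illustrative; assumes the default config/checkpoint paths below exist):
   | #   python freeu/text2image_freeu.py \
   | #       --input_text "This man has beard of medium length. He is in his thirties." \
   | #       --enable_freeu --b1 1.1 --b2 1.2 --s1 1 --s2 1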
18 |
19 |
20 | def parse_args():
21 |
22 | parser = argparse.ArgumentParser(description="")
23 |
24 | # conditions
25 | parser.add_argument(
26 | "--input_text",
27 | type=str,
28 | default="This man has beard of medium length. He is in his thirties.",
29 | help="text condition")
30 |
31 | # paths
32 | parser.add_argument(
33 | "--config_path",
34 | type=str,
35 | default="configs/512_text.yaml",
36 | help="path to model config")
37 | parser.add_argument(
38 | "--ckpt_path",
39 | type=str,
40 | default="pretrained/512_text.ckpt",
41 | help="path to model checkpoint")
42 | parser.add_argument(
43 | "--save_folder",
44 | type=str,
45 | default="outputs/512_text2image",
46 | help="folder to save synthesis outputs")
47 |
48 | # batch size and ddim steps
49 | parser.add_argument(
50 | "--batch_size",
51 | type=int,
52 | default=4,
53 | help="number of images to generate")
54 | parser.add_argument(
55 | "--ddim_steps",
56 | type=int,
57 |         default=50,
58 |         help=
59 |         "number of DDIM steps (between 20 and 1000; more steps are slower but give better quality)"
60 | )
61 |
62 | # whether save intermediate outputs
63 | parser.add_argument(
64 | "--save_z",
65 | type=bool,
66 | default=False,
67 | help=
68 |         "whether to visualize the VAE latent codes and save them in the output folder",
69 | )
70 | parser.add_argument(
71 | "--return_influence_function",
72 | type=bool,
73 | default=False,
74 | help=
75 |         "whether to visualize the Influence Functions and save them in the output folder",
76 | )
77 | parser.add_argument(
78 | "--display_x_inter",
79 | type=bool,
80 | default=False,
81 | help=
82 |         "whether to display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder",
83 | )
84 |
85 | # FreeU Config
86 | parser.add_argument(
87 | "--seed",
88 | type=int,
89 | default=2,
90 | help=
91 | "fix random seed to compare with and without FreeU",
92 | )
93 | parser.add_argument(
94 | "--enable_freeu",
95 | action='store_true',
96 | help=
97 |         "whether to enable FreeU",
98 | )
99 | parser.add_argument(
100 | "--b1",
101 | type=float,
102 | default=1.1,
103 | help=
104 | "parameter of FreeU",
105 | )
106 | parser.add_argument(
107 | "--b2",
108 | type=float,
109 | default=1.2,
110 | help=
111 | "parameter of FreeU",
112 | )
113 | parser.add_argument(
114 | "--s1",
115 | type=float,
116 | default=1,
117 | help=
118 | "parameter of FreeU",
119 | )
120 | parser.add_argument(
121 | "--s2",
122 | type=float,
123 | default=1,
124 | help=
125 | "parameter of FreeU",
126 | )
127 |
128 | args = parser.parse_args()
129 | return args
130 |
131 |
132 | def main():
133 |
134 | args = parse_args()
135 | torch.manual_seed(args.seed)
136 |
137 | # ========== set up model ==========
138 | print(f'Set up model')
139 | config = OmegaConf.load(args.config_path)
140 | model_config = config['model']
141 | # ---------- FreeU code starts ----------
142 | print('Setting up FreeU')
143 | OmegaConf.set_struct(model_config, True)
144 | with open_dict(model_config):
145 | model_config.params.unet_config.params.enable_freeu = args.enable_freeu
146 | model_config.params.unet_config.params.b1 = args.b1
147 | model_config.params.unet_config.params.b2 = args.b2
148 | model_config.params.unet_config.params.s1 = args.s1
149 | model_config.params.unet_config.params.s2 = args.s2
150 | if args.enable_freeu:
151 | args.save_folder = f'{args.save_folder}_with_freeu'
152 | else:
153 | args.save_folder = f'{args.save_folder}_without_freeu'
154 | print(f'args.enable_freeu = {args.enable_freeu}')
155 | print(f'args.save_folder = {args.save_folder}')
156 | # ---------- FreeU code ends ----------
157 | model = instantiate_from_config(model_config)
158 | model.init_from_ckpt(args.ckpt_path)
159 | model = model.cuda()
160 | model.eval()
161 |
162 | # ========== set output directory ==========
163 | os.makedirs(args.save_folder, exist_ok=True)
164 | # save a copy of this python script being used
165 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__))
166 |
167 | print(
168 | f'================================================================================'
169 | )
170 | print(f'text: {args.input_text}')
171 |
172 | # prepare directories
173 | save_sub_folder = os.path.join(args.save_folder, str(args.input_text))
174 | os.makedirs(save_sub_folder, exist_ok=True)
175 |
176 | # ========== inference ==========
177 | with torch.no_grad():
178 |
179 | # encode condition
180 | condition = []
181 | for i in range(args.batch_size):
182 | condition.append(args.input_text.lower())
183 |
184 | with model.ema_scope("Plotting"):
185 |
186 | # encode condition
187 | condition = model.get_learned_conditioning(
188 |                 condition)  # [B, 77, 640]
189 | print(f'condition.shape={condition.shape}') # [B, 77, 640]
190 |
191 | # DDIM sampling
192 | ddim_sampler = DDIMSampler(model)
193 | z_0_batch, intermediates = ddim_sampler.sample(
194 | S=args.ddim_steps,
195 | batch_size=args.batch_size,
196 | shape=(3, 64, 64),
197 | conditioning=condition,
198 | verbose=False,
199 | eta=1.0,
200 | log_every_t=1)
201 |
202 | # decode VAE latent z_0 to image x_0
203 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256]
204 |
205 | # ========== save outputs ==========
206 | for idx in range(args.batch_size):
207 |
208 | # ========== save synthesized image x_0 ==========
209 | save_x_0_path = os.path.join(save_sub_folder,
210 | f'{str(idx).zfill(6)}_x_0.png')
211 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256]
212 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy()
213 | x_0 = (x_0 + 1.0) * 127.5
214 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255
215 | x_0 = x_0.astype(np.uint8)
216 | x_0 = Image.fromarray(x_0[0])
217 | x_0.save(save_x_0_path)
218 |
219 | # save intermediate x_t and pred_x_0
220 | if args.display_x_inter:
221 | for cond_name in ['x_inter', 'pred_x0']:
222 | save_conf_path = os.path.join(
223 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png')
224 | conf = intermediates[f'{cond_name}']
225 |                 conf = torch.stack(conf, dim=0)  # [num_logged_steps, B, 3, 64, 64]
226 |                 conf = conf[:, idx, :, :, :]  # [num_logged_steps, 3, 64, 64]
227 | print('decoding x_inter ......')
228 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256]
229 | conf = make_grid(
230 |                     conf, nrow=10)  # 10 images per row (5 rows for 50 logged steps)
231 | conf = conf.permute(1, 2,
232 |                                     0).to('cpu').numpy()  # [c, h, w] -> [h, w, c]
233 | conf = (conf + 1.0) * 127.5
234 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
235 | conf = conf.astype(np.uint8)
236 | conf = Image.fromarray(conf)
237 | conf.save(save_conf_path)
238 |
239 | # save latent z_0
240 | if args.save_z:
241 | save_z_0_path = os.path.join(save_sub_folder,
242 | f'{str(idx).zfill(6)}_z_0.png')
243 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64]
244 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy()
245 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization
246 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255
247 | z_0 = z_0.astype(np.uint8)
248 | z_0 = Image.fromarray(z_0[0])
249 | z_0.save(save_z_0_path)
250 |
251 |
252 | if __name__ == "__main__":
253 | main()
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
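   |     # Shape walkthrough (a sketch; the concrete numbers assume in_channels=320, n_heads=8,
   |     # d_head=40 and a 32x32 feature map, so inner_dim = 8 * 40 = 320):
   |     #     x: [B, 320, 32, 32] --norm, proj_in--> [B, 320, 32, 32]
   |     #     rearrange 'b c h w -> b (h w) c'     -> [B, 1024, 320]
   |     #     transformer blocks (cross-attend to `context` when given) -> [B, 1024, 320]
   |     #     rearrange back, proj_out             -> [B, 320, 32, 32], added to the input (residual)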
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
--------------------------------------------------------------------------------
/freeu/mask2image_freeu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision
9 | from omegaconf import OmegaConf, open_dict
10 | from PIL import Image
11 | from torchvision.utils import make_grid
12 |
13 | from ldm.models.diffusion.ddim import DDIMSampler
14 | from ldm.util import instantiate_from_config
15 |
16 |
17 | def parse_args():
18 |
19 | parser = argparse.ArgumentParser(description="")
20 |
21 | # conditions
22 | parser.add_argument(
23 | "--mask_path",
24 | type=str,
25 | default="test_data/512_masks/27007.png",
26 | help="path to the segmentation mask")
27 |
28 | # paths
29 | parser.add_argument(
30 | "--config_path",
31 | type=str,
32 | default="configs/512_mask.yaml",
33 | help="path to model config")
34 | parser.add_argument(
35 | "--ckpt_path",
36 | type=str,
37 | default="pretrained/512_mask.ckpt",
38 | help="path to model checkpoint")
39 | parser.add_argument(
40 | "--save_folder",
41 | type=str,
42 | default="outputs/512_mask2image",
43 | help="folder to save synthesis outputs")
44 |
45 | # batch size and ddim steps
46 | parser.add_argument(
47 | "--batch_size",
48 | type=int,
49 | default=4,
50 | help="number of images to generate")
51 | parser.add_argument(
52 | "--ddim_steps",
53 | type=int,
54 |         default=50,
55 |         help=
56 |         "number of DDIM steps (between 20 and 1000; more steps are slower but give better quality)"
57 | )
58 |
59 | # whether save intermediate outputs
60 | parser.add_argument(
61 | "--save_z",
62 | type=bool,
63 | default=False,
64 | help=
65 |         "whether to visualize the VAE latent codes and save them in the output folder",
66 | )
67 | parser.add_argument(
68 | "--return_influence_function",
69 | type=bool,
70 | default=False,
71 | help=
72 |         "whether to visualize the Influence Functions and save them in the output folder",
73 | )
74 | parser.add_argument(
75 | "--display_x_inter",
76 | type=bool,
77 | default=False,
78 | help=
79 |         "whether to display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder",
80 | )
81 | parser.add_argument(
82 | "--save_mixed",
83 | type=bool,
84 | default=False,
85 | help=
86 |         "whether to overlay the segmentation mask on the synthesized image to visualize mask consistency",
87 | )
88 |
89 | # FreeU Config
90 | parser.add_argument(
91 | "--seed",
92 | type=int,
93 | default=2,
94 | help=
95 | "fix random seed to compare with and without FreeU",
96 | )
97 | parser.add_argument(
98 | "--enable_freeu",
99 | action='store_true',
100 | help=
101 |         "whether to enable FreeU",
102 | )
103 | parser.add_argument(
104 | "--b1",
105 | type=float,
106 | default=1.1,
107 | help=
108 | "parameter of FreeU",
109 | )
110 | parser.add_argument(
111 | "--b2",
112 | type=float,
113 | default=1.2,
114 | help=
115 | "parameter of FreeU",
116 | )
117 | parser.add_argument(
118 | "--s1",
119 | type=float,
120 | default=1,
121 | help=
122 | "parameter of FreeU",
123 | )
124 | parser.add_argument(
125 | "--s2",
126 | type=float,
127 | default=1,
128 | help=
129 | "parameter of FreeU",
130 | )
131 |
132 | args = parser.parse_args()
133 | return args
134 |
135 |
136 | def main():
137 |
138 | args = parse_args()
139 | torch.manual_seed(args.seed)
140 |
141 | # ========== set up model ==========
142 | print(f'Set up model')
143 | config = OmegaConf.load(args.config_path)
144 | model_config = config['model']
145 | # ---------- FreeU code starts ----------
146 | print('Setting up FreeU')
147 | OmegaConf.set_struct(model_config, True)
148 | with open_dict(model_config):
149 | model_config.params.unet_config.params.enable_freeu = args.enable_freeu
150 | model_config.params.unet_config.params.b1 = args.b1
151 | model_config.params.unet_config.params.b2 = args.b2
152 | model_config.params.unet_config.params.s1 = args.s1
153 | model_config.params.unet_config.params.s2 = args.s2
154 | if args.enable_freeu:
155 | args.save_folder = f'{args.save_folder}_with_freeu'
156 | else:
157 | args.save_folder = f'{args.save_folder}_without_freeu'
158 | print(f'args.enable_freeu = {args.enable_freeu}')
159 | print(f'args.save_folder = {args.save_folder}')
160 | # ---------- FreeU code ends ----------
161 | model = instantiate_from_config(model_config)
162 | model.init_from_ckpt(args.ckpt_path)
163 | model = model.cuda()
164 | model.eval()
165 |
166 | # ========== set output directory ==========
167 | os.makedirs(args.save_folder, exist_ok=True)
168 | # save a copy of this python script being used
169 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__))
170 |
171 | # ========== prepare seg mask for model ==========
172 | with open(args.mask_path, 'rb') as f:
173 | img = Image.open(f)
174 | resized_img = img.resize((32, 32), Image.NEAREST) # resize
175 | flattened_img = list(resized_img.getdata())
176 | flattened_img_tensor = torch.tensor(flattened_img) # flatten
177 | flattened_img_tensor_one_hot = F.one_hot(
178 | flattened_img_tensor, num_classes=19) # one hot
179 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose(
180 | 0, 1)
181 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze(
182 | flattened_img_tensor_one_hot_transpose,
183 | 0).cuda() # add batch dimension
184 |
185 | # ========== prepare mask for visualization ==========
186 | mask = Image.open(args.mask_path)
187 | mask = mask.convert('RGB')
188 | mask = np.array(mask).astype(np.uint8) # three channel integer
189 | input_mask = mask
190 |
191 | print(
192 | f'================================================================================'
193 | )
194 | print(f'mask_path: {args.mask_path}')
195 |
196 | # prepare directories
197 | mask_name = args.mask_path.split('/')[-1]
198 | save_sub_folder = os.path.join(args.save_folder, mask_name)
199 | os.makedirs(save_sub_folder, exist_ok=True)
200 |
201 | # save seg_mask
202 | save_path_mask = os.path.join(save_sub_folder, mask_name)
203 | mask_ = Image.fromarray(input_mask)
204 | mask_.save(save_path_mask)
205 |
206 | # ========== inference ==========
207 | with torch.no_grad():
208 |
209 | condition = flattened_img_tensor_one_hot_transpose
210 |
211 | with model.ema_scope("Plotting"):
212 |
213 | # encode condition
214 | condition = model.get_learned_conditioning(
215 | condition) # [1, 96, 640]
216 | condition = condition.repeat(args.batch_size, 1, 1) # [B, 96, 640]
217 |
218 | ddim_sampler = DDIMSampler(model)
219 | z_0_batch, intermediates = ddim_sampler.sample(
220 | S=args.ddim_steps,
221 | batch_size=args.batch_size,
222 | shape=(3, 64, 64),
223 | conditioning=condition,
224 | verbose=False,
225 | eta=1.0,
226 | log_every_t=1)
227 |
228 | # decode latent z_0 to image x_0
229 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256]
230 |
231 | for idx in range(args.batch_size):
232 |
233 | # ========== save synthesized image x_0 ==========
234 | save_x_0_path = os.path.join(save_sub_folder,
235 | f'{str(idx).zfill(6)}_x_0.png')
236 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256]
237 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy()
238 | x_0 = (x_0 + 1.0) * 127.5
239 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255
240 | x_0 = x_0.astype(np.uint8)
241 | x_0 = Image.fromarray(x_0[0])
242 | x_0.save(save_x_0_path)
243 |
244 | # save intermediate x_t and pred_x_0
245 | if args.display_x_inter:
246 | for cond_name in ['x_inter', 'pred_x0']:
247 | save_conf_path = os.path.join(
248 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png')
249 | conf = intermediates[f'{cond_name}']
250 |                     conf = torch.stack(conf, dim=0)  # [num_logged_steps, B, 3, 64, 64]
251 |                     conf = conf[:, idx, :, :, :]  # [num_logged_steps, 3, 64, 64]
252 | print('decoding x_inter ......')
253 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256]
254 | conf = make_grid(
255 |                         conf, nrow=10)  # 10 images per row (5 rows for 50 logged steps)
256 | conf = conf.permute(1, 2,
257 |                                         0).to('cpu').numpy()  # [c, h, w] -> [h, w, c]
258 | conf = (conf + 1.0) * 127.5
259 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
260 | conf = conf.astype(np.uint8)
261 | conf = Image.fromarray(conf)
262 | conf.save(save_conf_path)
263 |
264 | # save latent z_0
265 | if args.save_z:
266 | save_z_0_path = os.path.join(save_sub_folder,
267 | f'{str(idx).zfill(6)}_z_0.png')
268 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64]
269 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy()
270 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization
271 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255
272 | z_0 = z_0.astype(np.uint8)
273 | z_0 = Image.fromarray(z_0[0])
274 | z_0.save(save_z_0_path)
275 |
276 |             # overlay the segmentation mask on the synthesized image to visualize mask consistency
277 |             if args.save_mixed:
278 |                 save_mixed_path = os.path.join(save_sub_folder, f'{str(idx).zfill(6)}_mixed.png')
279 |                 Image.blend(x_0, mask_, 0.3).save(save_mixed_path)
280 |
281 |
282 | if __name__ == "__main__":
283 | main()
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | import kornia
7 |
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 | from ldm.util import instantiate_from_config_vq_diffusion
11 |
12 |
13 | class AbstractEncoder(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def encode(self, *args, **kwargs):
18 | raise NotImplementedError
19 |
20 |
21 |
22 | class ClassEmbedder(nn.Module):
23 | def __init__(self, embed_dim, n_classes=1000, key='class'):
24 | super().__init__()
25 | self.key = key
26 | self.embedding = nn.Embedding(n_classes, embed_dim)
27 |
28 | def forward(self, batch, key=None):
29 | if key is None:
30 | key = self.key
31 | # this is for use in crossattn
32 | c = batch[key][:, None]
33 | c = self.embedding(c)
34 | return c
35 |
36 |
37 | class TransformerEmbedder(AbstractEncoder):
38 | """Some transformer encoder layers"""
39 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
40 | super().__init__()
41 | self.device = device
42 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
43 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
44 |
45 | def forward(self, tokens):
46 | tokens = tokens.to(self.device) # meh
47 | z = self.transformer(tokens, return_embeddings=True)
48 | return z
49 |
50 | def encode(self, x):
51 | return self(x)
52 |
53 |
54 | class BERTTokenizer(AbstractEncoder):
55 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
56 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
57 | super().__init__()
58 |         from transformers import BertTokenizerFast  # TODO: add to requirements
59 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
60 | self.device = device
61 | self.vq_interface = vq_interface
62 | self.max_length = max_length
63 |
64 | def forward(self, text):
65 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
66 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
67 | tokens = batch_encoding["input_ids"].to(self.device)
68 | return tokens
69 |
70 | @torch.no_grad()
71 | def encode(self, text):
72 | tokens = self(text)
73 | if not self.vq_interface:
74 | return tokens
75 | return None, None, [None, None, tokens]
76 |
77 | def decode(self, text):
78 | return text
79 |
80 |
81 | class BERTEmbedder(AbstractEncoder):
82 |     """Uses the BERT tokenizer and adds some transformer encoder layers"""
83 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
84 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
85 | super().__init__()
86 | self.use_tknz_fn = use_tokenizer
87 | if self.use_tknz_fn:
88 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
89 | self.device = device
90 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
91 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
92 | emb_dropout=embedding_dropout)
93 |
94 | def forward(self, text):
95 | if self.use_tknz_fn:
96 | tokens = self.tknz_fn(text)#.to(self.device)
97 | else:
98 | tokens = text
99 | z = self.transformer(tokens, return_embeddings=True)
100 | return z
101 |
102 | def encode(self, text):
103 | # output of length 77
104 | return self(text) # [batch_size, 77, BERTEmbedder.n_embed] # exp0023: [B, 77, 640]
105 |
106 |
107 | class SpatialRescaler(nn.Module):
108 | def __init__(self,
109 | n_stages=1,
110 | method='bilinear',
111 | multiplier=0.5,
112 | in_channels=3,
113 | out_channels=None,
114 | bias=False):
115 | super().__init__()
116 | self.n_stages = n_stages
117 | assert self.n_stages >= 0
118 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
119 | self.multiplier = multiplier
120 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
121 | self.remap_output = out_channels is not None
122 | if self.remap_output:
123 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
124 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
125 |
126 | def forward(self,x):
127 | for stage in range(self.n_stages):
128 | x = self.interpolator(x, scale_factor=self.multiplier)
129 |
130 |
131 | if self.remap_output:
132 | x = self.channel_mapper(x)
133 | return x
134 |
135 | def encode(self, x):
136 | return self(x)
137 |
138 |
139 | class FrozenCLIPTextEmbedder(nn.Module):
140 | """
141 | Uses the CLIP transformer encoder for text.
142 | """
143 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
144 | super().__init__()
145 | self.model, _ = clip.load(version, jit=False, device="cpu")
146 | self.device = device
147 | self.max_length = max_length
148 | self.n_repeat = n_repeat
149 | self.normalize = normalize
150 |
151 | def freeze(self):
152 | self.model = self.model.eval()
153 | for param in self.parameters():
154 | param.requires_grad = False
155 |
156 | def forward(self, text):
157 | tokens = clip.tokenize(text).to(self.device)
158 | z = self.model.encode_text(tokens)
159 | if self.normalize:
160 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
161 | return z
162 |
163 | def encode(self, text):
164 | z = self(text)
165 | if z.ndim==2:
166 | z = z[:, None, :]
167 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
168 | return z
169 |
170 |
171 | class FrozenClipImageEmbedder(nn.Module):
172 | """
173 | Uses the CLIP image encoder.
174 | """
175 | def __init__(
176 | self,
177 | model,
178 | jit=False,
179 | device='cuda' if torch.cuda.is_available() else 'cpu',
180 | antialias=False,
181 | ):
182 | super().__init__()
183 | self.model, _ = clip.load(name=model, device=device, jit=jit)
184 |
185 | self.antialias = antialias
186 |
187 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
188 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
189 |
190 | def preprocess(self, x):
191 | # normalize to [0,1]
192 | x = kornia.geometry.resize(x, (224, 224),
193 | interpolation='bicubic',align_corners=True,
194 | antialias=self.antialias)
195 | x = (x + 1.) / 2.
196 | # renormalize according to clip
197 | x = kornia.enhance.normalize(x, self.mean, self.std)
198 | return x
199 |
200 | def forward(self, x):
201 | # x is assumed to be in range [-1,1]
202 | return self.model.encode_image(self.preprocess(x))
203 |
204 |
205 | class PassSegMaskEncoder(nn.Module):
206 |
207 |
208 | def __init__(self):
209 | super().__init__()
210 |
211 |
212 | def forward(self, x):
213 |         """Convert torch.cuda.LongTensor to torch.cuda.FloatTensor."""
214 | return x.to(torch.float)
215 |
216 | def encode(self, input):
217 | return self(input)
218 |
219 |
220 | class SegMaskAndTextEncoder(nn.Module):
221 |
222 | def __init__(self, seg_mask_encoder_config, text_encoder_config, mask_embed_dim=1024, text_embed_dim=640):
223 | super().__init__()
224 |
225 | self.seg_mask_encoder = instantiate_from_config_vq_diffusion(seg_mask_encoder_config)
226 | self.text_encoder = instantiate_from_config_vq_diffusion(text_encoder_config)
227 |
228 | self.mask_embed_dim = mask_embed_dim
229 | self.text_embed_dim = text_embed_dim
230 |
231 | self.linear = nn.Linear(
232 | in_features=self.mask_embed_dim, # 1024
233 | out_features=self.text_embed_dim, # 640
234 | bias=True)
235 |
236 |
237 | def forward(self, input):
238 |
239 | seg_mask = self.seg_mask_encoder(input['seg_mask']) # [B, 19, 1024]
240 | seg_mask = self.linear(seg_mask) # [B, 19, 640]
241 |
242 | text = self.text_encoder(input['text']) # [B, 77, 640]
243 |
244 | seg_mask_and_text = torch.cat([seg_mask, text], 1) # Bx(19+77)x640
245 |
246 | return seg_mask_and_text
247 |
248 | def encode(self, input):
249 | return self(input)
250 |
251 |
252 |
253 | class SegMaskEncoder(nn.Module):
254 |
255 | def __init__(self, seg_mask_encoder_config, mask_embed_dim=1024, context_dim=640):
256 | super().__init__()
257 |
258 | self.seg_mask_encoder = instantiate_from_config_vq_diffusion(seg_mask_encoder_config)
259 |
260 | self.mask_embed_dim = mask_embed_dim
261 | self.context_dim = context_dim
262 |
263 | self.linear = nn.Linear(
264 | in_features=self.mask_embed_dim, # 1024
265 | out_features=self.context_dim, # 640
266 | bias=True)
267 |
268 |
269 | def forward(self, input):
270 |
271 |
272 | if len(input.shape) == 4 and input.shape[1] == 1:
273 | # Bx1x19x1024 -> Bx19x1024
274 | # the '1' dimension comes from inner_function() in class DDPM
275 | input = input.view(input.shape[0], input.shape[2], input.shape[3])
276 | assert len(input.shape) == 3 # [B, 19, 1024]
277 | seg_mask = self.seg_mask_encoder(input) # [B, 19, 1024]
278 |
279 | # match mask dimension to the desired condition dimension
280 | seg_mask = self.linear(seg_mask) # [B, 19, 640]
281 |
282 | return seg_mask # [B, 19, 640]
283 |
284 | def encode(self, input):
285 | return self(input)
286 |
--------------------------------------------------------------------------------
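Note: the tensor-shape contract of SegMaskAndTextEncoder can be checked in isolation. The snippet below is a minimal, hypothetical sketch that skips instantiate_from_config_vq_diffusion and uses random tensors in place of the sub-encoder outputs, so only the linear projection and the concatenation are exercised.

import torch
import torch.nn as nn

# Stand-ins for the configured sub-encoder outputs (shapes follow the comments above).
mask_embed_dim, text_embed_dim = 1024, 640
seg_mask_emb = torch.randn(2, 19, mask_embed_dim)   # [B, 19, 1024], one token per facial region
text_emb = torch.randn(2, 77, text_embed_dim)       # [B, 77, 640], e.g. BERTEmbedder output

# Project the mask tokens to the text dimension, then concatenate along the token axis.
linear = nn.Linear(mask_embed_dim, text_embed_dim, bias=True)
cond = torch.cat([linear(seg_mask_emb), text_emb], dim=1)
print(cond.shape)  # torch.Size([2, 96, 640]) -> Bx(19+77)x640, as noted in forward()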
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
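As a quick sanity check of the helpers above, the snippet below (a minimal sketch using the default linear schedule parameters) derives the cumulative alphas from make_beta_schedule and builds sinusoidal embeddings for a few timesteps.

import numpy as np
import torch
from ldm.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array of shape (1000,)
alphas_cumprod = np.cumprod(1.0 - betas)                # cumulative product used by DDPM/DDIM

t = torch.tensor([0, 500, 999])
emb = timestep_embedding(t, dim=128)                    # [3, 128] sinusoidal embedding
print(alphas_cumprod[[0, 500, 999]], emb.shape)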
/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | from omegaconf import OmegaConf
11 | from PIL import Image
12 | from torchvision.utils import make_grid
13 |
14 | from ldm.models.diffusion.ddim_confidence import DDIMConfidenceSampler
15 | from ldm.util import instantiate_from_config
16 | """
17 | Inference script for multi-modal-driven face generation at 512x512 resolution
18 | """
19 |
20 |
21 | def parse_args():
22 |
23 | parser = argparse.ArgumentParser(description="")
24 |
25 | # multi-modal conditions
26 | parser.add_argument(
27 | "--mask_path",
28 | type=str,
29 | default="test_data/512_masks/27007.png",
30 | help="path to the segmentation mask")
31 | parser.add_argument(
32 | "--input_text",
33 | type=str,
34 | default="This man has beard of medium length. He is in his thirties.",
35 | help="text condition")
36 |
37 | # paths
38 | parser.add_argument(
39 | "--config_path",
40 | type=str,
41 | default="configs/512_codiff_mask_text.yaml",
42 | help="path to model config")
43 | parser.add_argument(
44 | "--ckpt_path",
45 | type=str,
46 | default="pretrained/512_codiff_mask_text.ckpt",
47 | help="path to model checkpoint")
48 | parser.add_argument(
49 | "--save_folder",
50 | type=str,
51 | default="outputs/inference_512_codiff_mask_text",
52 | help="folder to save synthesis outputs")
53 |
54 | # batch size and ddim steps
55 | parser.add_argument(
56 | "--batch_size",
57 | type=int,
58 | default=4,
59 | help="number of images to generate")
60 | parser.add_argument(
61 | "--ddim_steps",
62 | type=int,
63 |         default=50,
64 |         help=
65 |         "number of DDIM steps (20 to 1000; more steps are slower but give better quality)"
66 | )
67 |
68 | # whether save intermediate outputs
69 | parser.add_argument(
70 | "--save_z",
71 | type=bool,
72 | default=False,
73 | help=
74 |         "whether to visualize the VAE latent codes and save them in the output folder",
75 | )
76 | parser.add_argument(
77 | "--return_influence_function",
78 | type=bool,
79 | default=False,
80 | help=
81 |         "whether to visualize the Influence Functions and save them in the output folder",
82 | )
83 | parser.add_argument(
84 | "--display_x_inter",
85 | type=bool,
86 | default=False,
87 | help=
88 |         "whether to display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder",
89 | )
90 | parser.add_argument(
91 | "--save_mixed",
92 | type=bool,
93 | default=False,
94 | help=
95 |         "whether to overlay the segmentation mask on the synthesized image to visualize mask consistency",
96 | )
97 |
98 | args = parser.parse_args()
99 | return args
100 |
101 |
102 | def main():
103 |
104 | args = parse_args()
105 |
106 | # ========== set up model ==========
107 |     print('Setting up model')
108 | config = OmegaConf.load(args.config_path)
109 | model_config = config['model']
110 | model = instantiate_from_config(model_config)
111 | model.init_from_ckpt(args.ckpt_path)
112 | model = model.cuda()
113 | model.eval()
114 |
115 | # ========== set output directory ==========
116 | os.makedirs(args.save_folder, exist_ok=True)
117 | # save a copy of this python script being used
118 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__))
119 |
120 | # ========== prepare seg mask for model ==========
121 | with open(args.mask_path, 'rb') as f:
122 | img = Image.open(f)
123 | resized_img = img.resize((32, 32), Image.NEAREST) # resize
124 | flattened_img = list(resized_img.getdata())
125 | flattened_img_tensor = torch.tensor(flattened_img) # flatten
126 | flattened_img_tensor_one_hot = F.one_hot(
127 | flattened_img_tensor, num_classes=19) # one hot
128 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose(
129 | 0, 1)
130 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze(
131 | flattened_img_tensor_one_hot_transpose,
132 | 0).cuda() # add batch dimension
133 |
134 | # ========== prepare mask for visualization ==========
135 | mask = Image.open(args.mask_path)
136 | mask = mask.convert('RGB')
137 | mask = np.array(mask).astype(np.uint8) # three channel integer
138 | input_mask = mask
139 |
140 | print(
141 | f'================================================================================'
142 | )
143 | print(f'mask_path: {args.mask_path} | text: {args.input_text}')
144 |
145 | # prepare directories
146 | mask_name = args.mask_path.split('/')[-1]
147 | save_sub_folder = os.path.join(args.save_folder, mask_name,
148 | str(args.input_text))
149 | os.makedirs(save_sub_folder, exist_ok=True)
150 |
151 | # save seg_mask
152 | save_path_mask = os.path.join(save_sub_folder, mask_name)
153 | mask_ = Image.fromarray(input_mask)
154 | mask_.save(save_path_mask)
155 |
156 | # ========== inference ==========
157 | with torch.no_grad():
158 |
159 | condition = {
160 | 'seg_mask': flattened_img_tensor_one_hot_transpose,
161 | 'text': [args.input_text.lower()]
162 | }
163 |
164 | with model.ema_scope("Plotting"):
165 |
166 | # encode condition
167 | condition = model.get_learned_conditioning(condition)
168 | if isinstance(condition, dict):
169 | for key, value in condition.items():
170 | condition[key] = condition[key].repeat(
171 | args.batch_size, 1, 1)
172 | else:
173 | condition = condition.repeat(args.batch_size, 1, 1)
174 |
175 | # define DDIM sampler with dynamic diffusers
176 | ddim_sampler = DDIMConfidenceSampler(
177 | model=model,
178 | return_confidence_map=args.return_influence_function)
179 |
180 | # DDIM sampling
181 | z_0_batch, intermediates = ddim_sampler.sample(
182 | S=args.ddim_steps,
183 | batch_size=args.batch_size,
184 | shape=(3, 64, 64),
185 | conditioning=condition,
186 | verbose=False,
187 | eta=1.0,
188 | log_every_t=1)
189 |
190 | # decode VAE latent z_0 to image x_0
191 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256]
192 |
193 | # ========== save outputs ==========
194 | for idx in range(args.batch_size):
195 |
196 | # save synthesized image x_0
197 | save_x_0_path = os.path.join(save_sub_folder,
198 | f'{str(idx).zfill(6)}_x_0.png')
199 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256]
200 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy()
201 | x_0 = (x_0 + 1.0) * 127.5
202 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255
203 | x_0 = x_0.astype(np.uint8)
204 | x_0 = Image.fromarray(x_0[0])
205 | x_0.save(save_x_0_path)
206 |
207 | # save intermediate x_t and pred_x_0
208 | if args.display_x_inter:
209 | for cond_name in ['x_inter', 'pred_x0']:
210 | save_conf_path = os.path.join(
211 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png')
212 | conf = intermediates[f'{cond_name}']
213 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64
214 | conf = conf[:, idx, :, :, :] # 50x3x64x64
215 | print('decoding x_inter ......')
216 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256]
217 | conf = make_grid(
218 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10]
219 | conf = conf.permute(1, 2,
220 | 0).to('cpu').numpy() # cxhxh -> hxhxc
221 | conf = (conf + 1.0) * 127.5
222 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
223 | conf = conf.astype(np.uint8)
224 | conf = Image.fromarray(conf)
225 | conf.save(save_conf_path)
226 |
227 | # save influence functions
228 | if args.return_influence_function:
229 | for cond_name in ['seg_mask', 'text']:
230 | save_conf_path = os.path.join(
231 | save_sub_folder,
232 | f'{str(idx).zfill(6)}_{cond_name}_influence_function.png')
233 | conf = intermediates[f'{cond_name}_confidence_map']
234 | conf = torch.stack(conf, dim=0) # 50x8x1x64x64
235 | conf = conf[:, idx, :, :, :] # 50x1x64x64
236 | conf = torch.cat(
237 | [conf, conf, conf],
238 | dim=1) # manually create 3 channels # [50, 3, 64, 64]
239 | conf = make_grid(
240 | conf, nrow=10) # 10 images per row # [3, 332, 662]
241 | conf = conf.permute(1, 2,
242 | 0).to('cpu').numpy() # cxhxh -> hxhxc
243 | conf = conf * 255 # manually tuned denormalization: [0,1] -> [0,255]
244 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
245 | conf = conf.astype(np.uint8)
246 | conf = Image.fromarray(conf)
247 | conf.save(save_conf_path)
248 |
249 | # save latent z_0
250 | if args.save_z:
251 | save_z_0_path = os.path.join(save_sub_folder,
252 | f'{str(idx).zfill(6)}_z_0.png')
253 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64]
254 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy()
255 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization
256 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255
257 | z_0 = z_0.astype(np.uint8)
258 | z_0 = Image.fromarray(z_0[0])
259 | z_0.save(save_z_0_path)
260 |
261 | # overlay the segmentation mask on the synthesized image to visualize mask consistency
262 | if args.save_mixed:
263 | save_mixed_path = os.path.join(save_sub_folder,
264 | f'{str(idx).zfill(6)}_mixed.png')
265 | Image.blend(x_0, mask_, 0.3).save(save_mixed_path)
266 |
267 |
268 | if __name__ == "__main__":
269 | main()
--------------------------------------------------------------------------------
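The mask preprocessing in generate.py (resize to 32x32 with nearest-neighbour interpolation, flatten, one-hot into 19 classes, transpose, add a batch dimension) can be condensed into a helper. The function below is a hypothetical refactoring of the steps above, not part of the repository.

import torch
import torch.nn.functional as F
from PIL import Image

def preprocess_seg_mask(mask_path, size=32, num_classes=19):
    """Return a [1, num_classes, size*size] one-hot mask tensor on CUDA."""
    with open(mask_path, 'rb') as f:
        img = Image.open(f)
        resized = img.resize((size, size), Image.NEAREST)   # keep hard region labels
        labels = torch.tensor(list(resized.getdata()))      # [size*size] integer labels
    one_hot = F.one_hot(labels, num_classes=num_classes)    # [size*size, num_classes]
    return one_hot.transpose(0, 1).unsqueeze(0).cuda()      # [1, num_classes, size*size]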
/editing/imagic_edit_text.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 | from pathlib import Path
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 | from torchvision import transforms
12 | from tqdm import tqdm
13 |
14 | from einops import rearrange
15 | from ldm.models.diffusion.ddim import DDIMSampler
16 | from ldm.util import instantiate_from_config
17 | """
18 | Reference code:
19 | https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
20 | """
21 |
22 | parser = argparse.ArgumentParser()
23 |
24 | # directories
25 | parser.add_argument('--config', type=str, default='configs/256_text.yaml')
26 | parser.add_argument('--ckpt', type=str, default='pretrained/256_text.ckpt')
27 | parser.add_argument('--save_folder', type=str, default='outputs/text_edit')
28 | parser.add_argument(
29 | '--input_image_path',
30 | type=str,
31 | default='test_data/test_mask_edit/256_input_image/27044.jpg')
32 | parser.add_argument(
33 | '--text_prompt',
34 | type=str,
35 | default='He is a teen. The face is covered with short pointed beard.')
36 |
37 | # hyperparameters
38 | parser.add_argument('--seed', type=int, default=0)
39 | parser.add_argument('--stage1_lr', type=float, default=0.001)
40 | parser.add_argument('--stage1_num_iter', type=int, default=500)
41 | parser.add_argument('--stage2_lr', type=float, default=1e-6)
42 | parser.add_argument('--stage2_num_iter', type=int, default=1000)
43 | parser.add_argument(
44 | '--alpha_list',
45 | type=str,
46 | default='-1, -0.5, 0, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5')
47 | parser.add_argument('--set_random_seed', type=bool, default=False)
48 | parser.add_argument('--save_checkpoint', type=bool, default=True)
49 |
50 | args = parser.parse_args()
51 |
52 |
53 | def load_model_from_config(config, ckpt, device="cpu", verbose=False):
54 |     """Load a model from a config and a checkpoint.
55 |     If config is a path, it is loaded with OmegaConf.
56 | """
57 | if isinstance(config, (str, Path)):
58 | config = OmegaConf.load(config)
59 |
60 | pl_sd = torch.load(ckpt, map_location="cpu")
61 | global_step = pl_sd["global_step"]
62 | sd = pl_sd["state_dict"]
63 | model = instantiate_from_config(config.model)
64 | m, u = model.load_state_dict(sd, strict=False)
65 | model.to(device)
66 | model.eval()
67 | model.cond_stage_model.device = device
68 | return model
69 |
70 |
71 | @torch.no_grad()
72 | def sample_model(model,
73 | sampler,
74 | c,
75 | h,
76 | w,
77 | ddim_steps,
78 | scale,
79 | ddim_eta,
80 | start_code=None,
81 | n_samples=1):
82 | """Sample the model"""
83 | uc = None
84 | if scale != 1.0:
85 | uc = model.get_learned_conditioning(n_samples * [""])
86 |
87 | # print(f'model.model.parameters(): {model.model.parameters()}')
88 | # for name, param in model.model.named_parameters():
89 | # if param.requires_grad:
90 | # print (name, param.data)
91 | # break
92 |
93 | # print(f'unconditional_guidance_scale: {scale}') # 1.0
94 | # print(f'unconditional_conditioning: {uc}') # None
95 | with model.ema_scope("Plotting"):
96 |
97 | shape = [3, 64, 64] # [4, h // 8, w // 8]
98 | samples_ddim, _ = sampler.sample(
99 | S=ddim_steps,
100 | conditioning=c,
101 | batch_size=n_samples,
102 | shape=shape,
103 | verbose=False,
104 | start_code=start_code,
105 | unconditional_guidance_scale=scale,
106 | unconditional_conditioning=uc,
107 | eta=ddim_eta,
108 | )
109 | return samples_ddim
110 |
111 |
112 | def load_img(path, target_size=256):
113 |     """Load an image and output a tensor scaled to [-1, 1]"""
114 | image = Image.open(path).convert("RGB")
115 |
116 | tform = transforms.Compose([
117 | # transforms.Resize(target_size),
118 | # transforms.CenterCrop(target_size),
119 | transforms.ToTensor(),
120 | ])
121 | image = tform(image)
122 | return 2. * image - 1.
123 |
124 |
125 | def decode_to_im(samples, n_samples=1, nrow=1):
126 | """Decode a latent and return PIL image"""
127 | samples = model.decode_first_stage(samples)
128 | ims = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
129 | x_sample = 255. * rearrange(
130 | ims.cpu().numpy(),
131 | '(n1 n2) c h w -> (n1 h) (n2 w) c',
132 | n1=n_samples // nrow,
133 | n2=nrow)
134 | return Image.fromarray(x_sample.astype(np.uint8))
135 |
136 |
137 | if __name__ == '__main__':
138 |
139 | args.alpha_list = [float(i) for i in args.alpha_list.split(',')]
140 |
141 | device = "cuda" # "cuda:0"
142 |
143 | # Generation parameters
144 | scale = 1.0
145 | h = 256
146 | w = 256
147 | ddim_steps = 50
148 | ddim_eta = 1.0
149 |
150 | # initialize model
151 | global_model = load_model_from_config(args.config, args.ckpt, device)
152 |
153 | input_image = args.input_image_path
154 | image_name = input_image.split('/')[-1]
155 |
156 | prompt = args.text_prompt
157 |
158 | torch.manual_seed(args.seed)
159 |
160 | model = copy.deepcopy(global_model)
161 | sampler = DDIMSampler(model)
162 |
163 | # prepare directories
164 | save_dir = os.path.join(args.save_folder, image_name, str(prompt))
165 | os.makedirs(save_dir, exist_ok=True)
166 | print(
167 | f'================================================================================'
168 | )
169 | print(f'input_image: {input_image} | text: {prompt}')
170 |
171 | # read input image
172 | init_image = load_img(input_image).to(device).unsqueeze(
173 | 0) # [1, 3, 256, 256]
174 | gaussian_distribution = model.encode_first_stage(init_image)
175 | init_latent = model.get_first_stage_encoding(
176 | gaussian_distribution) # [1, 3, 64, 64]
177 | img = decode_to_im(init_latent)
178 | img.save(os.path.join(save_dir, 'input_image_reconstructed.png'))
179 |
180 | # obtain text embedding
181 | emb_tgt = model.get_learned_conditioning([prompt])
182 | emb_ = emb_tgt.clone()
183 | torch.save(emb_, os.path.join(save_dir, 'emb_tgt.pt'))
184 | emb = torch.load(os.path.join(save_dir, 'emb_tgt.pt')) # [1, 77, 640]
185 |
186 | # Sample the model with a fixed code to see what it looks like
187 | quick_sample = lambda x, s, code: decode_to_im(
188 | sample_model(
189 | model, sampler, x, h, w, ddim_steps, s, ddim_eta, start_code=code))
190 | # start_code = torch.randn_like(init_latent)
191 | start_code = torch.randn((1, 3, 64, 64), device=device)
192 | torch.save(start_code, os.path.join(save_dir,
193 | 'start_code.pt')) # [1, 3, 64, 64]
194 | torch.manual_seed(args.seed)
195 | img = quick_sample(emb_tgt, scale, start_code)
196 | img.save(os.path.join(save_dir, 'A_start_tgtText_origDM.png'))
197 |
198 | # ======================= (A) Text Embedding Optimization ===================================
199 | print('########### Step 1 - Optimise the embedding ###########')
200 | emb.requires_grad = True
201 | opt = torch.optim.Adam([emb], lr=args.stage1_lr)
202 | criteria = torch.nn.MSELoss()
203 | history = []
204 |
205 | pbar = tqdm(range(args.stage1_num_iter))
206 | for i in pbar:
207 | opt.zero_grad()
208 |
209 | if args.set_random_seed:
210 | torch.seed()
211 | noise = torch.randn_like(init_latent)
212 | t_enc = torch.randint(1000, (1, ), device=device)
213 | z = model.q_sample(init_latent, t_enc, noise=noise)
214 |
215 | pred_noise = model.apply_model(z, t_enc, emb)
216 |
217 | loss = criteria(pred_noise, noise)
218 | loss.backward()
219 | pbar.set_postfix({"loss": loss.item()})
220 | history.append(loss.item())
221 | opt.step()
222 |
223 | plt.plot(history)
224 | plt.show()
225 | torch.save(emb, os.path.join(save_dir, 'emb_opt.pt'))
226 | emb_opt = torch.load(os.path.join(save_dir, 'emb_opt.pt')) # [1, 77, 640]
227 |
228 | torch.manual_seed(args.seed)
229 | img = quick_sample(emb_opt, scale, start_code)
230 | img.save(os.path.join(save_dir, 'A_end_optText_origDM.png'))
231 |
232 | # Interpolate the embedding
233 | for idx, alpha in enumerate(args.alpha_list):
234 | print(f'alpha={alpha}')
235 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt
236 | torch.manual_seed(args.seed)
237 | img = quick_sample(new_emb, scale, start_code)
238 | img.save(
239 | os.path.join(
240 | save_dir,
241 | f'0A_interText_origDM_{idx}_alpha={round(alpha,3)}.png'))
242 |
243 | # ======================= (B) Model Fine-Tuning ===================================
244 | print('########### Step 2 - Fine tune the model ###########')
245 | emb_opt.requires_grad = False
246 | model.train()
247 |
248 | opt = torch.optim.Adam(model.model.parameters(), lr=args.stage2_lr)
249 | criteria = torch.nn.MSELoss()
250 | history = []
251 |
252 | pbar = tqdm(range(args.stage2_num_iter))
253 | for i in pbar:
254 | opt.zero_grad()
255 |
256 | if args.set_random_seed:
257 | torch.seed()
258 | noise = torch.randn_like(init_latent)
259 | t_enc = torch.randint(model.num_timesteps, (1, ), device=device)
260 | z = model.q_sample(init_latent, t_enc, noise=noise)
261 |
262 | pred_noise = model.apply_model(z, t_enc, emb_opt)
263 |
264 | loss = criteria(pred_noise, noise)
265 | loss.backward()
266 | pbar.set_postfix({"loss": loss.item()})
267 | history.append(loss.item())
268 | opt.step()
269 |
270 | model.eval()
271 | plt.plot(history)
272 | plt.show()
273 | torch.manual_seed(args.seed)
274 | img = quick_sample(emb_opt, scale, start_code)
275 | img.save(os.path.join(save_dir, 'B_end_optText_optDM.png'))
276 | # Should look like the original image
277 |
278 | if args.save_checkpoint:
279 | ckpt = {
280 | "state_dict": model.state_dict(),
281 | }
282 | ckpt_path = os.path.join(save_dir, 'optDM.ckpt')
283 | print(f'Saving optDM to {ckpt_path}')
284 | torch.save(ckpt, ckpt_path)
285 |
286 | # ======================= (C) Generation ===================================
287 | print('########### Step 3 - Generate images ###########')
288 | # Interpolate the embedding
289 | for idx, alpha in enumerate(args.alpha_list):
290 | print(f'alpha={alpha}')
291 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt
292 | torch.manual_seed(args.seed)
293 | img = quick_sample(new_emb, scale, start_code)
294 | img.save(
295 | os.path.join(
296 | save_dir,
297 | f'0C_interText_optDM_{idx}_alpha={round(alpha,3)}.png'))
298 |
299 | print('Done')
--------------------------------------------------------------------------------
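The alpha sweeps in steps (A) and (C) above are a plain linear blend between the target text embedding and the embedding optimized to reconstruct the input image. A standalone sketch with illustrative random tensors:

import torch

emb_tgt = torch.randn(1, 77, 640)   # embedding of the target (edit) prompt
emb_opt = torch.randn(1, 77, 640)   # embedding optimized to reconstruct the input image

for alpha in (-0.5, 0.0, 0.5, 1.0, 1.5):
    # alpha = 0 keeps the reconstruction, alpha = 1 applies the full edit,
    # and values outside [0, 1] extrapolate the edit direction.
    new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt
    # In the script, new_emb is passed to quick_sample(new_emb, scale, start_code).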
/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
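The accuracy helper compute_top_k above reduces to a per-sample top-k membership test; a minimal sketch with dummy logits:

import torch

logits = torch.tensor([[2.0, 1.0, 0.1],   # row 0: highest logit at class 0
                       [0.2, 0.3, 3.0]])  # row 1: highest logit at class 2
labels = torch.tensor([0, 1])

_, top_ks = torch.topk(logits, k=2, dim=1)  # indices of the 2 largest logits per row
acc_at_2 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc_at_2)  # 1.0: label 0 is in row 0's top-2 and label 1 is in row 1's top-2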
/editing/collaborative_edit.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | from omegaconf import OmegaConf
8 | from PIL import Image
9 | from torchvision.utils import make_grid
10 |
11 | from ldm.models.diffusion.ddim_confidence import DDIMConfidenceSampler
12 | from ldm.util import instantiate_from_config
13 | """
14 | Inference script for multi-modal-driven face editing at 256x256 resolution
15 | """
16 |
17 |
18 | def parse_args():
19 |
20 | parser = argparse.ArgumentParser(description="")
21 |
22 | # uni-modal editing results
23 | parser.add_argument(
24 | "--imagic_text_folder",
25 | type=str,
26 | default=
27 | "outputs/text_edit/27044.jpg/He is a teen. The face is covered with short pointed beard.",
28 | help="path to Imagic text-based editing results")
29 | parser.add_argument(
30 | "--imagic_mask_folder",
31 | type=str,
32 | default=
33 | "outputs/mask_edit/27044.jpg/27044_0_remove_smile_and_rings.png",
34 | help="path to Imagic mask-based editing results")
35 |
36 | # paths
37 | parser.add_argument(
38 | "--config_path",
39 | type=str,
40 | default="configs/256_codiff_mask_text.yaml",
41 | help="path to model config")
42 | parser.add_argument(
43 | "--ckpt_path",
44 | type=str,
45 | default="pretrained/256_codiff_mask_text.ckpt",
46 | help="path to model checkpoint")
47 | parser.add_argument(
48 | "--save_folder",
49 | type=str,
50 | default="outputs/collaborative_edit",
51 | help="folder to save editing outputs")
52 |
53 | # batch size and ddim steps
54 | parser.add_argument(
55 | "--batch_size",
56 | type=int,
57 | default=1,
58 | help="number of images to generate")
59 | parser.add_argument(
60 | "--ddim_steps",
61 | type=int,
62 | default=50,
63 | help=
64 |         "number of DDIM steps (20 to 1000; more steps are slower but give better quality)"
65 | )
66 | parser.add_argument(
67 | "--seed",
68 | type=int,
69 | default=2,
70 | )
71 |
72 | # whether save intermediate outputs
73 | parser.add_argument(
74 | "--save_z",
75 | type=bool,
76 | default=False,
77 | help=
78 |         "whether to visualize the VAE latent codes and save them in the output folder",
79 | )
80 | parser.add_argument(
81 | "--return_influence_function",
82 | type=bool,
83 | default=False,
84 | help=
85 |         "whether to visualize the Influence Functions and save them in the output folder",
86 | )
87 | parser.add_argument(
88 | "--display_x_inter",
89 | type=bool,
90 | default=False,
91 | help=
92 |         "whether to display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder",
93 | )
94 |
95 | args = parser.parse_args()
96 | return args
97 |
98 |
99 | def main():
100 |
101 | args = parse_args()
102 |
103 | # ========== set output directory ==========
104 | os.makedirs(args.save_folder, exist_ok=True)
105 |
106 | # ========== init model ==========
107 | config = OmegaConf.load(args.config_path)
108 | model_config = config['model']
109 | model_config['params']['seg_mask_ldm_ckpt_path'] = os.path.join(
110 | args.imagic_mask_folder, 'optDM.ckpt')
111 | model_config['params']['text_ldm_ckpt_path'] = os.path.join(
112 | args.imagic_text_folder, 'optDM.ckpt')
113 | model = instantiate_from_config(model_config)
114 | model.init_from_ckpt(args.ckpt_path)
115 | mask_optDM_ckpt = torch.load(
116 | os.path.join(args.imagic_mask_folder, 'optDM.ckpt'))
117 | text_optDM_ckpt = torch.load(
118 | os.path.join(args.imagic_text_folder, 'optDM.ckpt'))
119 |
120 | print('Updating text and mask model to the finetuned version ......')
121 | state_dict = model.state_dict()
122 | for name in state_dict.keys():
123 | if 'model.seg_mask_unet.' in name:
124 | name_end = name[20:]
125 | state_dict[name] = mask_optDM_ckpt['state_dict'][
126 | f'model.{name_end}']
127 | elif 'model.text_unet.' in name:
128 | name_end = name[16:]
129 | state_dict[name] = text_optDM_ckpt['state_dict'][
130 | f'model.{name_end}']
131 | model.load_state_dict(state_dict)
132 |
133 | print('Pushing model to CUDA ......')
134 | global_model = model.cuda()
135 | global_model.eval()
136 |
137 | seed = args.seed
138 |
139 | for alpha_idx, alpha in enumerate([
140 | 0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1, 1.125, 1.25,
141 | 1.375, 1.5, 1.625, 1.75, 1.875, 2.0, 2.5
142 | ]):
143 |
144 | print(f'alpha={alpha}')
145 |
146 | seed = int(seed)
147 | torch.manual_seed(seed)
148 | mask_start_code = torch.load(
149 | os.path.join(args.imagic_mask_folder, 'start_code.pt'))
150 | text_start_code = torch.load(
151 | os.path.join(args.imagic_text_folder, 'start_code.pt'))
152 | start_code = mask_start_code
153 |
154 | # prepare directories
155 | save_sub_folder = os.path.join(args.save_folder, f'seed={seed}')
156 | os.makedirs(save_sub_folder, exist_ok=True)
157 |
158 | seed = int(seed)
159 | model = copy.deepcopy(global_model)
160 |
161 | # ========== inference ==========
162 | with torch.no_grad():
163 |
164 | with model.ema_scope("Plotting"):
165 |
166 | condition = {}
167 |
168 | mask_alpha = alpha
169 | mask_emb_tgt = torch.load(
170 | os.path.join(args.imagic_mask_folder, 'emb_tgt.pt'))
171 | mask_emb_opt = torch.load(
172 | os.path.join(args.imagic_mask_folder, 'emb_opt.pt'))
173 | mask_new_emb = mask_alpha * mask_emb_tgt + (
174 | 1 - mask_alpha) * mask_emb_opt
175 | condition['seg_mask'] = mask_new_emb.repeat(
176 | args.batch_size, 1, 1)
177 |
178 | text_alpha = alpha
179 | text_emb_tgt = torch.load(
180 | os.path.join(args.imagic_text_folder, 'emb_tgt.pt'))
181 | text_emb_opt = torch.load(
182 | os.path.join(args.imagic_text_folder, 'emb_opt.pt'))
183 | text_new_emb = text_alpha * text_emb_tgt + (
184 | 1 - text_alpha) * text_emb_opt
185 | condition['text'] = text_new_emb.repeat(args.batch_size, 1, 1)
186 |
187 | torch.manual_seed(seed)
188 |
189 | ddim_sampler = DDIMConfidenceSampler(
190 | model=model,
191 | return_confidence_map=args.return_influence_function)
192 |
193 | torch.manual_seed(seed)
194 |
195 | z_0_batch, intermediates = ddim_sampler.sample(
196 | S=args.ddim_steps,
197 | batch_size=args.batch_size,
198 | shape=(3, 64, 64),
199 | conditioning=condition,
200 | verbose=False,
201 | start_code=start_code,
202 | eta=1.0,
203 | log_every_t=1)
204 |
205 | # decode VAE latent z_0 to image x_0
206 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256]
207 |
208 | # ========== save outputs ==========
209 | for idx in range(args.batch_size):
210 |
211 | # save synthesized image x_0
212 | save_x_0_path = os.path.join(
213 | save_sub_folder,
214 | f'{str(idx).zfill(6)}_x_0_{alpha_idx}_alpha={round(alpha, 3)}.png'
215 | )
216 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256]
217 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy()
218 | x_0 = (x_0 + 1.0) * 127.5
219 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255
220 | x_0 = x_0.astype(np.uint8)
221 | x_0 = Image.fromarray(x_0[0])
222 | x_0.save(save_x_0_path)
223 |
224 | # save intermediate x_t and pred_x_0
225 | if args.display_x_inter:
226 | for cond_name in ['x_inter', 'pred_x0']:
227 | save_conf_path = os.path.join(
228 | save_sub_folder,
229 | f'{str(idx).zfill(6)}_{cond_name}.png')
230 | conf = intermediates[f'{cond_name}']
231 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64
232 | conf = conf[:, idx, :, :, :] # 50x3x64x64
233 | print('decoding x_inter ......')
234 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256]
235 | conf = make_grid(
236 | conf,
237 | nrow=10) # 10 images per row # [3, 256x3, 256x10]
238 | conf = conf.permute(1, 2,
239 | 0).to('cpu').numpy() # cxhxh -> hxhxc
240 | conf = (conf + 1.0) * 127.5
241 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
242 | conf = conf.astype(np.uint8)
243 | conf = Image.fromarray(conf)
244 | conf.save(save_conf_path)
245 |
246 | # save influence functions
247 | if args.return_influence_function:
248 | for cond_name in ['seg_mask', 'text']:
249 | save_conf_path = os.path.join(
250 | save_sub_folder,
251 | f'{str(idx).zfill(6)}_{cond_name}_influence_function.png'
252 | )
253 | conf = intermediates[f'{cond_name}_confidence_map']
254 | conf = torch.stack(conf, dim=0) # 50x8x1x64x64
255 | conf = conf[:, idx, :, :, :] # 50x1x64x64
256 | conf = torch.cat(
257 | [conf, conf, conf],
258 | dim=1) # manually create 3 channels # [50, 3, 64, 64]
259 | conf = make_grid(
260 | conf, nrow=10) # 10 images per row # [3, 332, 662]
261 | conf = conf.permute(1, 2,
262 | 0).to('cpu').numpy() # cxhxh -> hxhxc
263 | conf = conf * 255 # manually tuned denormalization: [0,1] -> [0,255]
264 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255
265 | conf = conf.astype(np.uint8)
266 | conf = Image.fromarray(conf)
267 | conf.save(save_conf_path)
268 |
269 | # save latent z_0
270 | if args.save_z:
271 | save_z_0_path = os.path.join(save_sub_folder,
272 | f'{str(idx).zfill(6)}_z_0.png')
273 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64]
274 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy()
275 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization
276 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255
277 | z_0 = z_0.astype(np.uint8)
278 | z_0 = Image.fromarray(z_0[0])
279 | z_0.save(save_z_0_path)
280 |
281 |
282 | if __name__ == "__main__":
283 | main()
--------------------------------------------------------------------------------
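The state-dict surgery above copies the uni-modally finetuned UNet weights into the collaborative model by stripping the branch prefix from each key (20 characters for 'model.seg_mask_unet.', 16 for 'model.text_unet.'). Below is a simplified sketch of the same mapping as a helper; the function name is hypothetical, and mask_sd / text_sd stand for the 'state_dict' entries of the two optDM checkpoints.

def merge_finetuned_unets(state_dict, mask_sd, text_sd):
    """Fill branch-specific UNet weights from the finetuned uni-modal checkpoints."""
    for name in state_dict.keys():
        if 'model.seg_mask_unet.' in name:
            rest = name[len('model.seg_mask_unet.'):]   # same as name[20:] in the script
            state_dict[name] = mask_sd['model.' + rest]
        elif 'model.text_unet.' in name:
            rest = name[len('model.text_unet.'):]       # same as name[16:] in the script
            state_dict[name] = text_sd['model.' + rest]
    return state_dict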
/ldm/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """DDIM sampling using a single uni-modal pre-trained diffusion model"""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class DDIMSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | self.register_buffer('betas', to_torch(self.model.betas))
32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34 |
35 | # calculations for diffusion q(x_t | x_{t-1}) and others
36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41 |
42 | # ddim sampling parameters
43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44 | ddim_timesteps=self.ddim_timesteps,
45 | eta=ddim_eta,verbose=verbose)
46 | self.register_buffer('ddim_sigmas', ddim_sigmas)
47 | self.register_buffer('ddim_alphas', ddim_alphas)
48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54 |
55 | @torch.no_grad()
56 | def sample(self,
57 | S,
58 | batch_size,
59 | shape,
60 | conditioning=None,
61 | callback=None,
62 | normals_sequence=None,
63 | img_callback=None,
64 | quantize_x0=False,
65 | eta=0.,
66 | mask=None,
67 | x0=None,
68 | temperature=1.,
69 | noise_dropout=0.,
70 | score_corrector=None,
71 | corrector_kwargs=None,
72 | verbose=True,
73 | x_T=None,
74 | log_every_t=100,
75 | unconditional_guidance_scale=1.,
76 | unconditional_conditioning=None,
77 |                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
78 | **kwargs
79 | ):
80 |
81 | if conditioning is not None:
82 | if isinstance(conditioning, dict):
83 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
84 | if cbs != batch_size:
85 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
86 | else:
87 | if conditioning.shape[0] != batch_size:
88 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
89 |
90 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
91 | # sampling
92 | C, H, W = shape
93 | size = (batch_size, C, H, W)
94 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
95 |
96 | samples, intermediates = self.ddim_sampling(conditioning, size,
97 | callback=callback,
98 | img_callback=img_callback,
99 | quantize_denoised=quantize_x0,
100 | mask=mask, x0=x0,
101 | ddim_use_original_steps=False,
102 | noise_dropout=noise_dropout,
103 | temperature=temperature,
104 | score_corrector=score_corrector,
105 | corrector_kwargs=corrector_kwargs,
106 | x_T=x_T,
107 | log_every_t=log_every_t,
108 | unconditional_guidance_scale=unconditional_guidance_scale,
109 | unconditional_conditioning=unconditional_conditioning,
110 | )
111 | return samples, intermediates
112 |
113 | @torch.no_grad()
114 | def ddim_sampling(self, cond, shape,
115 | x_T=None, ddim_use_original_steps=False,
116 | callback=None, timesteps=None, quantize_denoised=False,
117 | mask=None, x0=None, img_callback=None, log_every_t=100,
118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
119 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
120 | device = self.model.betas.device
121 | b = shape[0]
122 | if x_T is None:
123 | img = torch.randn(shape, device=device)
124 | else:
125 | img = x_T
126 |
127 | if timesteps is None:
128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
129 | elif timesteps is not None and not ddim_use_original_steps:
130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
131 | timesteps = self.ddim_timesteps[:subset_end]
132 |
133 | intermediates = {'x_inter': [], 'pred_x0': []}
134 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
135 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
136 | print(f"Running DDIM Sampling with {total_steps} timesteps")
137 |
138 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
139 |
140 | for i, step in enumerate(iterator):
141 | index = total_steps - i - 1
142 | ts = torch.full((b,), step, device=device, dtype=torch.long)
143 |
144 | if mask is not None:
145 | assert x0 is not None
146 | img_orig = self.model.q_sample(x0, ts)
147 | img = img_orig * mask + (1. - mask) * img
148 |
149 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
150 | quantize_denoised=quantize_denoised, temperature=temperature,
151 | noise_dropout=noise_dropout, score_corrector=score_corrector,
152 | corrector_kwargs=corrector_kwargs,
153 | unconditional_guidance_scale=unconditional_guidance_scale,
154 | unconditional_conditioning=unconditional_conditioning)
155 | img, pred_x0 = outs
156 |
157 | if callback: callback(i)
158 | if img_callback: img_callback(pred_x0, i)
159 |
160 | if index % log_every_t == 0 or index == total_steps - 1:
161 | intermediates['x_inter'].append(img)
162 | intermediates['pred_x0'].append(pred_x0)
163 |
164 | return img, intermediates
165 |
166 | @torch.no_grad()
167 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
168 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
169 | unconditional_guidance_scale=1., unconditional_conditioning=None):
170 | b, *_, device = *x.shape, x.device
171 |
172 |
173 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
174 | e_t = self.model.apply_model(x, t, c)
175 | else:
176 | x_in = torch.cat([x] * 2)
177 | t_in = torch.cat([t] * 2)
178 | c_in = torch.cat([unconditional_conditioning, c])
179 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
180 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
181 |
182 | if score_corrector is not None:
183 | assert self.model.parameterization == "eps"
184 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
185 |
186 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
187 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
188 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
189 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
190 | # select parameters corresponding to the currently considered timestep
191 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
192 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
193 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
194 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
195 |
196 | # current prediction for x_0
197 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
198 | if quantize_denoised:
199 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
200 | # direction pointing to x_t
201 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
202 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
203 | if noise_dropout > 0.:
204 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
205 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
206 |
207 | return x_prev, pred_x0
208 |
--------------------------------------------------------------------------------
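
For orientation, the DDIMSampler above is typically driven as in the sketch below. It assumes a loaded LatentDiffusion-style `model` exposing `get_learned_conditioning` and `decode_first_stage` (as used by the scripts in this repository) and an already-prepared condition input `cond_input`; the function and variable names here are illustrative, not repository APIs beyond what the sampler itself defines.

import torch
from ldm.models.diffusion.ddim import DDIMSampler

def sample_uni_modal(model, cond_input, ddim_steps=50, eta=1.0):
    """Sketch: sample one image from a uni-modal diffusion model with DDIM."""
    sampler = DDIMSampler(model)
    c = model.get_learned_conditioning(cond_input)   # e.g. a mask or text condition
    with torch.no_grad():
        samples, _ = sampler.sample(
            S=ddim_steps,            # number of DDIM steps
            conditioning=c,
            batch_size=1,
            shape=[3, 64, 64],       # latent shape [C, H, W] used by the 256 models
            eta=eta,                 # eta=0.0 would give deterministic DDIM sampling
            verbose=False,
        )
        return model.decode_first_stage(samples)     # decode latents back to image space
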
/editing/imagic_edit_mask.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 | from pathlib import Path
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from omegaconf import OmegaConf
11 | from PIL import Image
12 | from torchvision import transforms
13 | from tqdm import tqdm
14 |
15 | from einops import rearrange
16 | from ldm.models.diffusion.ddim import DDIMSampler
17 | from ldm.util import instantiate_from_config
18 | """
19 | Reference code:
20 | https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
21 | """
22 |
23 | parser = argparse.ArgumentParser()
24 |
25 | # directories
26 | parser.add_argument('--config', type=str, default='configs/256_mask.yaml')
27 | parser.add_argument('--ckpt', type=str, default='pretrained/256_mask.ckpt')
28 | parser.add_argument('--save_folder', type=str, default='outputs/mask_edit')
29 | parser.add_argument(
30 | '--input_image_path',
31 | type=str,
32 | default='test_data/test_mask_edit/256_input_image/27044.jpg')
33 | parser.add_argument(
34 | '--mask_path',
35 | type=str,
36 | default=
37 | 'test_data/test_mask_edit/256_edited_masks/27044_0_remove_smile_and_rings.png'
38 | )
39 |
40 | # hyperparameters
41 | parser.add_argument('--seed', type=int, default=0)
42 | parser.add_argument('--stage1_lr', type=float, default=0.001)
43 | parser.add_argument('--stage1_num_iter', type=int, default=500)
44 | parser.add_argument('--stage2_lr', type=float, default=1e-6)
45 | parser.add_argument('--stage2_num_iter', type=int, default=1000)
46 | parser.add_argument(
47 | '--alpha_list',
48 | type=str,
49 | default='-1, -0.5, 0, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5')
50 | parser.add_argument('--set_random_seed', type=lambda x: str(x).lower() in ('true', '1'), default=False)  # argparse's type=bool treats any non-empty string as True, so parse explicitly
51 | parser.add_argument('--save_checkpoint', type=lambda x: str(x).lower() in ('true', '1'), default=True)
52 |
53 | args = parser.parse_args()
54 |
55 |
56 | def load_model_from_config(config, ckpt, device="cpu", verbose=False):
57 |     """Load a model from a config and a checkpoint.
58 |     If `config` is a path, it is loaded with OmegaConf.
59 |     """
60 | if isinstance(config, (str, Path)):
61 | config = OmegaConf.load(config)
62 |
63 | pl_sd = torch.load(ckpt, map_location="cpu")
64 | global_step = pl_sd["global_step"]
65 | sd = pl_sd["state_dict"]
66 | model = instantiate_from_config(config.model)
67 | m, u = model.load_state_dict(sd, strict=False)
68 | model.to(device)
69 | model.eval()
70 | model.cond_stage_model.device = device
71 | return model
72 |
73 |
74 | @torch.no_grad()
75 | def sample_model(model,
76 | sampler,
77 | c,
78 | h,
79 | w,
80 | ddim_steps,
81 | scale,
82 | ddim_eta,
83 | start_code=None,
84 | n_samples=1):
85 | """Sample the model"""
86 | uc = None
87 | if scale != 1.0:
88 | uc = model.get_learned_conditioning(n_samples * [""])
89 |
90 | # print(f'model.model.parameters(): {model.model.parameters()}')
91 | # for name, param in model.model.named_parameters():
92 | # if param.requires_grad:
93 | # print (name, param.data)
94 | # break
95 |
96 | # print(f'unconditional_guidance_scale: {scale}') # 1.0
97 | # print(f'unconditional_conditioning: {uc}') # None
98 | with model.ema_scope("Plotting"):
99 |
100 |         shape = [3, 64, 64]  # latent shape [C, H/4, W/4] for 256x256 inputs (3-channel latent)
101 | samples_ddim, _ = sampler.sample(
102 | S=ddim_steps,
103 | conditioning=c,
104 | batch_size=n_samples,
105 | shape=shape,
106 | verbose=False,
107 |             x_T=start_code,  # the sampler takes the fixed starting noise as x_T; a 'start_code' kwarg would be silently ignored
108 | unconditional_guidance_scale=scale,
109 | unconditional_conditioning=uc,
110 | eta=ddim_eta,
111 | )
112 | return samples_ddim
113 |
114 |
115 | def load_img(path, target_size=256):
116 |     """Load an image and rescale its pixel values to [-1, 1] (resizing transforms are commented out)."""
117 | image = Image.open(path).convert("RGB")
118 |
119 | tform = transforms.Compose([
120 | # transforms.Resize(target_size),
121 | # transforms.CenterCrop(target_size),
122 | transforms.ToTensor(),
123 | ])
124 | image = tform(image)
125 | return 2. * image - 1.
126 |
127 |
128 | def decode_to_im(samples, n_samples=1, nrow=1):
129 | """Decode a latent and return PIL image"""
130 | samples = model.decode_first_stage(samples)
131 | ims = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
132 | x_sample = 255. * rearrange(
133 | ims.cpu().numpy(),
134 | '(n1 n2) c h w -> (n1 h) (n2 w) c',
135 | n1=n_samples // nrow,
136 | n2=nrow)
137 | return Image.fromarray(x_sample.astype(np.uint8))
138 |
139 |
140 | if __name__ == '__main__':
141 |
142 | args.alpha_list = [float(i) for i in args.alpha_list.split(',')]
143 |
144 | device = "cuda" # "cuda:0"
145 |
146 | # Generation parameters
147 | scale = 1.0
148 | h = 256
149 | w = 256
150 | ddim_steps = 50
151 | ddim_eta = 1.0
152 |
153 | # initialize model
154 | global_model = load_model_from_config(args.config, args.ckpt, device)
155 |
156 | input_image = args.input_image_path
157 | image_name = input_image.split('/')[-1]
158 | mask_name = args.mask_path.split('/')[-1]
159 |
160 | torch.manual_seed(args.seed)
161 |
162 | model = copy.deepcopy(global_model)
163 | sampler = DDIMSampler(model)
164 |
165 | # prepare directories
166 | save_dir = os.path.join(args.save_folder, image_name, str(mask_name))
167 | print(f'save_dir = {save_dir}')
168 |
169 | os.makedirs(save_dir, exist_ok=True)
170 | print(
171 | f'================================================================================'
172 | )
173 | print(f'input_image: {input_image} | mask_name: {mask_name}')
174 |
175 | # read input image
176 | init_image = load_img(input_image).to(device).unsqueeze(
177 | 0) # [1, 3, 256, 256]
178 | gaussian_distribution = model.encode_first_stage(init_image)
179 | init_latent = model.get_first_stage_encoding(
180 | gaussian_distribution) # [1, 3, 64, 64]
181 | img = decode_to_im(init_latent)
182 | img.save(os.path.join(save_dir, 'input_image_reconstructed.png'))
183 |
184 | # obtain mask embedding
185 | # ========== prepare mask (one-hot): resize to 32x32 then one-hot ==========
186 | with open(args.mask_path, 'rb') as f:
187 | img = Image.open(f)
188 | resized_img = img.resize((32, 32), Image.NEAREST) # resize
189 | flattened_img = list(resized_img.getdata())
190 | flattened_img_tensor = torch.tensor(flattened_img) # flatten
191 | flattened_img_tensor_one_hot = F.one_hot(
192 | flattened_img_tensor, num_classes=19) # one hot
193 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose(
194 | 0, 1)
195 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze(
196 | flattened_img_tensor_one_hot_transpose,
197 | 0).cuda() # add batch dimension
198 |
199 | # ========== prepare mask (image) ==========
200 | mask = Image.open(args.mask_path)
201 | mask = mask.convert('RGB')
202 | mask = np.array(mask).astype(np.uint8) # three channel integer
203 | input_mask = mask
204 |
205 | # save seg_mask
206 | mask_ = Image.fromarray(input_mask)
207 | mask_.save(os.path.join(save_dir, mask_name))
208 |
209 | condition = flattened_img_tensor_one_hot_transpose
210 | emb_tgt = model.get_learned_conditioning(condition)
211 | emb_ = emb_tgt.clone()
212 | torch.save(emb_, os.path.join(save_dir, 'emb_tgt.pt'))
213 | emb = torch.load(os.path.join(save_dir, 'emb_tgt.pt')) # [1, 77, 640]
214 |
215 | # Sample the model with a fixed code to see what it looks like
216 | quick_sample = lambda x, s, code: decode_to_im(
217 | sample_model(
218 | model, sampler, x, h, w, ddim_steps, s, ddim_eta, start_code=code))
219 | # start_code = torch.randn_like(init_latent)
220 | start_code = torch.randn((1, 3, 64, 64), device=device)
221 | torch.save(start_code, os.path.join(save_dir,
222 | 'start_code.pt')) # [1, 3, 64, 64]
223 | torch.manual_seed(args.seed)
224 | img = quick_sample(emb_tgt, scale, start_code)
225 | img.save(os.path.join(save_dir, 'A_start_tgtText_origDM.png'))
226 |
227 | # ======================= (A) Text Embedding Optimization ===================================
228 | print('########### Step 1 - Optimise the embedding ###########')
229 | emb.requires_grad = True
230 | opt = torch.optim.Adam([emb], lr=args.stage1_lr)
231 | criteria = torch.nn.MSELoss()
232 | history = []
233 |
234 | pbar = tqdm(range(args.stage1_num_iter))
235 | for i in pbar:
236 | opt.zero_grad()
237 |
238 | if args.set_random_seed:
239 | torch.seed()
240 | noise = torch.randn_like(init_latent)
241 | t_enc = torch.randint(1000, (1, ), device=device)
242 | z = model.q_sample(init_latent, t_enc, noise=noise)
243 |
244 | pred_noise = model.apply_model(z, t_enc, emb)
245 |
246 | loss = criteria(pred_noise, noise)
247 | loss.backward()
248 | pbar.set_postfix({"loss": loss.item()})
249 | history.append(loss.item())
250 | opt.step()
251 |
252 | plt.plot(history)
253 | plt.show()
254 | torch.save(emb, os.path.join(save_dir, 'emb_opt.pt'))
255 | emb_opt = torch.load(os.path.join(save_dir, 'emb_opt.pt')) # [1, 77, 640]
256 |
257 | torch.manual_seed(args.seed)
258 | img = quick_sample(emb_opt, scale, start_code)
259 | img.save(os.path.join(save_dir, 'A_end_optText_origDM.png'))
260 |
261 | # Interpolate the embedding
262 | for idx, alpha in enumerate(args.alpha_list):
263 | print(f'alpha={alpha}')
264 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt
265 | torch.manual_seed(args.seed)
266 | img = quick_sample(new_emb, scale, start_code)
267 | img.save(
268 | os.path.join(
269 | save_dir,
270 | f'0A_interText_origDM_{idx}_alpha={round(alpha,3)}.png'))
271 |
272 | # ======================= (B) Model Fine-Tuning ===================================
273 | print('########### Step 2 - Fine tune the model ###########')
274 | emb_opt.requires_grad = False
275 | model.train()
276 |
277 | opt = torch.optim.Adam(model.model.parameters(), lr=args.stage2_lr)
278 | criteria = torch.nn.MSELoss()
279 | history = []
280 |
281 | pbar = tqdm(range(args.stage2_num_iter))
282 | for i in pbar:
283 | opt.zero_grad()
284 |
285 | if args.set_random_seed:
286 | torch.seed()
287 | noise = torch.randn_like(init_latent)
288 | t_enc = torch.randint(model.num_timesteps, (1, ), device=device)
289 | z = model.q_sample(init_latent, t_enc, noise=noise)
290 |
291 | pred_noise = model.apply_model(z, t_enc, emb_opt)
292 |
293 | loss = criteria(pred_noise, noise)
294 | loss.backward()
295 | pbar.set_postfix({"loss": loss.item()})
296 | history.append(loss.item())
297 | opt.step()
298 |
299 | model.eval()
300 | plt.plot(history)
301 | plt.show()
302 | torch.manual_seed(args.seed)
303 | img = quick_sample(emb_opt, scale, start_code)
304 | img.save(os.path.join(save_dir, 'B_end_optText_optDM.png'))
305 | # Should look like the original image
306 |
307 | if args.save_checkpoint:
308 | ckpt = {
309 | "state_dict": model.state_dict(),
310 | }
311 | ckpt_path = os.path.join(save_dir, 'optDM.ckpt')
312 | print(f'Saving optDM to {ckpt_path}')
313 | torch.save(ckpt, ckpt_path)
314 |
315 | # ======================= (C) Generation ===================================
316 | print('########### Step 3 - Generate images ###########')
317 | # Interpolate the embedding
318 | for idx, alpha in enumerate(args.alpha_list):
319 | print(f'alpha={alpha}')
320 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt
321 | torch.manual_seed(args.seed)
322 | img = quick_sample(new_emb, scale, start_code)
323 | img.save(
324 | os.path.join(
325 | save_dir,
326 | f'0C_interText_optDM_{idx}_alpha={round(alpha,3)}.png'))
327 |
328 | print('Done')
--------------------------------------------------------------------------------
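
The editing script above follows the Imagic recipe: (A) optimize the condition embedding so the frozen model reconstructs the input image, (B) fine-tune the diffusion model around that optimized embedding, and (C) generate by interpolating between the optimized embedding and the target (edited-mask) embedding. The interpolation step is the linear blend below; the helper name is illustrative and simply restates what the script already does.

# alpha = 0 reproduces the input image; alpha = 1 applies the target condition fully;
# alpha > 1 extrapolates the edit beyond the target embedding.
def interpolate_embedding(emb_opt, emb_tgt, alpha):
    return alpha * emb_tgt + (1.0 - alpha) * emb_opt

# e.g.
#   for alpha in (0.0, 0.5, 1.0, 1.3):
#       new_emb = interpolate_embedding(emb_opt, emb_tgt, alpha)
#       img = quick_sample(new_emb, scale, start_code)
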
/ldm/models/diffusion/compose_modules.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from functools import partial
3 | import math
4 | from typing import Iterable
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from ldm.modules.diffusionmodules.util import (
12 | checkpoint,
13 | conv_nd,
14 | linear,
15 | avg_pool_nd,
16 | zero_module,
17 | normalization,
18 | timestep_embedding,
19 | )
20 | from ldm.modules.attention import SpatialTransformer
21 |
22 | from ldm.modules.diffusionmodules.openaimodel import *
23 | from omegaconf import OmegaConf
24 |
25 | from ldm.models.diffusion.ddpm import DiffusionWrapper
26 | import random
27 |
28 | class ComposeUNet(nn.Module):
29 | """
30 |     Fuses the outputs of uni-modal diffusion UNets for one denoising step, weighted by influence functions predicted by per-modality dynamic diffusers.
31 | """
32 |
33 | def __init__(self,
34 | return_confidence_map=False,
35 | confidence_conditioning_key='crossattn',
36 | confidence_map_predictor_config=None,
37 | seg_mask_scale_factor = None,
38 | seg_mask_schedule = None,
39 | softmax_twice = False,
40 | boost_factor = 1.0,
41 | manual_prob = 1.0,
42 | return_each_branch = False,
43 | confidence_input = 'unet_output',
44 | conditions= ['seg_mask', 'text']
45 | ):
46 | super().__init__()
47 |
48 | self.conditions = conditions
49 | self.return_confidence_map = return_confidence_map
50 |
51 | # define dynamic diffusers
52 | if 'seg_mask' in self.conditions:
53 | self.seg_mask_confidence_predictor = DiffusionWrapper(confidence_map_predictor_config, confidence_conditioning_key)
54 | if 'text' in self.conditions:
55 | self.text_confidence_predictor = DiffusionWrapper(confidence_map_predictor_config, confidence_conditioning_key)
56 | if 'sketch' in self.conditions:
57 | self.sketch_confidence_predictor = DiffusionWrapper(confidence_map_predictor_config, confidence_conditioning_key)
58 |
59 | self.seg_mask_scale_factor = seg_mask_scale_factor #/ (seg_mask_scale_factor + text_scale_factor)
60 | self.seg_mask_schedule = seg_mask_schedule
61 | self.softmax_twice = softmax_twice
62 | self.boost_factor = boost_factor
63 | self.manual_prob = manual_prob
64 | self.return_each_branch = return_each_branch
65 |
66 | self.confidence_input = confidence_input
67 |
68 | def set_seg_mask_schedule(self, t):
69 | t = t[0]
70 |         if self.seg_mask_schedule is None:
71 | schedule_scale = 1.0
72 | elif self.seg_mask_schedule == 'linear_decay':
73 | schedule_scale = t / 1000
74 | elif self.seg_mask_schedule == 'cosine_decay':
75 | pi = torch.acos(torch.zeros(1)).item() * 2
76 | schedule_scale = (torch.cos( torch.tensor((1-(t/1000)) * (pi)))+1)/2
77 | else:
78 | raise NotImplementedError
79 | return schedule_scale
80 |
81 |
82 | def forward(self, x, t, cond):
83 | """
84 | One diffusion denoising step (using dynamic diffusers)
85 |
86 | input:
87 | - x: noisy image x_t
88 | - t: timestep
89 | - cond = {'seg_mask': tensor, 'text': tensor, '...': ...}
90 |         output:
91 |             the fused model output (a weighted sum of the branch UNet outputs), or a dict that also contains per-branch outputs and confidence maps
92 | """
93 |
94 |         # compute each branch's output using its pre-trained uni-modal diffusion UNet
95 | if 'seg_mask' in self.conditions:
96 | seg_mask_unet_output = self.seg_mask_unet(x=x, t=t, c_crossattn=[cond['seg_mask']]) # [B, 3, 64, 64]
97 | if 'text' in self.conditions:
98 | text_unet_output = self.text_unet(x=x, t=t, c_crossattn=[cond['text']]) # [B, 3, 64, 64]
99 |
100 | if 'sketch' in self.conditions:
101 | sketch_unet_output = self.sketch_unet(x=x, t=t, c_crossattn=[cond['sketch']]) # [B, 3, 64, 64]
102 |
103 |         # compute each branch's influence function using its dynamic diffuser
104 | if self.confidence_input == 'unet_output':
105 | if 'seg_mask' in self.conditions:
106 | seg_mask_confidence_map = self.seg_mask_confidence_predictor(x=seg_mask_unet_output, t=t, c_crossattn=[cond['seg_mask']]) # [B, 1, 64, 64]
107 | if 'text' in self.conditions:
108 | text_confidence_map = self.text_confidence_predictor(x=text_unet_output, t=t, c_crossattn=[cond['text']]) # [B, 1, 64, 64]
109 | if 'sketch' in self.conditions:
110 | sketch_confidence_map = self.sketch_confidence_predictor(x=sketch_unet_output, t=t, c_crossattn=[cond['sketch']]) # [B, 1, 64, 64]
111 | print('sketch forward')
112 |
113 | elif self.confidence_input == 'x_t':
114 | if 'seg_mask' in self.conditions:
115 | seg_mask_confidence_map = self.seg_mask_confidence_predictor(x=x, t=t, c_crossattn=[cond['seg_mask']]) # [B, 1, 64, 64]
116 | if 'text' in self.conditions:
117 | text_confidence_map = self.text_confidence_predictor(x=x, t=t, c_crossattn=[cond['text']]) # [B, 1, 64, 64]
118 | if 'sketch' in self.conditions:
119 | sketch_confidence_map = self.sketch_confidence_predictor(x=x, t=t, c_crossattn=[cond['sketch']]) # [B, 1, 64, 64]
120 |
121 | else:
122 | raise NotImplementedError
123 |
124 |
125 | # Use softmax to normalize the influence functions across all branches
126 | if ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' not in self.conditions):
127 | concat_map = torch.cat([seg_mask_confidence_map, text_confidence_map], dim=1) # first mask, then text # [B, 2, 64, 64]
128 | softmax_map = F.softmax(input=concat_map, dim=1) # [B, 2, 64, 64]
129 | seg_mask_confidence_map = softmax_map[:,0,:,:].unsqueeze(1) # [B, 1, 64, 64]
130 | text_confidence_map = softmax_map[:,1,:,:].unsqueeze(1) # [B, 1, 64, 64]
131 | elif ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' in self.conditions):
132 |             concat_map = torch.cat([seg_mask_confidence_map, text_confidence_map, sketch_confidence_map], dim=1) # order: mask, text, sketch # [B, 3, 64, 64]
133 | softmax_map = F.softmax(input=concat_map, dim=1) # [B, 3, 64, 64]
134 | seg_mask_confidence_map = softmax_map[:,0,:,:].unsqueeze(1) # [B, 1, 64, 64]
135 | text_confidence_map = softmax_map[:,1,:,:].unsqueeze(1) # [B, 1, 64, 64]
136 | sketch_confidence_map = softmax_map[:,2,:,:].unsqueeze(1) # [B, 1, 64, 64]
137 | else:
138 | raise NotImplementedError
139 |
140 | if random.random() <= self.manual_prob:
141 |
142 | if self.seg_mask_schedule is not None:
143 | seg_mask_schedule_scale = self.set_seg_mask_schedule(t)
144 | if 'sketch' not in self.conditions:
145 | seg_mask_confidence_map = seg_mask_confidence_map * seg_mask_schedule_scale
146 | text_confidence_map = 1 - seg_mask_confidence_map
147 | else:
148 | seg_mask_confidence_map = seg_mask_confidence_map * seg_mask_schedule_scale
149 | sketch_confidence_map = sketch_confidence_map * seg_mask_schedule_scale
150 | text_confidence_map = 1 - seg_mask_confidence_map - sketch_confidence_map
151 |
152 | if self.seg_mask_scale_factor is not None:
153 | if 'sketch' not in self.conditions:
154 | seg_mask_confidence_map = seg_mask_confidence_map * self.seg_mask_scale_factor
155 | sum_map = text_confidence_map + seg_mask_confidence_map
156 | seg_mask_confidence_map = seg_mask_confidence_map / sum_map
157 | text_confidence_map = text_confidence_map / sum_map
158 | else:
159 | seg_mask_confidence_map = seg_mask_confidence_map * self.seg_mask_scale_factor
160 | sketch_confidence_map = sketch_confidence_map * self.seg_mask_scale_factor
161 | sum_map = text_confidence_map + seg_mask_confidence_map + sketch_confidence_map
162 | seg_mask_confidence_map = seg_mask_confidence_map / sum_map
163 | text_confidence_map = text_confidence_map / sum_map
164 | sketch_confidence_map = sketch_confidence_map / sum_map
165 |
166 | if self.softmax_twice:
167 | assert ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' not in self.conditions), "softmax_twice is only implemented for two-modal controls"
168 | print(f'softmax_twice self.boost_factor={self.boost_factor}')
169 | concat_map = torch.cat([seg_mask_confidence_map, text_confidence_map], dim=1) * self.boost_factor # first mask, then text # [B, 2, 64, 64]
170 | softmax_map = F.softmax(input=concat_map, dim=1) # [B, 2, 64, 64]
171 | seg_mask_confidence_map = softmax_map[:,0,:,:].unsqueeze(1) # [B, 1, 64, 64]
172 | text_confidence_map = softmax_map[:,1,:,:].unsqueeze(1) # [B, 1, 64, 64]
173 |
174 |
175 |         # Compute the weighted sum of all branches' outputs
176 | if ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' not in self.conditions):
177 | seg_mask_weighted = seg_mask_unet_output * seg_mask_confidence_map # [B, 3, 64, 64]
178 | text_weighted = text_unet_output * text_confidence_map # [B, 3, 64, 64]
179 | output = text_weighted + seg_mask_weighted # [B, 3, 64, 64]
180 |
181 | if self.return_confidence_map:
182 | if self.return_each_branch:
183 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map, 'text_unet_output': text_unet_output, 'seg_mask_unet_output': seg_mask_unet_output}
184 | else:
185 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map}
186 | elif self.return_each_branch:
187 | return {'output': output, 'text_unet_output': text_unet_output, 'seg_mask_unet_output': seg_mask_unet_output}
188 | else:
189 | return output
190 | elif ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' in self.conditions):
191 | seg_mask_weighted = seg_mask_unet_output * seg_mask_confidence_map # [B, 3, 64, 64]
192 | text_weighted = text_unet_output * text_confidence_map # [B, 3, 64, 64]
193 | sketch_weighted = sketch_unet_output * sketch_confidence_map # [B, 3, 64, 64]
194 | output = text_weighted + seg_mask_weighted + sketch_weighted # [B, 3, 64, 64]
195 |
196 | if self.return_confidence_map:
197 | if self.return_each_branch:
198 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map, 'sketch_confidence_map': sketch_confidence_map,'seg_mask_unet_output': seg_mask_unet_output, 'text_unet_output': text_unet_output, 'sketch_unet_output': sketch_unet_output }
199 | else:
200 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map, 'sketch_confidence_map': sketch_confidence_map}
201 | elif self.return_each_branch:
202 | return {'output': output, 'text_unet_output': text_unet_output, 'seg_mask_unet_output': seg_mask_unet_output, 'sketch_unet_output': sketch_unet_output}
203 | else:
204 | return output
205 |
206 |
207 | class ComposeCondStageModel(nn.Module):
208 | """
209 | Condition Encoder of Multi-Modal Conditions
210 | """
211 |
212 |     def __init__(self, conditions=['seg_mask', 'text']):
213 | super().__init__()
214 | self.conditions = conditions
215 |
216 |
217 | def forward(self, input):
218 |
219 | composed_cond = {}
220 |
221 | if 'seg_mask' in self.conditions:
222 | seg_mask_output = self.seg_mask_cond_stage_model(input['seg_mask'])
223 | composed_cond['seg_mask'] = seg_mask_output
224 | if 'text' in self.conditions:
225 | text_output = self.text_cond_stage_model(input['text'])
226 | composed_cond['text'] = text_output
227 | if 'sketch' in self.conditions:
228 | sketch_output = self.sketch_cond_stage_model(input['sketch'])
229 | composed_cond['sketch'] = sketch_output
230 |
231 | return composed_cond
232 |
233 | def encode(self, input):
234 | return self(input)
235 |
--------------------------------------------------------------------------------
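
For reference, in the two-modal (mask + text) case the fusion in ComposeUNet.forward reduces to a pixel-wise softmax over the two predicted influence maps followed by a weighted sum of the branch outputs. A minimal standalone sketch of that core step (function and variable names are illustrative):

import torch
import torch.nn.functional as F

def fuse_two_branches(seg_mask_out, text_out, seg_mask_conf, text_conf):
    """Fuse two uni-modal UNet outputs with per-pixel influence weights (sketch)."""
    # normalize the raw influence maps across the modality dimension
    weights = F.softmax(torch.cat([seg_mask_conf, text_conf], dim=1), dim=1)  # [B, 2, H, W]
    seg_w, text_w = weights[:, 0:1], weights[:, 1:2]                          # [B, 1, H, W] each
    # per-pixel weighted sum of the branch predictions
    return seg_mask_out * seg_w + text_out * text_w                           # [B, C, H, W]

# e.g.
#   out = fuse_two_branches(torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64),
#                           torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64))
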