├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── lsun.py
│   │   └── celebahq.py
│   ├── models
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       └── compose_modules.py
│   ├── modules
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   └── util.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   ├── image_degradation
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── __init__.py
│   │   ├── ema.py
│   │   └── attention.py
│   ├── lr_scheduler.py
│   └── util.py
├── assets
│   ├── fig_teaser.jpg
│   └── fig_framework.jpg
├── test_data
│   ├── 256_masks
│   │   └── 29980.png
│   ├── 512_masks
│   │   ├── 27007.png
│   │   └── 29980.png
│   └── test_mask_edit
│       ├── 256_input_image
│       │   └── 27044.jpg
│       └── 256_edited_masks
│           └── 27044_0_remove_smile_and_rings.png
├── freeu
│   ├── assets
│   │   ├── mask2face_27007.jpeg
│   │   ├── mask2face_29980.jpeg
│   │   ├── text2face_female.jpeg
│   │   └── text2face_male.jpeg
│   ├── README.md
│   ├── text2image_freeu.py
│   └── mask2image_freeu.py
├── .gitignore
├── setup.py
├── environment.yaml
├── LICENSE
├── configs
│   ├── 512_vae.yaml
│   ├── 256_vae.yaml
│   ├── 512_text.yaml
│   ├── 512_codiff_mask_text.yaml
│   ├── 256_text.yaml
│   ├── 512_mask.yaml
│   ├── 256_codiff_mask_text.yaml
│   └── 256_mask.yaml
├── text2image.py
├── mask2image.py
├── generate.py
└── editing
    ├── imagic_edit_text.py
    ├── collaborative_edit.py
    └── imagic_edit_mask.py
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/assets/fig_teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/assets/fig_teaser.jpg
--------------------------------------------------------------------------------
/assets/fig_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/assets/fig_framework.jpg
--------------------------------------------------------------------------------
/test_data/256_masks/29980.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/256_masks/29980.png
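The `configs/*.yaml` files and the inference scripts in the tree above are tied together by `instantiate_from_config` from `ldm/util.py` (shown later in this listing): each config names a `target` class plus its `params`. A minimal sketch of that loading step, following the pattern in `text2image.py` below and assuming the corresponding pretrained checkpoint has already been downloaded to the default `pretrained/` path:

```python
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

# Load a config and build the class named under `model.target` with its `params`.
config = OmegaConf.load("configs/512_text.yaml")
model = instantiate_from_config(config["model"])

# Load pretrained weights; this path is the default used by text2image.py and
# assumes the checkpoint has been downloaded beforehand.
model.init_from_ckpt("pretrained/512_text.ckpt")
model = model.cuda().eval()
```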
-------------------------------------------------------------------------------- /test_data/512_masks/27007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/512_masks/27007.png -------------------------------------------------------------------------------- /test_data/512_masks/29980.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/512_masks/29980.png -------------------------------------------------------------------------------- /freeu/assets/mask2face_27007.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/mask2face_27007.jpeg -------------------------------------------------------------------------------- /freeu/assets/mask2face_29980.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/mask2face_29980.jpeg -------------------------------------------------------------------------------- /freeu/assets/text2face_female.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/text2face_female.jpeg -------------------------------------------------------------------------------- /freeu/assets/text2face_male.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/freeu/assets/text2face_male.jpeg -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /test_data/test_mask_edit/256_input_image/27044.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/test_mask_edit/256_input_image/27044.jpg -------------------------------------------------------------------------------- /test_data/test_mask_edit/256_edited_masks/27044_0_remove_smile_and_rings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziqihuangg/Collaborative-Diffusion/HEAD/test_data/test_mask_edit/256_edited_masks/27044_0_remove_smile_and_rings.png -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # postscripts 2 | *.ckpt 3 | *.zip 4 | *.pyc 5 | *.pt 6 | *.pth 7 | 8 
| # directories 9 | pretrained 10 | dataset 11 | outputs 12 | src 13 | trash 14 | stashed_for_git 15 | private_scripts 16 | ldm/data/trash 17 | 18 | latent_diffusion.egg-info 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: codiff 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.0 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - torch-fidelity==0.3.0 24 | - transformers==4.3.1 25 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 26 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 27 | - -e . -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 9 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /configs/512_vae.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 512 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4] 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 2 30 | wrap: True 31 | train: 32 | target: ldm.data.celebahq.CelebAConditionalDataset 33 | params: 34 | phase: train 35 | im_preprocessor_config: 36 | target: ldm.data.celebahq.DalleTransformerPreprocessor 37 | params: 38 | size: 512 39 | phase: train 40 | test_dataset_size: 3000 41 | conditions: [] 42 | validation: 43 | target: ldm.data.celebahq.CelebAConditionalDataset 44 | params: 45 | phase: test 46 | im_preprocessor_config: 47 | target: ldm.data.celebahq.DalleTransformerPreprocessor 48 | params: 49 | size: 512 50 | phase: val 51 | test_dataset_size: 3000 52 | conditions: [] 53 | 54 | lightning: 55 | callbacks: 56 | image_logger: 57 | target: main.ImageLogger 58 | params: 59 | batch_frequency: 1000 60 | max_images: 8 61 | increase_log_steps: True 62 | 63 | trainer: 64 | benchmark: True 65 | accumulate_grad_batches: 2 66 | -------------------------------------------------------------------------------- /configs/256_vae.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 8 31 | wrap: True 32 | train: 33 | target: ldm.data.celebahq.CelebAConditionalDataset 34 | params: 35 | phase: train 36 | im_preprocessor_config: 37 | target: ldm.data.celebahq.DalleTransformerPreprocessor 38 | params: 39 | size: 256 40 | phase: train 41 | test_dataset_size: 3000 42 | conditions: [] 43 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 44 | validation: 45 | target: 
ldm.data.celebahq.CelebAConditionalDataset
46 |       params:
47 |         phase: test
48 |         im_preprocessor_config:
49 |           target: ldm.data.celebahq.DalleTransformerPreprocessor
50 |           params:
51 |             size: 256
52 |             phase: val
53 |         test_dataset_size: 3000
54 |         conditions: []
55 |         image_folder: 'datasets/image_256_downsampled_from_hq_1024'
56 | 
57 | lightning:
58 |   callbacks:
59 |     image_logger:
60 |       target: main.ImageLogger
61 |       params:
62 |         batch_frequency: 1000
63 |         max_images: 8
64 |         increase_log_steps: True
65 | 
66 |   trainer:
67 |     benchmark: True
68 |     accumulate_grad_batches: 2
69 | 
--------------------------------------------------------------------------------
/freeu/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Face Generation + FreeU
3 | We have now integrated [FreeU](https://chenyangsi.top/FreeU/) [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ChenyangSi/FreeU) into our LDMs to further boost synthesis quality.
4 | We adapted FreeU's [code](https://github.com/ChenyangSi/FreeU) for our face generation diffusion models.
5 | For more details about FreeU, please refer to the [paper](https://arxiv.org/abs/2309.11497).
6 | 
7 | ## Mask-to-Face Generation:
8 | 
9 | 
10 | 
11 | 1. without FreeU
12 | ```bash
13 | python freeu/mask2image_freeu.py \
14 |     --mask_path "test_data/512_masks/27007.png"
15 | ```
16 | 2. with FreeU
17 | ```bash
18 | python freeu/mask2image_freeu.py \
19 |     --mask_path "test_data/512_masks/27007.png" \
20 |     --enable_freeu \
21 |     --b1 1.1 \
22 |     --b2 1.2 \
23 |     --s1 1 \
24 |     --s2 1
25 | ```
26 | ## Text-to-Face Generation:
27 | 
28 | 
29 | 
30 | 
31 | 1. without FreeU
32 | ```bash
33 | python freeu/text2image_freeu.py \
34 |     --input_text "This man has beard of medium length. He is in his thirties."
35 | ```
36 | 2. with FreeU
37 | ```bash
38 | python freeu/text2image_freeu.py \
39 |     --input_text "This man has beard of medium length. He is in his thirties." \
40 |     --enable_freeu \
41 |     --b1 1.1 \
42 |     --b2 1.2 \
43 |     --s1 1 \
44 |     --s2 1
45 | ```
46 | Another example:
47 | 1. without FreeU
48 | ```bash
49 | python freeu/text2image_freeu.py \
50 |     --input_text "This woman is in her forties."
51 | ```
52 | 2. with FreeU
53 | ```bash
54 | python freeu/text2image_freeu.py \
55 |     --input_text "This woman is in her forties."
\ 56 | --enable_freeu \ 57 | --b1 1.1 \ 58 | --b2 1.2 \ 59 | --s1 1 \ 60 | --s2 1 61 | ``` 62 | -------------------------------------------------------------------------------- /configs/512_text.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.141 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 64 24 | in_channels: 3 25 | out_channels: 3 26 | model_channels: 192 27 | attention_resolutions: 28 | - 8 29 | - 4 30 | - 2 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 3 36 | - 5 37 | num_heads: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 640 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 3 48 | monitor: val/rec_loss 49 | ckpt_path: pretrained/512_vae.ckpt 50 | ddconfig: 51 | double_z: true 52 | z_channels: 3 53 | resolution: 512 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [1, 2, 4, 4] 58 | num_res_blocks: 2 59 | attn_resolutions: [ ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.BERTEmbedder 66 | params: 67 | n_embed: 640 68 | n_layer: 32 69 | 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 8 74 | wrap: True 75 | train: 76 | target: ldm.data.celebahq.CelebAConditionalDataset 77 | params: 78 | phase: train 79 | im_preprocessor_config: 80 | target: ldm.data.celebahq.DalleTransformerPreprocessor 81 | params: 82 | size: 512 83 | phase: train 84 | test_dataset_size: 3000 85 | conditions: 86 | - 'text' 87 | validation: 88 | target: ldm.data.celebahq.CelebAConditionalDataset 89 | params: 90 | phase: test 91 | im_preprocessor_config: 92 | target: ldm.data.celebahq.DalleTransformerPreprocessor 93 | params: 94 | size: 512 95 | phase: val 96 | test_dataset_size: 3000 97 | conditions: 98 | - 'text' 99 | 100 | 101 | lightning: 102 | callbacks: 103 | image_logger: 104 | target: main.ImageLogger 105 | params: 106 | batch_frequency: 1000 107 | max_images: 8 108 | increase_log_steps: False 109 | 110 | trainer: 111 | benchmark: True -------------------------------------------------------------------------------- /configs/512_codiff_mask_text.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm_compose.LatentDiffusionCompose 4 | params: 5 | 6 | linear_start: 0.0015 7 | linear_end: 0.0195 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: image 12 | cond_stage_key: conditions 13 | image_size: 64 14 | channels: 3 15 | cond_stage_trainable: True # so that don't need to change LatentDiffusion def forward's code 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple 18 | scale_factor: 0.141 19 | use_ema: False 20 | 21 | seg_mask_ldm_config_path: configs/512_mask.yaml 22 | seg_mask_ldm_ckpt_path: 
pretrained/512_mask.ckpt 23 | text_ldm_config_path: configs/512_text.yaml 24 | text_ldm_ckpt_path: pretrained/512_text.ckpt 25 | 26 | compose_unet_config: 27 | target: ldm.models.diffusion.compose_modules.ComposeUNet 28 | params: 29 | confidence_conditioning_key: crossattn 30 | confidence_input: x_t 31 | confidence_map_predictor_config: 32 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 33 | params: 34 | image_size: 64 35 | in_channels: 3 36 | out_channels: 1 37 | model_channels: 32 38 | attention_resolutions: 39 | - 8 40 | - 4 41 | - 2 42 | num_res_blocks: 2 43 | channel_mult: 44 | - 1 45 | - 2 46 | - 3 47 | - 5 48 | num_heads: 32 49 | use_spatial_transformer: true 50 | transformer_depth: 1 51 | context_dim: 640 52 | use_checkpoint: true 53 | legacy: False 54 | 55 | 56 | compose_cond_stage_config: 57 | target: ldm.models.diffusion.compose_modules.ComposeCondStageModel 58 | params: {} 59 | 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 8 64 | wrap: True 65 | train: 66 | target: ldm.data.celebahq.CelebAConditionalDataset 67 | params: 68 | phase: train 69 | im_preprocessor_config: 70 | target: ldm.data.celebahq.DalleTransformerPreprocessor 71 | params: 72 | size: 512 73 | phase: train 74 | test_dataset_size: 3000 75 | conditions: 76 | - 'text' 77 | - 'seg_mask' 78 | validation: 79 | target: ldm.data.celebahq.CelebAConditionalDataset 80 | params: 81 | phase: test 82 | im_preprocessor_config: 83 | target: ldm.data.celebahq.DalleTransformerPreprocessor 84 | params: 85 | size: 512 86 | phase: val 87 | test_dataset_size: 3000 88 | conditions: 89 | - 'text' 90 | - 'seg_mask' 91 | 92 | 93 | lightning: 94 | callbacks: 95 | image_logger: 96 | target: main.ImageLogger 97 | params: 98 | batch_frequency: 1000 99 | max_images: 8 100 | increase_log_steps: False 101 | 102 | trainer: 103 | benchmark: True -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def 
copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /configs/256_text.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.058 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 64 24 | in_channels: 3 25 | out_channels: 3 26 | model_channels: 192 27 | attention_resolutions: 28 | - 8 29 | - 4 30 | - 2 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 3 36 | - 5 37 | num_heads: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 640 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 3 48 | monitor: val/rec_loss 49 | ckpt_path: pretrained/256_vae.ckpt 50 | ddconfig: 51 | double_z: true 52 | z_channels: 3 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [1, 2, 4] 58 | num_res_blocks: 2 59 | attn_resolutions: [ ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.BERTEmbedder 66 | params: 67 | n_embed: 640 68 | n_layer: 32 69 | 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 8 74 | wrap: True 75 | train: 76 | target: ldm.data.celebahq.CelebAConditionalDataset 77 | params: 78 | phase: train 79 | im_preprocessor_config: 80 | target: ldm.data.celebahq.DalleTransformerPreprocessor 81 | params: 82 | size: 256 83 | phase: train 84 | test_dataset_size: 3000 85 | conditions: 86 | - 'text' 87 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 88 | validation: 89 | target: ldm.data.celebahq.CelebAConditionalDataset 90 | params: 91 | phase: test 92 | im_preprocessor_config: 93 | target: ldm.data.celebahq.DalleTransformerPreprocessor 94 | params: 95 | 
size: 256 96 | phase: val 97 | test_dataset_size: 3000 98 | conditions: 99 | - 'text' 100 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 101 | 102 | 103 | lightning: 104 | callbacks: 105 | image_logger: 106 | target: main.ImageLogger 107 | params: 108 | batch_frequency: 1000 109 | max_images: 8 110 | increase_log_steps: False 111 | 112 | trainer: 113 | benchmark: True -------------------------------------------------------------------------------- /configs/512_mask.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: seg_mask 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple 17 | scale_factor: 0.141 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 64 24 | in_channels: 3 25 | out_channels: 3 26 | model_channels: 192 27 | attention_resolutions: 28 | - 8 29 | - 4 30 | - 2 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 3 36 | - 5 37 | num_heads: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 640 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 3 48 | monitor: val/rec_loss 49 | ckpt_path: pretrained/512_vae.ckpt 50 | ddconfig: 51 | double_z: true 52 | z_channels: 3 53 | resolution: 512 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [1, 2, 4, 4] 58 | num_res_blocks: 2 59 | attn_resolutions: [ ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.SegMaskEncoder 66 | params: 67 | seg_mask_encoder_config: 68 | target: ldm.modules.encoders.modules.PassSegMaskEncoder 69 | params: {} 70 | mask_embed_dim: 1024 71 | context_dim: 640 72 | 73 | data: 74 | target: main.DataModuleFromConfig 75 | params: 76 | batch_size: 8 77 | wrap: True 78 | train: 79 | target: ldm.data.celebahq.CelebAConditionalDataset 80 | params: 81 | phase: train 82 | im_preprocessor_config: 83 | target: ldm.data.celebahq.DalleTransformerPreprocessor 84 | params: 85 | size: 512 86 | phase: train 87 | test_dataset_size: 3000 88 | conditions: 89 | - 'seg_mask' 90 | validation: 91 | target: ldm.data.celebahq.CelebAConditionalDataset 92 | params: 93 | phase: test 94 | im_preprocessor_config: 95 | target: ldm.data.celebahq.DalleTransformerPreprocessor 96 | params: 97 | size: 512 98 | phase: val 99 | test_dataset_size: 3000 100 | conditions: 101 | - 'seg_mask' 102 | 103 | 104 | lightning: 105 | callbacks: 106 | image_logger: 107 | target: main.ImageLogger 108 | params: 109 | batch_frequency: 1000 110 | max_images: 8 111 | increase_log_steps: False 112 | 113 | trainer: 114 | benchmark: True -------------------------------------------------------------------------------- /configs/256_codiff_mask_text.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm_compose.LatentDiffusionCompose 4 | params: 5 | 6 | linear_start: 0.0015 7 | linear_end: 0.0195 8 | num_timesteps_cond: 1 9 | log_every_t: 200 
10 | timesteps: 1000 11 | first_stage_key: image 12 | cond_stage_key: conditions 13 | image_size: 64 14 | channels: 3 15 | cond_stage_trainable: True # so that don't need to change LatentDiffusion def forward's code 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple 18 | scale_factor: 0.058 19 | use_ema: False 20 | 21 | seg_mask_ldm_config_path: configs/256_mask.yaml 22 | seg_mask_ldm_ckpt_path: pretrained/256_mask.ckpt 23 | text_ldm_config_path: configs/256_text.yaml 24 | text_ldm_ckpt_path: pretrained/256_text.ckpt 25 | 26 | compose_unet_config: 27 | target: ldm.models.diffusion.compose_modules.ComposeUNet 28 | params: 29 | confidence_conditioning_key: crossattn 30 | confidence_input: x_t 31 | confidence_map_predictor_config: 32 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 33 | params: 34 | image_size: 64 35 | in_channels: 3 36 | out_channels: 1 37 | model_channels: 32 38 | attention_resolutions: 39 | - 8 40 | - 4 41 | - 2 42 | num_res_blocks: 2 43 | channel_mult: 44 | - 1 45 | - 2 46 | - 3 47 | - 5 48 | num_heads: 32 49 | use_spatial_transformer: true 50 | transformer_depth: 1 51 | context_dim: 640 52 | use_checkpoint: true 53 | legacy: False 54 | 55 | 56 | compose_cond_stage_config: 57 | target: ldm.models.diffusion.compose_modules.ComposeCondStageModel 58 | params: {} 59 | 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 8 64 | wrap: True 65 | train: 66 | target: ldm.data.celebahq.CelebAConditionalDataset 67 | params: 68 | phase: train 69 | im_preprocessor_config: 70 | target: ldm.data.celebahq.DalleTransformerPreprocessor 71 | params: 72 | size: 256 73 | phase: train 74 | test_dataset_size: 3000 75 | conditions: 76 | - 'text' 77 | - 'seg_mask' 78 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 79 | validation: 80 | target: ldm.data.celebahq.CelebAConditionalDataset 81 | params: 82 | phase: test 83 | im_preprocessor_config: 84 | target: ldm.data.celebahq.DalleTransformerPreprocessor 85 | params: 86 | size: 256 87 | phase: val 88 | test_dataset_size: 3000 89 | conditions: 90 | - 'text' 91 | - 'seg_mask' 92 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 93 | 94 | lightning: 95 | callbacks: 96 | image_logger: 97 | target: main.ImageLogger 98 | params: 99 | batch_frequency: 1000 100 | max_images: 8 101 | increase_log_steps: False 102 | 103 | trainer: 104 | benchmark: True -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def 
sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /configs/256_mask.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: seg_mask 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.058 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 64 24 | in_channels: 3 25 | out_channels: 3 26 | model_channels: 192 27 | attention_resolutions: 28 | - 8 29 | - 4 30 | - 2 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 3 36 | - 5 37 | num_heads: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 640 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 3 48 | monitor: val/rec_loss 49 | ckpt_path: pretrained/256_vae.ckpt 50 | ddconfig: 51 | double_z: true 52 | z_channels: 3 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [1, 2, 4] 58 | num_res_blocks: 2 59 | attn_resolutions: [ ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.SegMaskEncoder 66 | params: 67 | 
seg_mask_encoder_config: 68 | target: ldm.modules.encoders.modules.PassSegMaskEncoder 69 | params: {} 70 | mask_embed_dim: 1024 71 | context_dim: 640 72 | 73 | 74 | data: 75 | target: main.DataModuleFromConfig 76 | params: 77 | batch_size: 8 78 | wrap: True 79 | train: 80 | target: ldm.data.celebahq.CelebAConditionalDataset 81 | params: 82 | phase: train 83 | im_preprocessor_config: 84 | target: ldm.data.celebahq.DalleTransformerPreprocessor 85 | params: 86 | size: 256 87 | phase: train 88 | test_dataset_size: 3000 89 | conditions: 90 | - 'seg_mask' 91 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 92 | validation: 93 | target: ldm.data.celebahq.CelebAConditionalDataset 94 | params: 95 | phase: test 96 | im_preprocessor_config: 97 | target: ldm.data.celebahq.DalleTransformerPreprocessor 98 | params: 99 | size: 256 100 | phase: val 101 | test_dataset_size: 3000 102 | conditions: 103 | - 'seg_mask' 104 | image_folder: 'datasets/image_256_downsampled_from_hq_1024' 105 | 106 | lightning: 107 | callbacks: 108 | image_logger: 109 | target: main.ImageLogger 110 | params: 111 | batch_frequency: 1000 112 | max_images: 8 113 | increase_log_steps: False 114 | 115 | trainer: 116 | benchmark: True -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | 
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss # = rec_loss + p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar # = rec_loss + p_loss 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | 84 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss # # = rec_loss + p_loss + 1e-06 * kl_loss = l1_loss + p_loss + 1e-06 * kl_loss 85 | 86 | log = 
{"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 87 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 88 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 89 | "{}/d_weight".format(split): d_weight.detach(), 90 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 91 | "{}/g_loss".format(split): g_loss.detach().mean(), 92 | } 93 | return loss, log 94 | 95 | if optimizer_idx == 1: 96 | # second pass for discriminator update 97 | if cond is None: 98 | logits_real = self.discriminator(inputs.contiguous().detach()) 99 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 100 | else: 101 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 102 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 103 | 104 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 105 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 106 | 107 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 108 | "{}/logits_real".format(split): logits_real.detach().mean(), 109 | "{}/logits_fake".format(split): logits_fake.detach().mean() 110 | } 111 | return d_loss, log 112 | 113 | -------------------------------------------------------------------------------- /text2image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from omegaconf import OmegaConf 9 | from PIL import Image 10 | from torchvision.utils import make_grid 11 | 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.util import instantiate_from_config 14 | 15 | """ 16 | Inference script for text-to-face generation at 512x512 resolution 17 | """ 18 | 19 | 20 | def parse_args(): 21 | 22 | parser = argparse.ArgumentParser(description="") 23 | 24 | # conditions 25 | parser.add_argument( 26 | "--input_text", 27 | type=str, 28 | default="This man has beard of medium length. 
He is in his thirties.", 29 | help="text condition") 30 | 31 | # paths 32 | parser.add_argument( 33 | "--config_path", 34 | type=str, 35 | default="configs/512_text.yaml", 36 | help="path to model config") 37 | parser.add_argument( 38 | "--ckpt_path", 39 | type=str, 40 | default="pretrained/512_text.ckpt", 41 | help="path to model checkpoint") 42 | parser.add_argument( 43 | "--save_folder", 44 | type=str, 45 | default="outputs/512_text2image", 46 | help="folder to save synthesis outputs") 47 | 48 | # batch size and ddim steps 49 | parser.add_argument( 50 | "--batch_size", 51 | type=int, 52 | default=4, 53 | help="number of images to generate") 54 | parser.add_argument( 55 | "--ddim_steps", 56 | type=int, 57 | default="50", 58 | help= 59 | "number of ddim steps (between 20 to 1000, the larger the slower but better quality)" 60 | ) 61 | 62 | # whether save intermediate outputs 63 | parser.add_argument( 64 | "--save_z", 65 | type=bool, 66 | default=False, 67 | help= 68 | "whether visualize the VAE latent codes and save them in the output folder", 69 | ) 70 | parser.add_argument( 71 | "--return_influence_function", 72 | type=bool, 73 | default=False, 74 | help= 75 | "whether visualize the Influence Functions and save them in the output folder", 76 | ) 77 | parser.add_argument( 78 | "--display_x_inter", 79 | type=bool, 80 | default=False, 81 | help= 82 | "whether display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder", 83 | ) 84 | 85 | args = parser.parse_args() 86 | return args 87 | 88 | 89 | def main(): 90 | 91 | args = parse_args() 92 | 93 | # ========== set up model ========== 94 | print(f'Set up model') 95 | config = OmegaConf.load(args.config_path) 96 | model_config = config['model'] 97 | model = instantiate_from_config(model_config) 98 | model.init_from_ckpt(args.ckpt_path) 99 | model = model.cuda() 100 | model.eval() 101 | 102 | # ========== set output directory ========== 103 | os.makedirs(args.save_folder, exist_ok=True) 104 | # save a copy of this python script being used 105 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__)) 106 | 107 | print( 108 | f'================================================================================' 109 | ) 110 | print(f'text: {args.input_text}') 111 | 112 | # prepare directories 113 | save_sub_folder = os.path.join(args.save_folder, str(args.input_text)) 114 | os.makedirs(save_sub_folder, exist_ok=True) 115 | 116 | # ========== inference ========== 117 | with torch.no_grad(): 118 | 119 | # encode condition 120 | condition = [] 121 | for i in range(args.batch_size): 122 | condition.append(args.input_text.lower()) 123 | 124 | with model.ema_scope("Plotting"): 125 | 126 | # encode condition 127 | condition = model.get_learned_conditioning( 128 | condition) # [1, 77, 640] 129 | print(f'condition.shape={condition.shape}') # [B, 77, 640] 130 | 131 | # DDIM sampling 132 | ddim_sampler = DDIMSampler(model) 133 | z_0_batch, intermediates = ddim_sampler.sample( 134 | S=args.ddim_steps, 135 | batch_size=args.batch_size, 136 | shape=(3, 64, 64), 137 | conditioning=condition, 138 | verbose=False, 139 | eta=1.0, 140 | log_every_t=1) 141 | 142 | # decode VAE latent z_0 to image x_0 143 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256] 144 | 145 | # ========== save outputs ========== 146 | for idx in range(args.batch_size): 147 | 148 | # ========== save synthesized image x_0 ========== 149 | save_x_0_path = os.path.join(save_sub_folder, 150 | f'{str(idx).zfill(6)}_x_0.png') 151 | x_0 = 
x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256] 152 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy() 153 | x_0 = (x_0 + 1.0) * 127.5 154 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255 155 | x_0 = x_0.astype(np.uint8) 156 | x_0 = Image.fromarray(x_0[0]) 157 | x_0.save(save_x_0_path) 158 | 159 | # save intermediate x_t and pred_x_0 160 | if args.display_x_inter: 161 | for cond_name in ['x_inter', 'pred_x0']: 162 | save_conf_path = os.path.join( 163 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png') 164 | conf = intermediates[f'{cond_name}'] 165 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64 166 | conf = conf[:, idx, :, :, :] # 50x3x64x64 167 | print('decoding x_inter ......') 168 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256] 169 | conf = make_grid( 170 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10] 171 | conf = conf.permute(1, 2, 172 | 0).to('cpu').numpy() # cxhxh -> hxhxc 173 | conf = (conf + 1.0) * 127.5 174 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 175 | conf = conf.astype(np.uint8) 176 | conf = Image.fromarray(conf) 177 | conf.save(save_conf_path) 178 | 179 | # save latent z_0 180 | if args.save_z: 181 | save_z_0_path = os.path.join(save_sub_folder, 182 | f'{str(idx).zfill(6)}_z_0.png') 183 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64] 184 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy() 185 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization 186 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255 187 | z_0 = z_0.astype(np.uint8) 188 | z_0 = Image.fromarray(z_0[0]) 189 | z_0.save(save_z_0_path) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. 
Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | def instantiate_from_config_vq_diffusion(config): 88 | """the VQ-Diffusion version""" 89 | if config is None: 90 | return None 91 | if not "target" in config: 92 | raise KeyError("Expected key `target` to instantiate.") 93 | module, cls = config["target"].rsplit(".", 1) 94 | print(f'instantiate_from_config --- module: {module}, cls: {cls}') # ziqi added 95 | cls = getattr(importlib.import_module(module, package=None), cls) 96 | return cls(**config.get("params", dict())) 97 | 98 | 99 | def get_obj_from_str(string, reload=False): 100 | module, cls = string.rsplit(".", 1) 101 | if reload: 102 | module_imp = importlib.import_module(module) 103 | importlib.reload(module_imp) 104 | return getattr(importlib.import_module(module, package=None), cls) 105 | 106 | 107 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 108 | # create dummy dataset instance 109 | 110 | # run prefetching 111 | if idx_to_fn: 112 | res = func(data, worker_id=idx) 113 | else: 114 | res = func(data) 115 | Q.put([idx, res]) 116 | Q.put("Done") 117 | 118 | 119 | def parallel_data_prefetch( 120 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 121 | ): 122 | # if target_data_type not in ["ndarray", "list"]: 123 | # raise ValueError( 124 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 125 | # ) 126 | if isinstance(data, np.ndarray) and target_data_type == "list": 127 | raise ValueError("list expected but function got ndarray.") 128 | elif isinstance(data, abc.Iterable): 129 | if isinstance(data, dict): 130 | print( 131 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 
132 | ) 133 | data = list(data.values()) 134 | if target_data_type == "ndarray": 135 | data = np.asarray(data) 136 | else: 137 | data = list(data) 138 | else: 139 | raise TypeError( 140 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 141 | ) 142 | 143 | if cpu_intensive: 144 | Q = mp.Queue(1000) 145 | proc = mp.Process 146 | else: 147 | Q = Queue(1000) 148 | proc = Thread 149 | # spawn processes 150 | if target_data_type == "ndarray": 151 | arguments = [ 152 | [func, Q, part, i, use_worker_id] 153 | for i, part in enumerate(np.array_split(data, n_proc)) 154 | ] 155 | else: 156 | step = ( 157 | int(len(data) / n_proc + 1) 158 | if len(data) % n_proc != 0 159 | else int(len(data) / n_proc) 160 | ) 161 | arguments = [ 162 | [func, Q, part, i, use_worker_id] 163 | for i, part in enumerate( 164 | [data[i: i + step] for i in range(0, len(data), step)] 165 | ) 166 | ] 167 | processes = [] 168 | for i in range(n_proc): 169 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 170 | processes += [p] 171 | 172 | # start processes 173 | print(f"Start prefetching...") 174 | import time 175 | 176 | start = time.time() 177 | gather_res = [[] for _ in range(n_proc)] 178 | try: 179 | for p in processes: 180 | p.start() 181 | 182 | k = 0 183 | while k < n_proc: 184 | # get result 185 | res = Q.get() 186 | if res == "Done": 187 | k += 1 188 | else: 189 | gather_res[res[0]] = res[1] 190 | 191 | except Exception as e: 192 | print("Exception: ", e) 193 | for p in processes: 194 | p.terminate() 195 | 196 | raise e 197 | finally: 198 | for p in processes: 199 | p.join() 200 | print(f"Prefetching complete. [{time.time() - start} sec.]") 201 | 202 | if target_data_type == 'ndarray': 203 | if not isinstance(gather_res[0], np.ndarray): 204 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 205 | 206 | # order outputs 207 | return np.concatenate(gather_res, axis=0) 208 | elif target_data_type == 'list': 209 | out = [] 210 | for r in gather_res: 211 | out.extend(r) 212 | return out 213 | else: 214 | return gather_res 215 | -------------------------------------------------------------------------------- /ldm/data/celebahq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | import json 7 | import random 8 | from ldm.util import instantiate_from_config_vq_diffusion 9 | import albumentations 10 | from torchvision import transforms as trans 11 | 12 | def load_img(filepath): 13 | img = Image.open(filepath).convert('RGB') 14 | return img 15 | 16 | 17 | class DalleTransformerPreprocessor(object): 18 | def __init__(self, 19 | size=256, 20 | phase='train', 21 | additional_targets=None): 22 | 23 | self.size = size 24 | self.phase = phase 25 | # ddc: following dalle to use randomcrop 26 | self.train_preprocessor = albumentations.Compose([albumentations.RandomCrop(height=size, width=size)], 27 | additional_targets=additional_targets) 28 | self.val_preprocessor = albumentations.Compose([albumentations.CenterCrop(height=size, width=size)], 29 | additional_targets=additional_targets) 30 | 31 | 32 | def __call__(self, image, **kargs): 33 | """ 34 | image: PIL.Image 35 | """ 36 | if isinstance(image, np.ndarray): 37 | image = Image.fromarray(image.astype(np.uint8)) 38 | 39 | w, h = image.size 40 | s_min = min(h, w) 41 | 42 | if self.phase == 'train': 43 | off_h = 
int(random.uniform(3*(h-s_min)//8, max(3*(h-s_min)//8+1, 5*(h-s_min)//8))) 44 | off_w = int(random.uniform(3*(w-s_min)//8, max(3*(w-s_min)//8+1, 5*(w-s_min)//8))) 45 | 46 | image = image.crop((off_w, off_h, off_w + s_min, off_h + s_min)) 47 | 48 | # resize image 49 | t_max = min(s_min, round(9/8*self.size)) 50 | t_max = max(t_max, self.size) 51 | t = int(random.uniform(self.size, t_max+1)) 52 | image = image.resize((t, t)) 53 | image = np.array(image).astype(np.uint8) 54 | image = self.train_preprocessor(image=image) 55 | else: 56 | if w < h: 57 | w_ = self.size 58 | h_ = int(h * w_/w) 59 | else: 60 | h_ = self.size 61 | w_ = int(w * h_/h) 62 | image = image.resize((w_, h_)) 63 | image = np.array(image).astype(np.uint8) 64 | image = self.val_preprocessor(image=image) 65 | return image 66 | 67 | 68 | 69 | class CelebAConditionalDataset(Dataset): 70 | 71 | """ 72 | This Dataset can be used for: 73 | - image-only: setting 'conditions' = [] 74 | - image and multi-modal 'conditions': setting conditions as the list of modalities you need 75 | 76 | To toggle between 256 and 512 image resolution, simply change the 'image_folder' 77 | """ 78 | 79 | def __init__(self, 80 | phase = 'train', 81 | im_preprocessor_config=None, 82 | test_dataset_size=3000, 83 | conditions = ['seg_mask', 'text', 'sketch'], 84 | image_folder = 'datasets/image/image_512_downsampled_from_hq_1024', 85 | text_file = 'datasets/text/captions_hq_beard_and_age_2022-08-19.json', 86 | mask_folder = 'datasets/mask/CelebAMask-HQ-mask-color-palette_32_nearest_downsampled_from_hq_512_one_hot_2d_tensor', 87 | sketch_folder = 'datasets/sketch/sketch_1x1024_tensor', 88 | ): 89 | 90 | self.transform = instantiate_from_config_vq_diffusion(im_preprocessor_config) 91 | self.conditions = conditions 92 | print(f'conditions = {conditions}') 93 | 94 | self.image_folder = image_folder 95 | print(f'self.image_folder = {self.image_folder}') 96 | 97 | # conditions directory 98 | self.text_file = text_file 99 | print(f'self.text_file = {self.text_file}') 100 | print(f'start loading text') 101 | with open(self.text_file, 'r') as f: 102 | self.text_file_content = json.load(f) 103 | print(f'end loading text') 104 | if 'seg_mask' in self.conditions: 105 | self.mask_folder = mask_folder 106 | print(f'self.mask_folder = {self.mask_folder}') 107 | if 'sketch' in self.conditions: 108 | self.sketch_folder = sketch_folder 109 | print(f'self.sketch_folder = {self.sketch_folder}') 110 | 111 | # list of valid image names & train test split 112 | self.image_name_list = list(self.text_file_content.keys()) 113 | 114 | # train test split 115 | if phase == 'train': 116 | self.image_name_list = self.image_name_list[:-test_dataset_size] 117 | elif phase == 'test': 118 | self.image_name_list = self.image_name_list[-test_dataset_size:] 119 | else: 120 | raise NotImplementedError 121 | self.num = len(self.image_name_list) 122 | 123 | # verbose 124 | print(f'phase = {phase}') 125 | print(f'number of samples = {self.num}') 126 | print(f'self.image_name_list[:10] = {self.image_name_list[:10]}') 127 | print(f'self.image_name_list[-10:] = {self.image_name_list[-10:]}\n') 128 | 129 | 130 | def __len__(self): 131 | return self.num 132 | 133 | def __getitem__(self, index): 134 | 135 | # ---------- (1) get image ---------- 136 | image_name = self.image_name_list[index] 137 | image_path = os.path.join(self.image_folder, image_name) 138 | image = load_img(image_path) 139 | image = np.array(image).astype(np.uint8) 140 | image = self.transform(image = image)['image'] 141 | image = 
image.astype(np.float32)/127.5 - 1.0 142 | 143 | # record into data entry 144 | if len(self.conditions) == 1: 145 | data = { 146 | 'image': image, 147 | } 148 | else: 149 | data = { 150 | 'image': image, 151 | 'conditions': {} 152 | } 153 | 154 | # ---------- (2) get text ---------- 155 | if 'text' in self.conditions: 156 | text = self.text_file_content[image_name]["Beard_and_Age"].lower() 157 | # record into data entry 158 | if len(self.conditions) == 1: 159 | data['caption'] = text 160 | else: 161 | data['conditions']['text'] = text 162 | 163 | # ---------- (3) get mask ---------- 164 | if 'seg_mask' in self.conditions: 165 | mask_idx = image_name.split('.')[0] 166 | mask_name = f'{mask_idx}.pt' 167 | mask_path = os.path.join(self.mask_folder, mask_name) 168 | mask_one_hot_tensor = torch.load(mask_path) 169 | 170 | # record into data entry 171 | if len(self.conditions) == 1: 172 | data['seg_mask'] = mask_one_hot_tensor 173 | else: 174 | data['conditions']['seg_mask'] = mask_one_hot_tensor 175 | 176 | 177 | # ---------- (4) get sketch ---------- 178 | if 'sketch' in self.conditions: 179 | sketch_idx = image_name.split('.')[0] 180 | sketch_name = f'{sketch_idx}.pt' 181 | sketch_path = os.path.join(self.sketch_folder, sketch_name) 182 | sketch_one_hot_tensor = torch.load(sketch_path) 183 | 184 | # record into data entry 185 | if len(self.conditions) == 1: 186 | data['sketch'] = sketch_one_hot_tensor 187 | else: 188 | data['conditions']['sketch'] = sketch_one_hot_tensor 189 | 190 | 191 | return data 192 | -------------------------------------------------------------------------------- /mask2image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | from torchvision.utils import make_grid 12 | 13 | from ldm.models.diffusion.ddim import DDIMSampler 14 | from ldm.util import instantiate_from_config 15 | 16 | 17 | def parse_args(): 18 | 19 | parser = argparse.ArgumentParser(description="") 20 | 21 | # conditions 22 | parser.add_argument( 23 | "--mask_path", 24 | type=str, 25 | default="test_data/512_masks/27007.png", 26 | help="path to the segmentation mask") 27 | 28 | # paths 29 | parser.add_argument( 30 | "--config_path", 31 | type=str, 32 | default="configs/512_mask.yaml", 33 | help="path to model config") 34 | parser.add_argument( 35 | "--ckpt_path", 36 | type=str, 37 | default="pretrained/512_mask.ckpt", 38 | help="path to model checkpoint") 39 | parser.add_argument( 40 | "--save_folder", 41 | type=str, 42 | default="outputs/512_mask2image", 43 | help="folder to save synthesis outputs") 44 | 45 | # batch size and ddim steps 46 | parser.add_argument( 47 | "--batch_size", 48 | type=int, 49 | default=4, 50 | help="number of images to generate") 51 | parser.add_argument( 52 | "--ddim_steps", 53 | type=int, 54 | default="50", 55 | help= 56 | "number of ddim steps (between 20 to 1000, the larger the slower but better quality)" 57 | ) 58 | 59 | # whether save intermediate outputs 60 | parser.add_argument( 61 | "--save_z", 62 | type=bool, 63 | default=False, 64 | help= 65 | "whether visualize the VAE latent codes and save them in the output folder", 66 | ) 67 | parser.add_argument( 68 | "--return_influence_function", 69 | type=bool, 70 | default=False, 71 | help= 72 | "whether visualize the Influence Functions and save them in the output folder", 
73 | ) 74 | parser.add_argument( 75 | "--display_x_inter", 76 | type=bool, 77 | default=False, 78 | help= 79 | "whether display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder", 80 | ) 81 | parser.add_argument( 82 | "--save_mixed", 83 | type=bool, 84 | default=False, 85 | help= 86 | "whether overlay the segmentation mask on the synthesized image to visualize mask consistency", 87 | ) 88 | 89 | args = parser.parse_args() 90 | return args 91 | 92 | 93 | def main(): 94 | 95 | args = parse_args() 96 | 97 | # ========== set up model ========== 98 | print(f'Set up model') 99 | config = OmegaConf.load(args.config_path) 100 | model_config = config['model'] 101 | model = instantiate_from_config(model_config) 102 | model.init_from_ckpt(args.ckpt_path) 103 | model = model.cuda() 104 | model.eval() 105 | 106 | # ========== set output directory ========== 107 | os.makedirs(args.save_folder, exist_ok=True) 108 | # save a copy of this python script being used 109 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__)) 110 | 111 | # ========== prepare seg mask for model ========== 112 | with open(args.mask_path, 'rb') as f: 113 | img = Image.open(f) 114 | resized_img = img.resize((32, 32), Image.NEAREST) # resize 115 | flattened_img = list(resized_img.getdata()) 116 | flattened_img_tensor = torch.tensor(flattened_img) # flatten 117 | flattened_img_tensor_one_hot = F.one_hot( 118 | flattened_img_tensor, num_classes=19) # one hot 119 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose( 120 | 0, 1) 121 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze( 122 | flattened_img_tensor_one_hot_transpose, 123 | 0).cuda() # add batch dimension 124 | 125 | # ========== prepare mask for visualization ========== 126 | mask = Image.open(args.mask_path) 127 | mask = mask.convert('RGB') 128 | mask = np.array(mask).astype(np.uint8) # three channel integer 129 | input_mask = mask 130 | 131 | print( 132 | f'================================================================================' 133 | ) 134 | print(f'mask_path: {args.mask_path}') 135 | 136 | # prepare directories 137 | mask_name = args.mask_path.split('/')[-1] 138 | save_sub_folder = os.path.join(args.save_folder, mask_name) 139 | os.makedirs(save_sub_folder, exist_ok=True) 140 | 141 | # save seg_mask 142 | save_path_mask = os.path.join(save_sub_folder, mask_name) 143 | mask_ = Image.fromarray(input_mask) 144 | mask_.save(save_path_mask) 145 | 146 | # ========== inference ========== 147 | with torch.no_grad(): 148 | 149 | condition = flattened_img_tensor_one_hot_transpose 150 | 151 | with model.ema_scope("Plotting"): 152 | 153 | # encode condition 154 | condition = model.get_learned_conditioning( 155 | condition) # [1, 96, 640] 156 | condition = condition.repeat(args.batch_size, 1, 1) # [B, 96, 640] 157 | 158 | ddim_sampler = DDIMSampler(model) 159 | z_0_batch, intermediates = ddim_sampler.sample( 160 | S=args.ddim_steps, 161 | batch_size=args.batch_size, 162 | shape=(3, 64, 64), 163 | conditioning=condition, 164 | verbose=False, 165 | eta=1.0, 166 | log_every_t=1) 167 | 168 | # decode latent z_0 to image x_0 169 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256] 170 | 171 | for idx in range(args.batch_size): 172 | 173 | # ========== save synthesized image x_0 ========== 174 | save_x_0_path = os.path.join(save_sub_folder, 175 | f'{str(idx).zfill(6)}_x_0.png') 176 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256] 177 | x_0 = x_0.permute(0, 2, 3, 
1).to('cpu').numpy() 178 | x_0 = (x_0 + 1.0) * 127.5 179 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255 180 | x_0 = x_0.astype(np.uint8) 181 | x_0 = Image.fromarray(x_0[0]) 182 | x_0.save(save_x_0_path) 183 | 184 | # save intermediate x_t and pred_x_0 185 | if args.display_x_inter: 186 | for cond_name in ['x_inter', 'pred_x0']: 187 | save_conf_path = os.path.join( 188 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png') 189 | conf = intermediates[f'{cond_name}'] 190 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64 191 | conf = conf[:, idx, :, :, :] # 50x3x64x64 192 | print('decoding x_inter ......') 193 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256] 194 | conf = make_grid( 195 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10] 196 | conf = conf.permute(1, 2, 197 | 0).to('cpu').numpy() # cxhxh -> hxhxc 198 | conf = (conf + 1.0) * 127.5 199 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 200 | conf = conf.astype(np.uint8) 201 | conf = Image.fromarray(conf) 202 | conf.save(save_conf_path) 203 | 204 | # save latent z_0 205 | if args.save_z: 206 | save_z_0_path = os.path.join(save_sub_folder, 207 | f'{str(idx).zfill(6)}_z_0.png') 208 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64] 209 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy() 210 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization 211 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255 212 | z_0 = z_0.astype(np.uint8) 213 | z_0 = Image.fromarray(z_0[0]) 214 | z_0.save(save_z_0_path) 215 | 216 | # overlay the segmentation mask on the synthesized image to visualize mask consistency 217 | save_mixed_path = os.path.join(save_sub_folder, 218 | f'{str(idx).zfill(6)}_mixed.png') 219 | Image.blend(x_0, mask_, 0.3).save(save_mixed_path) 220 | 221 | 222 | if __name__ == "__main__": 223 | main() -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
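        # NOTE (clarifying sketch of the loss bookkeeping in this forward pass):
        # rec_loss above is still an element-wise tensor that mixes the pixel term
        # with the LPIPS perceptual term, roughly
        #     rec_loss = pixel_loss(x, x_rec) + perceptual_weight * LPIPS(x, x_rec)
        # and the mean below reduces it to the scalar "nll" that
        # calculate_adaptive_weight() balances against the generator loss via
        # gradient norms at the decoder's last layer, roughly
        #     d_weight ~ ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4),
        # clamped to [0, 1e4] and scaled by discriminator_weight.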
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /freeu/text2image_freeu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from omegaconf import OmegaConf, open_dict 9 | from PIL import Image 10 | from torchvision.utils import make_grid 11 | 12 | from ldm.models.diffusion.ddim import DDIMSampler 13 | from ldm.util import instantiate_from_config 14 | 15 | """ 16 | Inference script for text-to-face generation at 512x512 resolution 17 | """ 18 | 19 | 20 | def parse_args(): 21 | 22 | parser = argparse.ArgumentParser(description="") 23 | 24 | # conditions 25 | parser.add_argument( 26 | "--input_text", 27 | type=str, 28 | default="This man has beard of medium length. 
He is in his thirties.", 29 | help="text condition") 30 | 31 | # paths 32 | parser.add_argument( 33 | "--config_path", 34 | type=str, 35 | default="configs/512_text.yaml", 36 | help="path to model config") 37 | parser.add_argument( 38 | "--ckpt_path", 39 | type=str, 40 | default="pretrained/512_text.ckpt", 41 | help="path to model checkpoint") 42 | parser.add_argument( 43 | "--save_folder", 44 | type=str, 45 | default="outputs/512_text2image", 46 | help="folder to save synthesis outputs") 47 | 48 | # batch size and ddim steps 49 | parser.add_argument( 50 | "--batch_size", 51 | type=int, 52 | default=4, 53 | help="number of images to generate") 54 | parser.add_argument( 55 | "--ddim_steps", 56 | type=int, 57 | default="50", 58 | help= 59 | "number of ddim steps (between 20 to 1000, the larger the slower but better quality)" 60 | ) 61 | 62 | # whether save intermediate outputs 63 | parser.add_argument( 64 | "--save_z", 65 | type=bool, 66 | default=False, 67 | help= 68 | "whether visualize the VAE latent codes and save them in the output folder", 69 | ) 70 | parser.add_argument( 71 | "--return_influence_function", 72 | type=bool, 73 | default=False, 74 | help= 75 | "whether visualize the Influence Functions and save them in the output folder", 76 | ) 77 | parser.add_argument( 78 | "--display_x_inter", 79 | type=bool, 80 | default=False, 81 | help= 82 | "whether display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder", 83 | ) 84 | 85 | # FreeU Config 86 | parser.add_argument( 87 | "--seed", 88 | type=int, 89 | default=2, 90 | help= 91 | "fix random seed to compare with and without FreeU", 92 | ) 93 | parser.add_argument( 94 | "--enable_freeu", 95 | action='store_true', 96 | help= 97 | "whether enable FreeU", 98 | ) 99 | parser.add_argument( 100 | "--b1", 101 | type=float, 102 | default=1.1, 103 | help= 104 | "parameter of FreeU", 105 | ) 106 | parser.add_argument( 107 | "--b2", 108 | type=float, 109 | default=1.2, 110 | help= 111 | "parameter of FreeU", 112 | ) 113 | parser.add_argument( 114 | "--s1", 115 | type=float, 116 | default=1, 117 | help= 118 | "parameter of FreeU", 119 | ) 120 | parser.add_argument( 121 | "--s2", 122 | type=float, 123 | default=1, 124 | help= 125 | "parameter of FreeU", 126 | ) 127 | 128 | args = parser.parse_args() 129 | return args 130 | 131 | 132 | def main(): 133 | 134 | args = parse_args() 135 | torch.manual_seed(args.seed) 136 | 137 | # ========== set up model ========== 138 | print(f'Set up model') 139 | config = OmegaConf.load(args.config_path) 140 | model_config = config['model'] 141 | # ---------- FreeU code starts ---------- 142 | print('Setting up FreeU') 143 | OmegaConf.set_struct(model_config, True) 144 | with open_dict(model_config): 145 | model_config.params.unet_config.params.enable_freeu = args.enable_freeu 146 | model_config.params.unet_config.params.b1 = args.b1 147 | model_config.params.unet_config.params.b2 = args.b2 148 | model_config.params.unet_config.params.s1 = args.s1 149 | model_config.params.unet_config.params.s2 = args.s2 150 | if args.enable_freeu: 151 | args.save_folder = f'{args.save_folder}_with_freeu' 152 | else: 153 | args.save_folder = f'{args.save_folder}_without_freeu' 154 | print(f'args.enable_freeu = {args.enable_freeu}') 155 | print(f'args.save_folder = {args.save_folder}') 156 | # ---------- FreeU code ends ---------- 157 | model = instantiate_from_config(model_config) 158 | model.init_from_ckpt(args.ckpt_path) 159 | model = model.cuda() 160 | model.eval() 161 | 162 | # ========== set 
output directory ========== 163 | os.makedirs(args.save_folder, exist_ok=True) 164 | # save a copy of this python script being used 165 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__)) 166 | 167 | print( 168 | f'================================================================================' 169 | ) 170 | print(f'text: {args.input_text}') 171 | 172 | # prepare directories 173 | save_sub_folder = os.path.join(args.save_folder, str(args.input_text)) 174 | os.makedirs(save_sub_folder, exist_ok=True) 175 | 176 | # ========== inference ========== 177 | with torch.no_grad(): 178 | 179 | # encode condition 180 | condition = [] 181 | for i in range(args.batch_size): 182 | condition.append(args.input_text.lower()) 183 | 184 | with model.ema_scope("Plotting"): 185 | 186 | # encode condition 187 | condition = model.get_learned_conditioning( 188 | condition) # [1, 77, 640] 189 | print(f'condition.shape={condition.shape}') # [B, 77, 640] 190 | 191 | # DDIM sampling 192 | ddim_sampler = DDIMSampler(model) 193 | z_0_batch, intermediates = ddim_sampler.sample( 194 | S=args.ddim_steps, 195 | batch_size=args.batch_size, 196 | shape=(3, 64, 64), 197 | conditioning=condition, 198 | verbose=False, 199 | eta=1.0, 200 | log_every_t=1) 201 | 202 | # decode VAE latent z_0 to image x_0 203 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256] 204 | 205 | # ========== save outputs ========== 206 | for idx in range(args.batch_size): 207 | 208 | # ========== save synthesized image x_0 ========== 209 | save_x_0_path = os.path.join(save_sub_folder, 210 | f'{str(idx).zfill(6)}_x_0.png') 211 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256] 212 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy() 213 | x_0 = (x_0 + 1.0) * 127.5 214 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255 215 | x_0 = x_0.astype(np.uint8) 216 | x_0 = Image.fromarray(x_0[0]) 217 | x_0.save(save_x_0_path) 218 | 219 | # save intermediate x_t and pred_x_0 220 | if args.display_x_inter: 221 | for cond_name in ['x_inter', 'pred_x0']: 222 | save_conf_path = os.path.join( 223 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png') 224 | conf = intermediates[f'{cond_name}'] 225 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64 226 | conf = conf[:, idx, :, :, :] # 50x3x64x64 227 | print('decoding x_inter ......') 228 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256] 229 | conf = make_grid( 230 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10] 231 | conf = conf.permute(1, 2, 232 | 0).to('cpu').numpy() # cxhxh -> hxhxc 233 | conf = (conf + 1.0) * 127.5 234 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 235 | conf = conf.astype(np.uint8) 236 | conf = Image.fromarray(conf) 237 | conf.save(save_conf_path) 238 | 239 | # save latent z_0 240 | if args.save_z: 241 | save_z_0_path = os.path.join(save_sub_folder, 242 | f'{str(idx).zfill(6)}_z_0.png') 243 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64] 244 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy() 245 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization 246 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255 247 | z_0 = z_0.astype(np.uint8) 248 | z_0 = Image.fromarray(z_0[0]) 249 | z_0.save(save_z_0_path) 250 | 251 | 252 | if __name__ == "__main__": 253 | main() -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | 
import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * 
(int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /freeu/mask2image_freeu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision 9 | from omegaconf import OmegaConf, open_dict 10 | from PIL import Image 11 | from torchvision.utils import make_grid 12 | 13 | from ldm.models.diffusion.ddim import DDIMSampler 14 | from ldm.util import instantiate_from_config 15 | 16 | 17 | def parse_args(): 18 | 19 | parser = argparse.ArgumentParser(description="") 20 | 21 | # conditions 22 | parser.add_argument( 23 | "--mask_path", 24 | type=str, 25 | default="test_data/512_masks/27007.png", 26 | help="path to the segmentation mask") 27 | 28 | # paths 29 | parser.add_argument( 30 | "--config_path", 31 | type=str, 32 | default="configs/512_mask.yaml", 33 | help="path to model config") 34 | parser.add_argument( 35 | "--ckpt_path", 36 | type=str, 37 | default="pretrained/512_mask.ckpt", 38 | help="path to model checkpoint") 39 | parser.add_argument( 40 | "--save_folder", 41 | type=str, 42 | default="outputs/512_mask2image", 43 | help="folder to save synthesis outputs") 44 | 45 | # batch size and ddim steps 46 | parser.add_argument( 47 | "--batch_size", 48 | type=int, 49 | default=4, 50 | help="number of images to generate") 51 | parser.add_argument( 52 | "--ddim_steps", 53 | type=int, 54 | default="50", 55 | help= 56 | "number of ddim steps (between 20 to 1000, the larger the slower but better quality)" 57 | ) 58 | 59 | # whether save intermediate outputs 60 | parser.add_argument( 61 | "--save_z", 62 | type=bool, 63 | default=False, 64 | help= 65 | "whether visualize the VAE latent codes and save them in the output folder", 66 | ) 67 | parser.add_argument( 68 | "--return_influence_function", 69 | type=bool, 70 | default=False, 71 | help= 72 | "whether visualize the Influence Functions and save them in the output folder", 73 | ) 74 | parser.add_argument( 75 | "--display_x_inter", 76 | type=bool, 77 | default=False, 78 | help= 79 | "whether display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder", 80 | ) 81 | parser.add_argument( 82 | "--save_mixed", 83 | type=bool, 84 | default=False, 85 | help= 86 | "whether 
overlay the segmentation mask on the synthesized image to visualize mask consistency", 87 | ) 88 | 89 | # FreeU Config 90 | parser.add_argument( 91 | "--seed", 92 | type=int, 93 | default=2, 94 | help= 95 | "fix random seed to compare with and without FreeU", 96 | ) 97 | parser.add_argument( 98 | "--enable_freeu", 99 | action='store_true', 100 | help= 101 | "whether enable FreeU", 102 | ) 103 | parser.add_argument( 104 | "--b1", 105 | type=float, 106 | default=1.1, 107 | help= 108 | "parameter of FreeU", 109 | ) 110 | parser.add_argument( 111 | "--b2", 112 | type=float, 113 | default=1.2, 114 | help= 115 | "parameter of FreeU", 116 | ) 117 | parser.add_argument( 118 | "--s1", 119 | type=float, 120 | default=1, 121 | help= 122 | "parameter of FreeU", 123 | ) 124 | parser.add_argument( 125 | "--s2", 126 | type=float, 127 | default=1, 128 | help= 129 | "parameter of FreeU", 130 | ) 131 | 132 | args = parser.parse_args() 133 | return args 134 | 135 | 136 | def main(): 137 | 138 | args = parse_args() 139 | torch.manual_seed(args.seed) 140 | 141 | # ========== set up model ========== 142 | print(f'Set up model') 143 | config = OmegaConf.load(args.config_path) 144 | model_config = config['model'] 145 | # ---------- FreeU code starts ---------- 146 | print('Setting up FreeU') 147 | OmegaConf.set_struct(model_config, True) 148 | with open_dict(model_config): 149 | model_config.params.unet_config.params.enable_freeu = args.enable_freeu 150 | model_config.params.unet_config.params.b1 = args.b1 151 | model_config.params.unet_config.params.b2 = args.b2 152 | model_config.params.unet_config.params.s1 = args.s1 153 | model_config.params.unet_config.params.s2 = args.s2 154 | if args.enable_freeu: 155 | args.save_folder = f'{args.save_folder}_with_freeu' 156 | else: 157 | args.save_folder = f'{args.save_folder}_without_freeu' 158 | print(f'args.enable_freeu = {args.enable_freeu}') 159 | print(f'args.save_folder = {args.save_folder}') 160 | # ---------- FreeU code ends ---------- 161 | model = instantiate_from_config(model_config) 162 | model.init_from_ckpt(args.ckpt_path) 163 | model = model.cuda() 164 | model.eval() 165 | 166 | # ========== set output directory ========== 167 | os.makedirs(args.save_folder, exist_ok=True) 168 | # save a copy of this python script being used 169 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__)) 170 | 171 | # ========== prepare seg mask for model ========== 172 | with open(args.mask_path, 'rb') as f: 173 | img = Image.open(f) 174 | resized_img = img.resize((32, 32), Image.NEAREST) # resize 175 | flattened_img = list(resized_img.getdata()) 176 | flattened_img_tensor = torch.tensor(flattened_img) # flatten 177 | flattened_img_tensor_one_hot = F.one_hot( 178 | flattened_img_tensor, num_classes=19) # one hot 179 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose( 180 | 0, 1) 181 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze( 182 | flattened_img_tensor_one_hot_transpose, 183 | 0).cuda() # add batch dimension 184 | 185 | # ========== prepare mask for visualization ========== 186 | mask = Image.open(args.mask_path) 187 | mask = mask.convert('RGB') 188 | mask = np.array(mask).astype(np.uint8) # three channel integer 189 | input_mask = mask 190 | 191 | print( 192 | f'================================================================================' 193 | ) 194 | print(f'mask_path: {args.mask_path}') 195 | 196 | # prepare directories 197 | mask_name = args.mask_path.split('/')[-1] 198 | save_sub_folder = 
os.path.join(args.save_folder, mask_name) 199 | os.makedirs(save_sub_folder, exist_ok=True) 200 | 201 | # save seg_mask 202 | save_path_mask = os.path.join(save_sub_folder, mask_name) 203 | mask_ = Image.fromarray(input_mask) 204 | mask_.save(save_path_mask) 205 | 206 | # ========== inference ========== 207 | with torch.no_grad(): 208 | 209 | condition = flattened_img_tensor_one_hot_transpose 210 | 211 | with model.ema_scope("Plotting"): 212 | 213 | # encode condition 214 | condition = model.get_learned_conditioning( 215 | condition) # [1, 96, 640] 216 | condition = condition.repeat(args.batch_size, 1, 1) # [B, 96, 640] 217 | 218 | ddim_sampler = DDIMSampler(model) 219 | z_0_batch, intermediates = ddim_sampler.sample( 220 | S=args.ddim_steps, 221 | batch_size=args.batch_size, 222 | shape=(3, 64, 64), 223 | conditioning=condition, 224 | verbose=False, 225 | eta=1.0, 226 | log_every_t=1) 227 | 228 | # decode latent z_0 to image x_0 229 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256] 230 | 231 | for idx in range(args.batch_size): 232 | 233 | # ========== save synthesized image x_0 ========== 234 | save_x_0_path = os.path.join(save_sub_folder, 235 | f'{str(idx).zfill(6)}_x_0.png') 236 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256] 237 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy() 238 | x_0 = (x_0 + 1.0) * 127.5 239 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255 240 | x_0 = x_0.astype(np.uint8) 241 | x_0 = Image.fromarray(x_0[0]) 242 | x_0.save(save_x_0_path) 243 | 244 | # save intermediate x_t and pred_x_0 245 | if args.display_x_inter: 246 | for cond_name in ['x_inter', 'pred_x0']: 247 | save_conf_path = os.path.join( 248 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png') 249 | conf = intermediates[f'{cond_name}'] 250 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64 251 | conf = conf[:, idx, :, :, :] # 50x3x64x64 252 | print('decoding x_inter ......') 253 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256] 254 | conf = make_grid( 255 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10] 256 | conf = conf.permute(1, 2, 257 | 0).to('cpu').numpy() # cxhxh -> hxhxc 258 | conf = (conf + 1.0) * 127.5 259 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 260 | conf = conf.astype(np.uint8) 261 | conf = Image.fromarray(conf) 262 | conf.save(save_conf_path) 263 | 264 | # save latent z_0 265 | if args.save_z: 266 | save_z_0_path = os.path.join(save_sub_folder, 267 | f'{str(idx).zfill(6)}_z_0.png') 268 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64] 269 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy() 270 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization 271 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255 272 | z_0 = z_0.astype(np.uint8) 273 | z_0 = Image.fromarray(z_0[0]) 274 | z_0.save(save_z_0_path) 275 | 276 | # overlay the segmentation mask on the synthesized image to visualize mask consistency 277 | save_mixed_path = os.path.join(save_sub_folder, 278 | f'{str(idx).zfill(6)}_mixed.png') 279 | Image.blend(x_0, mask_, 0.3).save(save_mixed_path) 280 | 281 | 282 | if __name__ == "__main__": 283 | main() -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from ldm.modules.x_transformer import 
Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | from ldm.util import instantiate_from_config_vq_diffusion 11 | 12 | 13 | class AbstractEncoder(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def encode(self, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | 21 | 22 | class ClassEmbedder(nn.Module): 23 | def __init__(self, embed_dim, n_classes=1000, key='class'): 24 | super().__init__() 25 | self.key = key 26 | self.embedding = nn.Embedding(n_classes, embed_dim) 27 | 28 | def forward(self, batch, key=None): 29 | if key is None: 30 | key = self.key 31 | # this is for use in crossattn 32 | c = batch[key][:, None] 33 | c = self.embedding(c) 34 | return c 35 | 36 | 37 | class TransformerEmbedder(AbstractEncoder): 38 | """Some transformer encoder layers""" 39 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 40 | super().__init__() 41 | self.device = device 42 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 43 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 44 | 45 | def forward(self, tokens): 46 | tokens = tokens.to(self.device) # meh 47 | z = self.transformer(tokens, return_embeddings=True) 48 | return z 49 | 50 | def encode(self, x): 51 | return self(x) 52 | 53 | 54 | class BERTTokenizer(AbstractEncoder): 55 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 56 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 57 | super().__init__() 58 | from transformers import BertTokenizerFast # TODO: add to reuquirements 59 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 60 | self.device = device 61 | self.vq_interface = vq_interface 62 | self.max_length = max_length 63 | 64 | def forward(self, text): 65 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 66 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 67 | tokens = batch_encoding["input_ids"].to(self.device) 68 | return tokens 69 | 70 | @torch.no_grad() 71 | def encode(self, text): 72 | tokens = self(text) 73 | if not self.vq_interface: 74 | return tokens 75 | return None, None, [None, None, tokens] 76 | 77 | def decode(self, text): 78 | return text 79 | 80 | 81 | class BERTEmbedder(AbstractEncoder): 82 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 83 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 84 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 85 | super().__init__() 86 | self.use_tknz_fn = use_tokenizer 87 | if self.use_tknz_fn: 88 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 89 | self.device = device 90 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 91 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 92 | emb_dropout=embedding_dropout) 93 | 94 | def forward(self, text): 95 | if self.use_tknz_fn: 96 | tokens = self.tknz_fn(text)#.to(self.device) 97 | else: 98 | tokens = text 99 | z = self.transformer(tokens, return_embeddings=True) 100 | return z 101 | 102 | def encode(self, text): 103 | # output of length 77 104 | return self(text) # [batch_size, 77, BERTEmbedder.n_embed] # exp0023: [B, 77, 640] 105 | 106 | 107 | class SpatialRescaler(nn.Module): 108 | def __init__(self, 109 | n_stages=1, 110 | method='bilinear', 111 | multiplier=0.5, 112 | in_channels=3, 113 | 
out_channels=None, 114 | bias=False): 115 | super().__init__() 116 | self.n_stages = n_stages 117 | assert self.n_stages >= 0 118 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 119 | self.multiplier = multiplier 120 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 121 | self.remap_output = out_channels is not None 122 | if self.remap_output: 123 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 124 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 125 | 126 | def forward(self,x): 127 | for stage in range(self.n_stages): 128 | x = self.interpolator(x, scale_factor=self.multiplier) 129 | 130 | 131 | if self.remap_output: 132 | x = self.channel_mapper(x) 133 | return x 134 | 135 | def encode(self, x): 136 | return self(x) 137 | 138 | 139 | class FrozenCLIPTextEmbedder(nn.Module): 140 | """ 141 | Uses the CLIP transformer encoder for text. 142 | """ 143 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 144 | super().__init__() 145 | self.model, _ = clip.load(version, jit=False, device="cpu") 146 | self.device = device 147 | self.max_length = max_length 148 | self.n_repeat = n_repeat 149 | self.normalize = normalize 150 | 151 | def freeze(self): 152 | self.model = self.model.eval() 153 | for param in self.parameters(): 154 | param.requires_grad = False 155 | 156 | def forward(self, text): 157 | tokens = clip.tokenize(text).to(self.device) 158 | z = self.model.encode_text(tokens) 159 | if self.normalize: 160 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 161 | return z 162 | 163 | def encode(self, text): 164 | z = self(text) 165 | if z.ndim==2: 166 | z = z[:, None, :] 167 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 168 | return z 169 | 170 | 171 | class FrozenClipImageEmbedder(nn.Module): 172 | """ 173 | Uses the CLIP image encoder. 174 | """ 175 | def __init__( 176 | self, 177 | model, 178 | jit=False, 179 | device='cuda' if torch.cuda.is_available() else 'cpu', 180 | antialias=False, 181 | ): 182 | super().__init__() 183 | self.model, _ = clip.load(name=model, device=device, jit=jit) 184 | 185 | self.antialias = antialias 186 | 187 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 188 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 189 | 190 | def preprocess(self, x): 191 | # normalize to [0,1] 192 | x = kornia.geometry.resize(x, (224, 224), 193 | interpolation='bicubic',align_corners=True, 194 | antialias=self.antialias) 195 | x = (x + 1.) / 2. 
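        # NOTE (sketch of the value-range handling): inputs to this embedder are
        # expected in [-1, 1] (see forward below); after the bicubic resize to
        # 224x224, the line above maps them to [0, 1] so that the normalization
        # below can apply CLIP's channel statistics, roughly
        #     x_norm = (x - mean) / std,
        #     mean ~ (0.481, 0.458, 0.408), std ~ (0.269, 0.261, 0.276)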
196 | # renormalize according to clip 197 | x = kornia.enhance.normalize(x, self.mean, self.std) 198 | return x 199 | 200 | def forward(self, x): 201 | # x is assumed to be in range [-1,1] 202 | return self.model.encode_image(self.preprocess(x)) 203 | 204 | 205 | class PassSegMaskEncoder(nn.Module): 206 | 207 | 208 | def __init__(self): 209 | super().__init__() 210 | 211 | 212 | def forward(self, x): 213 | """convert :torch.cuda.LongTensor to torch.cuda.FloatTensor""" 214 | return x.to(torch.float) 215 | 216 | def encode(self, input): 217 | return self(input) 218 | 219 | 220 | class SegMaskAndTextEncoder(nn.Module): 221 | 222 | def __init__(self, seg_mask_encoder_config, text_encoder_config, mask_embed_dim=1024, text_embed_dim=640): 223 | super().__init__() 224 | 225 | self.seg_mask_encoder = instantiate_from_config_vq_diffusion(seg_mask_encoder_config) 226 | self.text_encoder = instantiate_from_config_vq_diffusion(text_encoder_config) 227 | 228 | self.mask_embed_dim = mask_embed_dim 229 | self.text_embed_dim = text_embed_dim 230 | 231 | self.linear = nn.Linear( 232 | in_features=self.mask_embed_dim, # 1024 233 | out_features=self.text_embed_dim, # 640 234 | bias=True) 235 | 236 | 237 | def forward(self, input): 238 | 239 | seg_mask = self.seg_mask_encoder(input['seg_mask']) # [B, 19, 1024] 240 | seg_mask = self.linear(seg_mask) # [B, 19, 640] 241 | 242 | text = self.text_encoder(input['text']) # [B, 77, 640] 243 | 244 | seg_mask_and_text = torch.cat([seg_mask, text], 1) # Bx(19+77)x640 245 | 246 | return seg_mask_and_text 247 | 248 | def encode(self, input): 249 | return self(input) 250 | 251 | 252 | 253 | class SegMaskEncoder(nn.Module): 254 | 255 | def __init__(self, seg_mask_encoder_config, mask_embed_dim=1024, context_dim=640): 256 | super().__init__() 257 | 258 | self.seg_mask_encoder = instantiate_from_config_vq_diffusion(seg_mask_encoder_config) 259 | 260 | self.mask_embed_dim = mask_embed_dim 261 | self.context_dim = context_dim 262 | 263 | self.linear = nn.Linear( 264 | in_features=self.mask_embed_dim, # 1024 265 | out_features=self.context_dim, # 640 266 | bias=True) 267 | 268 | 269 | def forward(self, input): 270 | 271 | 272 | if len(input.shape) == 4 and input.shape[1] == 1: 273 | # Bx1x19x1024 -> Bx19x1024 274 | # the '1' dimension comes from inner_function() in class DDPM 275 | input = input.view(input.shape[0], input.shape[2], input.shape[3]) 276 | assert len(input.shape) == 3 # [B, 19, 1024] 277 | seg_mask = self.seg_mask_encoder(input) # [B, 19, 1024] 278 | 279 | # match mask dimension to the desired condition dimension 280 | seg_mask = self.linear(seg_mask) # [B, 19, 640] 281 | 282 | return seg_mask # [B, 19, 640] 283 | 284 | def encode(self, input): 285 | return self(input) 286 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
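# Usage sketch (illustrative only; the step counts are example values, not
# defaults prescribed by this file). The schedule helpers defined below are
# combined roughly as follows, mirroring DDIMSampler.make_schedule() in
# ldm/models/diffusion/ddim.py:
#
#   import numpy as np
#   betas = make_beta_schedule("linear", n_timestep=1000)        # (1000,) array of beta_t
#   alphas_cumprod = np.cumprod(1.0 - betas)                     # cumulative product of (1 - beta_t)
#   ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
#                                    num_ddpm_timesteps=1000, verbose=False)
#   sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
#       alphacums=alphas_cumprod, ddim_timesteps=ddim_steps, eta=0.0, verbose=False)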
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
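    :param scale: multiplicative factor, applied to every parameter in-place via p.detach().mul_(scale) below.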
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | from torchvision.utils import make_grid 13 | 14 | from ldm.models.diffusion.ddim_confidence import DDIMConfidenceSampler 15 | from ldm.util import instantiate_from_config 16 | """ 17 | Inference script for multi-modal-driven face generation at 512x512 resolution 18 | """ 19 | 20 | 21 | def parse_args(): 22 | 23 | parser = argparse.ArgumentParser(description="") 24 | 25 | # multi-modal conditions 26 | parser.add_argument( 27 | "--mask_path", 28 | type=str, 29 | default="test_data/512_masks/27007.png", 30 | help="path to the segmentation mask") 31 | parser.add_argument( 32 | "--input_text", 33 | type=str, 34 | default="This 
man has beard of medium length. He is in his thirties.", 35 | help="text condition") 36 | 37 | # paths 38 | parser.add_argument( 39 | "--config_path", 40 | type=str, 41 | default="configs/512_codiff_mask_text.yaml", 42 | help="path to model config") 43 | parser.add_argument( 44 | "--ckpt_path", 45 | type=str, 46 | default="pretrained/512_codiff_mask_text.ckpt", 47 | help="path to model checkpoint") 48 | parser.add_argument( 49 | "--save_folder", 50 | type=str, 51 | default="outputs/inference_512_codiff_mask_text", 52 | help="folder to save synthesis outputs") 53 | 54 | # batch size and ddim steps 55 | parser.add_argument( 56 | "--batch_size", 57 | type=int, 58 | default=4, 59 | help="number of images to generate") 60 | parser.add_argument( 61 | "--ddim_steps", 62 | type=int, 63 | default="50", 64 | help= 65 | "number of ddim steps (between 20 to 1000, the larger the slower but better quality)" 66 | ) 67 | 68 | # whether save intermediate outputs 69 | parser.add_argument( 70 | "--save_z", 71 | type=bool, 72 | default=False, 73 | help= 74 | "whether visualize the VAE latent codes and save them in the output folder", 75 | ) 76 | parser.add_argument( 77 | "--return_influence_function", 78 | type=bool, 79 | default=False, 80 | help= 81 | "whether visualize the Influence Functions and save them in the output folder", 82 | ) 83 | parser.add_argument( 84 | "--display_x_inter", 85 | type=bool, 86 | default=False, 87 | help= 88 | "whether display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder", 89 | ) 90 | parser.add_argument( 91 | "--save_mixed", 92 | type=bool, 93 | default=False, 94 | help= 95 | "whether overlay the segmentation mask on the synthesized image to visualize mask consistency", 96 | ) 97 | 98 | args = parser.parse_args() 99 | return args 100 | 101 | 102 | def main(): 103 | 104 | args = parse_args() 105 | 106 | # ========== set up model ========== 107 | print(f'Set up model') 108 | config = OmegaConf.load(args.config_path) 109 | model_config = config['model'] 110 | model = instantiate_from_config(model_config) 111 | model.init_from_ckpt(args.ckpt_path) 112 | model = model.cuda() 113 | model.eval() 114 | 115 | # ========== set output directory ========== 116 | os.makedirs(args.save_folder, exist_ok=True) 117 | # save a copy of this python script being used 118 | # shutil.copyfile(__file__, os.path.join(args.save_folder, __file__)) 119 | 120 | # ========== prepare seg mask for model ========== 121 | with open(args.mask_path, 'rb') as f: 122 | img = Image.open(f) 123 | resized_img = img.resize((32, 32), Image.NEAREST) # resize 124 | flattened_img = list(resized_img.getdata()) 125 | flattened_img_tensor = torch.tensor(flattened_img) # flatten 126 | flattened_img_tensor_one_hot = F.one_hot( 127 | flattened_img_tensor, num_classes=19) # one hot 128 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose( 129 | 0, 1) 130 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze( 131 | flattened_img_tensor_one_hot_transpose, 132 | 0).cuda() # add batch dimension 133 | 134 | # ========== prepare mask for visualization ========== 135 | mask = Image.open(args.mask_path) 136 | mask = mask.convert('RGB') 137 | mask = np.array(mask).astype(np.uint8) # three channel integer 138 | input_mask = mask 139 | 140 | print( 141 | f'================================================================================' 142 | ) 143 | print(f'mask_path: {args.mask_path} | text: {args.input_text}') 144 | 145 | # prepare directories 146 | mask_name = 
args.mask_path.split('/')[-1] 147 | save_sub_folder = os.path.join(args.save_folder, mask_name, 148 | str(args.input_text)) 149 | os.makedirs(save_sub_folder, exist_ok=True) 150 | 151 | # save seg_mask 152 | save_path_mask = os.path.join(save_sub_folder, mask_name) 153 | mask_ = Image.fromarray(input_mask) 154 | mask_.save(save_path_mask) 155 | 156 | # ========== inference ========== 157 | with torch.no_grad(): 158 | 159 | condition = { 160 | 'seg_mask': flattened_img_tensor_one_hot_transpose, 161 | 'text': [args.input_text.lower()] 162 | } 163 | 164 | with model.ema_scope("Plotting"): 165 | 166 | # encode condition 167 | condition = model.get_learned_conditioning(condition) 168 | if isinstance(condition, dict): 169 | for key, value in condition.items(): 170 | condition[key] = condition[key].repeat( 171 | args.batch_size, 1, 1) 172 | else: 173 | condition = condition.repeat(args.batch_size, 1, 1) 174 | 175 | # define DDIM sampler with dynamic diffusers 176 | ddim_sampler = DDIMConfidenceSampler( 177 | model=model, 178 | return_confidence_map=args.return_influence_function) 179 | 180 | # DDIM sampling 181 | z_0_batch, intermediates = ddim_sampler.sample( 182 | S=args.ddim_steps, 183 | batch_size=args.batch_size, 184 | shape=(3, 64, 64), 185 | conditioning=condition, 186 | verbose=False, 187 | eta=1.0, 188 | log_every_t=1) 189 | 190 | # decode VAE latent z_0 to image x_0 191 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256] 192 | 193 | # ========== save outputs ========== 194 | for idx in range(args.batch_size): 195 | 196 | # save synthesized image x_0 197 | save_x_0_path = os.path.join(save_sub_folder, 198 | f'{str(idx).zfill(6)}_x_0.png') 199 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256] 200 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy() 201 | x_0 = (x_0 + 1.0) * 127.5 202 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255 203 | x_0 = x_0.astype(np.uint8) 204 | x_0 = Image.fromarray(x_0[0]) 205 | x_0.save(save_x_0_path) 206 | 207 | # save intermediate x_t and pred_x_0 208 | if args.display_x_inter: 209 | for cond_name in ['x_inter', 'pred_x0']: 210 | save_conf_path = os.path.join( 211 | save_sub_folder, f'{str(idx).zfill(6)}_{cond_name}.png') 212 | conf = intermediates[f'{cond_name}'] 213 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64 214 | conf = conf[:, idx, :, :, :] # 50x3x64x64 215 | print('decoding x_inter ......') 216 | conf = model.decode_first_stage(conf) # [50, 3, 256, 256] 217 | conf = make_grid( 218 | conf, nrow=10) # 10 images per row # [3, 256x3, 256x10] 219 | conf = conf.permute(1, 2, 220 | 0).to('cpu').numpy() # cxhxh -> hxhxc 221 | conf = (conf + 1.0) * 127.5 222 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 223 | conf = conf.astype(np.uint8) 224 | conf = Image.fromarray(conf) 225 | conf.save(save_conf_path) 226 | 227 | # save influence functions 228 | if args.return_influence_function: 229 | for cond_name in ['seg_mask', 'text']: 230 | save_conf_path = os.path.join( 231 | save_sub_folder, 232 | f'{str(idx).zfill(6)}_{cond_name}_influence_function.png') 233 | conf = intermediates[f'{cond_name}_confidence_map'] 234 | conf = torch.stack(conf, dim=0) # 50x8x1x64x64 235 | conf = conf[:, idx, :, :, :] # 50x1x64x64 236 | conf = torch.cat( 237 | [conf, conf, conf], 238 | dim=1) # manually create 3 channels # [50, 3, 64, 64] 239 | conf = make_grid( 240 | conf, nrow=10) # 10 images per row # [3, 332, 662] 241 | conf = conf.permute(1, 2, 242 | 0).to('cpu').numpy() # cxhxh -> hxhxc 243 | conf = conf * 255 # manually tuned 
denormalization: [0,1] -> [0,255] 244 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 245 | conf = conf.astype(np.uint8) 246 | conf = Image.fromarray(conf) 247 | conf.save(save_conf_path) 248 | 249 | # save latent z_0 250 | if args.save_z: 251 | save_z_0_path = os.path.join(save_sub_folder, 252 | f'{str(idx).zfill(6)}_z_0.png') 253 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64] 254 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy() 255 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization 256 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255 257 | z_0 = z_0.astype(np.uint8) 258 | z_0 = Image.fromarray(z_0[0]) 259 | z_0.save(save_z_0_path) 260 | 261 | # overlay the segmentation mask on the synthesized image to visualize mask consistency 262 | if args.save_mixed: 263 | save_mixed_path = os.path.join(save_sub_folder, 264 | f'{str(idx).zfill(6)}_mixed.png') 265 | Image.blend(x_0, mask_, 0.3).save(save_mixed_path) 266 | 267 | 268 | if __name__ == "__main__": 269 | main() -------------------------------------------------------------------------------- /editing/imagic_edit_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | from torchvision import transforms 12 | from tqdm import tqdm 13 | 14 | from einops import rearrange 15 | from ldm.models.diffusion.ddim import DDIMSampler 16 | from ldm.util import instantiate_from_config 17 | """ 18 | Reference code: 19 | https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb 20 | """ 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | # directories 25 | parser.add_argument('--config', type=str, default='configs/256_text.yaml') 26 | parser.add_argument('--ckpt', type=str, default='pretrained/256_text.ckpt') 27 | parser.add_argument('--save_folder', type=str, default='outputs/text_edit') 28 | parser.add_argument( 29 | '--input_image_path', 30 | type=str, 31 | default='test_data/test_mask_edit/256_input_image/27044.jpg') 32 | parser.add_argument( 33 | '--text_prompt', 34 | type=str, 35 | default='He is a teen. 
The face is covered with short pointed beard.') 36 | 37 | # hyperparameters 38 | parser.add_argument('--seed', type=int, default=0) 39 | parser.add_argument('--stage1_lr', type=float, default=0.001) 40 | parser.add_argument('--stage1_num_iter', type=int, default=500) 41 | parser.add_argument('--stage2_lr', type=float, default=1e-6) 42 | parser.add_argument('--stage2_num_iter', type=int, default=1000) 43 | parser.add_argument( 44 | '--alpha_list', 45 | type=str, 46 | default='-1, -0.5, 0, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5') 47 | parser.add_argument('--set_random_seed', type=bool, default=False) 48 | parser.add_argument('--save_checkpoint', type=bool, default=True) 49 | 50 | args = parser.parse_args() 51 | 52 | 53 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 54 | """Loads a model from config and a ckpt 55 | if config is a path will use omegaconf to load 56 | """ 57 | if isinstance(config, (str, Path)): 58 | config = OmegaConf.load(config) 59 | 60 | pl_sd = torch.load(ckpt, map_location="cpu") 61 | global_step = pl_sd["global_step"] 62 | sd = pl_sd["state_dict"] 63 | model = instantiate_from_config(config.model) 64 | m, u = model.load_state_dict(sd, strict=False) 65 | model.to(device) 66 | model.eval() 67 | model.cond_stage_model.device = device 68 | return model 69 | 70 | 71 | @torch.no_grad() 72 | def sample_model(model, 73 | sampler, 74 | c, 75 | h, 76 | w, 77 | ddim_steps, 78 | scale, 79 | ddim_eta, 80 | start_code=None, 81 | n_samples=1): 82 | """Sample the model""" 83 | uc = None 84 | if scale != 1.0: 85 | uc = model.get_learned_conditioning(n_samples * [""]) 86 | 87 | # print(f'model.model.parameters(): {model.model.parameters()}') 88 | # for name, param in model.model.named_parameters(): 89 | # if param.requires_grad: 90 | # print (name, param.data) 91 | # break 92 | 93 | # print(f'unconditional_guidance_scale: {scale}') # 1.0 94 | # print(f'unconditional_conditioning: {uc}') # None 95 | with model.ema_scope("Plotting"): 96 | 97 | shape = [3, 64, 64] # [4, h // 8, w // 8] 98 | samples_ddim, _ = sampler.sample( 99 | S=ddim_steps, 100 | conditioning=c, 101 | batch_size=n_samples, 102 | shape=shape, 103 | verbose=False, 104 | start_code=start_code, 105 | unconditional_guidance_scale=scale, 106 | unconditional_conditioning=uc, 107 | eta=ddim_eta, 108 | ) 109 | return samples_ddim 110 | 111 | 112 | def load_img(path, target_size=256): 113 | """Load an image, resize and output -1..1""" 114 | image = Image.open(path).convert("RGB") 115 | 116 | tform = transforms.Compose([ 117 | # transforms.Resize(target_size), 118 | # transforms.CenterCrop(target_size), 119 | transforms.ToTensor(), 120 | ]) 121 | image = tform(image) 122 | return 2. * image - 1. 123 | 124 | 125 | def decode_to_im(samples, n_samples=1, nrow=1): 126 | """Decode a latent and return PIL image""" 127 | samples = model.decode_first_stage(samples) 128 | ims = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0) 129 | x_sample = 255. 
* rearrange( 130 | ims.cpu().numpy(), 131 | '(n1 n2) c h w -> (n1 h) (n2 w) c', 132 | n1=n_samples // nrow, 133 | n2=nrow) 134 | return Image.fromarray(x_sample.astype(np.uint8)) 135 | 136 | 137 | if __name__ == '__main__': 138 | 139 | args.alpha_list = [float(i) for i in args.alpha_list.split(',')] 140 | 141 | device = "cuda" # "cuda:0" 142 | 143 | # Generation parameters 144 | scale = 1.0 145 | h = 256 146 | w = 256 147 | ddim_steps = 50 148 | ddim_eta = 1.0 149 | 150 | # initialize model 151 | global_model = load_model_from_config(args.config, args.ckpt, device) 152 | 153 | input_image = args.input_image_path 154 | image_name = input_image.split('/')[-1] 155 | 156 | prompt = args.text_prompt 157 | 158 | torch.manual_seed(args.seed) 159 | 160 | model = copy.deepcopy(global_model) 161 | sampler = DDIMSampler(model) 162 | 163 | # prepare directories 164 | save_dir = os.path.join(args.save_folder, image_name, str(prompt)) 165 | os.makedirs(save_dir, exist_ok=True) 166 | print( 167 | f'================================================================================' 168 | ) 169 | print(f'input_image: {input_image} | text: {prompt}') 170 | 171 | # read input image 172 | init_image = load_img(input_image).to(device).unsqueeze( 173 | 0) # [1, 3, 256, 256] 174 | gaussian_distribution = model.encode_first_stage(init_image) 175 | init_latent = model.get_first_stage_encoding( 176 | gaussian_distribution) # [1, 3, 64, 64] 177 | img = decode_to_im(init_latent) 178 | img.save(os.path.join(save_dir, 'input_image_reconstructed.png')) 179 | 180 | # obtain text embedding 181 | emb_tgt = model.get_learned_conditioning([prompt]) 182 | emb_ = emb_tgt.clone() 183 | torch.save(emb_, os.path.join(save_dir, 'emb_tgt.pt')) 184 | emb = torch.load(os.path.join(save_dir, 'emb_tgt.pt')) # [1, 77, 640] 185 | 186 | # Sample the model with a fixed code to see what it looks like 187 | quick_sample = lambda x, s, code: decode_to_im( 188 | sample_model( 189 | model, sampler, x, h, w, ddim_steps, s, ddim_eta, start_code=code)) 190 | # start_code = torch.randn_like(init_latent) 191 | start_code = torch.randn((1, 3, 64, 64), device=device) 192 | torch.save(start_code, os.path.join(save_dir, 193 | 'start_code.pt')) # [1, 3, 64, 64] 194 | torch.manual_seed(args.seed) 195 | img = quick_sample(emb_tgt, scale, start_code) 196 | img.save(os.path.join(save_dir, 'A_start_tgtText_origDM.png')) 197 | 198 | # ======================= (A) Text Embedding Optimization =================================== 199 | print('########### Step 1 - Optimise the embedding ###########') 200 | emb.requires_grad = True 201 | opt = torch.optim.Adam([emb], lr=args.stage1_lr) 202 | criteria = torch.nn.MSELoss() 203 | history = [] 204 | 205 | pbar = tqdm(range(args.stage1_num_iter)) 206 | for i in pbar: 207 | opt.zero_grad() 208 | 209 | if args.set_random_seed: 210 | torch.seed() 211 | noise = torch.randn_like(init_latent) 212 | t_enc = torch.randint(1000, (1, ), device=device) 213 | z = model.q_sample(init_latent, t_enc, noise=noise) 214 | 215 | pred_noise = model.apply_model(z, t_enc, emb) 216 | 217 | loss = criteria(pred_noise, noise) 218 | loss.backward() 219 | pbar.set_postfix({"loss": loss.item()}) 220 | history.append(loss.item()) 221 | opt.step() 222 | 223 | plt.plot(history) 224 | plt.show() 225 | torch.save(emb, os.path.join(save_dir, 'emb_opt.pt')) 226 | emb_opt = torch.load(os.path.join(save_dir, 'emb_opt.pt')) # [1, 77, 640] 227 | 228 | torch.manual_seed(args.seed) 229 | img = quick_sample(emb_opt, scale, start_code) 230 | 
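    # Sampling with the optimized embedding (and the same fixed start_code) should
    # already be noticeably closer to the input image than the earlier
    # 'A_start_tgtText_origDM.png' sample; exact reconstruction is only expected
    # after the stage-2 model fine-tuning below.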
img.save(os.path.join(save_dir, 'A_end_optText_origDM.png')) 231 | 232 | # Interpolate the embedding 233 | for idx, alpha in enumerate(args.alpha_list): 234 | print(f'alpha={alpha}') 235 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt 236 | torch.manual_seed(args.seed) 237 | img = quick_sample(new_emb, scale, start_code) 238 | img.save( 239 | os.path.join( 240 | save_dir, 241 | f'0A_interText_origDM_{idx}_alpha={round(alpha,3)}.png')) 242 | 243 | # ======================= (B) Model Fine-Tuning =================================== 244 | print('########### Step 2 - Fine tune the model ###########') 245 | emb_opt.requires_grad = False 246 | model.train() 247 | 248 | opt = torch.optim.Adam(model.model.parameters(), lr=args.stage2_lr) 249 | criteria = torch.nn.MSELoss() 250 | history = [] 251 | 252 | pbar = tqdm(range(args.stage2_num_iter)) 253 | for i in pbar: 254 | opt.zero_grad() 255 | 256 | if args.set_random_seed: 257 | torch.seed() 258 | noise = torch.randn_like(init_latent) 259 | t_enc = torch.randint(model.num_timesteps, (1, ), device=device) 260 | z = model.q_sample(init_latent, t_enc, noise=noise) 261 | 262 | pred_noise = model.apply_model(z, t_enc, emb_opt) 263 | 264 | loss = criteria(pred_noise, noise) 265 | loss.backward() 266 | pbar.set_postfix({"loss": loss.item()}) 267 | history.append(loss.item()) 268 | opt.step() 269 | 270 | model.eval() 271 | plt.plot(history) 272 | plt.show() 273 | torch.manual_seed(args.seed) 274 | img = quick_sample(emb_opt, scale, start_code) 275 | img.save(os.path.join(save_dir, 'B_end_optText_optDM.png')) 276 | # Should look like the original image 277 | 278 | if args.save_checkpoint: 279 | ckpt = { 280 | "state_dict": model.state_dict(), 281 | } 282 | ckpt_path = os.path.join(save_dir, 'optDM.ckpt') 283 | print(f'Saving optDM to {ckpt_path}') 284 | torch.save(ckpt, ckpt_path) 285 | 286 | # ======================= (C) Generation =================================== 287 | print('########### Step 3 - Generate images ###########') 288 | # Interpolate the embedding 289 | for idx, alpha in enumerate(args.alpha_list): 290 | print(f'alpha={alpha}') 291 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt 292 | torch.manual_seed(args.seed) 293 | img = quick_sample(new_emb, scale, start_code) 294 | img.save( 295 | os.path.join( 296 | save_dir, 297 | f'0C_interText_optDM_{idx}_alpha={round(alpha,3)}.png')) 298 | 299 | print('Done') -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | 
ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | 
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, 
_, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /editing/collaborative_edit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from torchvision.utils import make_grid 10 | 11 | from ldm.models.diffusion.ddim_confidence import DDIMConfidenceSampler 12 | from ldm.util import instantiate_from_config 13 | """ 14 | Inference script for multi-modal-driven face editing at 256x256 resolution 15 | """ 16 | 17 | 18 | def parse_args(): 19 | 20 | parser = argparse.ArgumentParser(description="") 21 | 22 | # uni-modal editing results 23 | parser.add_argument( 24 | "--imagic_text_folder", 25 | type=str, 26 | default= 27 | "outputs/text_edit/27044.jpg/He is a teen. 
The face is covered with short pointed beard.", 28 | help="path to Imagic text-based editing results") 29 | parser.add_argument( 30 | "--imagic_mask_folder", 31 | type=str, 32 | default= 33 | "outputs/mask_edit/27044.jpg/27044_0_remove_smile_and_rings.png", 34 | help="path to Imagic mask-based editing results") 35 | 36 | # paths 37 | parser.add_argument( 38 | "--config_path", 39 | type=str, 40 | default="configs/256_codiff_mask_text.yaml", 41 | help="path to model config") 42 | parser.add_argument( 43 | "--ckpt_path", 44 | type=str, 45 | default="pretrained/256_codiff_mask_text.ckpt", 46 | help="path to model checkpoint") 47 | parser.add_argument( 48 | "--save_folder", 49 | type=str, 50 | default="outputs/collaborative_edit", 51 | help="folder to save editing outputs") 52 | 53 | # batch size and ddim steps 54 | parser.add_argument( 55 | "--batch_size", 56 | type=int, 57 | default=1, 58 | help="number of images to generate") 59 | parser.add_argument( 60 | "--ddim_steps", 61 | type=int, 62 | default=50, 63 | help= 64 | "number of ddim steps (between 20 to 1000, the larger the slower but better quality)" 65 | ) 66 | parser.add_argument( 67 | "--seed", 68 | type=int, 69 | default=2, 70 | ) 71 | 72 | # whether save intermediate outputs 73 | parser.add_argument( 74 | "--save_z", 75 | type=bool, 76 | default=False, 77 | help= 78 | "whether visualize the VAE latent codes and save them in the output folder", 79 | ) 80 | parser.add_argument( 81 | "--return_influence_function", 82 | type=bool, 83 | default=False, 84 | help= 85 | "whether visualize the Influence Functions and save them in the output folder", 86 | ) 87 | parser.add_argument( 88 | "--display_x_inter", 89 | type=bool, 90 | default=False, 91 | help= 92 | "whether display the intermediate DDIM outputs (x_t and pred_x_0) and save them in the output folder", 93 | ) 94 | 95 | args = parser.parse_args() 96 | return args 97 | 98 | 99 | def main(): 100 | 101 | args = parse_args() 102 | 103 | # ========== set output directory ========== 104 | os.makedirs(args.save_folder, exist_ok=True) 105 | 106 | # ========== init model ========== 107 | config = OmegaConf.load(args.config_path) 108 | model_config = config['model'] 109 | model_config['params']['seg_mask_ldm_ckpt_path'] = os.path.join( 110 | args.imagic_mask_folder, 'optDM.ckpt') 111 | model_config['params']['text_ldm_ckpt_path'] = os.path.join( 112 | args.imagic_text_folder, 'optDM.ckpt') 113 | model = instantiate_from_config(model_config) 114 | model.init_from_ckpt(args.ckpt_path) 115 | mask_optDM_ckpt = torch.load( 116 | os.path.join(args.imagic_mask_folder, 'optDM.ckpt')) 117 | text_optDM_ckpt = torch.load( 118 | os.path.join(args.imagic_text_folder, 'optDM.ckpt')) 119 | 120 | print('Updating text and mask model to the finetuned version ......') 121 | state_dict = model.state_dict() 122 | for name in state_dict.keys(): 123 | if 'model.seg_mask_unet.' in name: 124 | name_end = name[20:] 125 | state_dict[name] = mask_optDM_ckpt['state_dict'][ 126 | f'model.{name_end}'] 127 | elif 'model.text_unet.' 
in name: 128 | name_end = name[16:] 129 | state_dict[name] = text_optDM_ckpt['state_dict'][ 130 | f'model.{name_end}'] 131 | model.load_state_dict(state_dict) 132 | 133 | print('Pushing model to CUDA ......') 134 | global_model = model.cuda() 135 | global_model.eval() 136 | 137 | seed = args.seed 138 | 139 | for alpha_idx, alpha in enumerate([ 140 | 0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1, 1.125, 1.25, 141 | 1.375, 1.5, 1.625, 1.75, 1.875, 2.0, 2.5 142 | ]): 143 | 144 | print(f'alpha={alpha}') 145 | 146 | seed = int(seed) 147 | torch.manual_seed(seed) 148 | mask_start_code = torch.load( 149 | os.path.join(args.imagic_mask_folder, 'start_code.pt')) 150 | text_start_code = torch.load( 151 | os.path.join(args.imagic_text_folder, 'start_code.pt')) 152 | start_code = mask_start_code 153 | 154 | # prepare directories 155 | save_sub_folder = os.path.join(args.save_folder, f'seed={seed}') 156 | os.makedirs(save_sub_folder, exist_ok=True) 157 | 158 | seed = int(seed) 159 | model = copy.deepcopy(global_model) 160 | 161 | # ========== inference ========== 162 | with torch.no_grad(): 163 | 164 | with model.ema_scope("Plotting"): 165 | 166 | condition = {} 167 | 168 | mask_alpha = alpha 169 | mask_emb_tgt = torch.load( 170 | os.path.join(args.imagic_mask_folder, 'emb_tgt.pt')) 171 | mask_emb_opt = torch.load( 172 | os.path.join(args.imagic_mask_folder, 'emb_opt.pt')) 173 | mask_new_emb = mask_alpha * mask_emb_tgt + ( 174 | 1 - mask_alpha) * mask_emb_opt 175 | condition['seg_mask'] = mask_new_emb.repeat( 176 | args.batch_size, 1, 1) 177 | 178 | text_alpha = alpha 179 | text_emb_tgt = torch.load( 180 | os.path.join(args.imagic_text_folder, 'emb_tgt.pt')) 181 | text_emb_opt = torch.load( 182 | os.path.join(args.imagic_text_folder, 'emb_opt.pt')) 183 | text_new_emb = text_alpha * text_emb_tgt + ( 184 | 1 - text_alpha) * text_emb_opt 185 | condition['text'] = text_new_emb.repeat(args.batch_size, 1, 1) 186 | 187 | torch.manual_seed(seed) 188 | 189 | ddim_sampler = DDIMConfidenceSampler( 190 | model=model, 191 | return_confidence_map=args.return_influence_function) 192 | 193 | torch.manual_seed(seed) 194 | 195 | z_0_batch, intermediates = ddim_sampler.sample( 196 | S=args.ddim_steps, 197 | batch_size=args.batch_size, 198 | shape=(3, 64, 64), 199 | conditioning=condition, 200 | verbose=False, 201 | start_code=start_code, 202 | eta=1.0, 203 | log_every_t=1) 204 | 205 | # decode VAE latent z_0 to image x_0 206 | x_0_batch = model.decode_first_stage(z_0_batch) # [B, 3, 256, 256] 207 | 208 | # ========== save outputs ========== 209 | for idx in range(args.batch_size): 210 | 211 | # save synthesized image x_0 212 | save_x_0_path = os.path.join( 213 | save_sub_folder, 214 | f'{str(idx).zfill(6)}_x_0_{alpha_idx}_alpha={round(alpha, 3)}.png' 215 | ) 216 | x_0 = x_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 256, 256] 217 | x_0 = x_0.permute(0, 2, 3, 1).to('cpu').numpy() 218 | x_0 = (x_0 + 1.0) * 127.5 219 | np.clip(x_0, 0, 255, out=x_0) # clip to range 0 to 255 220 | x_0 = x_0.astype(np.uint8) 221 | x_0 = Image.fromarray(x_0[0]) 222 | x_0.save(save_x_0_path) 223 | 224 | # save intermediate x_t and pred_x_0 225 | if args.display_x_inter: 226 | for cond_name in ['x_inter', 'pred_x0']: 227 | save_conf_path = os.path.join( 228 | save_sub_folder, 229 | f'{str(idx).zfill(6)}_{cond_name}.png') 230 | conf = intermediates[f'{cond_name}'] 231 | conf = torch.stack(conf, dim=0) # 50x8x3x64x64 232 | conf = conf[:, idx, :, :, :] # 50x3x64x64 233 | print('decoding x_inter ......') 234 | conf = model.decode_first_stage(conf) 
# [50, 3, 256, 256] 235 | conf = make_grid( 236 | conf, 237 | nrow=10) # 10 images per row # [3, 256x3, 256x10] 238 | conf = conf.permute(1, 2, 239 | 0).to('cpu').numpy() # cxhxh -> hxhxc 240 | conf = (conf + 1.0) * 127.5 241 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 242 | conf = conf.astype(np.uint8) 243 | conf = Image.fromarray(conf) 244 | conf.save(save_conf_path) 245 | 246 | # save influence functions 247 | if args.return_influence_function: 248 | for cond_name in ['seg_mask', 'text']: 249 | save_conf_path = os.path.join( 250 | save_sub_folder, 251 | f'{str(idx).zfill(6)}_{cond_name}_influence_function.png' 252 | ) 253 | conf = intermediates[f'{cond_name}_confidence_map'] 254 | conf = torch.stack(conf, dim=0) # 50x8x1x64x64 255 | conf = conf[:, idx, :, :, :] # 50x1x64x64 256 | conf = torch.cat( 257 | [conf, conf, conf], 258 | dim=1) # manually create 3 channels # [50, 3, 64, 64] 259 | conf = make_grid( 260 | conf, nrow=10) # 10 images per row # [3, 332, 662] 261 | conf = conf.permute(1, 2, 262 | 0).to('cpu').numpy() # cxhxh -> hxhxc 263 | conf = conf * 255 # manually tuned denormalization: [0,1] -> [0,255] 264 | np.clip(conf, 0, 255, out=conf) # clip to range 0 to 255 265 | conf = conf.astype(np.uint8) 266 | conf = Image.fromarray(conf) 267 | conf.save(save_conf_path) 268 | 269 | # save latent z_0 270 | if args.save_z: 271 | save_z_0_path = os.path.join(save_sub_folder, 272 | f'{str(idx).zfill(6)}_z_0.png') 273 | z_0 = z_0_batch[idx, :, :, :].unsqueeze(0) # [1, 3, 64, 64] 274 | z_0 = z_0.permute(0, 2, 3, 1).to('cpu').numpy() 275 | z_0 = (z_0 + 40) * 4 # manually tuned denormalization 276 | np.clip(z_0, 0, 255, out=z_0) # clip to range 0 to 255 277 | z_0 = z_0.astype(np.uint8) 278 | z_0 = Image.fromarray(z_0[0]) 279 | z_0.save(save_z_0_path) 280 | 281 | 282 | if __name__ == "__main__": 283 | main() -------------------------------------------------------------------------------- /ldm/models/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | """DDIM sampling using one uni-modal pre-trained diffusion models""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class DDIMSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | self.register_buffer('betas', to_torch(self.model.betas)) 32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 34 | 35 | # calculations 
for diffusion q(x_t | x_{t-1}) and others 36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 41 | 42 | # ddim sampling parameters 43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 44 | ddim_timesteps=self.ddim_timesteps, 45 | eta=ddim_eta,verbose=verbose) 46 | self.register_buffer('ddim_sigmas', ddim_sigmas) 47 | self.register_buffer('ddim_alphas', ddim_alphas) 48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 54 | 55 | @torch.no_grad() 56 | def sample(self, 57 | S, 58 | batch_size, 59 | shape, 60 | conditioning=None, 61 | callback=None, 62 | normals_sequence=None, 63 | img_callback=None, 64 | quantize_x0=False, 65 | eta=0., 66 | mask=None, 67 | x0=None, 68 | temperature=1., 69 | noise_dropout=0., 70 | score_corrector=None, 71 | corrector_kwargs=None, 72 | verbose=True, 73 | x_T=None, 74 | log_every_t=100, 75 | unconditional_guidance_scale=1., 76 | unconditional_conditioning=None, 77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
78 | **kwargs 79 | ): 80 | 81 | if conditioning is not None: 82 | if isinstance(conditioning, dict): 83 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 84 | if cbs != batch_size: 85 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 86 | else: 87 | if conditioning.shape[0] != batch_size: 88 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 89 | 90 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 91 | # sampling 92 | C, H, W = shape 93 | size = (batch_size, C, H, W) 94 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 95 | 96 | samples, intermediates = self.ddim_sampling(conditioning, size, 97 | callback=callback, 98 | img_callback=img_callback, 99 | quantize_denoised=quantize_x0, 100 | mask=mask, x0=x0, 101 | ddim_use_original_steps=False, 102 | noise_dropout=noise_dropout, 103 | temperature=temperature, 104 | score_corrector=score_corrector, 105 | corrector_kwargs=corrector_kwargs, 106 | x_T=x_T, 107 | log_every_t=log_every_t, 108 | unconditional_guidance_scale=unconditional_guidance_scale, 109 | unconditional_conditioning=unconditional_conditioning, 110 | ) 111 | return samples, intermediates 112 | 113 | @torch.no_grad() 114 | def ddim_sampling(self, cond, shape, 115 | x_T=None, ddim_use_original_steps=False, 116 | callback=None, timesteps=None, quantize_denoised=False, 117 | mask=None, x0=None, img_callback=None, log_every_t=100, 118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 119 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 120 | device = self.model.betas.device 121 | b = shape[0] 122 | if x_T is None: 123 | img = torch.randn(shape, device=device) 124 | else: 125 | img = x_T 126 | 127 | if timesteps is None: 128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 129 | elif timesteps is not None and not ddim_use_original_steps: 130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 131 | timesteps = self.ddim_timesteps[:subset_end] 132 | 133 | intermediates = {'x_inter': [], 'pred_x0': []} 134 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 135 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 136 | print(f"Running DDIM Sampling with {total_steps} timesteps") 137 | 138 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 139 | 140 | for i, step in enumerate(iterator): 141 | index = total_steps - i - 1 142 | ts = torch.full((b,), step, device=device, dtype=torch.long) 143 | 144 | if mask is not None: 145 | assert x0 is not None 146 | img_orig = self.model.q_sample(x0, ts) 147 | img = img_orig * mask + (1. 
- mask) * img 148 | 149 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 150 | quantize_denoised=quantize_denoised, temperature=temperature, 151 | noise_dropout=noise_dropout, score_corrector=score_corrector, 152 | corrector_kwargs=corrector_kwargs, 153 | unconditional_guidance_scale=unconditional_guidance_scale, 154 | unconditional_conditioning=unconditional_conditioning) 155 | img, pred_x0 = outs 156 | 157 | if callback: callback(i) 158 | if img_callback: img_callback(pred_x0, i) 159 | 160 | if index % log_every_t == 0 or index == total_steps - 1: 161 | intermediates['x_inter'].append(img) 162 | intermediates['pred_x0'].append(pred_x0) 163 | 164 | return img, intermediates 165 | 166 | @torch.no_grad() 167 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 168 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 169 | unconditional_guidance_scale=1., unconditional_conditioning=None): 170 | b, *_, device = *x.shape, x.device 171 | 172 | 173 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 174 | e_t = self.model.apply_model(x, t, c) 175 | else: 176 | x_in = torch.cat([x] * 2) 177 | t_in = torch.cat([t] * 2) 178 | c_in = torch.cat([unconditional_conditioning, c]) 179 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 180 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 181 | 182 | if score_corrector is not None: 183 | assert self.model.parameterization == "eps" 184 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 185 | 186 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 187 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 188 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 189 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 190 | # select parameters corresponding to the currently considered timestep 191 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 192 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 193 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 194 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 195 | 196 | # current prediction for x_0 197 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 198 | if quantize_denoised: 199 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 200 | # direction pointing to x_t 201 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 202 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 203 | if noise_dropout > 0.: 204 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 205 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 206 | 207 | return x_prev, pred_x0 208 | -------------------------------------------------------------------------------- /editing/imagic_edit_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | 15 | from einops import rearrange 16 | from ldm.models.diffusion.ddim import DDIMSampler 17 | from ldm.util import instantiate_from_config 18 | """ 19 | Reference code: 20 | https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb 21 | """ 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | # directories 26 | parser.add_argument('--config', type=str, default='configs/256_mask.yaml') 27 | parser.add_argument('--ckpt', type=str, default='pretrained/256_mask.ckpt') 28 | parser.add_argument('--save_folder', type=str, default='outputs/mask_edit') 29 | parser.add_argument( 30 | '--input_image_path', 31 | type=str, 32 | default='test_data/test_mask_edit/256_input_image/27044.jpg') 33 | parser.add_argument( 34 | '--mask_path', 35 | type=str, 36 | default= 37 | 'test_data/test_mask_edit/256_edited_masks/27044_0_remove_smile_and_rings.png' 38 | ) 39 | 40 | # hyperparameters 41 | parser.add_argument('--seed', type=int, default=0) 42 | parser.add_argument('--stage1_lr', type=float, default=0.001) 43 | parser.add_argument('--stage1_num_iter', type=int, default=500) 44 | parser.add_argument('--stage2_lr', type=float, default=1e-6) 45 | parser.add_argument('--stage2_num_iter', type=int, default=1000) 46 | parser.add_argument( 47 | '--alpha_list', 48 | type=str, 49 | default='-1, -0.5, 0, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5') 50 | parser.add_argument('--set_random_seed', type=bool, default=False) 51 | parser.add_argument('--save_checkpoint', type=bool, default=True) 52 | 53 | args = parser.parse_args() 54 | 55 | 56 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 57 | """Loads a model from config and a ckpt 58 | if config is a path will use omegaconf to load 59 | """ 60 | if isinstance(config, (str, Path)): 61 | config = OmegaConf.load(config) 62 | 63 | pl_sd = torch.load(ckpt, map_location="cpu") 64 | global_step = pl_sd["global_step"] 65 | sd = pl_sd["state_dict"] 66 | model = instantiate_from_config(config.model) 67 | m, u = model.load_state_dict(sd, strict=False) 68 | model.to(device) 69 | model.eval() 70 | model.cond_stage_model.device = device 71 | return model 72 | 73 | 74 | @torch.no_grad() 75 | def sample_model(model, 76 | sampler, 77 | c, 78 | h, 79 | w, 80 | ddim_steps, 81 | scale, 82 | ddim_eta, 83 | start_code=None, 84 | n_samples=1): 85 | """Sample the model""" 86 | uc = None 87 | if scale != 1.0: 88 | uc = model.get_learned_conditioning(n_samples * [""]) 89 | 90 | # print(f'model.model.parameters(): {model.model.parameters()}') 91 | # for name, param in model.model.named_parameters(): 92 | # if param.requires_grad: 93 | # print (name, param.data) 94 | # break 95 | 96 | # print(f'unconditional_guidance_scale: 
{scale}') # 1.0 97 | # print(f'unconditional_conditioning: {uc}') # None 98 | with model.ema_scope("Plotting"): 99 | 100 | shape = [3, 64, 64] # latent shape: [3, h // 4, w // 4] 101 | samples_ddim, _ = sampler.sample( 102 | S=ddim_steps, 103 | conditioning=c, 104 | batch_size=n_samples, 105 | shape=shape, 106 | verbose=False, 107 | x_T=start_code, # DDIMSampler.sample takes the fixed latent as x_T; it has no start_code argument 108 | unconditional_guidance_scale=scale, 109 | unconditional_conditioning=uc, 110 | eta=ddim_eta, 111 | ) 112 | return samples_ddim 113 | 114 | 115 | def load_img(path, target_size=256): 116 | """Load an image, resize and output -1..1""" 117 | image = Image.open(path).convert("RGB") 118 | 119 | tform = transforms.Compose([ 120 | # transforms.Resize(target_size), 121 | # transforms.CenterCrop(target_size), 122 | transforms.ToTensor(), 123 | ]) 124 | image = tform(image) 125 | return 2. * image - 1. 126 | 127 | 128 | def decode_to_im(samples, n_samples=1, nrow=1): 129 | """Decode a latent and return PIL image""" 130 | samples = model.decode_first_stage(samples) 131 | ims = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0) 132 | x_sample = 255. * rearrange( 133 | ims.cpu().numpy(), 134 | '(n1 n2) c h w -> (n1 h) (n2 w) c', 135 | n1=n_samples // nrow, 136 | n2=nrow) 137 | return Image.fromarray(x_sample.astype(np.uint8)) 138 | 139 | 140 | if __name__ == '__main__': 141 | 142 | args.alpha_list = [float(i) for i in args.alpha_list.split(',')] 143 | 144 | device = "cuda" # "cuda:0" 145 | 146 | # Generation parameters 147 | scale = 1.0 148 | h = 256 149 | w = 256 150 | ddim_steps = 50 151 | ddim_eta = 1.0 152 | 153 | # initialize model 154 | global_model = load_model_from_config(args.config, args.ckpt, device) 155 | 156 | input_image = args.input_image_path 157 | image_name = input_image.split('/')[-1] 158 | mask_name = args.mask_path.split('/')[-1] 159 | 160 | torch.manual_seed(args.seed) 161 | 162 | model = copy.deepcopy(global_model) 163 | sampler = DDIMSampler(model) 164 | 165 | # prepare directories 166 | save_dir = os.path.join(args.save_folder, image_name, str(mask_name)) 167 | print(f'save_dir = {save_dir}') 168 | 169 | os.makedirs(save_dir, exist_ok=True) 170 | print( 171 | f'================================================================================' 172 | ) 173 | print(f'input_image: {input_image} | mask_name: {mask_name}') 174 | 175 | # read input image 176 | init_image = load_img(input_image).to(device).unsqueeze( 177 | 0) # [1, 3, 256, 256] 178 | gaussian_distribution = model.encode_first_stage(init_image) 179 | init_latent = model.get_first_stage_encoding( 180 | gaussian_distribution) # [1, 3, 64, 64] 181 | img = decode_to_im(init_latent) 182 | img.save(os.path.join(save_dir, 'input_image_reconstructed.png')) 183 | 184 | # obtain mask embedding 185 | # ========== prepare mask (one-hot): resize to 32x32 then one-hot ========== 186 | with open(args.mask_path, 'rb') as f: 187 | img = Image.open(f) 188 | resized_img = img.resize((32, 32), Image.NEAREST) # resize 189 | flattened_img = list(resized_img.getdata()) 190 | flattened_img_tensor = torch.tensor(flattened_img) # flatten 191 | flattened_img_tensor_one_hot = F.one_hot( 192 | flattened_img_tensor, num_classes=19) # one hot 193 | flattened_img_tensor_one_hot_transpose = flattened_img_tensor_one_hot.transpose( 194 | 0, 1) 195 | flattened_img_tensor_one_hot_transpose = torch.unsqueeze( 196 | flattened_img_tensor_one_hot_transpose, 197 | 0).cuda() # add batch dimension 198 | 199 | # ========== prepare mask (image) ========== 200 | mask = Image.open(args.mask_path) 201 | mask =
mask.convert('RGB') 202 | mask = np.array(mask).astype(np.uint8) # three channel integer 203 | input_mask = mask 204 | 205 | # save seg_mask 206 | mask_ = Image.fromarray(input_mask) 207 | mask_.save(os.path.join(save_dir, mask_name)) 208 | 209 | condition = flattened_img_tensor_one_hot_transpose 210 | emb_tgt = model.get_learned_conditioning(condition) 211 | emb_ = emb_tgt.clone() 212 | torch.save(emb_, os.path.join(save_dir, 'emb_tgt.pt')) 213 | emb = torch.load(os.path.join(save_dir, 'emb_tgt.pt')) # [1, 77, 640] 214 | 215 | # Sample the model with a fixed code to see what it looks like 216 | quick_sample = lambda x, s, code: decode_to_im( 217 | sample_model( 218 | model, sampler, x, h, w, ddim_steps, s, ddim_eta, start_code=code)) 219 | # start_code = torch.randn_like(init_latent) 220 | start_code = torch.randn((1, 3, 64, 64), device=device) 221 | torch.save(start_code, os.path.join(save_dir, 222 | 'start_code.pt')) # [1, 3, 64, 64] 223 | torch.manual_seed(args.seed) 224 | img = quick_sample(emb_tgt, scale, start_code) 225 | img.save(os.path.join(save_dir, 'A_start_tgtText_origDM.png')) 226 | 227 | # ======================= (A) Text Embedding Optimization =================================== 228 | print('########### Step 1 - Optimise the embedding ###########') 229 | emb.requires_grad = True 230 | opt = torch.optim.Adam([emb], lr=args.stage1_lr) 231 | criteria = torch.nn.MSELoss() 232 | history = [] 233 | 234 | pbar = tqdm(range(args.stage1_num_iter)) 235 | for i in pbar: 236 | opt.zero_grad() 237 | 238 | if args.set_random_seed: 239 | torch.seed() 240 | noise = torch.randn_like(init_latent) 241 | t_enc = torch.randint(1000, (1, ), device=device) 242 | z = model.q_sample(init_latent, t_enc, noise=noise) 243 | 244 | pred_noise = model.apply_model(z, t_enc, emb) 245 | 246 | loss = criteria(pred_noise, noise) 247 | loss.backward() 248 | pbar.set_postfix({"loss": loss.item()}) 249 | history.append(loss.item()) 250 | opt.step() 251 | 252 | plt.plot(history) 253 | plt.show() 254 | torch.save(emb, os.path.join(save_dir, 'emb_opt.pt')) 255 | emb_opt = torch.load(os.path.join(save_dir, 'emb_opt.pt')) # [1, 77, 640] 256 | 257 | torch.manual_seed(args.seed) 258 | img = quick_sample(emb_opt, scale, start_code) 259 | img.save(os.path.join(save_dir, 'A_end_optText_origDM.png')) 260 | 261 | # Interpolate the embedding 262 | for idx, alpha in enumerate(args.alpha_list): 263 | print(f'alpha={alpha}') 264 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt 265 | torch.manual_seed(args.seed) 266 | img = quick_sample(new_emb, scale, start_code) 267 | img.save( 268 | os.path.join( 269 | save_dir, 270 | f'0A_interText_origDM_{idx}_alpha={round(alpha,3)}.png')) 271 | 272 | # ======================= (B) Model Fine-Tuning =================================== 273 | print('########### Step 2 - Fine tune the model ###########') 274 | emb_opt.requires_grad = False 275 | model.train() 276 | 277 | opt = torch.optim.Adam(model.model.parameters(), lr=args.stage2_lr) 278 | criteria = torch.nn.MSELoss() 279 | history = [] 280 | 281 | pbar = tqdm(range(args.stage2_num_iter)) 282 | for i in pbar: 283 | opt.zero_grad() 284 | 285 | if args.set_random_seed: 286 | torch.seed() 287 | noise = torch.randn_like(init_latent) 288 | t_enc = torch.randint(model.num_timesteps, (1, ), device=device) 289 | z = model.q_sample(init_latent, t_enc, noise=noise) 290 | 291 | pred_noise = model.apply_model(z, t_enc, emb_opt) 292 | 293 | loss = criteria(pred_noise, noise) 294 | loss.backward() 295 | pbar.set_postfix({"loss": loss.item()}) 296 | 
history.append(loss.item()) 297 | opt.step() 298 | 299 | model.eval() 300 | plt.plot(history) 301 | plt.show() 302 | torch.manual_seed(args.seed) 303 | img = quick_sample(emb_opt, scale, start_code) 304 | img.save(os.path.join(save_dir, 'B_end_optText_optDM.png')) 305 | # Should look like the original image 306 | 307 | if args.save_checkpoint: 308 | ckpt = { 309 | "state_dict": model.state_dict(), 310 | } 311 | ckpt_path = os.path.join(save_dir, 'optDM.ckpt') 312 | print(f'Saving optDM to {ckpt_path}') 313 | torch.save(ckpt, ckpt_path) 314 | 315 | # ======================= (C) Generation =================================== 316 | print('########### Step 3 - Generate images ###########') 317 | # Interpolate the embedding 318 | for idx, alpha in enumerate(args.alpha_list): 319 | print(f'alpha={alpha}') 320 | new_emb = alpha * emb_tgt + (1 - alpha) * emb_opt 321 | torch.manual_seed(args.seed) 322 | img = quick_sample(new_emb, scale, start_code) 323 | img.save( 324 | os.path.join( 325 | save_dir, 326 | f'0C_interText_optDM_{idx}_alpha={round(alpha,3)}.png')) 327 | 328 | print('Done') -------------------------------------------------------------------------------- /ldm/models/diffusion/compose_modules.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | import math 4 | from typing import Iterable 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from ldm.modules.diffusionmodules.util import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | from ldm.modules.attention import SpatialTransformer 21 | 22 | from ldm.modules.diffusionmodules.openaimodel import * 23 | from omegaconf import OmegaConf 24 | 25 | from ldm.models.diffusion.ddpm import DiffusionWrapper 26 | import random 27 | 28 | class ComposeUNet(nn.Module): 29 | """ 30 | One diffusion denoising step (using dynamic diffusers) 31 | """ 32 | 33 | def __init__(self, 34 | return_confidence_map=False, 35 | confidence_conditioning_key='crossattn', 36 | confidence_map_predictor_config=None, 37 | seg_mask_scale_factor = None, 38 | seg_mask_schedule = None, 39 | softmax_twice = False, 40 | boost_factor = 1.0, 41 | manual_prob = 1.0, 42 | return_each_branch = False, 43 | confidence_input = 'unet_output', 44 | conditions= ['seg_mask', 'text'] 45 | ): 46 | super().__init__() 47 | 48 | self.conditions = conditions 49 | self.return_confidence_map = return_confidence_map 50 | 51 | # define dynamic diffusers 52 | if 'seg_mask' in self.conditions: 53 | self.seg_mask_confidence_predictor = DiffusionWrapper(confidence_map_predictor_config, confidence_conditioning_key) 54 | if 'text' in self.conditions: 55 | self.text_confidence_predictor = DiffusionWrapper(confidence_map_predictor_config, confidence_conditioning_key) 56 | if 'sketch' in self.conditions: 57 | self.sketch_confidence_predictor = DiffusionWrapper(confidence_map_predictor_config, confidence_conditioning_key) 58 | 59 | self.seg_mask_scale_factor = seg_mask_scale_factor #/ (seg_mask_scale_factor + text_scale_factor) 60 | self.seg_mask_schedule = seg_mask_schedule 61 | self.softmax_twice = softmax_twice 62 | self.boost_factor = boost_factor 63 | self.manual_prob = manual_prob 64 | self.return_each_branch = return_each_branch 65 | 66 | self.confidence_input = confidence_input 67 | 68 | def set_seg_mask_schedule(self, t): 69 | t = t[0] 70 | if 
self.seg_mask_schedule is None: 71 | schedule_scale = 1.0 72 | elif self.seg_mask_schedule == 'linear_decay': 73 | schedule_scale = t / 1000 74 | elif self.seg_mask_schedule == 'cosine_decay': 75 | pi = torch.acos(torch.zeros(1)).item() * 2 76 | schedule_scale = (torch.cos( torch.tensor((1-(t/1000)) * (pi)))+1)/2 77 | else: 78 | raise NotImplementedError 79 | return schedule_scale 80 | 81 | 82 | def forward(self, x, t, cond): 83 | """ 84 | One diffusion denoising step (using dynamic diffusers) 85 | 86 | input: 87 | - x: noisy image x_t 88 | - t: timestep 89 | - cond = {'seg_mask': tensor, 'text': tensor, '...': ...} 90 | output: 91 | the fused model prediction (influence-weighted sum of the branch UNet outputs), from which the diffusion process derives x_t-1 92 | """ 93 | 94 | # compute individual branch's output using pretrained diffusion models 95 | if 'seg_mask' in self.conditions: 96 | seg_mask_unet_output = self.seg_mask_unet(x=x, t=t, c_crossattn=[cond['seg_mask']]) # [B, 3, 64, 64] 97 | if 'text' in self.conditions: 98 | text_unet_output = self.text_unet(x=x, t=t, c_crossattn=[cond['text']]) # [B, 3, 64, 64] 99 | 100 | if 'sketch' in self.conditions: 101 | sketch_unet_output = self.sketch_unet(x=x, t=t, c_crossattn=[cond['sketch']]) # [B, 3, 64, 64] 102 | 103 | # compute influence function for each branch using a dynamic diffuser for each branch 104 | if self.confidence_input == 'unet_output': 105 | if 'seg_mask' in self.conditions: 106 | seg_mask_confidence_map = self.seg_mask_confidence_predictor(x=seg_mask_unet_output, t=t, c_crossattn=[cond['seg_mask']]) # [B, 1, 64, 64] 107 | if 'text' in self.conditions: 108 | text_confidence_map = self.text_confidence_predictor(x=text_unet_output, t=t, c_crossattn=[cond['text']]) # [B, 1, 64, 64] 109 | if 'sketch' in self.conditions: 110 | sketch_confidence_map = self.sketch_confidence_predictor(x=sketch_unet_output, t=t, c_crossattn=[cond['sketch']]) # [B, 1, 64, 64] 111 | print('sketch forward') 112 | 113 | elif self.confidence_input == 'x_t': 114 | if 'seg_mask' in self.conditions: 115 | seg_mask_confidence_map = self.seg_mask_confidence_predictor(x=x, t=t, c_crossattn=[cond['seg_mask']]) # [B, 1, 64, 64] 116 | if 'text' in self.conditions: 117 | text_confidence_map = self.text_confidence_predictor(x=x, t=t, c_crossattn=[cond['text']]) # [B, 1, 64, 64] 118 | if 'sketch' in self.conditions: 119 | sketch_confidence_map = self.sketch_confidence_predictor(x=x, t=t, c_crossattn=[cond['sketch']]) # [B, 1, 64, 64] 120 | 121 | else: 122 | raise NotImplementedError 123 | 124 | 125 | # Use softmax to normalize the influence functions across all branches 126 | if ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' not in self.conditions): 127 | concat_map = torch.cat([seg_mask_confidence_map, text_confidence_map], dim=1) # first mask, then text # [B, 2, 64, 64] 128 | softmax_map = F.softmax(input=concat_map, dim=1) # [B, 2, 64, 64] 129 | seg_mask_confidence_map = softmax_map[:,0,:,:].unsqueeze(1) # [B, 1, 64, 64] 130 | text_confidence_map = softmax_map[:,1,:,:].unsqueeze(1) # [B, 1, 64, 64] 131 | elif ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' in self.conditions): 132 | concat_map = torch.cat([seg_mask_confidence_map, text_confidence_map, sketch_confidence_map], dim=1) # first mask, then text # [B, 3, 64, 64] 133 | softmax_map = F.softmax(input=concat_map, dim=1) # [B, 3, 64, 64] 134 | seg_mask_confidence_map = softmax_map[:,0,:,:].unsqueeze(1) # [B, 1, 64, 64] 135 | text_confidence_map = softmax_map[:,1,:,:].unsqueeze(1) # [B, 1, 64, 64] 136 | sketch_confidence_map = softmax_map[:,2,:,:].unsqueeze(1) # [B,
1, 64, 64] 137 | else: 138 | raise NotImplementedError 139 | 140 | if random.random() <= self.manual_prob: 141 | 142 | if self.seg_mask_schedule is not None: 143 | seg_mask_schedule_scale = self.set_seg_mask_schedule(t) 144 | if 'sketch' not in self.conditions: 145 | seg_mask_confidence_map = seg_mask_confidence_map * seg_mask_schedule_scale 146 | text_confidence_map = 1 - seg_mask_confidence_map 147 | else: 148 | seg_mask_confidence_map = seg_mask_confidence_map * seg_mask_schedule_scale 149 | sketch_confidence_map = sketch_confidence_map * seg_mask_schedule_scale 150 | text_confidence_map = 1 - seg_mask_confidence_map - sketch_confidence_map 151 | 152 | if self.seg_mask_scale_factor is not None: 153 | if 'sketch' not in self.conditions: 154 | seg_mask_confidence_map = seg_mask_confidence_map * self.seg_mask_scale_factor 155 | sum_map = text_confidence_map + seg_mask_confidence_map 156 | seg_mask_confidence_map = seg_mask_confidence_map / sum_map 157 | text_confidence_map = text_confidence_map / sum_map 158 | else: 159 | seg_mask_confidence_map = seg_mask_confidence_map * self.seg_mask_scale_factor 160 | sketch_confidence_map = sketch_confidence_map * self.seg_mask_scale_factor 161 | sum_map = text_confidence_map + seg_mask_confidence_map + sketch_confidence_map 162 | seg_mask_confidence_map = seg_mask_confidence_map / sum_map 163 | text_confidence_map = text_confidence_map / sum_map 164 | sketch_confidence_map = sketch_confidence_map / sum_map 165 | 166 | if self.softmax_twice: 167 | assert ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' not in self.conditions), "softmax_twice is only implemented for two-modal controls" 168 | print(f'softmax_twice self.boost_factor={self.boost_factor}') 169 | concat_map = torch.cat([seg_mask_confidence_map, text_confidence_map], dim=1) * self.boost_factor # first mask, then text # [B, 2, 64, 64] 170 | softmax_map = F.softmax(input=concat_map, dim=1) # [B, 2, 64, 64] 171 | seg_mask_confidence_map = softmax_map[:,0,:,:].unsqueeze(1) # [B, 1, 64, 64] 172 | text_confidence_map = softmax_map[:,1,:,:].unsqueeze(1) # [B, 1, 64, 64] 173 | 174 | 175 | # Compute weighted sum of all branch'es output 176 | if ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' not in self.conditions): 177 | seg_mask_weighted = seg_mask_unet_output * seg_mask_confidence_map # [B, 3, 64, 64] 178 | text_weighted = text_unet_output * text_confidence_map # [B, 3, 64, 64] 179 | output = text_weighted + seg_mask_weighted # [B, 3, 64, 64] 180 | 181 | if self.return_confidence_map: 182 | if self.return_each_branch: 183 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map, 'text_unet_output': text_unet_output, 'seg_mask_unet_output': seg_mask_unet_output} 184 | else: 185 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map} 186 | elif self.return_each_branch: 187 | return {'output': output, 'text_unet_output': text_unet_output, 'seg_mask_unet_output': seg_mask_unet_output} 188 | else: 189 | return output 190 | elif ('seg_mask' in self.conditions) and ('text' in self.conditions) and ('sketch' in self.conditions): 191 | seg_mask_weighted = seg_mask_unet_output * seg_mask_confidence_map # [B, 3, 64, 64] 192 | text_weighted = text_unet_output * text_confidence_map # [B, 3, 64, 64] 193 | sketch_weighted = sketch_unet_output * sketch_confidence_map # [B, 3, 64, 64] 194 | output = text_weighted + 
seg_mask_weighted + sketch_weighted # [B, 3, 64, 64] 195 | 196 | if self.return_confidence_map: 197 | if self.return_each_branch: 198 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map, 'sketch_confidence_map': sketch_confidence_map,'seg_mask_unet_output': seg_mask_unet_output, 'text_unet_output': text_unet_output, 'sketch_unet_output': sketch_unet_output } 199 | else: 200 | return {'output': output, 'seg_mask_confidence_map': seg_mask_confidence_map, 'text_confidence_map': text_confidence_map, 'sketch_confidence_map': sketch_confidence_map} 201 | elif self.return_each_branch: 202 | return {'output': output, 'text_unet_output': text_unet_output, 'seg_mask_unet_output': seg_mask_unet_output, 'sketch_unet_output': sketch_unet_output} 203 | else: 204 | return output 205 | 206 | 207 | class ComposeCondStageModel(nn.Module): 208 | """ 209 | Condition Encoder of Multi-Modal Conditions 210 | """ 211 | 212 | def __init__(self,conditions= ['seg_mask', 'text']): 213 | super().__init__() 214 | self.conditions = conditions 215 | 216 | 217 | def forward(self, input): 218 | 219 | composed_cond = {} 220 | 221 | if 'seg_mask' in self.conditions: 222 | seg_mask_output = self.seg_mask_cond_stage_model(input['seg_mask']) 223 | composed_cond['seg_mask'] = seg_mask_output 224 | if 'text' in self.conditions: 225 | text_output = self.text_cond_stage_model(input['text']) 226 | composed_cond['text'] = text_output 227 | if 'sketch' in self.conditions: 228 | sketch_output = self.sketch_cond_stage_model(input['sketch']) 229 | composed_cond['sketch'] = sketch_output 230 | 231 | return composed_cond 232 | 233 | def encode(self, input): 234 | return self(input) 235 | --------------------------------------------------------------------------------
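Note on ComposeUNet (ldm/models/diffusion/compose_modules.py): the weighting scheme in ComposeUNet.forward can be read in isolation as the minimal sketch below for the two-modal (mask + text) case. This is an illustration under simplifying assumptions, not the repository's implementation: toy_branch_unet and toy_dynamic_diffuser are hypothetical stand-ins for the pretrained branch UNets (seg_mask_unet, text_unet) and the DiffusionWrapper confidence predictors, and the tensor shapes follow the [B, 3, 64, 64] / [B, 1, 64, 64] comments in the file.

import torch
import torch.nn.functional as F

def toy_branch_unet(x, t, context):
    # stand-in for a pretrained single-modality denoiser: returns a prediction shaped like x
    return torch.randn_like(x)

def toy_dynamic_diffuser(x, t, context):
    # stand-in for a dynamic diffuser (confidence predictor): one influence channel per pixel
    return torch.randn(x.shape[0], 1, x.shape[2], x.shape[3])

B = 2
x_t = torch.randn(B, 3, 64, 64)                        # noisy latent
t = torch.full((B,), 500, dtype=torch.long)            # diffusion timestep
cond = {'seg_mask': torch.randn(B, 19, 1024),          # one-hot mask, flattened 32x32
        'text': torch.randn(B, 77, 640)}               # text embedding

# 1) each branch denoises independently, conditioned on its own modality
seg_out = toy_branch_unet(x_t, t, cond['seg_mask'])    # [B, 3, 64, 64]
text_out = toy_branch_unet(x_t, t, cond['text'])       # [B, 3, 64, 64]

# 2) dynamic diffusers predict per-pixel influence; a softmax across branches makes the
#    influence maps sum to 1 at every spatial location (confidence_input='unet_output')
seg_conf = toy_dynamic_diffuser(seg_out, t, cond['seg_mask'])    # [B, 1, 64, 64]
text_conf = toy_dynamic_diffuser(text_out, t, cond['text'])      # [B, 1, 64, 64]
weights = F.softmax(torch.cat([seg_conf, text_conf], dim=1), dim=1)  # [B, 2, 64, 64]

# 3) composed prediction = influence-weighted sum of the branch outputs
output = weights[:, 0:1] * seg_out + weights[:, 1:2] * text_out      # [B, 3, 64, 64]
assert output.shape == x_t.shape

In the repository itself, the branch UNets and condition encoders appear to be attached to ComposeUNet / ComposeCondStageModel outside this file, and options such as seg_mask_schedule, seg_mask_scale_factor, and softmax_twice further rescale the influence maps before the weighted sum.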