├── .gitignore ├── LICENSE ├── README.md ├── configs ├── ldm │ ├── coco_sg2im_ldm_Layout2I_vqgan_f8.yaml │ └── coco_stuff_ldm_T2I_vqgan_f8.yaml └── vqgan │ └── coco_vqgan_f8.yaml ├── environment.yaml ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ └── ddpm.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── setup.py ├── taming ├── data │ ├── annotated_objects_coco.py │ ├── annotated_objects_dataset.py │ ├── base.py │ ├── coco.py │ ├── conditional_builder │ │ ├── objects_bbox.py │ │ ├── objects_center_points.py │ │ └── utils.py │ ├── custom.py │ ├── helper_types.py │ ├── image_transforms.py │ ├── open_images_helper.py │ ├── sflckr.py │ └── utils.py ├── lr_scheduler.py ├── models │ ├── cond_transformer.py │ ├── dummy_cond_stage.py │ └── vqgan.py ├── modules │ ├── diffusionmodules │ │ └── model.py │ ├── discriminator │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── segmentation.py │ │ ├── soft_cross_entropy.py │ │ └── vqperceptual.py │ ├── misc │ │ ├── coord.py │ │ └── pos_embed.py │ ├── transformer │ │ ├── mingpt.py │ │ └── permuter.py │ ├── util.py │ └── vqvae │ │ ├── mapping.py │ │ └── quantize.py └── util.py └── tools ├── download_datasets.sh ├── download_models.sh ├── ldm ├── train_ldm_coco_Layout2I.sh └── train_ldm_coco_T2I.sh └── vqgan └── train_vqgan_coco.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # amlt 2 | .amltconfig 3 | 4 | # original 5 | *.swp 6 | *.pt 7 | *.pth 8 | *.ckpt 9 | *.txt 10 | *tfevents* 11 | *.json 12 | *.png 13 | **/__pycache__/** 14 | .dumbo.json 15 | checkpoints/ 16 | checkpoint/ 17 | .idea/* 18 | **/.ipynb_checkpoints/** 19 | run.sh 20 | output/ 21 | output_dir/ 22 | output_aml/ 23 | weights/ 24 | logs/* 25 | exp/* 26 | src/* 27 | 28 | ## Deep speed 29 | *.pyc 30 | *.idea/ 31 | *~ 32 | *.swp 33 | *.log 34 | *deepspeed/git_version_info_installed.py 35 | 36 | # Build + installation data 37 | *build/ 38 | *dist/ 39 | *.so 40 | *deepspeed.egg-info/ 41 | *build.txt 42 | 43 | # Website 44 | *docs/_site/ 45 | *docs/build 46 | *docs/code-docs/source/_build 47 | *docs/code-docs/_build 48 | *docs/code-docs/build 49 | *.sass-cache/ 50 | *.jekyll-cache/ 51 | *.jekyll-metadata 52 | 53 | # Testing data 54 | *tests/unit/saved_checkpoint/ 55 | 56 | # Dev/IDE data 57 | *.vscode 58 | *.theia 59 | 60 | *cache* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ChirsFan0312 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to 
permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Latent Diffusion Model (LDM) for Layout-to-image generation 2 | 3 | --- 4 | This is an unofficial repository of [LDM](https://arxiv.org/abs/2112.10752) for layout-to-image generation. Currently, the config and code in the official LDM repo are incomplete, so this repo aims to reproduce LDM on the layout-to-image generation task. If you find it useful, please cite the original [LDM](https://arxiv.org/abs/2112.10752) paper. 5 | 6 | --- 7 | ## Machine environment 8 | - Ubuntu version: 18.04.5 LTS 9 | - CUDA version: 11.6 10 | - Testing GPU: Nvidia Tesla V100 11 | --- 12 | 13 | ## Requirements 14 | A [conda](https://conda.io/) environment named `ldm_layout` can be created and activated with: 15 | 16 | ```bash 17 | conda env create -f environment.yaml 18 | conda activate ldm_layout 19 | ``` 20 | --- 21 | 22 | ## Datasets setup 23 | We provide two approaches to set up the datasets: 24 | ### Auto-download 25 | To automatically download the datasets and save them into the default path (`../`), please use the following script: 26 | ```bash 27 | bash tools/download_datasets.sh 28 | ``` 29 | ### Manual setup 30 | 31 | #### Text-to-image generation 32 | - We use the COCO 2014 splits for the text-to-image task, which can be downloaded from the [official COCO website](https://cocodataset.org/#download). 33 | - Please create a folder named `2014` and collect the downloaded data and annotations as follows; a download sketch is shown after the file tree below. 34 | 35 |
COCO 2014 file structure 36 | 37 | ``` 38 | >2014 39 | ├── annotations 40 | │ └── captions_val2014.json 41 | │ └── ... 42 | └── val2014 43 | └── COCO_val2014_000000000073.jpg 44 | └── ... 45 | ``` 46 | 47 |
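For reference, a minimal manual-download sketch for the COCO 2014 data. It assumes the standard COCO download URLs and that the commands are run from the `ldm_layout` folder, so the files land in the `../datasets/coco/2014` path used by the provided configs:

```bash
mkdir -p ../datasets/coco/2014 && cd ../datasets/coco/2014
# validation images (add train2014.zip as well if you plan to train on the 2014 split)
wget http://images.cocodataset.org/zips/val2014.zip
# captions and instance annotations (unzips into the annotations/ folder)
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
unzip val2014.zip && unzip annotations_trainval2014.zip
```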
48 | 49 | 50 | #### Layout-to-image generation 51 | - We use the COCO 2017 splits for the layout-to-image task, which can be downloaded from the [official COCO website](https://cocodataset.org/#download). 52 | - Please create a folder named `2017` and collect the downloaded data and annotations as follows; a download sketch is shown after the file tree below. 53 | 54 |
COCO 2017 file structure 55 | 56 | ``` 57 | >2017 58 | ├── annotations 59 | │ └── captions_val2017.json 60 | │ └── ... 61 | └── val2017 62 | └── 000000000872.jpg 63 | └── ... 64 | ``` 65 | 66 |
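The 2017 split can be fetched the same way (again a sketch with the standard COCO URLs; note that the provided `coco_sg2im_ldm_Layout2I_vqgan_f8.yaml` config additionally references `annotations/deprecated-challenge2017/{train,val}-ids.txt`, which this sketch does not cover):

```bash
mkdir -p ../datasets/coco/2017 && cd ../datasets/coco/2017
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/train2017.zip         # needed for training
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip val2017.zip && unzip train2017.zip && unzip annotations_trainval2017.zip
```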
67 | 68 | 69 | #### File structure for dataset and code 70 | Please make sure that the file structure matches the following; otherwise, modify the config files so that their paths point to the corresponding locations (a quick sanity check is given after the file tree below). 71 | 72 |
File structure 73 | 74 | ``` 75 | >datasets 76 | ├── coco 77 | │ └── 2014 78 | │ └── annotations 79 | │ └── val2014 80 | │ └── ... 81 | │ └── 2017 82 | │ └── annotations 83 | │ └── val2017 84 | │ └── ... 85 | >ldm_layout 86 | └── configs 87 | │ └── ldm 88 | │ └── ... 89 | └── exp 90 | │ └── ... 91 | └── ldm 92 | └── taming 93 | └── scripts 94 | └── tools 95 | └── ... 96 | ``` 97 | 98 |
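If the datasets or checkpoints live elsewhere, the entries to update are `data_path`, `img_id_file`, `caption_ann_path`, and `ckpt_path` in `configs/ldm/*.yaml` and `configs/vqgan/*.yaml`. Below is a quick sanity check for the default layout (a sketch, run from the `ldm_layout` folder once the datasets and the VQGAN checkpoint are in place):

```bash
for p in ../datasets/coco/2014/annotations ../datasets/coco/2014/val2014 \
         ../datasets/coco/2017/annotations ../datasets/coco/2017/val2017 \
         exp/vqgan/vq-f8/model.ckpt; do
  [ -e "$p" ] || echo "missing: $p"
done
```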
99 | 100 | --- 101 | 102 | 103 | ## VQGAN models setup 104 | We provide a script to download the VQGAN-f8 checkpoint released in the [LDM github](https://github.com/CompVis/latent-diffusion): 105 | 106 | To automatically download VQGAN-f8 and save it into the default path (`exp/`), please use the following script: 107 | ```bash 108 | bash tools/download_models.sh 109 | ``` 110 | 111 | ## Train LDM for layout-to-image generation 112 | We provide scripts for training LDM on the text-to-image and layout-to-image tasks. 113 | 114 | Once the datasets are properly set up, one may train LDM with the following commands. 115 | ### Text-to-image 116 | ```bash 117 | bash tools/ldm/train_ldm_coco_T2I.sh 118 | ``` 119 | - Default output folder will be `exp/ldm/T2I` 120 | ### Layout-to-image 121 | 122 | ```bash 123 | bash tools/ldm/train_ldm_coco_Layout2I.sh 124 | ``` 125 | - Default output folder will be `exp/ldm/Layout2I` 126 | 127 | ### Multi-GPU training 128 | 129 | Change `--gpus` to select the GPUs used for training. 130 | 131 | For example, to train with 4 GPUs: 132 | ```bash 133 | 134 | python main.py --base configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml \ 135 | -t True --gpus 0,1,2,3 -log_dir ./exp/ldm/Layout2I \ 136 | -n coco_sg2im_ldm_Layout2I_vqgan_f8 --scale_lr False -tb True 137 | ``` 138 | 139 | --- 140 | 141 | ## Inference 142 | 143 | Change `-t` to switch between the training and testing phases. 144 | (Note that multi-GPU testing is supported.) 145 | 146 | For example, to test with 4 GPUs: 147 | ```bash 148 | 149 | python main.py --base configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml \ 150 | -t False --gpus 0,1,2,3 -log_dir ./exp/ldm/Layout2I \ 151 | -n coco_sg2im_ldm_Layout2I_vqgan_f8 --scale_lr False -tb True 152 | ``` 153 | 154 | ## Acknowledgement 155 | We build the LDM_layout codebase heavily on [Latent Diffusion Model (LDM)](https://github.com/CompVis/latent-diffusion) and [VQGAN](https://github.com/CompVis/taming-transformers). We sincerely thank the authors for open-sourcing!
156 | 157 | ## Citation 158 | If you find this code useful for your research, please consider citing: 159 | ```bibtex 160 | @misc{rombach2021highresolution, 161 | title={High-Resolution Image Synthesis with Latent Diffusion Models}, 162 | author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, 163 | year={2021}, 164 | eprint={2112.10752}, 165 | archivePrefix={arXiv}, 166 | primaryClass={cs.CV} 167 | } 168 | 169 | @misc{https://doi.org/10.48550/arxiv.2204.11824, 170 | doi = {10.48550/ARXIV.2204.11824}, 171 | url = {https://arxiv.org/abs/2204.11824}, 172 | author = {Blattmann, Andreas and Rombach, Robin and Oktay, Kaan and Ommer, Björn}, 173 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences}, 174 | title = {Retrieval-Augmented Diffusion Models}, 175 | publisher = {arXiv}, 176 | year = {2022}, 177 | copyright = {arXiv.org perpetual, non-exclusive license} 178 | } 179 | 180 | ``` 181 | -------------------------------------------------------------------------------- /configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.e-6 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | first_stage_key: image 6 | cond_stage_key: objects_bbox 7 | linear_start: 0.0015 8 | linear_end: 0.0205 9 | num_timesteps_cond: 1 10 | log_every_t: 20 11 | timesteps: 1000 12 | loss_type: l1 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: true 16 | conditioning_key: crossattn 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 32 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 4 35 | num_head_channels: 32 36 | use_spatial_transformer: true 37 | transformer_depth: 2 38 | context_dim: 512 39 | 40 | first_stage_config: 41 | target: taming.models.vqgan.VQModelInterface 42 | params: 43 | ckpt_path: exp/vqgan/vq-f8/model.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/ 44 | embed_dim: 4 45 | n_embed: 16384 46 | ddconfig: 47 | double_z: False 48 | z_channels: 4 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 32 61 | dropout: 0.0 62 | vitconfig: 63 | embed_size: 256 64 | lossconfig: 65 | target: taming.modules.losses.DummyLoss 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 512 71 | n_layer: 16 72 | vocab_size: 16384 73 | max_seq_len: 92 74 | use_tokenizer: False 75 | 76 | plot_sample: False 77 | plot_inpaint: False 78 | plot_denoise_rows: False 79 | plot_progressive_rows: False 80 | plot_diffusion_rows: False 81 | plot_quantize_denoised: True 82 | 83 | data: 84 | target: main.DataModuleFromConfig 85 | params: 86 | batch_size: 4 87 | train: 88 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 89 | params: 90 | data_path: ../datasets/coco/2017 # substitute with path to full dataset 91 | img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/train-ids.txt 92 | split: train 93 | keys: [image, objects_bbox, 
file_name, annotations] 94 | no_tokens: 1024 95 | target_image_size: 256 96 | min_object_area: 0.02 97 | min_objects_per_image: 3 98 | max_objects_per_image: 8 99 | crop_method: center 100 | random_flip: True 101 | use_group_parameter: true 102 | encode_crop: true 103 | validation: 104 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 105 | params: 106 | data_path: ../datasets/coco/2017 # substitute with path to full dataset 107 | img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/val-ids.txt 108 | split: validation 109 | keys: [image, objects_bbox, file_name, annotations] 110 | no_tokens: 1024 111 | target_image_size: 256 112 | min_object_area: 0.02 113 | min_objects_per_image: 3 114 | max_objects_per_image: 8 115 | crop_method: center 116 | random_flip: false 117 | use_group_parameter: true 118 | encode_crop: true 119 | test: 120 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 121 | params: 122 | data_path: ../datasets/coco/2017 # substitute with path to full dataset 123 | img_id_file: ../datasets/coco/2017/annotations/deprecated-challenge2017/val-ids.txt 124 | split: validation 125 | keys: [image, objects_bbox, file_name, annotations] 126 | no_tokens: 1024 127 | target_image_size: 256 128 | min_object_area: 0.02 129 | min_objects_per_image: 3 130 | max_objects_per_image: 8 131 | crop_method: center 132 | random_flip: false 133 | use_group_parameter: true 134 | encode_crop: true 135 | 136 | lightning: 137 | callbacks: 138 | image_logger: 139 | target: main.ImageLogger 140 | params: 141 | batch_frequency: 1000 142 | max_images: 99 143 | increase_log_steps: False 144 | 145 | trainer: 146 | benchmark: True 147 | max_epochs: 300 148 | 149 | -------------------------------------------------------------------------------- /configs/ldm/coco_stuff_ldm_T2I_vqgan_f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.e-6 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | first_stage_key: image 6 | cond_stage_key: caption 7 | linear_start: 0.0015 8 | linear_end: 0.0155 9 | num_timesteps_cond: 1 10 | log_every_t: 100 11 | timesteps: 1000 12 | loss_type: l1 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: true 16 | conditioning_key: crossattn 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 32 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_head_channels: 32 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 640 40 | 41 | first_stage_config: 42 | target: taming.models.vqgan.VQModelInterface 43 | params: 44 | ckpt_path: exp/vqgan/vq-f8/model.ckpt 45 | embed_dim: 4 46 | n_embed: 16384 47 | ddconfig: 48 | double_z: False 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | vitconfig: 64 | embed_size: 256 65 | lossconfig: 66 | target: taming.modules.losses.DummyLoss 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.BERTEmbedder 70 | params: 71 | n_embed: 640 72 | n_layer: 32 73 | 74 | plot_sample: 
False 75 | plot_inpaint: False 76 | plot_denoise_rows: False 77 | plot_progressive_rows: False 78 | plot_diffusion_rows: False 79 | plot_quantize_denoised: True 80 | 81 | data: 82 | target: main.DataModuleFromConfig 83 | params: 84 | batch_size: 4 85 | train: 86 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 87 | params: 88 | data_path: ../datasets/coco/2014 # substitute with path to full dataset 89 | caption_ann_path: ../datasets/coco/2014/annotations/captions_train2014.json 90 | use_stuff: False 91 | split: train 92 | keys: [image, caption, file_name, annotations] 93 | no_tokens: 1024 94 | target_image_size: 256 95 | min_object_area: 0.00001 96 | min_objects_per_image: 2 97 | max_objects_per_image: 30 98 | crop_method: random-1d 99 | random_flip: true 100 | use_group_parameter: true 101 | encode_crop: False 102 | validation: 103 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 104 | params: 105 | data_path: ../datasets/coco/2014 # substitute with path to full dataset 106 | caption_ann_path: ../datasets/coco/2014/annotations/captions_val2014.json 107 | use_stuff: False 108 | split: validation 109 | keys: [image, caption, file_name, annotations] 110 | no_tokens: 1024 111 | target_image_size: 256 112 | min_object_area: 0.00001 113 | min_objects_per_image: 2 114 | max_objects_per_image: 30 115 | crop_method: center 116 | random_flip: false 117 | use_group_parameter: true 118 | encode_crop: False 119 | test: 120 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 121 | params: 122 | data_path: ../datasets/coco/2014 # substitute with path to full dataset 123 | caption_ann_path: ../datasets/coco/2014/annotations/captions_val2014.json 124 | use_stuff: False 125 | split: validation 126 | keys: [image, objects, caption, file_name, annotations] 127 | no_tokens: 1024 128 | target_image_size: 256 129 | min_object_area: 0.00001 130 | min_objects_per_image: 2 131 | max_objects_per_image: 30 132 | crop_method: center 133 | random_flip: false 134 | use_group_parameter: true 135 | encode_crop: false 136 | 137 | lightning: 138 | callbacks: 139 | image_logger: 140 | target: main.ImageLogger 141 | params: 142 | batch_frequency: 1000 143 | max_images: 99 144 | increase_log_steps: False 145 | 146 | trainer: 147 | benchmark: True 148 | max_epochs: 300 149 | 150 | -------------------------------------------------------------------------------- /configs/vqgan/coco_vqgan_f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [32] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_num_layers: 2 29 | disc_start: 1 30 | disc_weight: 0.6 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 4 37 | num_workers: 24 38 | train: 39 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 40 | params: 41 | data_path: ../datasets/coco/2017 # substitute with path to full dataset 42 | split: train 43 | keys: [image, objects_bbox, file_name, annotations] 44 | no_tokens: 1024 
45 | target_image_size: 256 46 | min_object_area: 0.00001 47 | min_objects_per_image: 2 48 | max_objects_per_image: 30 49 | crop_method: random-1d 50 | random_flip: true 51 | use_group_parameter: true 52 | encode_crop: true 53 | validation: 54 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 55 | params: 56 | data_path: ../datasets/coco/2017 # substitute with path to full dataset 57 | split: validation 58 | keys: [image, objects_bbox, file_name, annotations] 59 | no_tokens: 1024 60 | target_image_size: 256 61 | min_object_area: 0.00001 62 | min_objects_per_image: 2 63 | max_objects_per_image: 30 64 | crop_method: center 65 | random_flip: false 66 | use_group_parameter: true 67 | encode_crop: true 68 | test: 69 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 70 | params: 71 | data_path: ../datasets/coco/2017 # substitute with path to full dataset 72 | split: validation 73 | keys: [image, objects_bbox, file_name, annotations] 74 | no_tokens: 1024 75 | target_image_size: 256 76 | min_object_area: 0.0000001 77 | min_objects_per_image: 2 78 | max_objects_per_image: 30 79 | crop_method: center 80 | random_flip: false 81 | use_group_parameter: true 82 | encode_crop: true 83 | 84 | lightning: 85 | callbacks: 86 | image_logger: 87 | target: main.ImageLogger 88 | params: 89 | batch_frequency: 1000 90 | max_images: 99 91 | increase_log_steps: False 92 | 93 | lightning: 94 | trainer: 95 | max_epochs: 50 96 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm_layout 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - opencv-python==4.1.2.30 14 | - albumentations==0.4.3 15 | - einops==0.3.0 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - matplotlib==3.5.1 19 | - matplotlib-inline==0.1.3 20 | - more-itertools==8.12.0 21 | - numpy==1.19.2 22 | - omegaconf==2.0.0 23 | - opencv-python-headless==4.1.2.30 24 | - pandas==1.4.1 25 | - Pillow==9.0.1 26 | - pudb==2019.2 27 | - Pygments==2.11.2 28 | - Pympler==1.0.1 29 | - python-dateutil==2.8.2 30 | - torch-fidelity==0.3.0 31 | - pytorch-lightning==1.0.8 32 | - PyYAML==6.0 33 | - requests==2.27.1 34 | - requests-oauthlib==1.3.1 35 | - scikit-image==0.19.2 36 | - scipy==1.8.0 37 | - setuptools==58.0.4 38 | - six==1.16.0 39 | - streamlit==1.7.0 40 | - terminado==0.13.3 41 | - test-tube==0.7.5 42 | - timm==0.4.5 43 | - tokenizers==0.10.3 44 | - tornado==6.1 45 | - tqdm==4.63.0 46 | - transformers==4.3.1 47 | - typing-extensions==3.10.0.2 48 | - urllib3==1.26.9 49 | - -e . 
-------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", 
data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 
45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = 
rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def 
configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | 14 | def uniq(arr): 15 | return{el: True for el in arr}.keys() 16 | 17 | 18 | def default(val, d): 19 | if exists(val): 20 | return val 21 | return d() if isfunction(d) else d 22 | 23 | 24 | def max_neg_value(t): 25 | return -torch.finfo(t.dtype).max 26 | 27 | 28 | def init_(tensor): 29 | dim = tensor.shape[-1] 30 | std = 1 / math.sqrt(dim) 31 | tensor.uniform_(-std, std) 32 | return tensor 33 | 34 | 35 | # feedforward 36 | class GEGLU(nn.Module): 37 | def __init__(self, dim_in, dim_out): 38 | super().__init__() 39 | self.proj = nn.Linear(dim_in, dim_out * 2) 40 | 41 | def forward(self, x): 42 | x, gate = self.proj(x).chunk(2, dim=-1) 43 | return x * F.gelu(gate) 44 | 45 | 46 | class FeedForward(nn.Module): 47 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 48 | super().__init__() 49 | inner_dim = int(dim * mult) 50 | dim_out = default(dim_out, dim) 51 | project_in = nn.Sequential( 52 | nn.Linear(dim, inner_dim), 53 | nn.GELU() 54 | ) if not glu else GEGLU(dim, inner_dim) 55 | 56 | self.net = nn.Sequential( 57 | project_in, 58 | nn.Dropout(dropout), 59 | nn.Linear(inner_dim, dim_out) 60 | ) 61 | 62 | def forward(self, x): 63 | return self.net(x) 64 | 65 | 66 | def zero_module(module): 67 | """ 68 | Zero out the parameters of a module and return it. 
69 | """ 70 | for p in module.parameters(): 71 | p.detach().zero_() 72 | return module 73 | 74 | 75 | def Normalize(in_channels): 76 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 77 | 78 | 79 | class LinearAttention(nn.Module): 80 | def __init__(self, dim, heads=4, dim_head=32): 81 | super().__init__() 82 | self.heads = heads 83 | hidden_dim = dim_head * heads 84 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 85 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 86 | 87 | def forward(self, x): 88 | b, c, h, w = x.shape 89 | qkv = self.to_qkv(x) 90 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 91 | k = k.softmax(dim=-1) 92 | context = torch.einsum('bhdn,bhen->bhde', k, v) 93 | out = torch.einsum('bhde,bhdn->bhen', context, q) 94 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 95 | return self.to_out(out) 96 | 97 | 98 | class SpatialSelfAttention(nn.Module): 99 | def __init__(self, in_channels): 100 | super().__init__() 101 | self.in_channels = in_channels 102 | 103 | self.norm = Normalize(in_channels) 104 | self.q = torch.nn.Conv2d(in_channels, 105 | in_channels, 106 | kernel_size=1, 107 | stride=1, 108 | padding=0) 109 | self.k = torch.nn.Conv2d(in_channels, 110 | in_channels, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0) 114 | self.v = torch.nn.Conv2d(in_channels, 115 | in_channels, 116 | kernel_size=1, 117 | stride=1, 118 | padding=0) 119 | self.proj_out = torch.nn.Conv2d(in_channels, 120 | in_channels, 121 | kernel_size=1, 122 | stride=1, 123 | padding=0) 124 | 125 | def forward(self, x): 126 | h_ = x 127 | h_ = self.norm(h_) 128 | q = self.q(h_) 129 | k = self.k(h_) 130 | v = self.v(h_) 131 | 132 | # compute attention 133 | b,c,h,w = q.shape 134 | q = rearrange(q, 'b c h w -> b (h w) c') 135 | k = rearrange(k, 'b c h w -> b c (h w)') 136 | w_ = torch.einsum('bij,bjk->bik', q, k) 137 | 138 | w_ = w_ * (int(c)**(-0.5)) 139 | w_ = torch.nn.functional.softmax(w_, dim=2) 140 | 141 | # attend to values 142 | v = rearrange(v, 'b c h w -> b c (h w)') 143 | w_ = rearrange(w_, 'b i j -> b j i') 144 | h_ = torch.einsum('bij,bjk->bik', v, w_) 145 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 146 | h_ = self.proj_out(h_) 147 | 148 | return x+h_ 149 | 150 | 151 | class CrossAttention(nn.Module): 152 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 153 | super().__init__() 154 | inner_dim = dim_head * heads 155 | context_dim = default(context_dim, query_dim) 156 | 157 | self.scale = dim_head ** -0.5 158 | self.heads = heads 159 | 160 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 161 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 162 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 163 | 164 | self.to_out = nn.Sequential( 165 | nn.Linear(inner_dim, query_dim), 166 | nn.Dropout(dropout) 167 | ) 168 | 169 | def forward(self, x, context=None, mask=None): 170 | h = self.heads 171 | 172 | q = self.to_q(x) 173 | context = default(context, x) 174 | k = self.to_k(context) 175 | v = self.to_v(context) 176 | 177 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 178 | 179 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 180 | 181 | if exists(mask): 182 | mask = rearrange(mask, 'b ... 
-> b (...)') 183 | max_neg_value = -torch.finfo(sim.dtype).max 184 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 185 | sim.masked_fill_(~mask, max_neg_value) 186 | 187 | # attention, what we cannot get enough of 188 | attn = sim.softmax(dim=-1) 189 | 190 | out = einsum('b i j, b j d -> b i d', attn, v) 191 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 192 | return self.to_out(out) 193 | 194 | 195 | class BasicTransformerBlock(nn.Module): 196 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, use_mscond=False): 197 | super().__init__() 198 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 199 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 200 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 201 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 202 | self.norm1 = nn.LayerNorm(dim) 203 | self.norm2 = nn.LayerNorm(dim) 204 | self.norm3 = nn.LayerNorm(dim) 205 | self.checkpoint = checkpoint 206 | 207 | if use_mscond: 208 | self.attn_prev = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=0.5) 209 | self.norm_prev = nn.LayerNorm(dim) 210 | self.attn_cross = CrossAttention(query_dim=dim, context_dim=dim, 211 | heads=n_heads, dim_head=d_head, dropout=0.2) 212 | self.norm_cross = nn.LayerNorm(dim) 213 | 214 | def forward(self, x, context=None, x_prev_stage=None): 215 | 216 | if x_prev_stage is None: 217 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 218 | else: 219 | return checkpoint(self._forward_w_prev, (x, context, x_prev_stage), self.parameters(), self.checkpoint) 220 | 221 | def _forward(self, x, context=None): 222 | x_length = x.shape[1] 223 | x_attn = self.attn1(self.norm1(x)) + x 224 | x_attn = self.attn2(self.norm2(x_attn), context=context) + x_attn 225 | x_attn = self.ff(self.norm3(x_attn)) + x_attn 226 | return x_attn 227 | 228 | def _forward_w_prev(self, x, context=None, x_prev_stage=None): 229 | 230 | x_length = x.shape[1] 231 | x_attn = self.attn1(self.norm1(x)) + x 232 | 233 | x_prev_stage = self.attn_prev(self.norm_prev(x_prev_stage)) + x_prev_stage 234 | x_attn = self.attn_cross(self.norm_cross(x_attn), context=x_prev_stage) + x_attn 235 | 236 | x_attn = self.attn2(self.norm2(x_attn), context=context) + x_attn 237 | x_attn = self.ff(self.norm3(x_attn)) + x_attn 238 | 239 | return x_attn 240 | 241 | 242 | class SpatialTransformer(nn.Module): 243 | """ 244 | Transformer block for image-like data. 245 | First, project the input (aka embedding) 246 | and reshape to b, t, d. 247 | Then apply standard transformer action. 
248 | Finally, reshape to image 249 | """ 250 | def __init__(self, in_channels, channels_cond, n_heads, d_head, 251 | depth=1, dropout=0., context_dim=None, use_pos_embed=-1, use_mscond=False, mscond_dim=None): 252 | super().__init__() 253 | self.in_channels = in_channels 254 | self.use_pos_embed = use_pos_embed 255 | self.use_mscond = use_mscond 256 | 257 | inner_dim = n_heads * d_head 258 | self.norm = Normalize(in_channels) 259 | 260 | if use_pos_embed > 0: 261 | self.pos_embed = nn.Embedding(use_pos_embed, in_channels) 262 | 263 | self.proj_in = nn.Conv2d(in_channels, 264 | inner_dim, 265 | kernel_size=1, 266 | stride=1, 267 | padding=0) 268 | 269 | self.transformer_blocks = nn.ModuleList( 270 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_mscond=use_mscond) 271 | for d in range(depth)] 272 | ) 273 | 274 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 275 | in_channels, 276 | kernel_size=1, 277 | stride=1, 278 | padding=0)) 279 | 280 | if self.use_mscond: 281 | self.cond_proj_in = nn.Conv2d(mscond_dim, 282 | inner_dim, 283 | kernel_size=1, 284 | stride=1, 285 | padding=0) 286 | 287 | def forward(self, x, context=None, feat_cond=None): 288 | # note: if no context is given, cross-attention defaults to self-attention 289 | b, c, h, w = x.shape 290 | x_in = x 291 | x = self.norm(x) 292 | 293 | if feat_cond is not None and self.use_mscond: 294 | feat_cond = F.interpolate(feat_cond, size=x.size()[2:], mode='nearest') 295 | feat_cond = self.cond_proj_in(feat_cond) 296 | feat_cond = rearrange(feat_cond, 'b c h w -> b (h w) c') 297 | 298 | x = self.proj_in(x) 299 | x = rearrange(x, 'b c h w -> b (h w) c') 300 | 301 | if self.use_pos_embed > 0: 302 | pos_x = torch.arange(w) 303 | pos_y = torch.arange(h) 304 | grid_x, grid_y = torch.meshgrid(pos_x, pos_y) 305 | grid_x = grid_x.reshape(1, -1).repeat(b, 1).cuda() 306 | grid_y = grid_y.reshape(1, -1).repeat(b, 1).cuda() 307 | emb_pos_x = self.pos_embed(grid_x) 308 | emb_pos_y = self.pos_embed(grid_y) 309 | emb_pos = (emb_pos_x + emb_pos_y) / 2. 310 | x = x + emb_pos 311 | 312 | if feat_cond is not None and self.use_mscond: 313 | for block in self.transformer_blocks: 314 | x = block(x, context=context, x_prev_stage=feat_cond) 315 | else: 316 | for block in self.transformer_blocks: 317 | x = block(x, context=context) 318 | 319 | 320 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 321 | x = self.proj_out(x) 322 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, 
parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
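    # For reference, the expression returned below is the closed-form KL between two
    # diagonal Gaussians,
    #   KL(N(mean1, var1) || N(mean2, var2))
    #     = 0.5 * (logvar2 - logvar1 + (var1 + (mean1 - mean2)^2) / var2 - 1),
    # with var_i = exp(logvar_i).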
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay, torch.true_divide((1 + self.num_updates), (10 + self.num_updates))) 31 | # decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 32 | 33 | one_minus_decay = 1.0 - decay 34 | 35 | with torch.no_grad(): 36 | m_param = dict(model.named_parameters()) 37 | shadow_params = dict(self.named_buffers()) 38 | 39 | for key in m_param: 40 | if m_param[key].requires_grad: 41 | sname = self.m_name2s_name[key] 42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 44 | else: 45 | assert not key in self.m_name2s_name 46 | 47 | def copy_to(self, model): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | for key in m_param: 51 | if m_param[key].requires_grad: 52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def store(self, parameters): 57 | """ 58 | Save the current parameters for restoring later. 59 | Args: 60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 61 | temporarily stored. 62 | """ 63 | self.collected_params = [param.clone() for param in parameters] 64 | 65 | def restore(self, parameters): 66 | """ 67 | Restore the parameters stored with the `store` method. 68 | Useful to validate the model with EMA parameters without affecting the 69 | original optimization process. Store the parameters before the 70 | `copy_to` method. After validation (or model saving), use this to 71 | restore the former parameters. 72 | Args: 73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 74 | updated with the stored parameters. 
75 | """ 76 | for c_param, param in zip(self.collected_params, parameters): 77 | param.data.copy_(c_param.data) 78 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidhalladay/ldm_layout/49d664158db0c9d51aa057494dd72b8669fe586c/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import kornia 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, multilabel=False, padding_idx=1023, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.multilabel = multilabel 26 | self.embedding = nn.Embedding(n_classes, embed_dim) 27 | 28 | def forward(self, batch, key=None): 29 | if key is None: 30 | key = self.key 31 | # this is for use in crossattn 32 | if self.multilabel: 33 | c = batch[key].cuda() 34 | c = self.embedding(c) 35 | c = c.max(-2)[0] 36 | else: 37 | c = batch[key][:, None].cuda() 38 | c = self.embedding(c) 39 | return c 40 | 41 | 42 | class TransformerEmbedder(AbstractEncoder): 43 | """Some transformer encoder layers""" 44 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 45 | super().__init__() 46 | self.device = device 47 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 48 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 49 | 50 | def forward(self, tokens): 51 | tokens = tokens.to(self.device) # meh 52 | z = self.transformer(tokens, return_embeddings=True) 53 | return z 54 | 55 | def encode(self, x): 56 | return self(x) 57 | 58 | 59 | class BERTTokenizer(AbstractEncoder): 60 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 61 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 62 | super().__init__() 63 | from transformers import BertTokenizerFast # TODO: add to reuquirements 64 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 65 | self.device = device 66 | self.vq_interface = vq_interface 67 | self.max_length = max_length 68 | 69 | def forward(self, text): 70 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 71 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 72 | tokens = batch_encoding["input_ids"].to(self.device) 73 | return tokens 74 | 75 | @torch.no_grad() 76 | def encode(self, text): 77 | tokens = self(text) 78 | if not self.vq_interface: 79 | return tokens 80 | return None, None, [None, None, tokens] 81 | 82 | def decode(self, text): 83 | return text 84 | 85 | class BERTEmbedder(AbstractEncoder): 86 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 87 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 88 | device="cuda",use_tokenizer=True, embedding_dropout=0.0, cond_key=''): 89 | super().__init__() 90 | self.use_tknz_fn = use_tokenizer 91 | self.cond_key = cond_key 92 | if self.use_tknz_fn: 93 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 94 | self.device = device 95 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 96 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 97 | emb_dropout=embedding_dropout) 98 | 99 | def forward(self, text, return_token=False): 100 | if self.use_tknz_fn: 101 | tokens = self.tknz_fn(text).to(self.device) 102 | else: 103 | if self.cond_key != '': 104 | text = text[self.cond_key].cuda() 105 | tokens = text.long() 106 | 107 | z = self.transformer(tokens, return_embeddings=True) 108 | if return_token: 109 | return z, tokens 110 | return z 111 | 112 | def encode(self, text): 113 | # output of length 77 114 | return self(text) 115 | 116 | class BERTEmbedderVQTInterface(BERTTokenizer): 117 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 118 | super().__init__(device=device, vq_interface=vq_interface, max_length=max_length) 119 | 120 | def encode(self, c): 121 | tokens = self(c) 122 | return c, None, [None,None,tokens] 123 | 124 | def decode(self, c): 125 | return c 126 | 127 | class SpatialRescaler(nn.Module): 128 | def __init__(self, 129 | n_stages=1, 130 | method='bilinear', 131 | multiplier=0.5, 132 | in_channels=3, 133 | out_channels=None, 134 | bias=False): 135 | super().__init__() 136 | self.n_stages = n_stages 137 | assert self.n_stages >= 0 138 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 139 | self.multiplier = multiplier 140 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 141 | self.remap_output = out_channels is not None 142 | if self.remap_output: 143 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 144 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 145 | 146 | def forward(self,x): 147 | for stage in range(self.n_stages): 148 | x = self.interpolator(x, scale_factor=self.multiplier) 149 | 150 | 151 | if self.remap_output: 152 | x = self.channel_mapper(x) 153 | return x 154 | 155 | def encode(self, x): 156 | return self(x) 157 | 158 | 159 | class FrozenCLIPEmbedder(AbstractEncoder): 160 | """Uses the CLIP transformer encoder for text (from 
Hugging Face)""" 161 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 162 | super().__init__() 163 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 164 | self.transformer = CLIPTextModel.from_pretrained(version) 165 | self.device = device 166 | self.max_length = max_length 167 | self.freeze() 168 | 169 | def freeze(self): 170 | self.transformer = self.transformer.eval() 171 | for param in self.parameters(): 172 | param.requires_grad = False 173 | 174 | def forward(self, text): 175 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 176 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 177 | tokens = batch_encoding["input_ids"].to(self.device) 178 | outputs = self.transformer(input_ids=tokens) 179 | 180 | z = outputs.last_hidden_state 181 | return z 182 | 183 | def encode(self, text): 184 | return self(text) 185 | 186 | 187 | class FrozenCLIPTextEmbedder(nn.Module): 188 | """ 189 | Uses the CLIP transformer encoder for text. 190 | """ 191 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 192 | super().__init__() 193 | self.model, _ = clip.load(version, jit=False, device="cpu") 194 | self.device = device 195 | self.max_length = max_length 196 | self.n_repeat = n_repeat 197 | self.normalize = normalize 198 | self.use_tknz_fn = True 199 | 200 | def freeze(self): 201 | self.model = self.model.eval() 202 | for param in self.parameters(): 203 | param.requires_grad = False 204 | 205 | def forward(self, text): 206 | tokens = clip.tokenize(text).to(self.device) 207 | z = self.model.encode_text(tokens) 208 | if self.normalize: 209 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 210 | return z 211 | 212 | def encode(self, text): 213 | z = self(text) 214 | if z.ndim==2: 215 | z = z[:, None, :] 216 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 217 | return z 218 | 219 | 220 | class FrozenClipImageEmbedder(nn.Module): 221 | """ 222 | Uses the CLIP image encoder. 223 | """ 224 | def __init__( 225 | self, 226 | model, 227 | jit=False, 228 | device='cuda' if torch.cuda.is_available() else 'cpu', 229 | antialias=False, 230 | ): 231 | super().__init__() 232 | self.model, _ = clip.load(name=model, device=device, jit=jit) 233 | 234 | self.antialias = antialias 235 | 236 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 237 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 238 | 239 | def preprocess(self, x): 240 | # normalize to [0,1] 241 | x = kornia.geometry.resize(x, (224, 224), 242 | interpolation='bicubic',align_corners=True, 243 | antialias=self.antialias) 244 | x = (x + 1.) / 2. 
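        # the embedder expects inputs in [-1, 1]; shift to [0, 1] so the CLIP
        # mean/std normalization below is applied in standard image space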
245 | # renormalize according to clip 246 | x = kornia.enhance.normalize(x, self.mean, self.std) 247 | return x 248 | 249 | def forward(self, x): 250 | # x is assumed to be in range [-1,1] 251 | return self.model.encode_image(self.preprocess(x)) 252 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is 
not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if codebook_loss is None: 101 | 
codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | 10 | def log_txt_as_img(wh, xc, size=10): 11 | # wh a tuple of (width, height) 12 | # xc a list of captions 
to plot 13 | b = len(xc) 14 | txts = list() 15 | for bi in range(b): 16 | txt = Image.new("RGB", wh, color="white") 17 | draw = ImageDraw.Draw(txt) 18 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 19 | nc = int(40 * (wh[0] / 256)) 20 | if type(xc[bi]) is list: 21 | xc[bi] = '{}'.format(xc[bi])[1:-1] 22 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 23 | 24 | try: 25 | draw.text((0, 0), lines, fill="black", font=font) 26 | except UnicodeEncodeError: 27 | print("Cant encode string for logging. Skipping.") 28 | 29 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 30 | txts.append(txt) 31 | txts = np.stack(txts) 32 | txts = torch.tensor(txts) 33 | return txts 34 | 35 | 36 | def ismap(x): 37 | if not isinstance(x, torch.Tensor): 38 | return False 39 | return (len(x.shape) == 4) and (x.shape[1] > 3) 40 | 41 | 42 | def isimage(x): 43 | if not isinstance(x,torch.Tensor): 44 | return False 45 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 46 | 47 | 48 | def exists(x): 49 | return x is not None 50 | 51 | 52 | def default(val, d): 53 | if exists(val): 54 | return val 55 | return d() if isfunction(d) else d 56 | 57 | 58 | def mean_flat(tensor): 59 | """ 60 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 61 | Take the mean over all non-batch dimensions. 62 | """ 63 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 64 | 65 | 66 | def count_params(model, verbose=False): 67 | total_params = sum(p.numel() for p in model.parameters()) 68 | if verbose: 69 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 70 | return total_params 71 | 72 | 73 | def instantiate_from_config(config): 74 | if not "target" in config: 75 | if config == '__is_first_stage__': 76 | return None 77 | elif config == "__is_unconditional__": 78 | return None 79 | raise KeyError("Expected key `target` to instantiate.") 80 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 81 | 82 | # TODO: clean this 83 | def instantiate_from_config_main(config, *args, **kwargs): 84 | if not "target" in config: 85 | raise KeyError("Expected key `target` to instantiate.") 86 | return get_obj_from_str(config["target"])(*args, **config.get("params", dict()), **kwargs) 87 | 88 | 89 | def get_obj_from_str(string, reload=False): 90 | module, cls = string.rsplit(".", 1) 91 | if reload: 92 | module_imp = importlib.import_module(module) 93 | importlib.reload(module_imp) 94 | return getattr(importlib.import_module(module, package=None), cls) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='ldm_layout', 5 | version='1.0.0', 6 | packages=find_packages(), 7 | install_requires=[ 8 | 'torch', 9 | 'numpy', 10 | ], 11 | ) -------------------------------------------------------------------------------- /taming/data/annotated_objects_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List, Callable, Dict, Any, Union 3 | import warnings 4 | 5 | import PIL.Image as pil_image 6 | from torch import Tensor 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | 10 | from taming.data.conditional_builder.objects_bbox import 
ObjectsBoundingBoxConditionalBuilder, ObjectsConditionalBuilder, CaptionsConditionalBuilder 11 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 12 | from taming.data.conditional_builder.utils import load_object_from_string 13 | from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType 14 | from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \ 15 | Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor 16 | 17 | 18 | class AnnotatedObjectsDataset(Dataset): 19 | def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int, 20 | min_object_area: float, min_objects_per_image: int, max_objects_per_image: int, 21 | crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool, 22 | encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "", 23 | no_object_classes: Optional[int] = None): 24 | 25 | self.data_path = data_path 26 | self.split = split 27 | self.keys = keys 28 | self.target_image_size = target_image_size 29 | self.min_object_area = min_object_area 30 | self.min_objects_per_image = min_objects_per_image 31 | self.max_objects_per_image = max_objects_per_image 32 | self.crop_method = crop_method 33 | self.random_flip = random_flip 34 | self.no_tokens = no_tokens 35 | self.use_group_parameter = use_group_parameter 36 | self.encode_crop = encode_crop 37 | 38 | self.annotations = None 39 | self.image_descriptions = None 40 | self.categories = None 41 | self.category_ids = None 42 | self.category_number = None 43 | self.image_ids = None 44 | self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip) 45 | self.paths = self.build_paths(self.data_path) 46 | self._conditional_builders = None 47 | self.category_allow_list = None 48 | if category_allow_list_target: 49 | allow_list = load_object_from_string(category_allow_list_target) 50 | self.category_allow_list = {name for name, _ in allow_list} 51 | self.category_mapping = {} 52 | if category_mapping_target: 53 | self.category_mapping = load_object_from_string(category_mapping_target) 54 | self.no_object_classes = no_object_classes 55 | 56 | def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]: 57 | top_level = Path(top_level) 58 | sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()} 59 | for path in sub_paths.values(): 60 | if not path.exists(): 61 | raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.') 62 | return sub_paths 63 | 64 | @staticmethod 65 | def load_image_from_disk(path: Path) -> Image: 66 | return pil_image.open(path).convert('RGB') 67 | 68 | @staticmethod 69 | def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool): 70 | transform_functions = [] 71 | if crop_method == 'none': 72 | transform_functions.append(transforms.Resize((target_image_size, target_image_size))) 73 | # transform_functions.extend([ 74 | # transforms.Resize((target_image_size, target_image_size)), 75 | # CenterCropReturnCoordinates(target_image_size) 76 | # ]) 77 | elif crop_method == 'center': 78 | transform_functions.extend([ 79 | transforms.Resize(target_image_size), 80 | CenterCropReturnCoordinates(target_image_size) 81 | ]) 82 | elif crop_method == 'random-1d': 83 | 
transform_functions.extend([ 84 | transforms.Resize(target_image_size), 85 | RandomCrop1dReturnCoordinates(target_image_size) 86 | ]) 87 | elif crop_method == 'random-2d': 88 | transform_functions.extend([ 89 | Random2dCropReturnCoordinates(target_image_size), 90 | transforms.Resize(target_image_size) 91 | ]) 92 | elif crop_method is None: 93 | return None 94 | else: 95 | raise ValueError(f'Received invalid crop method [{crop_method}].') 96 | if random_flip: 97 | transform_functions.append(RandomHorizontalFlipReturn()) 98 | transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.)) 99 | return transform_functions 100 | 101 | def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor): 102 | crop_bbox = None 103 | flipped = None 104 | for t in self.transform_functions: 105 | if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)): 106 | crop_bbox, x = t(x) 107 | elif isinstance(t, RandomHorizontalFlipReturn): 108 | flipped, x = t(x) 109 | else: 110 | x = t(x) 111 | return crop_bbox, flipped, x 112 | 113 | @property 114 | def no_classes(self) -> int: 115 | return self.no_object_classes if self.no_object_classes else len(self.categories) 116 | 117 | @property 118 | def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: 119 | # cannot set this up in init because no_classes is only known after loading data in init of superclass 120 | if self._conditional_builders is None: 121 | self._conditional_builders = { 122 | 'objects_center_points': ObjectsCenterPointsConditionalBuilder( 123 | self.no_classes, 124 | self.max_objects_per_image, 125 | self.no_tokens, 126 | self.encode_crop, 127 | self.use_group_parameter, 128 | getattr(self, 'use_additional_parameters', False) 129 | ), 130 | 'objects_bbox': ObjectsBoundingBoxConditionalBuilder( 131 | self.no_classes, 132 | self.max_objects_per_image, 133 | self.no_tokens, 134 | self.encode_crop, 135 | self.use_group_parameter, 136 | getattr(self, 'use_additional_parameters', False) 137 | ), 138 | 'objects': ObjectsConditionalBuilder( 139 | self.no_classes, 140 | self.max_objects_per_image, 141 | self.no_tokens, 142 | self.encode_crop, 143 | self.use_group_parameter, 144 | getattr(self, 'use_additional_parameters', False) 145 | ), 146 | # 'captions': CaptionsConditionalBuilder( 147 | # self.no_classes, 148 | # self.max_objects_per_image, 149 | # self.no_tokens, 150 | # self.encode_crop, 151 | # self.use_group_parameter, 152 | # getattr(self, 'use_additional_parameters', False) 153 | # ), 154 | } 155 | return self._conditional_builders 156 | 157 | def filter_categories(self) -> None: 158 | if self.category_allow_list: 159 | self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list} 160 | if self.category_mapping: 161 | self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping} 162 | try: 163 | print('Filting appending categories') 164 | if self.category_allow_list: 165 | self.categories_append = {id_: cat for id_, cat in self.categories_append.items() if cat.name in self.category_allow_list} 166 | if self.category_mapping: 167 | self.categories_append = {id_: cat for id_, cat in self.categories_append.items() if cat.id not in self.category_mapping} 168 | except: 169 | pass 170 | 171 | def setup_category_id_and_number(self) -> None: 172 | self.category_ids = list(self.categories.keys()) 173 | self.category_ids.sort() 174 | if '/m/01s55n' in 
self.category_ids: 175 | self.category_ids.remove('/m/01s55n') 176 | self.category_ids.append('/m/01s55n') 177 | try: 178 | print('Adding appending categories into main one.') 179 | self.category_ids_append = list(self.categories_append.keys()) 180 | self.category_ids_append.sort() 181 | self.category_ids += self.category_ids_append 182 | self.categories = {**self.categories, **self.categories_append} 183 | except: 184 | pass 185 | self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)} 186 | if self.category_allow_list is not None and self.category_mapping is None \ 187 | and len(self.category_ids) != len(self.category_allow_list): 188 | warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. ' 189 | 'Make sure all names in category_allow_list exist.') 190 | 191 | def clean_up_annotations_and_image_descriptions(self) -> None: 192 | image_id_set = set(self.image_ids) 193 | self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set} 194 | self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set} 195 | 196 | @staticmethod 197 | def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float, 198 | min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]: 199 | filtered = {} 200 | for image_id, annotations in all_annotations.items(): 201 | annotations_with_min_area = [a for a in annotations if a.area > min_object_area] 202 | if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image: 203 | filtered[image_id] = annotations_with_min_area 204 | return filtered 205 | 206 | def __len__(self): 207 | return len(self.image_ids) 208 | 209 | def __getitem__(self, n: int) -> Dict[str, Any]: 210 | image_id = self.get_image_id(n) 211 | sample = self.get_image_description(image_id) 212 | sample['annotations'] = self.get_annotation(image_id) 213 | 214 | if 'image' in self.keys: 215 | sample['image_path'] = str(self.get_image_path(image_id)) 216 | sample['image'] = self.load_image_from_disk(sample['image_path']) 217 | sample['image'] = convert_pil_to_tensor(sample['image']) 218 | sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image']) 219 | sample['image'] = sample['image'].permute(1, 2, 0) 220 | 221 | for conditional, builder in self.conditional_builders.items(): 222 | if conditional in self.keys: 223 | sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped']) 224 | 225 | if self.keys: 226 | # only return specified keys 227 | sample = {key: sample[key] for key in self.keys} 228 | return sample 229 | 230 | def get_image_id(self, no: int) -> str: 231 | return self.image_ids[no] 232 | 233 | def get_annotation(self, image_id: str) -> str: 234 | return self.annotations[image_id] 235 | 236 | def get_textual_label_for_category_id(self, category_id: str) -> str: 237 | return self.categories[category_id].name 238 | 239 | def get_textual_label_for_category_no(self, category_no: int) -> str: 240 | return self.categories[self.get_category_id(category_no)].name 241 | 242 | def get_category_number(self, category_id: str) -> int: 243 | return self.category_number[category_id] 244 | 245 | def get_category_id(self, category_no: int) -> str: 246 | return self.category_ids[category_no] 247 | 248 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 249 | raise NotImplementedError() 250 | 251 | def get_path_structure(self): 
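        """Mapping from logical names to sub-paths below `data_path`; overridden by
        subclasses and resolved into absolute paths by `build_paths`."""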
252 | raise NotImplementedError 253 | 254 | def get_image_path(self, image_id: str) -> Path: 255 | raise NotImplementedError 256 | -------------------------------------------------------------------------------- /taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /taming/data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import albumentations 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | 9 | from taming.data.sflckr import SegmentationBase # for examples included in repo 10 | 11 | 12 | class Examples(SegmentationBase): 13 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 14 | super().__init__(data_csv="data/coco_examples.txt", 15 | data_root="data/coco_images", 16 | segmentation_root="data/coco_segmentations", 17 | size=size, 
random_crop=random_crop, 18 | interpolation=interpolation, 19 | n_labels=183, shift_segmentation=True) 20 | 21 | 22 | class CocoBase(Dataset): 23 | """needed for (image, caption, segmentation) pairs""" 24 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, 25 | crop_size=None, force_no_crop=False, given_files=None): 26 | self.split = self.get_split() 27 | self.size = size 28 | if crop_size is None: 29 | self.crop_size = size 30 | else: 31 | self.crop_size = crop_size 32 | 33 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 34 | self.stuffthing = use_stuffthing # include thing in segmentation 35 | if self.onehot and not self.stuffthing: 36 | raise NotImplemented("One hot mode is only supported for the " 37 | "stuffthings version because labels are stored " 38 | "a bit different.") 39 | 40 | data_json = datajson 41 | with open(data_json) as json_file: 42 | self.json_data = json.load(json_file) 43 | self.img_id_to_captions = dict() 44 | self.img_id_to_filepath = dict() 45 | self.img_id_to_segmentation_filepath = dict() 46 | 47 | assert data_json.split("/")[-1] in ["captions_train2017.json", 48 | "captions_val2017.json"] 49 | if self.stuffthing: 50 | self.segmentation_prefix = ( 51 | "data/cocostuffthings/val2017" if 52 | data_json.endswith("captions_val2017.json") else 53 | "data/cocostuffthings/train2017") 54 | else: 55 | self.segmentation_prefix = ( 56 | "data/coco/annotations/stuff_val2017_pixelmaps" if 57 | data_json.endswith("captions_val2017.json") else 58 | "data/coco/annotations/stuff_train2017_pixelmaps") 59 | 60 | imagedirs = self.json_data["images"] 61 | self.labels = {"image_ids": list()} 62 | for imgdir in tqdm(imagedirs, desc="ImgToPath"): 63 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) 64 | self.img_id_to_captions[imgdir["id"]] = list() 65 | pngfilename = imgdir["file_name"].replace("jpg", "png") 66 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( 67 | self.segmentation_prefix, pngfilename) 68 | if given_files is not None: 69 | if pngfilename in given_files: 70 | self.labels["image_ids"].append(imgdir["id"]) 71 | else: 72 | self.labels["image_ids"].append(imgdir["id"]) 73 | 74 | capdirs = self.json_data["annotations"] 75 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 76 | # there are in average 5 captions per image 77 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 78 | 79 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 80 | if self.split=="validation": 81 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 82 | else: 83 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 84 | self.preprocessor = albumentations.Compose( 85 | [self.rescaler, self.cropper], 86 | additional_targets={"segmentation": "image"}) 87 | if force_no_crop: 88 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 89 | self.preprocessor = albumentations.Compose( 90 | [self.rescaler], 91 | additional_targets={"segmentation": "image"}) 92 | 93 | def __len__(self): 94 | return len(self.labels["image_ids"]) 95 | 96 | def preprocess_image(self, image_path, segmentation_path): 97 | image = Image.open(image_path) 98 | if not image.mode == "RGB": 99 | image = image.convert("RGB") 100 | image = np.array(image).astype(np.uint8) 101 | 102 | segmentation = Image.open(segmentation_path) 103 | if not self.onehot and not 
segmentation.mode == "RGB": 104 | segmentation = segmentation.convert("RGB") 105 | segmentation = np.array(segmentation).astype(np.uint8) 106 | if self.onehot: 107 | assert self.stuffthing 108 | # stored in caffe format: unlabeled==255. stuff and thing from 109 | # 0-181. to be compatible with the labels in 110 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 111 | # we shift stuffthing one to the right and put unlabeled in zero 112 | # as long as segmentation is uint8 shifting to right handles the 113 | # latter too 114 | assert segmentation.dtype == np.uint8 115 | segmentation = segmentation + 1 116 | 117 | processed = self.preprocessor(image=image, segmentation=segmentation) 118 | image, segmentation = processed["image"], processed["segmentation"] 119 | image = (image / 127.5 - 1.0).astype(np.float32) 120 | 121 | if self.onehot: 122 | assert segmentation.dtype == np.uint8 123 | # make it one hot 124 | n_labels = 183 125 | flatseg = np.ravel(segmentation) 126 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 127 | onehot[np.arange(flatseg.size), flatseg] = True 128 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 129 | segmentation = onehot 130 | else: 131 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 132 | return image, segmentation 133 | 134 | def __getitem__(self, i): 135 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 136 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 137 | image, segmentation = self.preprocess_image(img_path, seg_path) 138 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 139 | # randomly draw one of all available captions per image 140 | caption = captions[np.random.randint(0, len(captions))] 141 | example = {"image": image, 142 | "caption": [str(caption[0])], 143 | "segmentation": segmentation, 144 | "img_path": img_path, 145 | "seg_path": seg_path, 146 | "filename_": img_path.split(os.sep)[-1] 147 | } 148 | return example 149 | 150 | 151 | class CocoImagesAndCaptionsTrain(CocoBase): 152 | """returns a pair of (image, caption)""" 153 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): 154 | super().__init__(size=size, 155 | dataroot="data/coco/train2017", 156 | datajson="data/coco/annotations/captions_train2017.json", 157 | onehot_segmentation=onehot_segmentation, 158 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 159 | 160 | def get_split(self): 161 | return "train" 162 | 163 | 164 | class CocoImagesAndCaptionsValidation(CocoBase): 165 | """returns a pair of (image, caption)""" 166 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 167 | given_files=None): 168 | super().__init__(size=size, 169 | dataroot="data/coco/val2017", 170 | datajson="data/coco/annotations/captions_val2017.json", 171 | onehot_segmentation=onehot_segmentation, 172 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 173 | given_files=given_files) 174 | 175 | def get_split(self): 176 | return "validation" 177 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as 
pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | # font = ImageFont.truetype( 47 | # "arial.ttf", 48 | # size=get_plot_font_size(font_size, figure_size) 49 | # ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK) #, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 
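# Note on the encoding above: each annotation is flattened into a triple
# (class token, bbox token, bbox token); `token_pair_from_bbox` is inherited from
# ObjectsCenterPointsConditionalBuilder and, in the upstream taming-transformers
# code, tokenizes the top-left and bottom-right corners on a discretized grid.
# `inverse_build` regroups the flat sequence into triples and drops the `none`
# padding, so a conditioning of length 3 * no_max_objects (+ 2 crop tokens when
# `encode_crop` is set) round-trips back to (class, bbox) pairs.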
61 | 62 | 63 | class ObjectsConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 64 | @property 65 | def object_descriptor_length(self) -> int: 66 | return 1 67 | 68 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 69 | object_triples = [ 70 | (self.object_representation(ann),) 71 | for ann in annotations 72 | ] 73 | 74 | empty_triple = (self.none,) 75 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 76 | return object_triples 77 | 78 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 79 | conditional_list = conditional.tolist() 80 | crop_coordinates = None 81 | # if self.encode_crop: 82 | # crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 83 | # conditional_list = conditional_list[:-2] 84 | object_triples = grouper(conditional_list, 1) 85 | assert conditional.shape[0] == self.embedding_dim 86 | return [ 87 | (object_triple[0]) 88 | for object_triple in object_triples if object_triple[0] != self.none 89 | ], crop_coordinates 90 | 91 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 92 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 93 | 94 | return 0 95 | # plot = pil_image.new('RGB', figure_size, WHITE) 96 | # draw = pil_img_draw.Draw(plot) 97 | # font = ImageFont.truetype( 98 | # "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 99 | # size=get_plot_font_size(font_size, figure_size) 100 | # ) 101 | # width, height = plot.size 102 | # description, crop_coordinates = self.inverse_build(conditional) 103 | # for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 104 | # annotation = self.representation_to_annotation(representation) 105 | # class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 106 | # bbox = absolute_bbox(bbox, width, height) 107 | # draw.rectangle(bbox, outline=color, width=line_width) 108 | # draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 109 | # if crop_coordinates is not None: 110 | # draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 111 | # return convert_pil_to_tensor(plot) / 127.5 - 1. 
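# Note (added for clarity, not part of the original file): unlike the bounding-box
# builder above, ObjectsConditionalBuilder emits a single class token per object
# (object_descriptor_length == 1), so the conditional encodes which objects are
# present but not where they are. Its inverse_build consequently yields bare class
# tokens, and plot() is effectively disabled (it returns 0).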
112 | 113 | class CaptionsConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 114 | @property 115 | def object_descriptor_length(self) -> int: 116 | return 1 117 | 118 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 119 | object_triples = [ 120 | (self.object_representation(ann),) 121 | for ann in annotations 122 | ] 123 | empty_triple = (self.none,) 124 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 125 | return object_triples 126 | 127 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 128 | conditional_list = conditional.tolist() 129 | crop_coordinates = None 130 | # if self.encode_crop: 131 | # crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 132 | # conditional_list = conditional_list[:-2] 133 | object_triples = grouper(conditional_list, 1) 134 | assert conditional.shape[0] == self.embedding_dim 135 | return [ 136 | (object_triple[0]) 137 | for object_triple in object_triples if object_triple[0] != self.none 138 | ], crop_coordinates 139 | 140 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 141 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 142 | 143 | return 0 144 | # plot = pil_image.new('RGB', figure_size, WHITE) 145 | # draw = pil_img_draw.Draw(plot) 146 | # font = ImageFont.truetype( 147 | # "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 148 | # size=get_plot_font_size(font_size, figure_size) 149 | # ) 150 | # width, height = plot.size 151 | # description, crop_coordinates = self.inverse_build(conditional) 152 | # for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 153 | # annotation = self.representation_to_annotation(representation) 154 | # class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 155 | # bbox = absolute_bbox(bbox, width, height) 156 | # draw.rectangle(bbox, outline=color, width=line_width) 157 | # draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 158 | # if crop_coordinates is not None: 159 | # draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 160 | # return convert_pil_to_tensor(plot) / 127.5 - 1. 
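# --- Illustrative demo (added; not part of the original repository) -----------------
# A minimal sketch of how these builders are driven end to end: build() turns a list
# of annotations into a flat LongTensor of layout tokens, and inverse_build() recovers
# (class token, bbox) pairs from it. All constructor arguments below are hypothetical
# placeholders; the real values come from the Layout2I dataset config.
if __name__ == "__main__":
    builder = ObjectsBoundingBoxConditionalBuilder(
        no_object_classes=184, no_max_objects=30, no_tokens=1024, encode_crop=False,
        use_group_parameter=False, use_additional_parameters=False)
    ann = Annotation(area=0.0625, image_id="0", bbox=(0.25, 0.5, 0.25, 0.25),
                     category_no=3, category_id="3")
    cond = builder.build([ann])                # LongTensor of length 30 * 3 = 90
    pairs, crop = builder.inverse_build(cond)  # bbox comes back quantised to the token grid
    print(cond.shape, pairs, crop)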
-------------------------------------------------------------------------------- /taming/data/conditional_builder/objects_center_points.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import warnings 4 | from itertools import cycle 5 | from typing import List, Optional, Tuple, Callable 6 | 7 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 8 | from more_itertools.recipes import grouper 9 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ 10 | additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ 11 | absolute_bbox, rescale_annotations 12 | from taming.data.helper_types import BoundingBox, Annotation 13 | from taming.data.image_transforms import convert_pil_to_tensor 14 | from torch import LongTensor, Tensor 15 | 16 | 17 | class ObjectsCenterPointsConditionalBuilder: 18 | def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, 19 | use_group_parameter: bool, use_additional_parameters: bool): 20 | self.no_object_classes = no_object_classes 21 | self.no_max_objects = no_max_objects 22 | self.no_tokens = no_tokens 23 | self.encode_crop = encode_crop 24 | self.no_sections = int(math.sqrt(self.no_tokens)) 25 | self.use_group_parameter = use_group_parameter 26 | self.use_additional_parameters = use_additional_parameters 27 | 28 | @property 29 | def none(self) -> int: 30 | return self.no_tokens - 1 31 | 32 | @property 33 | def object_descriptor_length(self) -> int: 34 | return 2 35 | 36 | @property 37 | def embedding_dim(self) -> int: 38 | extra_length = 2 if self.encode_crop else 0 39 | return self.no_max_objects * self.object_descriptor_length + extra_length 40 | 41 | def tokenize_coordinates(self, x: float, y: float) -> int: 42 | """ 43 | Express 2d coordinates with one number. 44 | Example: assume self.no_tokens = 16, then no_sections = 4: 45 | 0 0 0 0 46 | 0 0 # 0 47 | 0 0 0 0 48 | 0 0 0 x 49 | Then the # position corresponds to token 6, the x position to token 15. 
50 | @param x: float in [0, 1] 51 | @param y: float in [0, 1] 52 | @return: discrete tokenized coordinate 53 | """ 54 | x_discrete = int(round(x * (self.no_sections - 1))) 55 | y_discrete = int(round(y * (self.no_sections - 1))) 56 | return y_discrete * self.no_sections + x_discrete 57 | 58 | def coordinates_from_token(self, token: int) -> (float, float): 59 | x = token % self.no_sections 60 | y = token // self.no_sections 61 | return x / (self.no_sections - 1), y / (self.no_sections - 1) 62 | 63 | def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: 64 | x0, y0 = self.coordinates_from_token(token1) 65 | x1, y1 = self.coordinates_from_token(token2) 66 | return x0, y0, x1 - x0, y1 - y0 67 | 68 | def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: 69 | return self.tokenize_coordinates(bbox[0], bbox[1]), \ 70 | self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) 71 | 72 | def inverse_build(self, conditional: LongTensor) \ 73 | -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: 74 | conditional_list = conditional.tolist() 75 | crop_coordinates = None 76 | if self.encode_crop: 77 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 78 | conditional_list = conditional_list[:-2] 79 | table_of_content = grouper(conditional_list, self.object_descriptor_length) 80 | assert conditional.shape[0] == self.embedding_dim 81 | return [ 82 | (object_tuple[0], self.coordinates_from_token(object_tuple[1])) 83 | for object_tuple in table_of_content if object_tuple[0] != self.none 84 | ], crop_coordinates 85 | 86 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 87 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 88 | plot = pil_image.new('RGB', figure_size, WHITE) 89 | draw = pil_img_draw.Draw(plot) 90 | circle_size = get_circle_size(figure_size) 91 | font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', 92 | size=get_plot_font_size(font_size, figure_size)) 93 | width, height = plot.size 94 | description, crop_coordinates = self.inverse_build(conditional) 95 | for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): 96 | x_abs, y_abs = x * width, y * height 97 | ann = self.representation_to_annotation(representation) 98 | label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) 99 | ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] 100 | draw.ellipse(ellipse_bbox, fill=color, width=0) 101 | draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) 102 | if crop_coordinates is not None: 103 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 104 | return convert_pil_to_tensor(plot) / 127.5 - 1. 
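    # Worked example (added for clarity, not in the original file): with no_tokens=16
    # the grid above has no_sections=4, so
    #   tokenize_coordinates(2/3, 1/3) -> x_discrete=2, y_discrete=1 -> 1*4 + 2 = 6
    #   tokenize_coordinates(1.0, 1.0) -> x_discrete=3, y_discrete=3 -> 3*4 + 3 = 15
    # matching the '#' and 'x' cells in the tokenize_coordinates docstring. The
    # inverse, coordinates_from_token(6), returns (2/3, 1/3) again, and
    # bbox_from_token_pair decodes two such corner tokens into a relative
    # (x0, y0, w, h) box.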
105 | 106 | def object_representation(self, annotation: Annotation) -> int: 107 | modifier = 0 108 | if self.use_group_parameter: 109 | modifier |= 1 * (annotation.is_group_of is True) 110 | if self.use_additional_parameters: 111 | modifier |= 2 * (annotation.is_occluded is True) 112 | modifier |= 4 * (annotation.is_depiction is True) 113 | modifier |= 8 * (annotation.is_inside is True) 114 | return annotation.category_no + self.no_object_classes * modifier 115 | 116 | def representation_to_annotation(self, representation: int) -> Annotation: 117 | category_no = representation % self.no_object_classes 118 | modifier = representation // self.no_object_classes 119 | # noinspection PyTypeChecker 120 | return Annotation( 121 | area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, 122 | category_no=category_no, 123 | is_group_of=bool((modifier & 1) * self.use_group_parameter), 124 | is_occluded=bool((modifier & 2) * self.use_additional_parameters), 125 | is_depiction=bool((modifier & 4) * self.use_additional_parameters), 126 | is_inside=bool((modifier & 8) * self.use_additional_parameters) 127 | ) 128 | 129 | def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: 130 | return list(self.token_pair_from_bbox(crop_coordinates)) 131 | 132 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 133 | object_tuples = [ 134 | (self.object_representation(a), 135 | self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) 136 | for a in annotations 137 | ] 138 | empty_tuple = (self.none, self.none) 139 | object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) 140 | return object_tuples 141 | 142 | def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ 143 | -> LongTensor: 144 | if len(annotations) == 0: 145 | warnings.warn('Did not receive any annotations.') 146 | if len(annotations) > self.no_max_objects: 147 | warnings.warn('Received more annotations than allowed.') 148 | annotations = annotations[:self.no_max_objects] 149 | 150 | if not crop_coordinates: 151 | crop_coordinates = FULL_CROP 152 | 153 | random.shuffle(annotations) 154 | annotations = filter_annotations(annotations, crop_coordinates) 155 | if self.encode_crop: 156 | annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) 157 | if horizontal_flip: 158 | crop_coordinates = horizontally_flip_bbox(crop_coordinates) 159 | extra = self._crop_encoder(crop_coordinates) 160 | else: 161 | annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) 162 | extra = [] 163 | 164 | object_tuples = self._make_object_descriptors(annotations) 165 | 166 | # flatten 167 | flattened = [token for tuple_ in object_tuples for token in tuple_] + extra 168 | assert len(flattened) == self.embedding_dim 169 | assert all(0 <= value < self.no_tokens for value in flattened) 170 | return LongTensor(flattened) 171 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 
126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def 
load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, 
RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 
109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /taming/data/open_images_helper.py: -------------------------------------------------------------------------------- 1 | open_images_unify_categories_for_coco = { 2 | '/m/03bt1vf': '/m/01g317', 3 | '/m/04yx4': '/m/01g317', 4 | '/m/05r655': '/m/01g317', 5 | '/m/01bl7v': '/m/01g317', 6 | '/m/0cnyhnx': '/m/01xq0k1', 7 | '/m/01226z': '/m/018xm', 8 | '/m/05ctyq': '/m/018xm', 9 | '/m/058qzx': '/m/04ctx', 10 | '/m/06pcq': '/m/0l515', 11 | '/m/03m3pdh': '/m/02crq1', 12 | '/m/046dlr': '/m/01x3z', 13 | '/m/0h8mzrc': '/m/01x3z', 14 | } 15 | 16 | 17 | top_300_classes_plus_coco_compatibility = [ 18 | ('Man', 1060962), 19 | ('Clothing', 986610), 20 | ('Tree', 748162), 21 | ('Woman', 611896), 22 | ('Person', 610294), 23 | ('Human face', 442948), 24 | ('Girl', 175399), 25 | ('Building', 162147), 26 | ('Car', 159135), 27 | ('Plant', 155704), 28 | ('Human body', 137073), 29 | ('Flower', 133128), 30 | ('Window', 127485), 31 | ('Human arm', 118380), 32 | ('House', 114365), 33 | ('Wheel', 111684), 34 | ('Suit', 99054), 35 | ('Human hair', 98089), 36 | ('Human head', 92763), 37 | ('Chair', 88624), 38 | ('Boy', 79849), 39 | ('Table', 73699), 40 | ('Jeans', 57200), 41 | ('Tire', 55725), 42 | ('Skyscraper', 53321), 43 | ('Food', 52400), 44 | ('Footwear', 50335), 45 | ('Dress', 50236), 46 | ('Human leg', 47124), 47 | ('Toy', 46636), 48 | ('Tower', 45605), 49 | ('Boat', 43486), 50 | ('Land vehicle', 40541), 51 | ('Bicycle wheel', 34646), 52 | ('Palm tree', 33729), 53 | ('Fashion accessory', 32914), 54 | ('Glasses', 31940), 55 | ('Bicycle', 31409), 56 | ('Furniture', 30656), 57 | ('Sculpture', 29643), 58 | ('Bottle', 27558), 59 | ('Dog', 26980), 60 | ('Snack', 26796), 61 | ('Human hand', 26664), 62 | ('Bird', 25791), 63 | ('Book', 25415), 64 | ('Guitar', 24386), 65 | ('Jacket', 23998), 66 | ('Poster', 22192), 67 | ('Dessert', 21284), 68 | ('Baked goods', 20657), 69 | ('Drink', 19754), 70 | ('Flag', 18588), 71 | ('Houseplant', 18205), 72 | ('Tableware', 17613), 73 | ('Airplane', 17218), 74 | ('Door', 17195), 75 | ('Sports uniform', 17068), 76 | ('Shelf', 16865), 77 | ('Drum', 16612), 78 | ('Vehicle', 16542), 79 | ('Microphone', 15269), 80 | ('Street light', 14957), 81 | ('Cat', 14879), 82 | ('Fruit', 13684), 83 | ('Fast food', 13536), 84 | ('Animal', 12932), 85 | ('Vegetable', 12534), 86 | ('Train', 12358), 87 | ('Horse', 11948), 88 | ('Flowerpot', 11728), 89 | ('Motorcycle', 11621), 90 | ('Fish', 11517), 91 | ('Desk', 11405), 92 | ('Helmet', 10996), 93 | ('Truck', 10915), 94 | ('Bus', 10695), 95 | ('Hat', 10532), 96 | ('Auto part', 10488), 97 | ('Musical instrument', 10303), 98 | ('Sunglasses', 10207), 99 | ('Picture frame', 
10096), 100 | ('Sports equipment', 10015), 101 | ('Shorts', 9999), 102 | ('Wine glass', 9632), 103 | ('Duck', 9242), 104 | ('Wine', 9032), 105 | ('Rose', 8781), 106 | ('Tie', 8693), 107 | ('Butterfly', 8436), 108 | ('Beer', 7978), 109 | ('Cabinetry', 7956), 110 | ('Laptop', 7907), 111 | ('Insect', 7497), 112 | ('Goggles', 7363), 113 | ('Shirt', 7098), 114 | ('Dairy Product', 7021), 115 | ('Marine invertebrates', 7014), 116 | ('Cattle', 7006), 117 | ('Trousers', 6903), 118 | ('Van', 6843), 119 | ('Billboard', 6777), 120 | ('Balloon', 6367), 121 | ('Human nose', 6103), 122 | ('Tent', 6073), 123 | ('Camera', 6014), 124 | ('Doll', 6002), 125 | ('Coat', 5951), 126 | ('Mobile phone', 5758), 127 | ('Swimwear', 5729), 128 | ('Strawberry', 5691), 129 | ('Stairs', 5643), 130 | ('Goose', 5599), 131 | ('Umbrella', 5536), 132 | ('Cake', 5508), 133 | ('Sun hat', 5475), 134 | ('Bench', 5310), 135 | ('Bookcase', 5163), 136 | ('Bee', 5140), 137 | ('Computer monitor', 5078), 138 | ('Hiking equipment', 4983), 139 | ('Office building', 4981), 140 | ('Coffee cup', 4748), 141 | ('Curtain', 4685), 142 | ('Plate', 4651), 143 | ('Box', 4621), 144 | ('Tomato', 4595), 145 | ('Coffee table', 4529), 146 | ('Office supplies', 4473), 147 | ('Maple', 4416), 148 | ('Muffin', 4365), 149 | ('Cocktail', 4234), 150 | ('Castle', 4197), 151 | ('Couch', 4134), 152 | ('Pumpkin', 3983), 153 | ('Computer keyboard', 3960), 154 | ('Human mouth', 3926), 155 | ('Christmas tree', 3893), 156 | ('Mushroom', 3883), 157 | ('Swimming pool', 3809), 158 | ('Pastry', 3799), 159 | ('Lavender (Plant)', 3769), 160 | ('Football helmet', 3732), 161 | ('Bread', 3648), 162 | ('Traffic sign', 3628), 163 | ('Common sunflower', 3597), 164 | ('Television', 3550), 165 | ('Bed', 3525), 166 | ('Cookie', 3485), 167 | ('Fountain', 3484), 168 | ('Paddle', 3447), 169 | ('Bicycle helmet', 3429), 170 | ('Porch', 3420), 171 | ('Deer', 3387), 172 | ('Fedora', 3339), 173 | ('Canoe', 3338), 174 | ('Carnivore', 3266), 175 | ('Bowl', 3202), 176 | ('Human eye', 3166), 177 | ('Ball', 3118), 178 | ('Pillow', 3077), 179 | ('Salad', 3061), 180 | ('Beetle', 3060), 181 | ('Orange', 3050), 182 | ('Drawer', 2958), 183 | ('Platter', 2937), 184 | ('Elephant', 2921), 185 | ('Seafood', 2921), 186 | ('Monkey', 2915), 187 | ('Countertop', 2879), 188 | ('Watercraft', 2831), 189 | ('Helicopter', 2805), 190 | ('Kitchen appliance', 2797), 191 | ('Personal flotation device', 2781), 192 | ('Swan', 2739), 193 | ('Lamp', 2711), 194 | ('Boot', 2695), 195 | ('Bronze sculpture', 2693), 196 | ('Chicken', 2677), 197 | ('Taxi', 2643), 198 | ('Juice', 2615), 199 | ('Cowboy hat', 2604), 200 | ('Apple', 2600), 201 | ('Tin can', 2590), 202 | ('Necklace', 2564), 203 | ('Ice cream', 2560), 204 | ('Human beard', 2539), 205 | ('Coin', 2536), 206 | ('Candle', 2515), 207 | ('Cart', 2512), 208 | ('High heels', 2441), 209 | ('Weapon', 2433), 210 | ('Handbag', 2406), 211 | ('Penguin', 2396), 212 | ('Rifle', 2352), 213 | ('Violin', 2336), 214 | ('Skull', 2304), 215 | ('Lantern', 2285), 216 | ('Scarf', 2269), 217 | ('Saucer', 2225), 218 | ('Sheep', 2215), 219 | ('Vase', 2189), 220 | ('Lily', 2180), 221 | ('Mug', 2154), 222 | ('Parrot', 2140), 223 | ('Human ear', 2137), 224 | ('Sandal', 2115), 225 | ('Lizard', 2100), 226 | ('Kitchen & dining room table', 2063), 227 | ('Spider', 1977), 228 | ('Coffee', 1974), 229 | ('Goat', 1926), 230 | ('Squirrel', 1922), 231 | ('Cello', 1913), 232 | ('Sushi', 1881), 233 | ('Tortoise', 1876), 234 | ('Pizza', 1870), 235 | ('Studio couch', 1864), 236 | ('Barrel', 1862), 237 | 
('Cosmetics', 1841), 238 | ('Moths and butterflies', 1841), 239 | ('Convenience store', 1817), 240 | ('Watch', 1792), 241 | ('Home appliance', 1786), 242 | ('Harbor seal', 1780), 243 | ('Luggage and bags', 1756), 244 | ('Vehicle registration plate', 1754), 245 | ('Shrimp', 1751), 246 | ('Jellyfish', 1730), 247 | ('French fries', 1723), 248 | ('Egg (Food)', 1698), 249 | ('Football', 1697), 250 | ('Musical keyboard', 1683), 251 | ('Falcon', 1674), 252 | ('Candy', 1660), 253 | ('Medical equipment', 1654), 254 | ('Eagle', 1651), 255 | ('Dinosaur', 1634), 256 | ('Surfboard', 1630), 257 | ('Tank', 1628), 258 | ('Grape', 1624), 259 | ('Lion', 1624), 260 | ('Owl', 1622), 261 | ('Ski', 1613), 262 | ('Waste container', 1606), 263 | ('Frog', 1591), 264 | ('Sparrow', 1585), 265 | ('Rabbit', 1581), 266 | ('Pen', 1546), 267 | ('Sea lion', 1537), 268 | ('Spoon', 1521), 269 | ('Sink', 1512), 270 | ('Teddy bear', 1507), 271 | ('Bull', 1495), 272 | ('Sofa bed', 1490), 273 | ('Dragonfly', 1479), 274 | ('Brassiere', 1478), 275 | ('Chest of drawers', 1472), 276 | ('Aircraft', 1466), 277 | ('Human foot', 1463), 278 | ('Pig', 1455), 279 | ('Fork', 1454), 280 | ('Antelope', 1438), 281 | ('Tripod', 1427), 282 | ('Tool', 1424), 283 | ('Cheese', 1422), 284 | ('Lemon', 1397), 285 | ('Hamburger', 1393), 286 | ('Dolphin', 1390), 287 | ('Mirror', 1390), 288 | ('Marine mammal', 1387), 289 | ('Giraffe', 1385), 290 | ('Snake', 1368), 291 | ('Gondola', 1364), 292 | ('Wheelchair', 1360), 293 | ('Piano', 1358), 294 | ('Cupboard', 1348), 295 | ('Banana', 1345), 296 | ('Trumpet', 1335), 297 | ('Lighthouse', 1333), 298 | ('Invertebrate', 1317), 299 | ('Carrot', 1268), 300 | ('Sock', 1260), 301 | ('Tiger', 1241), 302 | ('Camel', 1224), 303 | ('Parachute', 1224), 304 | ('Bathroom accessory', 1223), 305 | ('Earrings', 1221), 306 | ('Headphones', 1218), 307 | ('Skirt', 1198), 308 | ('Skateboard', 1190), 309 | ('Sandwich', 1148), 310 | ('Saxophone', 1141), 311 | ('Goldfish', 1136), 312 | ('Stool', 1104), 313 | ('Traffic light', 1097), 314 | ('Shellfish', 1081), 315 | ('Backpack', 1079), 316 | ('Sea turtle', 1078), 317 | ('Cucumber', 1075), 318 | ('Tea', 1051), 319 | ('Toilet', 1047), 320 | ('Roller skates', 1040), 321 | ('Mule', 1039), 322 | ('Bust', 1031), 323 | ('Broccoli', 1030), 324 | ('Crab', 1020), 325 | ('Oyster', 1019), 326 | ('Cannon', 1012), 327 | ('Zebra', 1012), 328 | ('French horn', 1008), 329 | ('Grapefruit', 998), 330 | ('Whiteboard', 997), 331 | ('Zucchini', 997), 332 | ('Crocodile', 992), 333 | 334 | ('Clock', 960), 335 | ('Wall clock', 958), 336 | 337 | ('Doughnut', 869), 338 | ('Snail', 868), 339 | 340 | ('Baseball glove', 859), 341 | 342 | ('Panda', 830), 343 | ('Tennis racket', 830), 344 | 345 | ('Pear', 652), 346 | 347 | ('Bagel', 617), 348 | ('Oven', 616), 349 | ('Ladybug', 615), 350 | ('Shark', 615), 351 | ('Polar bear', 614), 352 | ('Ostrich', 609), 353 | 354 | ('Hot dog', 473), 355 | ('Microwave oven', 467), 356 | ('Fire hydrant', 20), 357 | ('Stop sign', 20), 358 | ('Parking meter', 20), 359 | ('Bear', 20), 360 | ('Flying disc', 20), 361 | ('Snowboard', 20), 362 | ('Tennis ball', 20), 363 | ('Kite', 20), 364 | ('Baseball bat', 20), 365 | ('Kitchen knife', 20), 366 | ('Knife', 20), 367 | ('Submarine sandwich', 20), 368 | ('Computer mouse', 20), 369 | ('Remote control', 20), 370 | ('Toaster', 20), 371 | ('Sink', 20), 372 | ('Refrigerator', 20), 373 | ('Alarm clock', 20), 374 | ('Wall clock', 20), 375 | ('Scissors', 20), 376 | ('Hair dryer', 20), 377 | ('Toothbrush', 20), 378 | ('Suitcase', 20) 379 | ] 380 | 
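# Illustrative note (added; not part of the original file): open_images_unify_categories_for_coco
# collapses several fine-grained Open Images category ids into a single COCO-compatible id.
# A dataset loader could apply it as a simple lookup with fallback, e.g.
#   category_id = open_images_unify_categories_for_coco.get(category_id, category_id)
# before matching against the frequency-sorted class list above.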
-------------------------------------------------------------------------------- /taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | 
segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 
112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
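        # Added note (not in the original file): schedule(n) below returns an LR
        # multiplier, which is why the class docstring asks for base_lr=1.0. It ramps
        # linearly from lr_start to lr_max over the first warm_up_steps, then follows
        # half a cosine from lr_max down to lr_min, which is reached at max_decay_steps
        # and held afterwards.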
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, 
padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 
conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /taming/modules/losses/soft_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def SoftCrossEntropy(inputs, target, reduction='sum'): 5 | log_likelihood = -F.log_softmax(inputs, dim=1) 6 | batch = inputs.shape[0] 7 | if reduction == 'average': 8 | loss = torch.sum(torch.mul(log_likelihood, target)) / batch 9 | else: 10 | loss = torch.sum(torch.mul(log_likelihood, 
target)) 11 | return loss -------------------------------------------------------------------------------- /taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import cv2 5 | import torchvision 6 | 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 9 | 10 | 11 | class DummyLoss(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | 16 | def adopt_weight(weight, global_step, threshold=0, value=0.): 17 | if global_step < threshold: 18 | weight = value 19 | return weight 20 | 21 | 22 | def hinge_d_loss(logits_real, logits_fake): 23 | loss_real = torch.mean(F.relu(1. - logits_real)) 24 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 25 | d_loss = 0.5 * (loss_real + loss_fake) 26 | return d_loss 27 | 28 | 29 | def vanilla_d_loss(logits_real, logits_fake): 30 | d_loss = 0.5 * ( 31 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 32 | torch.mean(torch.nn.functional.softplus(logits_fake))) 33 | return d_loss 34 | 35 | 36 | class VQLPIPSWithDiscriminator(nn.Module): 37 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 38 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 39 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 40 | disc_ndf=64, disc_loss="hinge", aux_downscale=4.): 41 | super().__init__() 42 | assert disc_loss in ["hinge", "vanilla"] 43 | self.codebook_weight = codebook_weight 44 | self.pixel_weight = pixelloss_weight 45 | self.perceptual_loss = LPIPS().eval() 46 | self.perceptual_weight = perceptual_weight 47 | self.aux_downscale = aux_downscale 48 | 49 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 50 | n_layers=disc_num_layers, 51 | use_actnorm=use_actnorm, 52 | ndf=disc_ndf 53 | ).apply(weights_init) 54 | self.discriminator_iter_start = disc_start 55 | if disc_loss == "hinge": 56 | self.disc_loss = hinge_d_loss 57 | elif disc_loss == "vanilla": 58 | self.disc_loss = vanilla_d_loss 59 | else: 60 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 61 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 62 | self.disc_factor = disc_factor 63 | self.discriminator_weight = disc_weight 64 | self.disc_conditional = disc_conditional 65 | 66 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 67 | if last_layer is not None: 68 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 70 | else: 71 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 72 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 73 | 74 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 75 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 76 | d_weight = d_weight * self.discriminator_weight 77 | return d_weight 78 | 79 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 80 | global_step, last_layer=None, cond=None, split="train", xrec_aux=None): 81 | 82 | aux_downscale = self.aux_downscale 83 | 84 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 85 | if self.perceptual_weight > 0: 86 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 87 | rec_loss = rec_loss + 
self.perceptual_weight * p_loss 88 | else: 89 | p_loss = torch.tensor([0.0]) 90 | 91 | if xrec_aux is not None: 92 | # print(aux_downscale) 93 | inputs_aux = F.interpolate(inputs, scale_factor=1./aux_downscale) 94 | inputs_aux = F.interpolate(inputs_aux, scale_factor=aux_downscale, mode='bilinear') 95 | # inputs_cv = torchvision.utils.make_grid(inputs) 96 | # inputs_aux_cv = torchvision.utils.make_grid(inputs_aux) 97 | # torchvision.utils.save_image(inputs_cv, "input.png") 98 | # torchvision.utils.save_image(inputs_aux_cv, "input_aux.png") 99 | rec_aux_loss = torch.abs(inputs_aux.contiguous() - xrec_aux.contiguous()) 100 | rec_loss = rec_loss + 0.5 * rec_aux_loss 101 | else: 102 | rec_aux_loss = torch.tensor([0.0]) 103 | 104 | nll_loss = rec_loss 105 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 106 | nll_loss = torch.mean(nll_loss) 107 | 108 | # now the GAN part 109 | if optimizer_idx == 0: 110 | # generator update 111 | if cond is None: 112 | assert not self.disc_conditional 113 | logits_fake = self.discriminator(reconstructions.contiguous()) 114 | else: 115 | assert self.disc_conditional 116 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 117 | g_loss = -torch.mean(logits_fake) 118 | 119 | try: 120 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 121 | except RuntimeError: 122 | assert not self.training 123 | d_weight = torch.tensor(0.0) 124 | 125 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 126 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 127 | 128 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 129 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 130 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 131 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 132 | "{}/p_loss".format(split): p_loss.detach().mean(), 133 | "{}/rec_aux_loss".format(split): rec_aux_loss.detach().mean(), 134 | "{}/d_weight".format(split): d_weight.detach(), 135 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 136 | "{}/g_loss".format(split): g_loss.detach().mean(), 137 | } 138 | return loss, log 139 | 140 | if optimizer_idx == 1: 141 | # second pass for discriminator update 142 | if cond is None: 143 | logits_real = self.discriminator(inputs.contiguous().detach()) 144 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 145 | else: 146 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 147 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 148 | 149 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 150 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 151 | 152 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 153 | "{}/logits_real".format(split): logits_real.detach().mean(), 154 | "{}/logits_fake".format(split): logits_fake.detach().mean() 155 | } 156 | return d_loss, log 157 | -------------------------------------------------------------------------------- /taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 
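# eval() is a no-op that returns self, so CoordStage can stand in for a frozen
# first-stage VQ model wherever the caller chains `.eval()` before encoding.
# A minimal usage sketch (the n_embed/down_factor values are hypothetical):
#   stage = CoordStage(n_embed=1024, down_factor=16)
#   c_quant, _, (_, _, c_ind) = stage.eval().encode(coords)  # coords in [0, 1], shape (B, 1, H, W)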
10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /taming/modules/misc/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /taming/modules/transformer/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractPermuter(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | def forward(self, x, reverse=False): 10 | raise NotImplementedError 11 | 12 | 13 | class Identity(AbstractPermuter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, reverse=False): 18 | return x 19 | 20 | 21 | class Subsample(AbstractPermuter): 22 | def __init__(self, H, W): 23 | super().__init__() 24 | C = 1 25 | indices = np.arange(H*W).reshape(C,H,W) 26 | while min(H, W) > 1: 27 | indices = indices.reshape(C,H//2,2,W//2,2) 28 | indices = indices.transpose(0,2,4,1,3) 29 | indices = indices.reshape(C*4,H//2, W//2) 30 | H = H//2 31 | W = W//2 32 | C = C*4 33 | assert H == W == 1 34 | idx = torch.tensor(indices.ravel()) 35 | self.register_buffer('forward_shuffle_idx', 36 | nn.Parameter(idx, requires_grad=False)) 37 | self.register_buffer('backward_shuffle_idx', 38 | nn.Parameter(torch.argsort(idx), requires_grad=False)) 39 | 40 | def forward(self, x, reverse=False): 41 | if not reverse: 42 | return x[:, self.forward_shuffle_idx] 43 | else: 44 | return x[:, self.backward_shuffle_idx] 45 | 46 | 47 | def mortonify(i, j): 48 | """(i,j) index to linear morton code""" 49 | i = np.uint64(i) 50 | j = np.uint64(j) 51 | 52 | z = np.uint(0) 53 | 54 | for pos in range(32): 55 | z = (z | 56 | ((j & (np.uint64(1) << 
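# bit interleaving: bit `pos` of j is shifted to the even output position 2*pos
# and bit `pos` of i to the odd position 2*pos + 1 (the two shifted terms below),
# which yields the Morton / Z-order code of (i, j)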
np.uint64(pos))) << np.uint64(pos)) | 57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) 58 | ) 59 | return z 60 | 61 | 62 | class ZCurve(AbstractPermuter): 63 | def __init__(self, H, W): 64 | super().__init__() 65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] 66 | idx = np.argsort(reverseidx) 67 | idx = torch.tensor(idx) 68 | reverseidx = torch.tensor(reverseidx) 69 | self.register_buffer('forward_shuffle_idx', 70 | idx) 71 | self.register_buffer('backward_shuffle_idx', 72 | reverseidx) 73 | 74 | def forward(self, x, reverse=False): 75 | if not reverse: 76 | return x[:, self.forward_shuffle_idx] 77 | else: 78 | return x[:, self.backward_shuffle_idx] 79 | 80 | 81 | class SpiralOut(AbstractPermuter): 82 | def __init__(self, H, W): 83 | super().__init__() 84 | assert H == W 85 | size = W 86 | indices = np.arange(size*size).reshape(size,size) 87 | 88 | i0 = size//2 89 | j0 = size//2-1 90 | 91 | i = i0 92 | j = j0 93 | 94 | idx = [indices[i0, j0]] 95 | step_mult = 0 96 | for c in range(1, size//2+1): 97 | step_mult += 1 98 | # steps left 99 | for k in range(step_mult): 100 | i = i - 1 101 | j = j 102 | idx.append(indices[i, j]) 103 | 104 | # step down 105 | for k in range(step_mult): 106 | i = i 107 | j = j + 1 108 | idx.append(indices[i, j]) 109 | 110 | step_mult += 1 111 | if c < size//2: 112 | # step right 113 | for k in range(step_mult): 114 | i = i + 1 115 | j = j 116 | idx.append(indices[i, j]) 117 | 118 | # step up 119 | for k in range(step_mult): 120 | i = i 121 | j = j - 1 122 | idx.append(indices[i, j]) 123 | else: 124 | # end reached 125 | for k in range(step_mult-1): 126 | i = i + 1 127 | idx.append(indices[i, j]) 128 | 129 | assert len(idx) == size*size 130 | idx = torch.tensor(idx) 131 | self.register_buffer('forward_shuffle_idx', idx) 132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 133 | 134 | def forward(self, x, reverse=False): 135 | if not reverse: 136 | return x[:, self.forward_shuffle_idx] 137 | else: 138 | return x[:, self.backward_shuffle_idx] 139 | 140 | 141 | class SpiralIn(AbstractPermuter): 142 | def __init__(self, H, W): 143 | super().__init__() 144 | assert H == W 145 | size = W 146 | indices = np.arange(size*size).reshape(size,size) 147 | 148 | i0 = size//2 149 | j0 = size//2-1 150 | 151 | i = i0 152 | j = j0 153 | 154 | idx = [indices[i0, j0]] 155 | step_mult = 0 156 | for c in range(1, size//2+1): 157 | step_mult += 1 158 | # steps left 159 | for k in range(step_mult): 160 | i = i - 1 161 | j = j 162 | idx.append(indices[i, j]) 163 | 164 | # step down 165 | for k in range(step_mult): 166 | i = i 167 | j = j + 1 168 | idx.append(indices[i, j]) 169 | 170 | step_mult += 1 171 | if c < size//2: 172 | # step right 173 | for k in range(step_mult): 174 | i = i + 1 175 | j = j 176 | idx.append(indices[i, j]) 177 | 178 | # step up 179 | for k in range(step_mult): 180 | i = i 181 | j = j - 1 182 | idx.append(indices[i, j]) 183 | else: 184 | # end reached 185 | for k in range(step_mult-1): 186 | i = i + 1 187 | idx.append(indices[i, j]) 188 | 189 | assert len(idx) == size*size 190 | idx = idx[::-1] 191 | idx = torch.tensor(idx) 192 | self.register_buffer('forward_shuffle_idx', idx) 193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 194 | 195 | def forward(self, x, reverse=False): 196 | if not reverse: 197 | return x[:, self.forward_shuffle_idx] 198 | else: 199 | return x[:, self.backward_shuffle_idx] 200 | 201 | 202 | class Random(nn.Module): 203 | def __init__(self, H, W): 204 | 
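# the permutation below is drawn from a fixed seed (RandomState(1)), so the
# "random" ordering is deterministic across runs; both the forward and the
# inverse index are registered as buffers and therefore end up in state_dict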
super().__init__() 205 | indices = np.random.RandomState(1).permutation(H*W) 206 | idx = torch.tensor(indices.ravel()) 207 | self.register_buffer('forward_shuffle_idx', idx) 208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 209 | 210 | def forward(self, x, reverse=False): 211 | if not reverse: 212 | return x[:, self.forward_shuffle_idx] 213 | else: 214 | return x[:, self.backward_shuffle_idx] 215 | 216 | 217 | class AlternateParsing(AbstractPermuter): 218 | def __init__(self, H, W): 219 | super().__init__() 220 | indices = np.arange(W*H).reshape(H,W) 221 | for i in range(1, H, 2): 222 | indices[i, :] = indices[i, ::-1] 223 | idx = indices.flatten() 224 | assert len(idx) == H*W 225 | idx = torch.tensor(idx) 226 | self.register_buffer('forward_shuffle_idx', idx) 227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 228 | 229 | def forward(self, x, reverse=False): 230 | if not reverse: 231 | return x[:, self.forward_shuffle_idx] 232 | else: 233 | return x[:, self.backward_shuffle_idx] 234 | 235 | 236 | if __name__ == "__main__": 237 | p0 = AlternateParsing(16, 16) 238 | print(p0.forward_shuffle_idx) 239 | print(p0.backward_shuffle_idx) 240 | 241 | x = torch.randint(0, 768, size=(11, 256)) 242 | y = p0(x) 243 | xre = p0(y, reverse=True) 244 | assert torch.equal(x, xre) 245 | 246 | p1 = SpiralOut(2, 2) 247 | print(p1.forward_shuffle_idx) 248 | print(p1.backward_shuffle_idx) 249 | -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not 
self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /taming/modules/vqvae/mapping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import copy 9 | import datetime 10 | import os 11 | import random 12 | import time 13 | import timeit 14 | import warnings 15 | from collections import OrderedDict 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn.functional import interpolate 21 | from torch.nn.modules.sparse import Embedding 22 | 23 | class PixelNormLayer(nn.Module): 24 | def __init__(self, epsilon=1e-8): 25 | super().__init__() 26 | self.epsilon = epsilon 27 | 28 | def forward(self, x): 29 | return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon) 30 | 31 | 32 | class EqualizedLinear(nn.Module): 33 | """Linear layer with equalized learning rate and custom learning rate multiplier.""" 34 | 35 | def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True): 36 | super().__init__() 37 | he_std = gain * input_size ** (-0.5) # He init 38 | # Equalized learning rate and custom learning rate multiplier. 
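# with use_wscale the weight is stored at unit scale (init_std = 1/lrmul) and
# rescaled at run time by w_mul = he_std * lrmul (the "equalized learning rate"
# trick); otherwise the He std is baked into the initialization and only lrmul
# is applied at run time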
39 | if use_wscale: 40 | init_std = 1.0 / lrmul 41 | self.w_mul = he_std * lrmul 42 | else: 43 | init_std = he_std / lrmul 44 | self.w_mul = lrmul 45 | self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std) 46 | if bias: 47 | self.bias = torch.nn.Parameter(torch.zeros(output_size)) 48 | self.b_mul = lrmul 49 | else: 50 | self.bias = None 51 | 52 | def forward(self, x): 53 | bias = self.bias 54 | if bias is not None: 55 | bias = bias * self.b_mul 56 | return F.linear(x, self.weight * self.w_mul, bias) 57 | 58 | class GMapping(nn.Module): 59 | 60 | def __init__(self, latent_size=512, dlatent_size=512, dlatent_broadcast=None, 61 | mapping_layers=8, mapping_fmaps=512, mapping_lrmul=0.01, mapping_nonlinearity='lrelu', 62 | use_wscale=True, normalize_latents=False, **kwargs): 63 | super().__init__() 64 | 65 | self.latent_size = latent_size 66 | self.mapping_fmaps = mapping_fmaps 67 | self.dlatent_size = dlatent_size 68 | self.dlatent_broadcast = dlatent_broadcast 69 | 70 | # Activation function. 71 | act, gain = {'relu': (torch.relu, np.sqrt(2)), 72 | 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[mapping_nonlinearity] 73 | 74 | # Embed labels and concatenate them with latents. 75 | # TODO 76 | 77 | layers = [] 78 | # Normalize latents. 79 | if normalize_latents: 80 | layers.append(('pixel_norm', PixelNormLayer())) 81 | 82 | # Mapping layers. (apply_bias?) 83 | layers.append(('dense0', EqualizedLinear(self.latent_size, self.mapping_fmaps, 84 | gain=gain, lrmul=mapping_lrmul, use_wscale=use_wscale))) 85 | layers.append(('dense0_act', act)) 86 | for layer_idx in range(1, mapping_layers): 87 | fmaps_in = self.mapping_fmaps 88 | fmaps_out = self.dlatent_size if layer_idx == mapping_layers - 1 else self.mapping_fmaps 89 | layers.append( 90 | ('dense{:d}'.format(layer_idx), 91 | EqualizedLinear(fmaps_in, fmaps_out, gain=gain, lrmul=mapping_lrmul, use_wscale=use_wscale))) 92 | layers.append(('dense{:d}_act'.format(layer_idx), act)) 93 | 94 | # Output. 95 | self.map = nn.Sequential(OrderedDict(layers)) 96 | 97 | def forward(self, x): 98 | # First input: Latent vectors (Z) [mini_batch, latent_size]. 
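# the mapping MLP turns z into the intermediate latent w; when dlatent_broadcast
# is set, w is additionally repeated along a new axis so that each of the
# dlatent_broadcast downstream positions receives its own copy of the same w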
99 | x = self.map(x) 100 | 101 | # Broadcast -> batch_size * dlatent_broadcast * dlatent_size 102 | if self.dlatent_broadcast is not None: 103 | x = x.unsqueeze(1).expand(-1, self.dlatent_broadcast, -1) 104 | return x 105 | 106 | if __name__=='__main__': 107 | 108 | m = GMapping(latent_size=256, dlatent_size=256) 109 | 110 | x = torch.randn(10, 16, 16, 256) 111 | o=m(x) 112 | print(o.shape) 113 | -------------------------------------------------------------------------------- /taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | 19 | def download(url, local_path, chunk_size=1024): 20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 21 | with requests.get(url, stream=True) as r: 22 | total_size = int(r.headers.get("content-length", 0)) 23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 24 | with open(local_path, "wb") as f: 25 | for data in r.iter_content(chunk_size=chunk_size): 26 | if data: 27 | f.write(data) 28 | pbar.update(chunk_size) 29 | 30 | 31 | def md5_hash(path): 32 | with open(path, "rb") as f: 33 | content = f.read() 34 | return hashlib.md5(content).hexdigest() 35 | 36 | 37 | def get_ckpt_path(name, root, check=False): 38 | assert name in URL_MAP 39 | path = os.path.join(root, CKPT_MAP[name]) 40 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 41 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 42 | download(URL_MAP[name], path) 43 | md5 = md5_hash(path) 44 | assert md5 == MD5_MAP[name], md5 45 | return path 46 | 47 | 48 | class KeyNotFoundError(Exception): 49 | def __init__(self, cause, keys=None, visited=None): 50 | self.cause = cause 51 | self.keys = keys 52 | self.visited = visited 53 | messages = list() 54 | if keys is not None: 55 | messages.append("Key not found: {}".format(keys)) 56 | if visited is not None: 57 | messages.append("Visited: {}".format(visited)) 58 | messages.append("Cause:\n{}".format(cause)) 59 | message = "\n".join(messages) 60 | super().__init__(message) 61 | 62 | 63 | def retrieve( 64 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 65 | ): 66 | """Given a nested list or dict return the desired value at key expanding 67 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 68 | is done in-place. 69 | 70 | Parameters 71 | ---------- 72 | list_or_dict : list or dict 73 | Possibly nested list or dictionary. 74 | key : str 75 | key/to/value, path like string describing all keys necessary to 76 | consider to get to the desired value. List indices can also be 77 | passed here. 78 | splitval : str 79 | String that defines the delimiter between keys of the 80 | different depth levels in `key`. 81 | default : obj 82 | Value returned if :attr:`key` is not found. 83 | expand : bool 84 | Whether to expand callable nodes on the path or not. 85 | 86 | Returns 87 | ------- 88 | The desired value or if :attr:`default` is not ``None`` and the 89 | :attr:`key` is not found returns ``default``. 90 | 91 | Raises 92 | ------ 93 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 94 | ``None``. 
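    Examples
    --------
    With the ``config`` dict from the ``__main__`` block below,
    ``retrieve(config, "keyc/cc2")`` returns ``2`` and
    ``retrieve(config, "keyc/cc3", default=0)`` returns ``0`` instead of raising.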
95 | """ 96 | 97 | keys = key.split(splitval) 98 | 99 | success = True 100 | try: 101 | visited = [] 102 | parent = None 103 | last_key = None 104 | for key in keys: 105 | if callable(list_or_dict): 106 | if not expand: 107 | raise KeyNotFoundError( 108 | ValueError( 109 | "Trying to get past callable node with expand=False." 110 | ), 111 | keys=keys, 112 | visited=visited, 113 | ) 114 | list_or_dict = list_or_dict() 115 | parent[last_key] = list_or_dict 116 | 117 | last_key = key 118 | parent = list_or_dict 119 | 120 | try: 121 | if isinstance(list_or_dict, dict): 122 | list_or_dict = list_or_dict[key] 123 | else: 124 | list_or_dict = list_or_dict[int(key)] 125 | except (KeyError, IndexError, ValueError) as e: 126 | raise KeyNotFoundError(e, keys=keys, visited=visited) 127 | 128 | visited += [key] 129 | # final expansion of retrieved value 130 | if expand and callable(list_or_dict): 131 | list_or_dict = list_or_dict() 132 | parent[last_key] = list_or_dict 133 | except KeyNotFoundError as e: 134 | if default is None: 135 | raise e 136 | else: 137 | list_or_dict = default 138 | success = False 139 | 140 | if not pass_success: 141 | return list_or_dict 142 | else: 143 | return list_or_dict, success 144 | 145 | 146 | if __name__ == "__main__": 147 | config = {"keya": "a", 148 | "keyb": "b", 149 | "keyc": 150 | {"cc1": 1, 151 | "cc2": 2, 152 | } 153 | } 154 | from omegaconf import OmegaConf 155 | config = OmegaConf.create(config) 156 | print(config) 157 | retrieve(config, "keya") 158 | 159 | -------------------------------------------------------------------------------- /tools/download_datasets.sh: -------------------------------------------------------------------------------- 1 | set -e # exit script if error 2 | 3 | echo "Please install wget first!" 4 | echo "Auto set up datasets in \"../datasets\"" 5 | 6 | echo "Create dataset folders and subfolders" 7 | mkdir datasets 8 | mkdir datasets/coco 9 | mkdir datasets/coco/2014 10 | mkdir datasets/coco/2017 11 | 12 | echo "Download coco 2014 datasets (valid split) in \"./datasets/coco/2014\"." 13 | wget http://images.cocodataset.org/zips/val2014.zip 14 | wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip 15 | 16 | mv val2014.zip datasets/coco/2014 17 | mv annotations_trainval2014.zip datasets/coco/2014 18 | unzip datasets/coco/2014/val2014.zip -d datasets/coco/2014 19 | unzip datasets/coco/2014/annotations_trainval2014.zip -d datasets/coco/2014 20 | 21 | echo "Download coco 2017 datasets (valid split) in \"./datasets/coco/2017\"." 22 | wget http://images.cocodataset.org/zips/val2017.zip 23 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 24 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip 25 | 26 | mv val2017.zip datasets/coco/2017 27 | mv annotations_trainval2017.zip datasets/coco/2017 28 | mv stuff_annotations_trainval2017.zip datasets/coco/2017 29 | unzip datasets/coco/2017/val2017.zip -d datasets/coco/2017 30 | unzip datasets/coco/2017/annotations_trainval2017.zip -d datasets/coco/2017 31 | unzip datasets/coco/2017/stuff_annotations_trainval2017.zip -d datasets/coco/2017 32 | 33 | echo "Remove cache files..." 
34 | rm -rf datasets/coco/2014/val2014.zip 35 | rm -rf datasets/coco/2014/annotations_trainval2014.zip 36 | rm -rf datasets/coco/2017/val2017.zip 37 | rm -rf datasets/coco/2017/annotations_trainval2017.zip 38 | rm -rf datasets/coco/2017/stuff_annotations_trainval2017.zip 39 | 40 | echo "Move datasets to \"../datasets\"" 41 | mv datasets ../ -------------------------------------------------------------------------------- /tools/download_models.sh: -------------------------------------------------------------------------------- 1 | wget https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 2 | 3 | mkdir exp 4 | mkdir exp/vqgan 5 | mkdir exp/vqgan/vq-f8 6 | 7 | mv ./vq-f8.zip exp/vqgan/vq-f8/ 8 | 9 | cd exp/vqgan/vq-f8/ 10 | unzip vq-f8.zip 11 | rm vq-f8.zip 12 | cd ../../ -------------------------------------------------------------------------------- /tools/ldm/train_ldm_coco_Layout2I.sh: -------------------------------------------------------------------------------- 1 | 2 | python main.py --base configs/ldm/coco_sg2im_ldm_Layout2I_vqgan_f8.yaml \ 3 | -t True --gpus 1 -log_dir ./exp/ldm/Layout2I \ 4 | -n coco_sg2im_ldm_Layout2I_vqgan_f8 --scale_lr False -tb True 5 | -------------------------------------------------------------------------------- /tools/ldm/train_ldm_coco_T2I.sh: -------------------------------------------------------------------------------- 1 | 2 | python main.py --base configs/ldm/coco_stuff_ldm_T2I_vqgan_f8.yaml \ 3 | -t True --gpus 1 -log_dir ./exp/ldm/T2I \ 4 | -n coco_stuff_ldm_T2I_vqgan_f8 --scale_lr False -tb True 5 | -------------------------------------------------------------------------------- /tools/vqgan/train_vqgan_coco.sh: -------------------------------------------------------------------------------- 1 | python main.py -t True --base configs/vqgan/coco_vqgan_f8.yaml \ 2 | --gpus 1 -log_dir ./exp/vqgan -n coco_vqgan_f8 --------------------------------------------------------------------------------