├── .gitignore ├── DATA ├── README.MD ├── gatedSA_first_conv.jpeg ├── mydata_to_tsv.py ├── process_grounding.py └── tsv.py ├── LICENSE ├── README.md ├── SD_input_conv_weight_bias.pth ├── color150.mat ├── configs ├── GoldG+SBU+CC3M+O365_box_text.yaml ├── GoldG+SBU+CC3M+O365_box_text_image.yaml ├── ade_sem.yaml ├── cc3m_canny.yaml ├── cc3m_depth.yaml ├── cc3m_hed.yaml ├── coco2017K.yaml ├── diode_normal.yaml ├── flickr_text.yaml └── flickr_text_image.yaml ├── convert_ckpt.py ├── dataset ├── __init__.py ├── base_dataset.py ├── base_dataset_kp.py ├── catalog.py ├── concat_dataset.py ├── dataset_canny.py ├── dataset_depth.py ├── dataset_hed.py ├── dataset_kp.py ├── dataset_normal.py ├── dataset_sem.py ├── tsv.py ├── tsv_dataset.py └── utils.py ├── demo ├── .gitignore ├── DejaVuSansMono.ttf ├── README.md ├── __init__.py ├── app.py ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ ├── catalog.py │ ├── cd_dataset.py │ ├── concat_dataset.py │ ├── grounding_dataset.py │ ├── layout_dataset.py │ ├── tsv.py │ ├── tsv_dataset.py │ └── utils.py ├── environment.yaml ├── environment_cpu_mps.yaml ├── gligen │ ├── __init__.py │ ├── create_meta.py │ ├── distributed.py │ ├── evaluator.py │ ├── image_projection_matrix │ ├── ldm │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── imagenet.py │ │ │ ├── imagenet_clsidx_to_label.txt │ │ │ ├── imagenet_train_hr_indices.p │ │ │ ├── imagenet_val_hr_indices.p │ │ │ ├── index_synset.yaml │ │ │ └── lsun.py │ │ ├── lr_scheduler.py │ │ ├── models │ │ │ ├── autoencoder.py │ │ │ └── diffusion │ │ │ │ ├── __init__.py │ │ │ │ ├── classifier.py │ │ │ │ ├── ddim.py │ │ │ │ ├── ddpm.py │ │ │ │ ├── ldm.py │ │ │ │ └── plms.py │ │ ├── modules │ │ │ ├── attention.py │ │ │ ├── diffusionmodules │ │ │ │ ├── __init__.py │ │ │ │ ├── model.py │ │ │ │ ├── openaimodel.py │ │ │ │ ├── positionnet.py │ │ │ │ ├── positionnet_with_image.py │ │ │ │ └── util.py │ │ │ ├── distributions │ │ │ │ ├── __init__.py │ │ │ │ └── distributions.py │ │ │ ├── ema.py │ │ │ ├── encoders │ │ │ │ ├── __init__.py │ │ │ │ ├── modules.py │ │ │ │ └── modules_backup.py │ │ │ ├── image_degradation │ │ │ │ ├── __init__.py │ │ │ │ ├── bsrgan.py │ │ │ │ ├── bsrgan_light.py │ │ │ │ └── utils_image.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── contperceptual.py │ │ │ │ └── vqperceptual.py │ │ │ └── x_transformer.py │ │ └── util.py │ ├── projection_matrix │ ├── task_grounded_generation.py │ └── trainer.py └── images │ ├── arg_corgis.jpeg │ ├── blank.png │ ├── flower_beach.jpg │ ├── red_bird.jpg │ ├── style_cloudpurple.png │ ├── style_gold.png │ └── teddy.jpg ├── distributed.py ├── docs ├── gligen_controlnet.jpeg ├── gligen_vs_controlnet.MD └── unet.jpeg ├── env_docker └── Dockerfile ├── figures └── teaser_v4.png ├── gligen_inference.py ├── grounding_input ├── __init__.py ├── canny_grounding_downsampler_input.py ├── canny_grounding_tokinzer_input.py ├── depth_grounding_downsampler_input.py ├── depth_grounding_tokinzer_input.py ├── hed_grounding_downsampler_input.py ├── hed_grounding_tokinzer_input.py ├── keypoint_grounding_tokinzer_input.py ├── normal_grounding_downsampler_input.py ├── normal_grounding_tokinzer_input.py ├── sem_grounding_downsampler_input.py ├── sem_grounding_tokinzer_input.py ├── text_grounding_tokinzer_input.py └── text_image_grounding_tokinzer_input.py ├── inference_images ├── beach.jpg ├── bigben.jpg ├── canny_robot.png ├── clock.png ├── dalle2_museum.jpg ├── depth_bird.png ├── hed_man_eat.png ├── normal_tree_building.jpg ├── placeholder.png ├── readme.txt ├── 
sem_ade_living_room.png └── style_golden.jpg ├── inpaint_mask_func.py ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ ├── imagenet_clsidx_to_label.txt │ ├── imagenet_train_hr_indices.p │ ├── imagenet_val_hr_indices.p │ ├── index_synset.yaml │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── ldm.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── canny_grounding_downsampler.py │ │ ├── canny_grounding_net.py │ │ ├── convnext.py │ │ ├── depth_grounding_downsampler.py │ │ ├── depth_grounding_net.py │ │ ├── grounding_net_example.py │ │ ├── hed_grounding_downsampler.py │ │ ├── hed_grounding_net.py │ │ ├── keypoint_grounding_net.py │ │ ├── model.py │ │ ├── normal_grounding_downsampler.py │ │ ├── normal_grounding_net.py │ │ ├── openaimodel.py │ │ ├── pseudo_example.py │ │ ├── resnet.py │ │ ├── sem_grounding_downsampler.py │ │ ├── sem_grounding_net.py │ │ ├── text_grounding_net.py │ │ ├── text_image_grounding_net.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── modules.py │ │ └── modules_backup.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── projection_matrix ├── trainer.py └── tsv_split_merge.py /.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ project files 2 | .idea 3 | *.iml 4 | out 5 | gen 6 | 7 | ### Vim template 8 | [._]*.s[a-w][a-z] 9 | [._]s[a-w][a-z] 10 | *.un~ 11 | Session.vim 12 | .netrwhist 13 | *~ 14 | 15 | ### IPythonNotebook template 16 | # Temporary data 17 | .ipynb_checkpoints/ 18 | 19 | ### Python template 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | env/ 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | #lib/ 38 | #lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *,cover 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | *.ipynb 80 | *.params 81 | # *.json 82 | .vscode/ 83 | *.code-workspace/ 84 | 85 | lib/pycocotools/_mask.c 86 | lib/nms/cpu_nms.c 87 | 88 | -------------------------------------------------------------------------------- /DATA/gatedSA_first_conv.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/DATA/gatedSA_first_conv.jpeg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yuheng Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /SD_input_conv_weight_bias.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/SD_input_conv_weight_bias.pth -------------------------------------------------------------------------------- /color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/color150.mat -------------------------------------------------------------------------------- /configs/GoldG+SBU+CC3M+O365_box_text.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | grounding_tokenizer: 26 | target: ldm.modules.diffusionmodules.text_grounding_net.PositionNet 27 | params: 28 | in_dim: 768 # this is pre-processing feature dim from CLIP Text encoder; penultimate feature 29 | out_dim: 768 # Not constrained to this, as one linear project is appiled at each Gated layer to match visual dimension 30 | 31 | 32 | autoencoder: 33 | target: ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | scale_factor: 0.18215 36 | embed_dim: 4 37 | ddconfig: 38 | double_z: true 39 | z_channels: 4 40 | resolution: 256 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: 45 | - 1 46 | - 2 47 | - 4 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | 53 | 54 | text_encoder: 55 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 56 | 57 | 58 | 59 | 60 | train_dataset_names: 61 | VGGrounding: 62 | which_layer_text: before 63 | image_size: 512 64 | max_boxes_per_data: 30 65 | prob_use_caption: 0.5 66 | random_crop: False 67 | random_flip: True 68 | FlickrGrounding: 69 | which_layer_text: before 70 | image_size: 512 71 | max_boxes_per_data: 30 72 | prob_use_caption: 0.5 73 | random_crop: False 74 | random_flip: True 75 | SBUGrounding: 76 | which_layer_text: before 77 | image_size: 512 78 | max_boxes_per_data: 30 79 | prob_use_caption: 0.5 80 | random_crop: False 81 | random_flip: True 82 | CC3MGrounding: 83 | which_layer_text: before 84 | image_size: 512 85 | max_boxes_per_data: 30 86 | prob_use_caption: 0.5 87 | random_crop: False 88 | random_flip: True 89 | Obj365Detection: 90 | which_layer_text: before 91 | image_size: 512 92 | max_boxes_per_data: 30 93 | prob_use_caption: 0.5 94 | random_crop: False 95 | random_flip: True 96 | 97 | 98 | 99 | 100 | grounding_tokenizer_input: 101 | target: grounding_input.text_grounding_tokinzer_input.GroundingNetInput 102 | -------------------------------------------------------------------------------- /configs/GoldG+SBU+CC3M+O365_box_text_image.yaml: 
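Editor's aside on the configs in this directory: every component in these YAML files (diffusion, model, grounding_tokenizer, autoencoder, text_encoder, grounding_tokenizer_input) is written as a `target` class path plus a `params` mapping. The repo resolves such entries with ldm.util.instantiate_from_config (it is imported that way in dataset/concat_dataset.py further below). The snippet here is only a minimal sketch of that general pattern under that assumption, not the project's actual helper; the standalone function names and the example config dict are illustrative.

import importlib

def get_obj_from_str(string):
    # e.g. "ldm.modules.encoders.modules.FrozenCLIPEmbedder" -> class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # config is a dict shaped like one YAML block: {"target": "...", "params": {...}}
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# Hypothetical usage, mirroring the "model" block of the configs above
# (commented out because it requires the ldm package to be importable):
# unet = instantiate_from_config({
#     "target": "ldm.modules.diffusionmodules.openaimodel.UNetModel",
#     "params": {"image_size": 64, "in_channels": 4, "out_channels": 4, "model_channels": 320},
# })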
-------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA 23 | use_checkpoint: True 24 | 25 | grounding_tokenizer: 26 | target: ldm.modules.diffusionmodules.text_image_grounding_net.PositionNet 27 | params: 28 | in_dim: 768 # this is pre-processing feature dim from CLIP Text encoder; penultimate feature 29 | out_dim: 768 # Not constrained to this, as one linear project is appiled at each Gated layer to match visual dimension 30 | 31 | 32 | autoencoder: 33 | target: ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | scale_factor: 0.18215 36 | embed_dim: 4 37 | ddconfig: 38 | double_z: true 39 | z_channels: 4 40 | resolution: 256 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: 45 | - 1 46 | - 2 47 | - 4 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | 53 | 54 | text_encoder: 55 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 56 | 57 | 58 | 59 | train_dataset_names: 60 | VGGrounding: 61 | which_layer_text: before 62 | which_layer_image: after_reproject 63 | image_size: 512 64 | max_boxes_per_data: 30 65 | prob_use_caption: 0.5 66 | random_drop_embedding: both 67 | random_crop: False 68 | random_flip: True 69 | FlickrGrounding: 70 | which_layer_text: before 71 | which_layer_image: after_reproject 72 | image_size: 512 73 | max_boxes_per_data: 30 74 | prob_use_caption: 0.5 75 | random_drop_embedding: both 76 | random_crop: False 77 | random_flip: True 78 | SBUGrounding: 79 | which_layer_text: before 80 | which_layer_image: after_reproject 81 | image_size: 512 82 | max_boxes_per_data: 30 83 | prob_use_caption: 0.5 84 | random_drop_embedding: both 85 | random_crop: False 86 | random_flip: True 87 | CC3MGrounding: 88 | which_layer_text: before 89 | which_layer_image: after_reproject 90 | image_size: 512 91 | max_boxes_per_data: 30 92 | prob_use_caption: 0.5 93 | random_drop_embedding: both 94 | random_crop: False 95 | random_flip: True 96 | Obj365Detection: 97 | which_layer_text: before 98 | which_layer_image: after_reproject 99 | image_size: 512 100 | max_boxes_per_data: 30 101 | prob_use_caption: 0.5 102 | random_drop_embedding: both 103 | random_crop: False 104 | random_flip: True 105 | 106 | 107 | grounding_tokenizer_input: 108 | target: grounding_input.text_image_grounding_tokinzer_input.GroundingNetInput 109 | -------------------------------------------------------------------------------- /configs/ade_sem.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 
18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | 26 | grounding_downsampler: 27 | target: ldm.modules.diffusionmodules.sem_grounding_downsampler.GroundingDownsampler 28 | params: 29 | in_dim: 152 30 | resize_input: 256 31 | out_dim: 8 32 | 33 | grounding_tokenizer: 34 | target: ldm.modules.diffusionmodules.sem_grounding_net.PositionNet 35 | params: 36 | in_dim: 152 37 | resize_input: 256 38 | out_dim: 768 39 | 40 | 41 | autoencoder: 42 | target: ldm.models.autoencoder.AutoencoderKL 43 | params: 44 | scale_factor: 0.18215 45 | embed_dim: 4 46 | ddconfig: 47 | double_z: true 48 | z_channels: 4 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: [] 60 | dropout: 0.0 61 | 62 | 63 | text_encoder: 64 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 65 | 66 | 67 | 68 | 69 | train_dataset_names: 70 | ADESemantic: 71 | prob_use_caption: 1 72 | image_size: 512 73 | random_flip: True 74 | 75 | 76 | grounding_tokenizer_input: 77 | target: grounding_input.sem_grounding_tokinzer_input.GroundingNetInput 78 | 79 | 80 | grounding_downsampler_input: 81 | target: grounding_input.sem_grounding_downsampler_input.GroundingDSInput 82 | -------------------------------------------------------------------------------- /configs/cc3m_canny.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. 
We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | 26 | grounding_downsampler: 27 | target: ldm.modules.diffusionmodules.canny_grounding_downsampler.GroundingDownsampler 28 | params: 29 | resize_input: 256 30 | out_dim: 8 31 | 32 | grounding_tokenizer: 33 | target: ldm.modules.diffusionmodules.canny_grounding_net.PositionNet 34 | params: 35 | resize_input: 256 36 | out_dim: 768 37 | 38 | 39 | autoencoder: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | scale_factor: 0.18215 43 | embed_dim: 4 44 | ddconfig: 45 | double_z: true 46 | z_channels: 4 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 2 54 | - 4 55 | - 4 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | 60 | 61 | text_encoder: 62 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 63 | 64 | 65 | 66 | 67 | train_dataset_names: 68 | CC3MGroundingCanny: 69 | prob_use_caption: 1 70 | image_size: 512 71 | random_flip: True 72 | 73 | 74 | grounding_tokenizer_input: 75 | target: grounding_input.canny_grounding_tokinzer_input.GroundingNetInput 76 | 77 | 78 | grounding_downsampler_input: 79 | target: grounding_input.canny_grounding_downsampler_input.GroundingDSInput 80 | -------------------------------------------------------------------------------- /configs/cc3m_depth.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. 
We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | 26 | grounding_downsampler: 27 | target: ldm.modules.diffusionmodules.depth_grounding_downsampler.GroundingDownsampler 28 | params: 29 | resize_input: 256 30 | out_dim: 8 31 | 32 | grounding_tokenizer: 33 | target: ldm.modules.diffusionmodules.depth_grounding_net.PositionNet 34 | params: 35 | resize_input: 256 36 | out_dim: 768 37 | 38 | 39 | autoencoder: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | scale_factor: 0.18215 43 | embed_dim: 4 44 | ddconfig: 45 | double_z: true 46 | z_channels: 4 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 2 54 | - 4 55 | - 4 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | 60 | 61 | text_encoder: 62 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 63 | 64 | 65 | 66 | 67 | train_dataset_names: 68 | CC3MGroundingDepth: 69 | prob_use_caption: 1 70 | image_size: 512 71 | random_flip: True 72 | 73 | 74 | grounding_tokenizer_input: 75 | target: grounding_input.depth_grounding_tokinzer_input.GroundingNetInput 76 | 77 | 78 | grounding_downsampler_input: 79 | target: grounding_input.depth_grounding_downsampler_input.GroundingDSInput 80 | -------------------------------------------------------------------------------- /configs/cc3m_hed.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. 
We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | 26 | grounding_downsampler: 27 | target: ldm.modules.diffusionmodules.hed_grounding_downsampler.GroundingDownsampler 28 | params: 29 | out_dim: 1 30 | 31 | grounding_tokenizer: 32 | target: ldm.modules.diffusionmodules.hed_grounding_net.PositionNet 33 | params: 34 | resize_input: 256 35 | out_dim: 768 36 | 37 | 38 | autoencoder: 39 | target: ldm.models.autoencoder.AutoencoderKL 40 | params: 41 | scale_factor: 0.18215 42 | embed_dim: 4 43 | ddconfig: 44 | double_z: true 45 | z_channels: 4 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | - 4 55 | num_res_blocks: 2 56 | attn_resolutions: [] 57 | dropout: 0.0 58 | 59 | 60 | text_encoder: 61 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 62 | 63 | 64 | 65 | 66 | train_dataset_names: 67 | CC3MGroundingHed: 68 | prob_use_caption: 1 69 | image_size: 512 70 | random_flip: True 71 | 72 | 73 | grounding_tokenizer_input: 74 | target: grounding_input.hed_grounding_tokinzer_input.GroundingNetInput 75 | 76 | 77 | grounding_downsampler_input: 78 | target: grounding_input.hed_grounding_downsampler_input.GroundingDSInput 79 | -------------------------------------------------------------------------------- /configs/coco2017K.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA 23 | use_checkpoint: True 24 | 25 | grounding_tokenizer: 26 | target: ldm.modules.diffusionmodules.keypoint_grounding_net.PositionNet 27 | params: 28 | max_persons_per_image: 8 # must same as the one in dataset 29 | out_dim: 768 # Not constrained to this, as one linear project is appiled at each Gated layer to match visual dimension 30 | 31 | 32 | autoencoder: 33 | target: ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | scale_factor: 0.18215 36 | embed_dim: 4 37 | ddconfig: 38 | double_z: true 39 | z_channels: 4 40 | resolution: 256 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: 45 | - 1 46 | - 2 47 | - 4 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | 53 | 54 | text_encoder: 55 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 56 | 57 | 58 | 59 | 60 | train_dataset_names: 61 | COCO2017Keypoint: 62 | image_size: 512 63 | prob_real_caption: 1 64 | max_persons_per_image: 8 # This must be same as the one in Model 65 | random_flip: True 66 | 67 | 68 | grounding_tokenizer_input: 69 | target: grounding_input.keypoint_grounding_tokinzer_input.GroundingNetInput 70 | -------------------------------------------------------------------------------- /configs/diode_normal.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 
| 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | 26 | grounding_downsampler: 27 | target: ldm.modules.diffusionmodules.normal_grounding_downsampler.GroundingDownsampler 28 | params: 29 | resize_input: 256 30 | out_dim: 8 31 | 32 | grounding_tokenizer: 33 | target: ldm.modules.diffusionmodules.normal_grounding_net.PositionNet 34 | params: 35 | resize_input: 256 36 | out_dim: 768 37 | 38 | 39 | autoencoder: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | scale_factor: 0.18215 43 | embed_dim: 4 44 | ddconfig: 45 | double_z: true 46 | z_channels: 4 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 2 54 | - 4 55 | - 4 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | 60 | 61 | text_encoder: 62 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 63 | 64 | 65 | 66 | 67 | train_dataset_names: 68 | DIODENormal: 69 | prob_use_caption: 1 70 | image_size: 512 71 | random_flip: True 72 | 73 | 74 | grounding_tokenizer_input: 75 | target: grounding_input.normal_grounding_tokinzer_input.GroundingNetInput 76 | 77 | 78 | grounding_downsampler_input: 79 | target: grounding_input.normal_grounding_downsampler_input.GroundingDSInput 80 | -------------------------------------------------------------------------------- /configs/flickr_text.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA. 
We have ablate this, self-attention is better than cross-attention, thus please set this as gatedSA usually 23 | use_checkpoint: True 24 | 25 | grounding_tokenizer: 26 | target: ldm.modules.diffusionmodules.text_grounding_net.PositionNet 27 | params: 28 | in_dim: 768 # this is pre-processing feature dim from CLIP Text encoder; penultimate feature 29 | out_dim: 768 # Not constrained to this, as one linear project is appiled at each Gated layer to match visual dimension 30 | 31 | 32 | autoencoder: 33 | target: ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | scale_factor: 0.18215 36 | embed_dim: 4 37 | ddconfig: 38 | double_z: true 39 | z_channels: 4 40 | resolution: 256 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: 45 | - 1 46 | - 2 47 | - 4 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | 53 | 54 | text_encoder: 55 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 56 | 57 | 58 | 59 | 60 | train_dataset_names: 61 | FlickrGrounding: 62 | which_layer_text: before 63 | image_size: 512 64 | max_boxes_per_data: 30 65 | prob_use_caption: 0.5 66 | random_crop: False 67 | random_flip: True 68 | 69 | 70 | grounding_tokenizer_input: 71 | target: grounding_input.text_grounding_tokinzer_input.GroundingNetInput 72 | -------------------------------------------------------------------------------- /configs/flickr_text_image.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | target: ldm.models.diffusion.ldm.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | timesteps: 1000 7 | 8 | 9 | model: 10 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 11 | params: 12 | image_size: 64 # unused in the unet, but will be used when create xT 13 | in_channels: 4 14 | out_channels: 4 15 | model_channels: 320 16 | attention_resolutions: [ 4, 2, 1 ] 17 | num_res_blocks: 2 18 | channel_mult: [ 1, 2, 4, 4 ] 19 | num_heads: 8 20 | transformer_depth: 1 21 | context_dim: 768 22 | fuser_type: gatedSA # gatedCA or gatedSA 23 | use_checkpoint: True 24 | 25 | grounding_tokenizer: 26 | target: ldm.modules.diffusionmodules.text_image_grounding_net.PositionNet 27 | params: 28 | in_dim: 768 # this is pre-processing feature dim from CLIP Text encoder; penultimate feature 29 | out_dim: 768 # Not constrained to this, as one linear project is appiled at each Gated layer to match visual dimension 30 | 31 | 32 | autoencoder: 33 | target: ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | scale_factor: 0.18215 36 | embed_dim: 4 37 | ddconfig: 38 | double_z: true 39 | z_channels: 4 40 | resolution: 256 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: 45 | - 1 46 | - 2 47 | - 4 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | 53 | 54 | text_encoder: 55 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 56 | 57 | 58 | 59 | # This is second stage training. 
(Resume base model from O365+GoldG) 60 | # randomly drop caption for all dataset (made caption used for O365) 61 | train_dataset_names: 62 | FlickrGrounding: 63 | which_layer_text: before 64 | which_layer_image: after_reproject 65 | image_size: 512 66 | max_boxes_per_data: 30 67 | prob_use_caption: 0.5 68 | random_drop_embedding: both 69 | random_crop: False 70 | random_flip: True 71 | 72 | 73 | grounding_tokenizer_input: 74 | target: grounding_input.text_image_grounding_tokinzer_input.GroundingNetInput 75 | -------------------------------------------------------------------------------- /convert_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | 5 | def add_additional_channels(state_dict, num_additional_channels): 6 | "state_dict should be just from unet model, not the entire SD or GLIGEN" 7 | 8 | if num_additional_channels != 0: 9 | 10 | new_conv_weight = torch.zeros(320, 4+num_additional_channels, 3, 3 ) 11 | 12 | for key,value in state_dict.items(): 13 | if key == "input_blocks.0.0.weight": 14 | old_conv_weight = value 15 | new_conv_weight[:,0:4,:,:] = old_conv_weight 16 | state_dict[key] = new_conv_weight 17 | 18 | 19 | 20 | 21 | 22 | 23 | if __name__ == "__main__": 24 | # The following code will add additional 5 channels (for inpainting) to a GLIGEN ckpt 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--ckpt_path", type=str, default=None, help="") 28 | parser.add_argument("--new_ckpt_path", type=str, default=None, help="") 29 | args = parser.parse_args() 30 | 31 | 32 | new_conv_weight = torch.zeros(320, 4+4+1, 3, 3 ) 33 | 34 | ckpt = torch.load(args.ckpt_path, map_location="cpu") 35 | 36 | for key,value in ckpt["model"].items(): 37 | if key == "input_blocks.0.0.weight": 38 | old_conv_weight = value 39 | new_conv_weight[:,0:4,:,:] = old_conv_weight 40 | ckpt["model"]["input_blocks.0.0.weight"] = new_conv_weight 41 | 42 | save = {"model":ckpt["model"]} 43 | torch.save(save, args.new_ckpt_path) 44 | 45 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/catalog.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class DatasetCatalog: 4 | def __init__(self, ROOT): 5 | 6 | 7 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 8 | 9 | 10 | self.VGGrounding = { 11 | "target": "dataset.tsv_dataset.TSVDataset", 12 | "train_params": dict( 13 | tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'), 14 | ), 15 | } 16 | 17 | 18 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 19 | 20 | 21 | self.FlickrGrounding = { 22 | "target": "dataset.tsv_dataset.TSVDataset", 23 | "train_params":dict( 24 | tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'), 25 | ), 26 | } 27 | 28 | 29 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 30 | 31 | self.SBUGrounding = { 32 | "target": "dataset.tsv_dataset.TSVDataset", 33 | "train_params":dict( 34 | tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'), 35 | ), 36 | } 37 | 38 | 39 | # - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 40 | 41 | 42 | self.CC3MGrounding = { 43 | "target": "dataset.tsv_dataset.TSVDataset", 44 | "train_params":dict( 45 | tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'), 46 | ), 47 | } 48 | 49 | 50 | 51 | 52 | 53 | self.CC3MGroundingHed = { 54 | "target": "dataset.dataset_hed.HedDataset", 55 | "train_params":dict( 56 | tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'), 57 | hed_tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv_hed/train-00.tsv'), 58 | ), 59 | } 60 | 61 | 62 | self.CC3MGroundingCanny = { 63 | "target": "dataset.dataset_canny.CannyDataset", 64 | "train_params":dict( 65 | tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'), 66 | canny_tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv_canny/train-00.tsv'), 67 | ), 68 | } 69 | 70 | 71 | self.CC3MGroundingDepth = { 72 | "target": "dataset.dataset_depth.DepthDataset", 73 | "train_params":dict( 74 | tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'), 75 | depth_tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv_depth/train-00.tsv'), 76 | ), 77 | } 78 | 79 | 80 | 81 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 82 | 83 | 84 | self.CC12MGrounding = { 85 | "target": "dataset.tsv_dataset.TSVDataset", 86 | "train_params":dict( 87 | tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'), 88 | ), 89 | } 90 | 91 | 92 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 93 | 94 | self.Obj365Detection = { 95 | "target": "dataset.tsv_dataset.TSVDataset", 96 | "train_params":dict( 97 | tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'), 98 | ), 99 | } 100 | 101 | 102 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 103 | 104 | self.COCO2017Keypoint = { 105 | "target": "dataset.dataset_kp.KeypointDataset", 106 | "train_params":dict( 107 | image_root = os.path.join(ROOT,'COCO/images'), 108 | keypoints_json_path = os.path.join(ROOT,'COCO/annotations2017/person_keypoints_train2017.json'), 109 | caption_json_path = os.path.join(ROOT,'COCO/annotations2017/captions_train2017.json'), 110 | ), 111 | } 112 | 113 | 114 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 115 | 116 | self.DIODENormal = { 117 | "target": "dataset.dataset_normal.NormalDataset", 118 | "train_params":dict( 119 | image_rootdir = os.path.join(ROOT,'normal/image_train'), 120 | normal_rootdir = os.path.join(ROOT,'normal/normal_train'), 121 | caption_path = os.path.join(ROOT,'normal/diode_cation.json'), 122 | ), 123 | } 124 | 125 | 126 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 127 | 128 | self.ADESemantic = { 129 | "target": "dataset.dataset_sem.SemanticDataset", 130 | "train_params":dict( 131 | image_rootdir = os.path.join(ROOT,'ADE/ADEChallengeData2016/images/training'), 132 | sem_rootdir = os.path.join(ROOT,'ADE/ADEChallengeData2016/annotations/training'), 133 | caption_path = os.path.join(ROOT,'ADE/ade_train_images_cation.json'), 134 | ), 135 | } 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /dataset/concat_dataset.py: -------------------------------------------------------------------------------- 1 | from .catalog import DatasetCatalog 2 | from ldm.util import instantiate_from_config 3 | import torch 4 | 5 | 6 | 7 | 8 | class ConCatDataset(): 9 | def __init__(self, 
dataset_name_list, ROOT, train=True, repeats=None): 10 | self.datasets = [] 11 | cul_previous_dataset_length = 0 12 | offset_map = [] 13 | which_dataset = [] 14 | 15 | if repeats is None: 16 | repeats = [1] * len(dataset_name_list) 17 | else: 18 | assert len(repeats) == len(dataset_name_list) 19 | 20 | 21 | Catalog = DatasetCatalog(ROOT) 22 | for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()): 23 | repeat = repeats[dataset_idx] 24 | 25 | dataset_dict = getattr(Catalog, dataset_name) 26 | 27 | target = dataset_dict['target'] 28 | params = dataset_dict['train_params'] if train else dataset_dict['val_params'] 29 | if yaml_params is not None: 30 | params.update(yaml_params) 31 | dataset = instantiate_from_config( dict(target=target, params=params) ) 32 | 33 | self.datasets.append(dataset) 34 | for _ in range(repeat): 35 | offset_map.append( torch.ones(len(dataset))*cul_previous_dataset_length ) 36 | which_dataset.append( torch.ones(len(dataset))*dataset_idx ) 37 | cul_previous_dataset_length += len(dataset) 38 | offset_map = torch.cat(offset_map, dim=0).long() 39 | self.total_length = cul_previous_dataset_length 40 | 41 | self.mapping = torch.arange(self.total_length) - offset_map 42 | self.which_dataset = torch.cat(which_dataset, dim=0).long() 43 | 44 | 45 | def total_images(self): 46 | count = 0 47 | for dataset in self.datasets: 48 | print(dataset.total_images()) 49 | count += dataset.total_images() 50 | return count 51 | 52 | 53 | 54 | def __getitem__(self, idx): 55 | dataset = self.datasets[ self.which_dataset[idx] ] 56 | return dataset[ self.mapping[idx] ] 57 | 58 | 59 | def __len__(self): 60 | return self.total_length 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /dataset/dataset_canny.py: -------------------------------------------------------------------------------- 1 | from tkinter.messagebox import NO 2 | import torch 3 | import json 4 | from PIL import Image, ImageDraw, ImageOps 5 | import torchvision.transforms as transforms 6 | from io import BytesIO 7 | import random 8 | import torchvision.transforms.functional as TF 9 | 10 | from .tsv import TSVFile 11 | 12 | from io import BytesIO 13 | import base64 14 | import numpy as np 15 | 16 | 17 | def decode_base64_to_pillow(image_b64): 18 | return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB') 19 | 20 | def decode_tensor_from_string(arr_str, use_tensor=True): 21 | arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32') 22 | if use_tensor: 23 | arr = torch.from_numpy(arr) 24 | return arr 25 | 26 | 27 | def decode_item(item): 28 | "This is for decoding TSV for box data" 29 | item = json.loads(item) 30 | item['image'] = decode_base64_to_pillow(item['image']) 31 | 32 | for anno in item['annos']: 33 | anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before']) 34 | anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before']) 35 | anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after']) 36 | anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after']) 37 | return item 38 | 39 | 40 | def decode_item_canny(item): 41 | "This is for decoding TSV for canny data" 42 | item = json.loads(item) 43 | item['canny_edge'] = decode_base64_to_pillow(item['canny_edge']) 44 | return item 45 | 46 | 47 | 48 | class CannyDataset(): 49 | def __init__(self, tsv_path, canny_tsv_path, prob_use_caption=1, image_size=512, 
random_flip=False): 50 | 51 | self.tsv_path = tsv_path 52 | self.canny_tsv_path = canny_tsv_path 53 | self.prob_use_caption = prob_use_caption 54 | self.image_size = image_size 55 | self.random_flip = random_flip 56 | 57 | # Load tsv data 58 | self.tsv_file = TSVFile(self.tsv_path) 59 | self.canny_tsv_file = TSVFile(self.canny_tsv_path) 60 | 61 | self.pil_to_tensor = transforms.PILToTensor() 62 | 63 | 64 | def total_images(self): 65 | return len(self) 66 | 67 | 68 | def get_item_from_tsv(self, index): 69 | _, item = self.tsv_file[index] 70 | item = decode_item(item) 71 | return item 72 | 73 | 74 | def get_item_from_canny_tsv(self, index): 75 | _, item = self.canny_tsv_file[index] 76 | item = decode_item_canny(item) 77 | return item 78 | 79 | 80 | 81 | def __getitem__(self, index): 82 | 83 | raw_item = self.get_item_from_tsv(index) 84 | raw_item_canny = self.get_item_from_canny_tsv(index) 85 | 86 | assert raw_item['data_id'] == raw_item_canny['data_id'] 87 | 88 | out = {} 89 | 90 | out['id'] = raw_item['data_id'] 91 | image = raw_item['image'] 92 | canny_edge = raw_item_canny['canny_edge'] 93 | 94 | # - - - - - center_crop, resize and random_flip - - - - - - # 95 | assert image.size == canny_edge.size 96 | 97 | crop_size = min(image.size) 98 | image = TF.center_crop(image, crop_size) 99 | image = image.resize( (self.image_size, self.image_size) ) 100 | 101 | canny_edge = TF.center_crop(canny_edge, crop_size) 102 | canny_edge = canny_edge.resize( (self.image_size, self.image_size) ) 103 | 104 | 105 | if self.random_flip and random.random()<0.5: 106 | image = ImageOps.mirror(image) 107 | canny_edge = ImageOps.mirror(canny_edge) 108 | 109 | out['image'] = ( self.pil_to_tensor(image).float()/255 - 0.5 ) / 0.5 110 | out['canny_edge'] = ( self.pil_to_tensor(canny_edge).float()/255 - 0.5 ) / 0.5 111 | out['mask'] = torch.tensor(1.0) 112 | 113 | # -------------------- caption ------------------- # 114 | if random.uniform(0, 1) < self.prob_use_caption: 115 | out["caption"] = raw_item["caption"] 116 | else: 117 | out["caption"] = "" 118 | 119 | return out 120 | 121 | 122 | def __len__(self): 123 | return len(self.tsv_file) 124 | 125 | 126 | -------------------------------------------------------------------------------- /dataset/dataset_depth.py: -------------------------------------------------------------------------------- 1 | from tkinter.messagebox import NO 2 | import torch 3 | import json 4 | from PIL import Image, ImageDraw, ImageOps 5 | import torchvision.transforms as transforms 6 | from io import BytesIO 7 | import random 8 | import torchvision.transforms.functional as TF 9 | 10 | from .tsv import TSVFile 11 | 12 | from io import BytesIO 13 | import base64 14 | import numpy as np 15 | 16 | 17 | def decode_base64_to_pillow(image_b64): 18 | return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB') 19 | 20 | def decode_tensor_from_string(arr_str, use_tensor=True): 21 | arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32') 22 | if use_tensor: 23 | arr = torch.from_numpy(arr) 24 | return arr 25 | 26 | 27 | def decode_item(item): 28 | "This is for decoding TSV for box data" 29 | item = json.loads(item) 30 | item['image'] = decode_base64_to_pillow(item['image']) 31 | 32 | for anno in item['annos']: 33 | anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before']) 34 | anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before']) 35 | anno['image_embedding_after'] = 
decode_tensor_from_string(anno['image_embedding_after']) 36 | anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after']) 37 | return item 38 | 39 | 40 | def decode_item_depth(item): 41 | "This is for decoding TSV for depth data" 42 | item = json.loads(item) 43 | item['depth'] = decode_base64_to_pillow(item['depth']) 44 | return item 45 | 46 | 47 | 48 | class DepthDataset(): 49 | def __init__(self, tsv_path, depth_tsv_path, prob_use_caption=1, image_size=512, random_flip=False): 50 | 51 | self.tsv_path = tsv_path 52 | self.depth_tsv_path = depth_tsv_path 53 | self.prob_use_caption = prob_use_caption 54 | self.image_size = image_size 55 | self.random_flip = random_flip 56 | 57 | # Load tsv data 58 | self.tsv_file = TSVFile(self.tsv_path) 59 | self.depth_tsv_file = TSVFile(self.depth_tsv_path) 60 | 61 | self.pil_to_tensor = transforms.PILToTensor() 62 | 63 | 64 | def total_images(self): 65 | return len(self) 66 | 67 | 68 | def get_item_from_tsv(self, index): 69 | _, item = self.tsv_file[index] 70 | item = decode_item(item) 71 | return item 72 | 73 | 74 | def get_item_from_depth_tsv(self, index): 75 | _, item = self.depth_tsv_file[index] 76 | item = decode_item_depth(item) 77 | return item 78 | 79 | 80 | 81 | def __getitem__(self, index): 82 | 83 | raw_item = self.get_item_from_tsv(index) 84 | raw_item_depth = self.get_item_from_depth_tsv(index) 85 | 86 | assert raw_item['data_id'] == raw_item_depth['data_id'] 87 | 88 | out = {} 89 | 90 | out['id'] = raw_item['data_id'] 91 | image = raw_item['image'] 92 | depth = raw_item_depth['depth'] 93 | 94 | # - - - - - center_crop, resize and random_flip - - - - - - # 95 | assert image.size == depth.size 96 | 97 | crop_size = min(image.size) 98 | image = TF.center_crop(image, crop_size) 99 | image = image.resize( (self.image_size, self.image_size) ) 100 | 101 | depth = TF.center_crop(depth, crop_size) 102 | depth = depth.resize( (self.image_size, self.image_size) ) 103 | 104 | 105 | if self.random_flip and random.random()<0.5: 106 | image = ImageOps.mirror(image) 107 | depth = ImageOps.mirror(depth) 108 | 109 | out['image'] = ( self.pil_to_tensor(image).float()/255 - 0.5 ) / 0.5 110 | out['depth'] = ( self.pil_to_tensor(depth).float()/255 - 0.5 ) / 0.5 111 | out['mask'] = torch.tensor(1.0) 112 | 113 | # -------------------- caption ------------------- # 114 | if random.uniform(0, 1) < self.prob_use_caption: 115 | out["caption"] = raw_item["caption"] 116 | else: 117 | out["caption"] = "" 118 | 119 | return out 120 | 121 | 122 | def __len__(self): 123 | return len(self.tsv_file) 124 | 125 | 126 | -------------------------------------------------------------------------------- /dataset/dataset_hed.py: -------------------------------------------------------------------------------- 1 | from tkinter.messagebox import NO 2 | import torch 3 | import json 4 | from PIL import Image, ImageDraw, ImageOps 5 | import torchvision.transforms as transforms 6 | from io import BytesIO 7 | import random 8 | import torchvision.transforms.functional as TF 9 | 10 | from .tsv import TSVFile 11 | 12 | from io import BytesIO 13 | import base64 14 | import numpy as np 15 | 16 | 17 | def decode_base64_to_pillow(image_b64): 18 | return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB') 19 | 20 | def decode_tensor_from_string(arr_str, use_tensor=True): 21 | arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32') 22 | if use_tensor: 23 | arr = torch.from_numpy(arr) 24 | return arr 25 | 26 | 27 | def decode_item(item): 28 | "This is for 
decoding TSV for box data" 29 | item = json.loads(item) 30 | item['image'] = decode_base64_to_pillow(item['image']) 31 | 32 | for anno in item['annos']: 33 | anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before']) 34 | anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before']) 35 | anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after']) 36 | anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after']) 37 | return item 38 | 39 | 40 | def decode_item_hed(item): 41 | "This is for decoding TSV for hed data" 42 | item = json.loads(item) 43 | item['hed_edge'] = decode_base64_to_pillow(item['hed_edge']) 44 | return item 45 | 46 | 47 | 48 | class HedDataset(): 49 | def __init__(self, tsv_path, hed_tsv_path, prob_use_caption=1, image_size=512, random_flip=False): 50 | 51 | self.tsv_path = tsv_path 52 | self.hed_tsv_path = hed_tsv_path 53 | self.prob_use_caption = prob_use_caption 54 | self.image_size = image_size 55 | self.random_flip = random_flip 56 | 57 | # Load tsv data 58 | self.tsv_file = TSVFile(self.tsv_path) 59 | self.hed_tsv_file = TSVFile(self.hed_tsv_path) 60 | 61 | self.pil_to_tensor = transforms.PILToTensor() 62 | 63 | 64 | def total_images(self): 65 | return len(self) 66 | 67 | 68 | def get_item_from_tsv(self, index): 69 | _, item = self.tsv_file[index] 70 | item = decode_item(item) 71 | return item 72 | 73 | 74 | def get_item_from_hed_tsv(self, index): 75 | _, item = self.hed_tsv_file[index] 76 | item = decode_item_hed(item) 77 | return item 78 | 79 | 80 | 81 | def __getitem__(self, index): 82 | 83 | raw_item = self.get_item_from_tsv(index) 84 | raw_item_hed = self.get_item_from_hed_tsv(index) 85 | 86 | assert raw_item['data_id'] == raw_item_hed['data_id'] 87 | 88 | out = {} 89 | 90 | out['id'] = raw_item['data_id'] 91 | image = raw_item['image'] 92 | hed_edge = raw_item_hed['hed_edge'] 93 | 94 | # - - - - - center_crop, resize and random_flip - - - - - - # 95 | assert image.size == hed_edge.size 96 | 97 | crop_size = min(image.size) 98 | image = TF.center_crop(image, crop_size) 99 | image = image.resize( (self.image_size, self.image_size) ) 100 | 101 | hed_edge = TF.center_crop(hed_edge, crop_size) 102 | hed_edge = hed_edge.resize( (self.image_size, self.image_size) ) 103 | 104 | 105 | if self.random_flip and random.random()<0.5: 106 | image = ImageOps.mirror(image) 107 | hed_edge = ImageOps.mirror(hed_edge) 108 | 109 | out['image'] = ( self.pil_to_tensor(image).float()/255 - 0.5 ) / 0.5 110 | out['hed_edge'] = ( self.pil_to_tensor(hed_edge).float()/255 - 0.5 ) / 0.5 111 | out['mask'] = torch.tensor(1.0) 112 | 113 | # -------------------- caption ------------------- # 114 | if random.uniform(0, 1) < self.prob_use_caption: 115 | out["caption"] = raw_item["caption"] 116 | else: 117 | out["caption"] = "" 118 | 119 | return out 120 | 121 | 122 | def __len__(self): 123 | return len(self.tsv_file) 124 | 125 | 126 | -------------------------------------------------------------------------------- /dataset/dataset_normal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from PIL import Image, ImageDraw, ImageOps 5 | import torchvision.transforms as transforms 6 | import random 7 | import torchvision.transforms.functional as TF 8 | import numpy as np 9 | 10 | 11 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg"]): 12 | out = [] 13 | for r, d, f in 
os.walk(rootdir): 14 | for file in f: 15 | if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): 16 | out.append(os.path.join(r, file)) 17 | return out 18 | 19 | 20 | def exist_in(short_str, list_of_string): 21 | for string in list_of_string: 22 | if short_str in string: 23 | return True 24 | return False 25 | 26 | 27 | def clean_files(image_files, normal_files): 28 | """ 29 | Not sure why some images do not have normal map annotations, thus delete these images from list. 30 | 31 | The implementation here is inefficient ..... 32 | """ 33 | new_image_files = [] 34 | 35 | for image_file in image_files: 36 | image_file_basename = os.path.basename(image_file).split('.')[0] 37 | if exist_in(image_file_basename,normal_files): 38 | new_image_files.append(image_file) 39 | image_files = new_image_files 40 | 41 | 42 | # a sanity check 43 | for image_file, normal_file in zip(image_files, normal_files): 44 | image_file_basename = os.path.basename(image_file).split('.')[0] 45 | normal_file_basename = os.path.basename(normal_file).split('.')[0] 46 | assert image_file_basename == normal_file_basename[:-7] 47 | 48 | return image_files, normal_files 49 | 50 | 51 | 52 | 53 | class NormalDataset(): 54 | def __init__(self, image_rootdir, normal_rootdir, caption_path, prob_use_caption=1, image_size=512, random_flip=False): 55 | self.image_rootdir = image_rootdir 56 | self.normal_rootdir = normal_rootdir 57 | self.caption_path = caption_path 58 | self.prob_use_caption = prob_use_caption 59 | self.image_size = image_size 60 | self.random_flip = random_flip 61 | 62 | 63 | # Image and normal files 64 | image_files = recursively_read(rootdir=image_rootdir, must_contain="", exts=['png']) 65 | image_files.sort() 66 | normal_files = recursively_read(rootdir=normal_rootdir, must_contain="normal", exts=['npy']) 67 | normal_files.sort() 68 | 69 | image_files, normal_files = clean_files(image_files, normal_files) 70 | self.image_files = image_files 71 | self.normal_files = normal_files 72 | 73 | # Open caption json 74 | with open(caption_path, 'r') as f: 75 | self.image_filename_to_caption_mapping = json.load(f) 76 | 77 | 78 | self.pil_to_tensor = transforms.PILToTensor() 79 | 80 | 81 | def total_images(self): 82 | return len(self) 83 | 84 | 85 | def __getitem__(self, index): 86 | 87 | image_path = self.image_files[index] 88 | 89 | out = {} 90 | 91 | out['id'] = index 92 | image = Image.open(image_path).convert("RGB") 93 | 94 | normal = np.load( self.normal_files[index] ) # -1 to 1 numpy array 95 | normal = ((normal*0.5+0.5)*255).astype("uint8") 96 | normal = Image.fromarray(normal) # first convet normal map from array to image. 
So we can do crop etc easily 97 | assert image.size == normal.size 98 | 99 | 100 | # - - - - - center_crop, resize and random_flip - - - - - - # 101 | 102 | crop_size = min(image.size) 103 | image = TF.center_crop(image, crop_size) 104 | image = image.resize( (self.image_size, self.image_size) ) 105 | 106 | normal = TF.center_crop(normal, crop_size) 107 | normal = normal.resize( (self.image_size, self.image_size) ) 108 | 109 | 110 | if self.random_flip and random.random()<0.5: 111 | image = ImageOps.mirror(image) 112 | normal = ImageOps.mirror(normal) 113 | 114 | out['image'] = ( self.pil_to_tensor(image).float()/255 - 0.5 ) / 0.5 115 | out['normal'] = ( self.pil_to_tensor(normal).float()/255 - 0.5 ) / 0.5 # -1,1 is the actual range from numpy array annotation 116 | out['mask'] = torch.tensor(1.0) 117 | 118 | # -------------------- caption ------------------- # 119 | if random.uniform(0, 1) < self.prob_use_caption: 120 | out["caption"] = self.image_filename_to_caption_mapping[ os.path.basename(image_path) ] 121 | else: 122 | out["caption"] = "" 123 | 124 | return out 125 | 126 | 127 | def __len__(self): 128 | return len(self.image_files) 129 | 130 | 131 | -------------------------------------------------------------------------------- /dataset/dataset_sem.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from PIL import Image, ImageDraw, ImageOps 5 | import torchvision.transforms as transforms 6 | import random 7 | import torchvision.transforms.functional as TF 8 | import numpy as np 9 | 10 | 11 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg"]): 12 | out = [] 13 | for r, d, f in os.walk(rootdir): 14 | for file in f: 15 | if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): 16 | out.append(os.path.join(r, file)) 17 | return out 18 | 19 | 20 | def exist_in(short_str, list_of_string): 21 | for string in list_of_string: 22 | if short_str in string: 23 | return True 24 | return False 25 | 26 | 27 | def clean_files(image_files, normal_files): 28 | """ 29 | Not sure why some images do not have normal map annotations, thus delete these images from list. 30 | 31 | The implementation here is inefficient ..... 
32 | """ 33 | new_image_files = [] 34 | 35 | for image_file in image_files: 36 | image_file_basename = os.path.basename(image_file).split('.')[0] 37 | if exist_in(image_file_basename,normal_files): 38 | new_image_files.append(image_file) 39 | image_files = new_image_files 40 | 41 | 42 | # a sanity check 43 | for image_file, normal_file in zip(image_files, normal_files): 44 | image_file_basename = os.path.basename(image_file).split('.')[0] 45 | normal_file_basename = os.path.basename(normal_file).split('.')[0] 46 | assert image_file_basename == normal_file_basename[:-7] 47 | 48 | return image_files, normal_files 49 | 50 | 51 | 52 | 53 | class SemanticDataset(): 54 | def __init__(self, image_rootdir, sem_rootdir, caption_path, prob_use_caption=1, image_size=512, random_flip=False): 55 | self.image_rootdir = image_rootdir 56 | self.sem_rootdir = sem_rootdir 57 | self.caption_path = caption_path 58 | self.prob_use_caption = prob_use_caption 59 | self.image_size = image_size 60 | self.random_flip = random_flip 61 | 62 | 63 | # Image and normal files 64 | image_files = recursively_read(rootdir=image_rootdir, must_contain="", exts=['jpg']) 65 | image_files.sort() 66 | sem_files = recursively_read(rootdir=sem_rootdir, must_contain="", exts=['png']) 67 | sem_files.sort() 68 | 69 | 70 | self.image_files = image_files 71 | self.sem_files = sem_files 72 | 73 | # Open caption json 74 | with open(caption_path, 'r') as f: 75 | self.image_filename_to_caption_mapping = json.load(f) 76 | 77 | 78 | assert len(self.image_files) == len(self.sem_files) == len(self.image_filename_to_caption_mapping) 79 | self.pil_to_tensor = transforms.PILToTensor() 80 | 81 | 82 | def total_images(self): 83 | return len(self) 84 | 85 | 86 | def __getitem__(self, index): 87 | 88 | image_path = self.image_files[index] 89 | 90 | out = {} 91 | 92 | out['id'] = index 93 | image = Image.open(image_path).convert("RGB") 94 | sem = Image.open( self.sem_files[index] ).convert("L") # semantic class index 0,1,2,3,4 in uint8 representation 95 | 96 | assert image.size == sem.size 97 | 98 | 99 | # - - - - - center_crop, resize and random_flip - - - - - - # 100 | 101 | crop_size = min(image.size) 102 | image = TF.center_crop(image, crop_size) 103 | image = image.resize( (self.image_size, self.image_size) ) 104 | 105 | sem = TF.center_crop(sem, crop_size) 106 | sem = sem.resize( (self.image_size, self.image_size), Image.NEAREST ) # acorrding to official, it is nearest by default, but I don't know why it can prodice new values if not specify explicitly 107 | 108 | if self.random_flip and random.random()<0.5: 109 | image = ImageOps.mirror(image) 110 | sem = ImageOps.mirror(sem) 111 | 112 | sem = self.pil_to_tensor(sem)[0,:,:] 113 | 114 | input_label = torch.zeros(152, self.image_size, self.image_size) 115 | sem = input_label.scatter_(0, sem.long().unsqueeze(0), 1.0) 116 | 117 | out['image'] = ( self.pil_to_tensor(image).float()/255 - 0.5 ) / 0.5 118 | out['sem'] = sem 119 | out['mask'] = torch.tensor(1.0) 120 | 121 | 122 | # -------------------- caption ------------------- # 123 | if random.uniform(0, 1) < self.prob_use_caption: 124 | out["caption"] = self.image_filename_to_caption_mapping[ os.path.basename(image_path) ] 125 | else: 126 | out["caption"] = "" 127 | 128 | return out 129 | 130 | 131 | def __len__(self): 132 | return len(self.image_files) 133 | 134 | 135 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 
1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import PIL 18 | import torch 19 | import torchvision.transforms as T 20 | 21 | 22 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 23 | IMAGENET_STD = [0.229, 0.224, 0.225] 24 | 25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 27 | 28 | 29 | def imagenet_preprocess(): 30 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 31 | 32 | 33 | def rescale(x): 34 | lo, hi = x.min(), x.max() 35 | return x.sub(lo).div(hi - lo) 36 | 37 | 38 | def imagenet_deprocess(rescale_image=True): 39 | transforms = [ 40 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), 41 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), 42 | ] 43 | if rescale_image: 44 | transforms.append(rescale) 45 | return T.Compose(transforms) 46 | 47 | 48 | def imagenet_deprocess_batch(imgs, rescale=True): 49 | """ 50 | Input: 51 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images 52 | 53 | Output: 54 | - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images 55 | in the range [0, 255] 56 | """ 57 | if isinstance(imgs, torch.autograd.Variable): 58 | imgs = imgs.data 59 | imgs = imgs.cpu().clone() 60 | deprocess_fn = imagenet_deprocess(rescale_image=rescale) 61 | imgs_de = [] 62 | for i in range(imgs.size(0)): 63 | img_de = deprocess_fn(imgs[i])[None] 64 | img_de = img_de.mul(255).clamp(0, 255).byte() 65 | imgs_de.append(img_de) 66 | imgs_de = torch.cat(imgs_de, dim=0) 67 | return imgs_de 68 | 69 | 70 | class Resize(object): 71 | def __init__(self, size, interp=PIL.Image.BILINEAR): 72 | if isinstance(size, tuple): 73 | H, W = size 74 | self.size = (W, H) 75 | else: 76 | self.size = (size, size) 77 | self.interp = interp 78 | 79 | def __call__(self, img): 80 | return img.resize(self.size, self.interp) 81 | 82 | 83 | def unpack_var(v): 84 | if isinstance(v, torch.autograd.Variable): 85 | return v.data 86 | return v 87 | 88 | 89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): 90 | triples = unpack_var(triples) 91 | obj_data = [unpack_var(o) for o in obj_data] 92 | obj_to_img = unpack_var(obj_to_img) 93 | triple_to_img = unpack_var(triple_to_img) 94 | 95 | triples_out = [] 96 | obj_data_out = [[] for _ in obj_data] 97 | obj_offset = 0 98 | N = obj_to_img.max() + 1 99 | for i in range(N): 100 | o_idxs = (obj_to_img == i).nonzero().view(-1) 101 | t_idxs = (triple_to_img == i).nonzero().view(-1) 102 | 103 | cur_triples = triples[t_idxs].clone() 104 | cur_triples[:, 0] -= obj_offset 105 | cur_triples[:, 2] -= obj_offset 106 | triples_out.append(cur_triples) 107 | 108 | for j, o_data in enumerate(obj_data): 109 | cur_o_data = None 110 | if o_data is not None: 111 | cur_o_data = o_data[o_idxs] 112 | obj_data_out[j].append(cur_o_data) 113 | 114 | obj_offset += o_idxs.size(0) 115 | 116 | return triples_out, obj_data_out 117 | 
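The `SemanticDataset.__getitem__` above turns an `(H, W)` map of class indices into a 152-channel one-hot tensor with a single `scatter_` call. The following is a minimal standalone sketch of that one step, not repository code; the toy sizes are assumptions for illustration.

```Python
import torch

# Toy semantic map of class indices, shape (H, W); 152 channels as in SemanticDataset.
num_classes, H, W = 152, 4, 4
sem = torch.randint(0, num_classes, (H, W))

# scatter_ writes 1.0 into channel sem[h, w] for every pixel, turning the index
# map into a one-hot tensor of shape (num_classes, H, W).
one_hot = torch.zeros(num_classes, H, W).scatter_(0, sem.long().unsqueeze(0), 1.0)

assert one_hot.sum(dim=0).eq(1).all()        # exactly one active class per pixel
assert (one_hot.argmax(dim=0) == sem).all()  # the original indices are recoverable
```

Because `scatter_` indexes the channel dimension with the class map, every pixel ends up with exactly one channel set to 1, which is the dense label format `out['sem']` carries into training.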
-------------------------------------------------------------------------------- /demo/.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ project files 2 | .idea 3 | *.iml 4 | out 5 | gen 6 | 7 | ### Vim template 8 | [._]*.s[a-w][a-z] 9 | [._]s[a-w][a-z] 10 | *.un~ 11 | Session.vim 12 | .netrwhist 13 | *~ 14 | 15 | ### IPythonNotebook template 16 | # Temporary data 17 | .ipynb_checkpoints/ 18 | 19 | ### Python template 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | env/ 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | #lib/ 38 | #lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *,cover 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | *.ipynb 80 | *.params 81 | # *.json 82 | .vscode/ 83 | *.code-workspace/ 84 | 85 | lib/pycocotools/_mask.c 86 | lib/nms/cpu_nms.c 87 | 88 | OUTPUT 89 | OUTPUT/* 90 | models/* 91 | DATASET 92 | DATASET/* 93 | external/ 94 | MODELS 95 | MODELS/* 96 | gradio_cached_examples/* 97 | 98 | kill.sh 99 | 100 | draws/ 101 | #:wq 102 | #plot/figs 103 | 104 | *venv/* 105 | 106 | # images 107 | # images/* 108 | 109 | create_samples/ 110 | create_samples/* 111 | 112 | ckpts/* 113 | -------------------------------------------------------------------------------- /demo/DejaVuSansMono.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/DejaVuSansMono.ttf -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: GLIGen 3 | emoji: 👁 4 | colorFrom: red 5 | colorTo: green 6 | sdk: gradio 7 | sdk_version: 3.15.0 8 | app_file: app.py 9 | pinned: false 10 | --- 11 | 12 | # Gradio App Demo for GLIGEN 13 | 14 | ## :notes: Introduction 15 | 16 | This folder includes the source code of our [Gradio app demo](https://huggingface.co/spaces/gligen/demo) for GLIGEN. It automatically downloads and loads our checkpoints hosted on Huggingface. 17 | 18 | NOTE: You may notice slight implementation differences of the pipeline between this code base and main GLIGEN repo, although the functionality and the checkpoints are the same. We'll replace the implementation pipeline to Diffusers after we finish the integration. 19 | 20 | ## :toolbox: Installation 21 | 22 | To install GLIGEN demo with CUDA support, create an environment. 23 | 24 | ```Shell 25 | conda env create -f environment.yaml 26 | ``` 27 | 28 | In case you don't have a CUDA-enabled GPU, you can run it on a CPU - though, it will be very slow. 
On MacBooks with M1 Apple Silicon, there is some speedup via [MPS](https://pytorch.org/docs/stable/notes/mps.html) support (much faster than CPU, slower than CUDA). To use the MacBook GPU, make sure you install [conda miniforge for the arm64 architecture (recommended: mambaforge)](https://github.com/conda-forge/miniforge). 30 | 31 | ```Shell 32 | mamba env create -f environment_cpu_mps.yaml 33 | ``` 34 | 35 | ## :notebook: Usage 36 | 37 | Activate the environment with 38 | 39 | ```Shell 40 | conda activate gligen_demo 41 | ``` 42 | 43 | By default, the app only loads the base text-box generation pipeline to save memory. You will see an error in the UI if you attempt to run a pipeline that is not loaded. Modify the command-line arguments to enable or disable specific pipelines. 44 | 45 | ```Shell 46 | python app.py \ 47 | --load-text-box-generation=True \ 48 | --load-text-box-inpainting=False \ 49 | --load-text-image-box-generation=False 50 | ``` 51 | 52 | ## :question: How do you draw bounding boxes using the Gradio sketchpad? 53 | 54 | Gradio does not natively support drawing bounding boxes in its sketchpad. In this repo, we use a simple workaround: users draw their boxes with the freeform brush, and the backend computes the min/max points along the x/y axes to "guess" a bounding box. The interpreted boxes are visualized on the side for a better user experience. 55 | 56 | We hope Gradio will have native support for drawing bounding boxes soon! :partying_face: 57 | 58 | ## :snowflake: TODO 59 | 60 | - [ ] Use diffusers as the inference pipeline 61 | - [ ] Refactor code base 62 | 63 | ## :book: Citation 64 | 65 | ``` 66 | @article{li2023gligen, 67 | title={GLIGEN: Open-Set Grounded Text-to-Image Generation}, 68 | author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae}, 69 | journal={CVPR}, 70 | year={2023} 71 | } 72 | ``` 73 | 74 | ## Disclaimer 75 | 76 | The original GLIGEN was partly implemented and trained during an internship at Microsoft. This repo re-implements GLIGEN in PyTorch with university GPUs after the internship. Despite the minor implementation differences, this repo aims to reproduce the results and observations in the paper for research purposes.
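As a concrete illustration of the sketchpad workaround described above, here is a minimal sketch of the min/max idea. `guess_box_from_sketch` is a hypothetical helper written for this note, not the actual function in `app.py`, and the way the sketch mask reaches the backend may differ.

```Python
import numpy as np

def guess_box_from_sketch(mask: np.ndarray):
    """Turn a freeform sketch (H, W) binary/alpha mask into a normalized xyxy box
    by taking the min/max of the drawn pixels. Returns None if nothing was drawn."""
    ys, xs = np.nonzero(mask > 0)
    if len(xs) == 0:
        return None
    h, w = mask.shape
    x0, x1 = xs.min() / w, (xs.max() + 1) / w
    y0, y1 = ys.min() / h, (ys.max() + 1) / h
    return [float(x0), float(y0), float(x1), float(y1)]

# Example: a scribble covering rows 10..29 and columns 20..49 of a 100x100 canvas.
canvas = np.zeros((100, 100), dtype=np.uint8)
canvas[10:30, 20:50] = 255
print(guess_box_from_sketch(canvas))  # [0.2, 0.1, 0.5, 0.3]
```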
77 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/__init__.py -------------------------------------------------------------------------------- /demo/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/dataset/__init__.py -------------------------------------------------------------------------------- /demo/dataset/catalog.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class DatasetCatalog: 4 | def __init__(self, ROOT, which_embedder): 5 | assert which_embedder in ['clip', 'bert'] 6 | 7 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 8 | 9 | 10 | self.VGGrounding = { 11 | "target": "dataset.tsv_dataset.TSVDataset", 12 | "train_params": dict( 13 | tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'), 14 | ) 15 | } 16 | 17 | 18 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 19 | 20 | 21 | self.FlickrGrounding = { 22 | "target": "dataset.tsv_dataset.TSVDataset", 23 | "train_params":dict( 24 | tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'), 25 | ) 26 | } 27 | 28 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 29 | 30 | self.SBUGrounding = { 31 | "target": "dataset.tsv_dataset.TSVDataset", 32 | "train_params":dict( 33 | tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'), 34 | ) 35 | } 36 | 37 | 38 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 39 | 40 | 41 | self.CC3MGrounding = { 42 | "target": "dataset.tsv_dataset.TSVDataset", 43 | "train_params":dict( 44 | tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'), 45 | ) 46 | } 47 | 48 | 49 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 50 | 51 | 52 | self.CC12MGrounding = { 53 | "target": "dataset.tsv_dataset.TSVDataset", 54 | "train_params":dict( 55 | tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'), 56 | ) 57 | } 58 | 59 | 60 | # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 61 | 62 | # temp = 'category_embedding_clip.pth' if which_embedder == 'clip' else 'category_embedding_bert.pth' 63 | # obj365_category_embedding_path = os.path.join(ROOT, 'OBJECTS365', temp) 64 | 65 | self.Obj365Detection = { 66 | "target": "dataset.tsv_dataset.TSVDataset", 67 | "train_params":dict( 68 | tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'), 69 | ), 70 | } 71 | 72 | 73 | -------------------------------------------------------------------------------- /demo/dataset/concat_dataset.py: -------------------------------------------------------------------------------- 1 | from .catalog import DatasetCatalog 2 | from ldm.util import instantiate_from_config 3 | import torch 4 | 5 | 6 | 7 | 8 | class ConCatDataset(): 9 | def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None): 10 | self.datasets = [] 11 | cul_previous_dataset_length = 0 12 | offset_map = [] 13 | which_dataset = [] 14 | 15 | if repeats is None: 16 | repeats = [1] * len(dataset_name_list) 17 | else: 
18 | assert len(repeats) == len(dataset_name_list) 19 | 20 | 21 | Catalog = DatasetCatalog(ROOT, which_embedder) 22 | for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()): 23 | repeat = repeats[dataset_idx] 24 | 25 | dataset_dict = getattr(Catalog, dataset_name) 26 | 27 | target = dataset_dict['target'] 28 | params = dataset_dict['train_params'] if train else dataset_dict['val_params'] 29 | if yaml_params is not None: 30 | params.update(yaml_params) 31 | dataset = instantiate_from_config( dict(target=target, params=params) ) 32 | 33 | self.datasets.append(dataset) 34 | for _ in range(repeat): 35 | offset_map.append( torch.ones(len(dataset))*cul_previous_dataset_length ) 36 | which_dataset.append( torch.ones(len(dataset))*dataset_idx ) 37 | cul_previous_dataset_length += len(dataset) 38 | offset_map = torch.cat(offset_map, dim=0).long() 39 | self.total_length = cul_previous_dataset_length 40 | 41 | self.mapping = torch.arange(self.total_length) - offset_map 42 | self.which_dataset = torch.cat(which_dataset, dim=0).long() 43 | 44 | 45 | def total_images(self): 46 | count = 0 47 | for dataset in self.datasets: 48 | print(dataset.total_images()) 49 | count += dataset.total_images() 50 | return count 51 | 52 | 53 | 54 | def __getitem__(self, idx): 55 | dataset = self.datasets[ self.which_dataset[idx] ] 56 | return dataset[ self.mapping[idx] ] 57 | 58 | 59 | def __len__(self): 60 | return self.total_length 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /demo/dataset/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import PIL 18 | import torch 19 | import torchvision.transforms as T 20 | 21 | 22 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 23 | IMAGENET_STD = [0.229, 0.224, 0.225] 24 | 25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 27 | 28 | 29 | def imagenet_preprocess(): 30 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 31 | 32 | 33 | def rescale(x): 34 | lo, hi = x.min(), x.max() 35 | return x.sub(lo).div(hi - lo) 36 | 37 | 38 | def imagenet_deprocess(rescale_image=True): 39 | transforms = [ 40 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), 41 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), 42 | ] 43 | if rescale_image: 44 | transforms.append(rescale) 45 | return T.Compose(transforms) 46 | 47 | 48 | def imagenet_deprocess_batch(imgs, rescale=True): 49 | """ 50 | Input: 51 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images 52 | 53 | Output: 54 | - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images 55 | in the range [0, 255] 56 | """ 57 | if isinstance(imgs, torch.autograd.Variable): 58 | imgs = imgs.data 59 | imgs = imgs.cpu().clone() 60 | deprocess_fn = imagenet_deprocess(rescale_image=rescale) 61 | imgs_de = [] 62 | for i in range(imgs.size(0)): 63 | img_de = deprocess_fn(imgs[i])[None] 64 | img_de = img_de.mul(255).clamp(0, 255).byte() 65 | imgs_de.append(img_de) 66 | imgs_de = torch.cat(imgs_de, dim=0) 67 | return imgs_de 68 | 69 | 70 | class Resize(object): 71 | def __init__(self, size, interp=PIL.Image.BILINEAR): 72 | if isinstance(size, tuple): 73 | H, W = size 74 | self.size = (W, H) 75 | else: 76 | self.size = (size, size) 77 | self.interp = interp 78 | 79 | def __call__(self, img): 80 | return img.resize(self.size, self.interp) 81 | 82 | 83 | def unpack_var(v): 84 | if isinstance(v, torch.autograd.Variable): 85 | return v.data 86 | return v 87 | 88 | 89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): 90 | triples = unpack_var(triples) 91 | obj_data = [unpack_var(o) for o in obj_data] 92 | obj_to_img = unpack_var(obj_to_img) 93 | triple_to_img = unpack_var(triple_to_img) 94 | 95 | triples_out = [] 96 | obj_data_out = [[] for _ in obj_data] 97 | obj_offset = 0 98 | N = obj_to_img.max() + 1 99 | for i in range(N): 100 | o_idxs = (obj_to_img == i).nonzero().view(-1) 101 | t_idxs = (triple_to_img == i).nonzero().view(-1) 102 | 103 | cur_triples = triples[t_idxs].clone() 104 | cur_triples[:, 0] -= obj_offset 105 | cur_triples[:, 2] -= obj_offset 106 | triples_out.append(cur_triples) 107 | 108 | for j, o_data in enumerate(obj_data): 109 | cur_o_data = None 110 | if o_data is not None: 111 | cur_o_data = o_data[o_idxs] 112 | obj_data_out[j].append(cur_o_data) 113 | 114 | obj_offset += o_idxs.size(0) 115 | 116 | return triples_out, obj_data_out 117 | -------------------------------------------------------------------------------- /demo/environment.yaml: -------------------------------------------------------------------------------- 1 | name: gligen_demo 2 | channels: 3 | - xformers/label/dev 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python=3.10.8 8 | - pip=22.2.2 9 | - cudatoolkit=11.3 10 | - pytorch=1.12.1 11 | - torchvision=0.13.1 12 | - numpy=1.23.1 13 | - xformers 14 | - pip: 15 | - omegaconf==2.1.1 16 | - albumentations==1.3.0 17 | - opencv-python 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - test-tube>=0.7.5 22 | - streamlit==1.12.1 23 | - einops==0.3.0 24 | - 
git+https://github.com/openai/CLIP.git 25 | - protobuf~=3.20.1 26 | - torchmetrics==0.6.0 27 | - transformers==4.19.2 28 | - kornia==0.6.0 29 | - gradio==3.16.0 -------------------------------------------------------------------------------- /demo/environment_cpu_mps.yaml: -------------------------------------------------------------------------------- 1 | name: gligen_demo 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10.8 7 | - pip=22.2.2 8 | - pytorch=1.13.1 9 | - torchvision=0.14.1 10 | - numpy=1.23.1 11 | - pip: 12 | - omegaconf==2.1.1 13 | - albumentations==1.3.0 14 | - opencv-python 15 | - imageio==2.9.0 16 | - imageio-ffmpeg==0.4.2 17 | - pytorch-lightning==1.4.2 18 | - test-tube>=0.7.5 19 | - streamlit==1.12.1 20 | - einops==0.3.0 21 | - git+https://github.com/openai/CLIP.git 22 | - protobuf~=3.20.1 23 | - torchmetrics==0.6.0 24 | - transformers==4.26.1 25 | - kornia==0.6.0 26 | - gradio==3.16.0 27 | -------------------------------------------------------------------------------- /demo/gligen/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys 3 | sys.path.append(os.path.dirname(__file__)) 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "ldm")) 5 | 6 | import gligen.evaluator as evaluator 7 | import gligen.trainer as trainer 8 | 9 | 10 | # import gligen.ldm as ldm -------------------------------------------------------------------------------- /demo/gligen/create_meta.py: -------------------------------------------------------------------------------- 1 | CKPTS = [ 2 | 3 | dict( 4 | path="/home/chunyl/azure_mount/yuhengdb/fine_tune_ldm/version5_branch6_output/GoldG+SBU+CC3M+CC12M+O365/second_stage_drop_both/tag01/checkpoint_00450001.pth", 5 | feature_type=['before','after_reproject'], 6 | save_folder_name="v5b6_drop_both", 7 | ), 8 | 9 | 10 | # dict( 11 | # path="/home/v-yuhengli/blobfuse/output/fine_tune_ldm/version5_branch6_output/GoldG+SBU+CC3M+CC12M+O365/second_stage_drop_none/tag00/checkpoint_00165001.pth", 12 | # feature_type=['before','after_reproject'], 13 | # save_folder_name="v5b6_drop_none", 14 | # ), 15 | 16 | 17 | 18 | 19 | 20 | ] 21 | 22 | 23 | 24 | # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | # if meta["has_image_mask"] == 0: 34 | # image_embeddings = text_embeddings 35 | # if meta["has_text_mask"] == 0: 36 | # text_embeddings = image_embeddings 37 | 38 | # out = { 39 | # "boxes" : boxes.unsqueeze(0).repeat(batch,1,1), 40 | # "masks" : masks.unsqueeze(0).repeat(batch,1), 41 | # "text_masks" : masks.unsqueeze(0).repeat(batch,1), 42 | # "image_masks" : masks.unsqueeze(0).repeat(batch,1), 43 | # "text_embeddings" : text_embeddings.unsqueeze(0).repeat(batch,1,1), 44 | # "image_embeddings" : image_embeddings.unsqueeze(0).repeat(batch,1,1) 45 | # } 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | META = [ 54 | 55 | 56 | dict( 57 | prompt = "a teddy bear sitting next to a red bird", 58 | phrases = ['a teddy bear', 'a red bird'], 59 | images = ['images/teddy.jpg', 'images/red_bird.jpg'], 60 | locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ], 61 | alpha_type = [1.0, 0, 0.0], 62 | has_text_mask = 1, 63 | has_image_mask = 0, 64 | save_folder_name="teddy_bird_1_1" 65 | ), 66 | 67 | 68 | # dict( 69 | # prompt = "a teddy bear sitting next to a bird", 70 | # phrases = ['a teddy bear', 'a bird'], 71 | # images = ['images/teddy.jpg', 'images/red_bird.jpg'], 72 | # locations = [ [0.0,0.09,0.33,0.76], 
[0.55,0.11,1.0,0.8] ], 73 | # alpha_type = [1.0, 0, 0.0], 74 | # has_text_mask = 1, 75 | # has_image_mask = 1, 76 | # save_folder_name="teddy_bird_1_1" 77 | # ), 78 | 79 | 80 | # dict( 81 | # prompt = "a teddy bear sitting next to a bird", 82 | # phrases = ['a teddy bear', 'a bird'], 83 | # images = ['images/teddy.jpg', 'images/red_bird.jpg'], 84 | # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ], 85 | # alpha_type = [0.5, 0, 0.5], 86 | # has_text_mask = 1, 87 | # has_image_mask = 0, 88 | # save_folder_name="teddy_bird_1_0" 89 | # ), 90 | 91 | # dict( 92 | # prompt = "", 93 | # phrases = ['a teddy bear', 'an umbrella'], 94 | # images = ['images/teddy.jpg', 'images/umbrella.png'], 95 | # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ], 96 | # alpha_type = [1.0, 0, 0.0], 97 | # has_text_mask = 1, 98 | # has_image_mask = 1, 99 | # save_folder_name="empty_teddy_umbrella_1_1" 100 | # ), 101 | 102 | # dict( 103 | # prompt = "hello kitty and bird hybrid", 104 | # phrases = ['a hello kitty', 'a hello kitty'], 105 | # images = ['images/red_bird.jpg', 'images/red_bird.jpg'], 106 | # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ], 107 | # has_text_mask = 1, 108 | # has_image_mask = 1, 109 | # save_folder_name="hello+bird_1_1" 110 | # ), 111 | 112 | # dict( 113 | # prompt = "hello kitty and teddy bear hybrid", 114 | # phrases = ['a hello kitty', 'a hello kitty'], 115 | # images = ['images/teddy.jpg', 'images/teddy.jpg'], 116 | # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ], 117 | # has_text_mask = 1, 118 | # has_image_mask = 1, 119 | # save_folder_name="hello+teddy_1_1" 120 | # ), 121 | 122 | # dict( 123 | # prompt = "bird and hello kitty hybrid", 124 | # phrases = ['a bird', 'a bird'], 125 | # images = ['images/hello.jpg', 'images/hello.jpg'], 126 | # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ], 127 | # alpha_type = [1.0, 0, 0.0], 128 | # has_text_mask = 1, 129 | # has_image_mask = 0.5, 130 | # save_folder_name="bird+hello_1_1" 131 | # ), 132 | 133 | 134 | 135 | # dict( 136 | # prompt = "a deer standing in front of a brick house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k", 137 | # phrases = ['a deer'], 138 | # images = ['images/sky.jpg'], 139 | # locations = [ [0.0,0.5,0.5,0.9] ], 140 | # alpha_type = [1, 0, 0], 141 | # has_text_mask = 1, 142 | # has_image_mask = 1, 143 | # save_folder_name="deer_sky" 144 | # ), 145 | 146 | 147 | # dict( 148 | # prompt = "A woman sitting in a restaurant with a slice of pizza in front of her", 149 | # phrases = ['dining table', 'pizza', 'person', 'wall', 'car', 'paper', 'chair', 'window', 'bottle', 'cup'], 150 | # images = ['images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg'], 151 | # locations = [ [0.0030, 0.3589, 1.0000, 1.0000], 152 | # [0.0779, 0.6744, 0.9768, 1.0000], 153 | # [0.2236, 0.0000, 0.7809, 0.4352], 154 | # [0.0000, 0.0000, 0.4313, 0.4505], 155 | # [0.6275, 0.1050, 0.9444, 0.2497], 156 | # [0.0000, 0.3859, 0.1250, 0.6922], 157 | # [0.7137, 0.2389, 0.8540, 0.4549], 158 | # [0.0000, 0.0000, 0.4667, 0.0630], 159 | # [0.3822, 0.4235, 0.4932, 0.6575], 160 | # [0.6616, 0.3617, 0.7880, 0.5165] ], 161 | # alpha_type = [0.0, 0, 1.0], 162 | # has_text_mask = 1, 163 | # has_image_mask = 0, 164 | # save_folder_name="pizza_1_0" 165 | # ), 166 | 167 | 168 | 169 | 170 | ] 
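For reference, the commented-out block near the top of `create_meta.py` hints at the batched dictionary the grounded-generation pipeline consumes (boxes, masks, text/image embeddings repeated over the batch). The sketch below assembles such a dictionary from one `META`-style entry; `pack_meta`, `max_objs=30`, `dim=768`, and the random stand-in embeddings are illustrative assumptions, not the repository's actual code path, which computes real CLIP features.

```Python
import torch

def pack_meta(meta, batch=2, max_objs=30, dim=768):
    n = len(meta["locations"])
    boxes = torch.zeros(max_objs, 4)
    masks = torch.zeros(max_objs)
    text_embeddings = torch.zeros(max_objs, dim)
    image_embeddings = torch.zeros(max_objs, dim)

    boxes[:n] = torch.tensor(meta["locations"])  # normalized xyxy box per phrase
    masks[:n] = 1                                # mark which slots hold real objects
    # Dummy stand-ins for the CLIP features of each phrase / reference image.
    text_embeddings[:n] = torch.randn(n, dim)
    image_embeddings[:n] = torch.randn(n, dim)

    return {
        "boxes": boxes.unsqueeze(0).repeat(batch, 1, 1),
        "masks": masks.unsqueeze(0).repeat(batch, 1),
        "text_masks": masks.unsqueeze(0).repeat(batch, 1) * meta["has_text_mask"],
        "image_masks": masks.unsqueeze(0).repeat(batch, 1) * meta["has_image_mask"],
        "text_embeddings": text_embeddings.unsqueeze(0).repeat(batch, 1, 1),
        "image_embeddings": image_embeddings.unsqueeze(0).repeat(batch, 1, 1),
    }

meta = dict(
    prompt="a teddy bear sitting next to a red bird",
    phrases=["a teddy bear", "a red bird"],
    locations=[[0.0, 0.09, 0.33, 0.76], [0.55, 0.11, 1.0, 0.8]],
    has_text_mask=1,
    has_image_mask=0,
)
out = pack_meta(meta)
print(out["boxes"].shape, out["text_masks"].sum())  # torch.Size([2, 30, 4]) tensor(4.)
```

Unused slots keep zero masks, so the grounding tokenizer can swap their features for its learnable null embeddings downstream.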
-------------------------------------------------------------------------------- /demo/gligen/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | from ldm.util import default_device 9 | 10 | device = default_device() 11 | 12 | def get_rank(): 13 | if not dist.is_available(): 14 | return 0 15 | 16 | if not dist.is_initialized(): 17 | return 0 18 | 19 | return dist.get_rank() 20 | 21 | 22 | def synchronize(): 23 | if not dist.is_available(): 24 | return 25 | if not dist.is_initialized(): 26 | return 27 | 28 | world_size = dist.get_world_size() 29 | if world_size == 1: 30 | return 31 | 32 | dist.barrier() 33 | 34 | 35 | def get_world_size(): 36 | if not dist.is_available(): 37 | return 1 38 | if not dist.is_initialized(): 39 | return 1 40 | return dist.get_world_size() 41 | 42 | 43 | def reduce_sum(tensor): 44 | if not dist.is_available(): 45 | return tensor 46 | 47 | if not dist.is_initialized(): 48 | return tensor 49 | 50 | tensor = tensor.clone() 51 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 52 | 53 | return tensor 54 | 55 | 56 | def gather_grad(params): 57 | world_size = get_world_size() 58 | 59 | if world_size == 1: 60 | return 61 | 62 | for param in params: 63 | if param.grad is not None: 64 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 65 | param.grad.data.div_(world_size) 66 | 67 | 68 | def all_gather(data): 69 | world_size = get_world_size() 70 | 71 | if world_size == 1: 72 | return [data] 73 | 74 | buffer = pickle.dumps(data) 75 | storage = torch.ByteStorage.from_buffer(buffer) 76 | tensor = torch.ByteTensor(storage).to(device) 77 | 78 | local_size = torch.IntTensor([tensor.numel()]).to(device) 79 | size_list = [torch.IntTensor([0]).to(device) for _ in range(world_size)] 80 | dist.all_gather(size_list, local_size) 81 | size_list = [int(size.item()) for size in size_list] 82 | max_size = max(size_list) 83 | 84 | tensor_list = [] 85 | for _ in size_list: 86 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device)) 87 | 88 | if local_size != max_size: 89 | padding = torch.ByteTensor(size=(max_size - local_size,)).to(device) 90 | tensor = torch.cat((tensor, padding), 0) 91 | 92 | dist.all_gather(tensor_list, tensor) 93 | 94 | data_list = [] 95 | 96 | for size, tensor in zip(size_list, tensor_list): 97 | buffer = tensor.cpu().numpy().tobytes()[:size] 98 | data_list.append(pickle.loads(buffer)) 99 | 100 | return data_list 101 | 102 | 103 | def reduce_loss_dict(loss_dict): 104 | world_size = get_world_size() 105 | 106 | if world_size < 2: 107 | return loss_dict 108 | 109 | with torch.no_grad(): 110 | keys = [] 111 | losses = [] 112 | 113 | for k in sorted(loss_dict.keys()): 114 | keys.append(k) 115 | losses.append(loss_dict[k]) 116 | 117 | losses = torch.stack(losses, 0) 118 | dist.reduce(losses, dst=0) 119 | 120 | if dist.get_rank() == 0: 121 | losses /= world_size 122 | 123 | reduced_losses = {k: v for k, v in zip(keys, losses)} 124 | 125 | return reduced_losses 126 | -------------------------------------------------------------------------------- /demo/gligen/image_projection_matrix: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/image_projection_matrix -------------------------------------------------------------------------------- 
/demo/gligen/ldm/__init__.py: -------------------------------------------------------------------------------- 1 | import gligen.evaluator as evaluator 2 | import gligen.trainer as trainer 3 | import gligen.ldm as ldm -------------------------------------------------------------------------------- /demo/gligen/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/data/__init__.py -------------------------------------------------------------------------------- /demo/gligen/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /demo/gligen/ldm/data/imagenet_train_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/data/imagenet_train_hr_indices.p -------------------------------------------------------------------------------- /demo/gligen/ldm/data/imagenet_val_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/data/imagenet_val_hr_indices.p -------------------------------------------------------------------------------- /demo/gligen/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 
45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /demo/gligen/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 
40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /demo/gligen/ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #import pytorch_lightning as pl 4 | import torch.nn.functional as F 5 | from contextlib import contextmanager 6 | 7 | # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 8 | 9 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 10 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 11 | 12 | from ldm.util import instantiate_from_config 13 | 14 | 15 | 16 | 17 | class AutoencoderKL(nn.Module): 18 | def __init__(self, 19 | ddconfig, 20 | embed_dim, 21 | scale_factor=1 22 | ): 23 | super().__init__() 24 | self.encoder = Encoder(**ddconfig) 25 | self.decoder = Decoder(**ddconfig) 26 | assert ddconfig["double_z"] 27 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 28 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 29 | self.embed_dim = embed_dim 30 | self.scale_factor = scale_factor 31 | 32 | 33 | 34 | def encode(self, x): 35 | h = self.encoder(x) 36 | moments = self.quant_conv(h) 37 | posterior = 
DiagonalGaussianDistribution(moments) 38 | return posterior.sample() * self.scale_factor 39 | 40 | def decode(self, z): 41 | z = 1. / self.scale_factor * z 42 | z = self.post_quant_conv(z) 43 | dec = self.decoder(z) 44 | return dec 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /demo/gligen/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /demo/gligen/ldm/models/diffusion/ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | from ldm.modules.diffusionmodules.util import make_beta_schedule 6 | 7 | 8 | 9 | 10 | 11 | class DDPM(nn.Module): 12 | def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 13 | super().__init__() 14 | 15 | self.v_posterior = 0 16 | self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s) 17 | 18 | 19 | def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 20 | 21 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) 22 | alphas = 1. - betas 23 | alphas_cumprod = np.cumprod(alphas, axis=0) 24 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 25 | 26 | timesteps, = betas.shape 27 | self.num_timesteps = int(timesteps) 28 | self.linear_start = linear_start 29 | self.linear_end = linear_end 30 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 31 | 32 | to_torch = partial(torch.tensor, dtype=torch.float32) 33 | 34 | self.register_buffer('betas', to_torch(betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 44 | 45 | # calculations for posterior q(x_{t-1} | x_t, x_0) 46 | posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas 47 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 48 | 49 | self.register_buffer('posterior_variance', to_torch(posterior_variance)) 50 | 51 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 52 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 53 | self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. 
- alphas_cumprod))) 54 | self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /demo/gligen/ldm/models/diffusion/ldm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import tqdm 5 | from ldm.util import default 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor 7 | from .ddpm import DDPM 8 | 9 | 10 | 11 | class LatentDiffusion(DDPM): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | # hardcoded 15 | self.clip_denoised = False 16 | 17 | 18 | 19 | def q_sample(self, x_start, t, noise=None): 20 | noise = default(noise, lambda: torch.randn_like(x_start)) 21 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 22 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 23 | 24 | 25 | "Does not support DDPM sampling anymore. Only do DDIM or PLMS" 26 | 27 | # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = # 28 | 29 | # def predict_start_from_noise(self, x_t, t, noise): 30 | # return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 31 | # extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) 32 | 33 | # def q_posterior(self, x_start, x_t, t): 34 | # posterior_mean = ( 35 | # extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + 36 | # extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 37 | # ) 38 | # posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) 39 | # posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) 40 | # return posterior_mean, posterior_variance, posterior_log_variance_clipped 41 | 42 | 43 | # def p_mean_variance(self, model, x, c, t): 44 | 45 | # model_out = model(x, t, c) 46 | # x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) 47 | 48 | # if self.clip_denoised: 49 | # x_recon.clamp_(-1., 1.) 
50 | 51 | # model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 52 | # return model_mean, posterior_variance, posterior_log_variance, x_recon 53 | 54 | 55 | # @torch.no_grad() 56 | # def p_sample(self, model, x, c, t): 57 | # b, *_, device = *x.shape, x.device 58 | # model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, ) 59 | # noise = torch.randn_like(x) 60 | 61 | # # no noise when t == 0 62 | # nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 63 | 64 | # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 65 | 66 | 67 | # @torch.no_grad() 68 | # def p_sample_loop(self, model, shape, c): 69 | # device = self.betas.device 70 | # b = shape[0] 71 | # img = torch.randn(shape, device=device) 72 | 73 | # iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps) 74 | # for i in iterator: 75 | # ts = torch.full((b,), i, device=device, dtype=torch.long) 76 | # img, x0 = self.p_sample(model, img, c, ts) 77 | 78 | # return img 79 | 80 | 81 | # @torch.no_grad() 82 | # def sample(self, model, shape, c, uc=None, guidance_scale=None): 83 | # return self.p_sample_loop(model, shape, c) 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/diffusionmodules/positionnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ldm.modules.attention import BasicTransformerBlock 4 | from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder 5 | import torch.nn.functional as F 6 | 7 | 8 | 9 | class PositionNet(nn.Module): 10 | def __init__(self, positive_len, out_dim, fourier_freqs=8): 11 | super().__init__() 12 | self.positive_len = positive_len 13 | self.out_dim = out_dim 14 | 15 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) 16 | self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy 17 | 18 | self.linears = nn.Sequential( 19 | nn.Linear( self.positive_len + self.position_dim, 512), 20 | nn.SiLU(), 21 | nn.Linear( 512, 512), 22 | nn.SiLU(), 23 | nn.Linear(512, out_dim), 24 | ) 25 | 26 | self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) 27 | self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) 28 | 29 | 30 | def forward(self, boxes, masks, positive_embeddings): 31 | B, N, _ = boxes.shape 32 | masks = masks.unsqueeze(-1) 33 | 34 | # embedding position (it may includes padding as placeholder) 35 | xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C 36 | 37 | # learnable null embedding 38 | positive_null = self.null_positive_feature.view(1,1,-1) 39 | xyxy_null = self.null_position_feature.view(1,1,-1) 40 | 41 | # replace padding with learnable null embedding 42 | positive_embeddings = positive_embeddings*masks + (1-masks)*positive_null 43 | xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null 44 | 45 | objs = self.linears( torch.cat([positive_embeddings, xyxy_embedding], dim=-1) ) 46 | assert 
objs.shape == torch.Size([B,N,self.out_dim]) 47 | return objs 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/diffusionmodules/positionnet_with_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ldm.modules.attention import BasicTransformerBlock 4 | from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder 5 | import torch.nn.functional as F 6 | 7 | 8 | 9 | class PositionNet(nn.Module): 10 | def __init__(self, positive_len, out_dim, fourier_freqs=8): 11 | super().__init__() 12 | self.positive_len = positive_len 13 | self.out_dim = out_dim 14 | 15 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) 16 | self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy 17 | 18 | # -------------------------------------------------------------- # 19 | self.linears_text = nn.Sequential( 20 | nn.Linear( self.positive_len + self.position_dim, 512), 21 | nn.SiLU(), 22 | nn.Linear( 512, 512), 23 | nn.SiLU(), 24 | nn.Linear(512, out_dim), 25 | ) 26 | 27 | self.linears_image = nn.Sequential( 28 | nn.Linear( self.positive_len + self.position_dim, 512), 29 | nn.SiLU(), 30 | nn.Linear( 512, 512), 31 | nn.SiLU(), 32 | nn.Linear(512, out_dim), 33 | ) 34 | 35 | # -------------------------------------------------------------- # 36 | self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) 37 | self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) 38 | self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) 39 | 40 | 41 | def forward(self, boxes, masks, text_masks, image_masks, text_embeddings, image_embeddings): 42 | B, N, _ = boxes.shape 43 | masks = masks.unsqueeze(-1) # B*N*1 44 | text_masks = text_masks.unsqueeze(-1) # B*N*1 45 | image_masks = image_masks.unsqueeze(-1) # B*N*1 46 | 47 | # embedding position (it may includes padding as placeholder) 48 | xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C 49 | 50 | # learnable null embedding 51 | text_null = self.null_text_feature.view(1,1,-1) # 1*1*C 52 | image_null = self.null_image_feature.view(1,1,-1) # 1*1*C 53 | xyxy_null = self.null_position_feature.view(1,1,-1) # 1*1*C 54 | 55 | # replace padding with learnable null embedding 56 | text_embeddings = text_embeddings*text_masks + (1-text_masks)*text_null 57 | image_embeddings = image_embeddings*image_masks + (1-image_masks)*image_null 58 | xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null 59 | 60 | objs_text = self.linears_text( torch.cat([text_embeddings, xyxy_embedding], dim=-1) ) 61 | objs_image = self.linears_image( torch.cat([image_embeddings,xyxy_embedding], dim=-1) ) 62 | objs = torch.cat( [objs_text,objs_image], dim=1 ) 63 | 64 | assert objs.shape == torch.Size([B,N*2,self.out_dim]) 65 | return objs 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /demo/gligen/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = 
self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /demo/gligen/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | 10 | def log_txt_as_img(wh, xc, size=10): 11 | # wh a tuple of (width, height) 12 | # xc a list of captions to plot 13 | b = len(xc) 14 | txts = list() 15 | for bi in range(b): 16 | txt = Image.new("RGB", wh, color="white") 17 | draw = ImageDraw.Draw(txt) 18 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 19 | nc = int(40 * (wh[0] / 256)) 20 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, 
len(xc[bi]), nc)) 21 | 22 | try: 23 | draw.text((0, 0), lines, fill="black", font=font) 24 | except UnicodeEncodeError: 25 | print("Cant encode string for logging. Skipping.") 26 | 27 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 28 | txts.append(txt) 29 | txts = np.stack(txts) 30 | txts = torch.tensor(txts) 31 | return txts 32 | 33 | 34 | def ismap(x): 35 | if not isinstance(x, torch.Tensor): 36 | return False 37 | return (len(x.shape) == 4) and (x.shape[1] > 3) 38 | 39 | 40 | def isimage(x): 41 | if not isinstance(x,torch.Tensor): 42 | return False 43 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 44 | 45 | 46 | def exists(x): 47 | return x is not None 48 | 49 | 50 | def default(val, d): 51 | if exists(val): 52 | return val 53 | return d() if isfunction(d) else d 54 | 55 | 56 | def mean_flat(tensor): 57 | """ 58 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 59 | Take the mean over all non-batch dimensions. 60 | """ 61 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 62 | 63 | 64 | def count_params(model, verbose=False): 65 | total_params = sum(p.numel() for p in model.parameters()) 66 | if verbose: 67 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 68 | return total_params 69 | 70 | 71 | def instantiate_from_config(config): 72 | if not "target" in config: 73 | if config == '__is_first_stage__': 74 | return None 75 | elif config == "__is_unconditional__": 76 | return None 77 | raise KeyError("Expected key `target` to instantiate.") 78 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 79 | 80 | 81 | def get_obj_from_str(string, reload=False): 82 | module, cls = string.rsplit(".", 1) 83 | if reload: 84 | module_imp = importlib.import_module(module) 85 | importlib.reload(module_imp) 86 | return getattr(importlib.import_module(module, package=None), cls) 87 | 88 | def default_device() -> str: 89 | if torch.cuda.is_available(): 90 | return "cuda" 91 | elif torch.backends.mps.is_available(): 92 | return "mps" 93 | else: 94 | return "cpu" -------------------------------------------------------------------------------- /demo/gligen/projection_matrix: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/gligen/projection_matrix -------------------------------------------------------------------------------- /demo/images/arg_corgis.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/arg_corgis.jpeg -------------------------------------------------------------------------------- /demo/images/blank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/blank.png -------------------------------------------------------------------------------- /demo/images/flower_beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/flower_beach.jpg -------------------------------------------------------------------------------- /demo/images/red_bird.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/red_bird.jpg -------------------------------------------------------------------------------- /demo/images/style_cloudpurple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/style_cloudpurple.png -------------------------------------------------------------------------------- /demo/images/style_gold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/style_gold.png -------------------------------------------------------------------------------- /demo/images/teddy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/demo/images/teddy.jpg -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | if not dist.is_initialized(): 23 | return 24 | 25 | world_size = dist.get_world_size() 26 | if world_size == 1: 27 | return 28 | 29 | dist.barrier() 30 | 31 | 32 | def get_world_size(): 33 | if not dist.is_available(): 34 | return 1 35 | if not dist.is_initialized(): 36 | return 1 37 | return dist.get_world_size() 38 | 39 | 40 | def reduce_sum(tensor): 41 | if not dist.is_available(): 42 | return tensor 43 | 44 | if not dist.is_initialized(): 45 | return tensor 46 | 47 | tensor = tensor.clone() 48 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 49 | 50 | return tensor 51 | 52 | 53 | def gather_grad(params): 54 | world_size = get_world_size() 55 | 56 | if world_size == 1: 57 | return 58 | 59 | for param in params: 60 | if param.grad is not None: 61 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 62 | param.grad.data.div_(world_size) 63 | 64 | 65 | def all_gather(data): 66 | world_size = get_world_size() 67 | 68 | if world_size == 1: 69 | return [data] 70 | 71 | buffer = pickle.dumps(data) 72 | storage = torch.ByteStorage.from_buffer(buffer) 73 | tensor = torch.ByteTensor(storage).to('cuda') 74 | 75 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 76 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 77 | dist.all_gather(size_list, local_size) 78 | size_list = [int(size.item()) for size in size_list] 79 | max_size = max(size_list) 80 | 81 | tensor_list = [] 82 | for _ in size_list: 83 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 84 | 85 | if local_size != max_size: 86 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 87 | tensor = torch.cat((tensor, padding), 0) 88 | 89 | dist.all_gather(tensor_list, tensor) 90 | 91 | data_list = [] 92 | 93 | for size, tensor in zip(size_list, tensor_list): 94 | buffer = tensor.cpu().numpy().tobytes()[:size] 95 | 
data_list.append(pickle.loads(buffer)) 96 | 97 | return data_list 98 | 99 | 100 | def reduce_loss_dict(loss_dict): 101 | world_size = get_world_size() 102 | 103 | if world_size < 2: 104 | return loss_dict 105 | 106 | with torch.no_grad(): 107 | keys = [] 108 | losses = [] 109 | 110 | for k in sorted(loss_dict.keys()): 111 | keys.append(k) 112 | losses.append(loss_dict[k]) 113 | 114 | losses = torch.stack(losses, 0) 115 | dist.reduce(losses, dst=0) 116 | 117 | if dist.get_rank() == 0: 118 | losses /= world_size 119 | 120 | reduced_losses = {k: v for k, v in zip(keys, losses)} 121 | 122 | return reduced_losses 123 | -------------------------------------------------------------------------------- /docs/gligen_controlnet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gligen/GLIGEN/f9dccb9c6cf48bad03c3666290a7dec8c5e58f3c/docs/gligen_controlnet.jpeg -------------------------------------------------------------------------------- /docs/gligen_vs_controlnet.MD: -------------------------------------------------------------------------------- 1 | ## GLIGEN vs ControlNet 2 | 3 | 4 | Both [GLIGEN](https://gligen.github.io/) and [ControlNet](https://github.com/lllyasviel/ControlNet) are controllable diffusion models. They add only new learnable parameters that adapt and modify the intermediate features of an existing diffusion model, without changing the original weights. 5 | 6 | What is the difference between GLIGEN and ControlNet? At a high level they are similar, but they differ in model design details. The following figure shows the architecture of a U-Net, the backbone commonly used for the diffusion model in this context. It consists of several encoder and decoder blocks. Each encoder block sequentially applies a residual block, a self-attention layer, and a cross-attention layer. 7 |
8 | 
9 | 
[image omitted: U-Net architecture figure]
18 | 
19 | 
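
To make the shared recipe concrete, here is a minimal sketch in Python. It is not code from this repository, and every class and variable name in it is invented for illustration: the point is only the common idea of keeping a pretrained block frozen while injecting new, zero-gated learnable layers next to it, so that at initialization the model still behaves exactly like the original. The repository's own gated attention layers and grounding-token networks are implemented under `ldm/modules/attention.py` and `ldm/modules/diffusionmodules/`.

```python
import torch
import torch.nn as nn


class GatedAdapterBlock(nn.Module):
    """Illustrative only: wrap a frozen pretrained block and add a gated attention layer.

    Only the newly added parameters are trainable; the wrapped block's weights
    are left exactly as they were in the pretrained model.
    """

    def __init__(self, frozen_block: nn.Module, dim: int, n_heads: int = 8):
        super().__init__()
        self.frozen_block = frozen_block
        for p in self.frozen_block.parameters():
            p.requires_grad = False  # original weights stay untouched

        # new learnable parameters: an extra attention layer plus a scalar gate
        self.adapter_norm = nn.LayerNorm(dim)
        self.adapter_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) = 0 -> no effect at init

    def forward(self, x: torch.Tensor, grounding_tokens: torch.Tensor) -> torch.Tensor:
        x = self.frozen_block(x)  # original computation, unchanged
        # visual tokens attend over [visual tokens + grounding tokens]
        context = torch.cat([x, grounding_tokens], dim=1)
        h, _ = self.adapter_attn(self.adapter_norm(x), context, context)
        return x + torch.tanh(self.gate) * h  # gated residual injection


if __name__ == "__main__":
    # stand-in for one pretrained (frozen) block of the U-Net
    pretrained_block = nn.Sequential(nn.LayerNorm(320), nn.Linear(320, 320))
    block = GatedAdapterBlock(pretrained_block, dim=320)

    x = torch.randn(2, 64, 320)  # B x (H*W) x C visual tokens
    g = torch.randn(2, 30, 320)  # B x N x C grounding tokens (e.g. boxes + text)
    print(block(x, g).shape)     # torch.Size([2, 64, 320])
```

Because the gate starts at zero, the adapted network initially reproduces the pretrained model exactly, and the new grounding signal is blended in only as the gate is learned; both methods rely on a zero-initialized connection of this kind (GLIGEN via a learned gate, ControlNet via zero convolutions), while differing in where the new layers sit and what they attend to.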