├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── openai-clip-vit-large-patch14 │ ├── config.json │ ├── merges.txt │ ├── preprocessor_config.json │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json ├── configs └── controlcom.yaml ├── examples ├── background │ ├── 000000049931.png │ ├── 000000330273.png │ └── 000001219153.png ├── bbox │ ├── 000000049931.txt │ ├── 000000330273.txt │ └── 000001219153.txt ├── foreground │ ├── 000000049931.png │ ├── 000000330273.png │ └── 000001219153.png ├── foreground_mask │ ├── 000000049931.png │ ├── 000000330273.png │ └── 000001219153.png └── mask_bbox │ ├── 000000049931.png │ ├── 000000330273.png │ └── 000001219153.png ├── figures ├── architecture.png ├── controllability_necessity.jpg ├── controllable_results.jpg └── task.png ├── ldm ├── __pycache__ │ ├── lr_scheduler.cpython-38.pyc │ └── util.cpython-38.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── base.cpython-38.pyc │ │ ├── datamodule.cpython-38.pyc │ │ ├── open-images-control.cpython-38.pyc │ │ ├── open-images.cpython-38.pyc │ │ ├── open_images_control.cpython-38.pyc │ │ ├── open_images_pcache.cpython-38.pyc │ │ ├── oss_pcache.cpython-38.pyc │ │ └── preview │ │ │ └── Pytorch │ │ │ └── kernel_size.png │ ├── base.py │ ├── datamodule.py │ └── open_images_control.py ├── lr_scheduler.py ├── models │ ├── __pycache__ │ │ ├── autoencoder.cpython-38.pyc │ │ └── mask_generator.cpython-38.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── ddim.cpython-38.pyc │ │ ├── ddpm.cpython-38.pyc │ │ └── plms.cpython-38.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-38.pyc │ │ ├── attention_map.cpython-38.pyc │ │ ├── diffusion_metrics.cpython-38.pyc │ │ ├── ema.cpython-38.pyc │ │ ├── local_module.cpython-38.pyc │ │ ├── mask_blur.cpython-38.pyc │ │ ├── preview │ │ │ └── Pytorch │ │ │ │ ├── bbox.png │ │ │ │ ├── context.png │ │ │ │ ├── idx_bbox.png │ │ │ │ ├── indices.png │ │ │ │ ├── k.png │ │ │ │ ├── q.png │ │ │ │ └── x.png │ │ └── x_transformer.cpython-38.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ ├── preview │ │ │ │ └── Pytorch │ │ │ │ │ └── bbox.png │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── modules.cpython-38.pyc │ │ │ └── xf.cpython-38.pyc │ │ ├── modules.py │ │ └── xf.py │ ├── local_module.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── contperceptual.cpython-38.pyc │ │ │ └── mask_loss.cpython-38.pyc │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── requirements.txt ├── scripts └── inference.py ├── src └── taming-transformers │ ├── License.txt │ ├── README.md │ ├── configs │ ├── coco_cond_stage.yaml │ ├── coco_scene_images_transformer.yaml │ ├── custom_vqgan.yaml │ ├── drin_transformer.yaml │ ├── faceshq_transformer.yaml │ ├── faceshq_vqgan.yaml │ ├── imagenet_vqgan.yaml │ ├── imagenetdepth_vqgan.yaml │ ├── open_images_scene_images_transformer.yaml │ └── 
sflckr_cond_stage.yaml │ ├── environment.yaml │ ├── main.py │ ├── scripts │ ├── extract_depth.py │ ├── extract_segmentation.py │ ├── extract_submodel.py │ ├── make_samples.py │ ├── make_scene_samples.py │ ├── sample_conditional.py │ └── sample_fast.py │ ├── setup.py │ └── taming │ ├── __pycache__ │ └── util.cpython-38.pyc │ ├── data │ ├── ade20k.py │ ├── annotated_objects_coco.py │ ├── annotated_objects_dataset.py │ ├── annotated_objects_open_images.py │ ├── base.py │ ├── coco.py │ ├── conditional_builder │ │ ├── objects_bbox.py │ │ ├── objects_center_points.py │ │ └── utils.py │ ├── custom.py │ ├── faceshq.py │ ├── helper_types.py │ ├── image_transforms.py │ ├── imagenet.py │ ├── open_images_helper.py │ ├── sflckr.py │ └── utils.py │ ├── lr_scheduler.py │ ├── models │ ├── cond_transformer.py │ ├── dummy_cond_stage.py │ └── vqgan.py │ ├── modules │ ├── __pycache__ │ │ └── util.cpython-38.pyc │ ├── diffusionmodules │ │ └── model.py │ ├── discriminator │ │ ├── __pycache__ │ │ │ └── model.cpython-38.pyc │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── lpips.cpython-38.pyc │ │ │ └── vqperceptual.cpython-38.pyc │ │ ├── lpips.py │ │ ├── segmentation.py │ │ └── vqperceptual.py │ ├── misc │ │ └── coord.py │ ├── transformer │ │ ├── mingpt.py │ │ └── permuter.py │ ├── util.py │ └── vqvae │ │ ├── __pycache__ │ │ └── quantize.cpython-38.pyc │ │ └── quantize.py │ └── util.py └── test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | *.egg-info/ 3 | results/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 BCMI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /checkpoints/openai-clip-vit-large-patch14/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "clip-vit-large-patch14/", 3 | "architectures": [ 4 | "CLIPModel" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 | "model_type": "clip", 9 | "projection_dim": 768, 10 | "text_config": { 11 | "_name_or_path": "", 12 | "add_cross_attention": false, 13 | "architectures": null, 14 | "attention_dropout": 0.0, 15 | "bad_words_ids": null, 16 | "bos_token_id": 0, 17 | "chunk_size_feed_forward": 0, 18 | "cross_attention_hidden_size": null, 19 | "decoder_start_token_id": null, 20 | "diversity_penalty": 0.0, 21 | "do_sample": false, 22 | "dropout": 0.0, 23 | "early_stopping": false, 24 | "encoder_no_repeat_ngram_size": 0, 25 | "eos_token_id": 2, 26 | "finetuning_task": null, 27 | "forced_bos_token_id": null, 28 | "forced_eos_token_id": null, 29 | "hidden_act": "quick_gelu", 30 | "hidden_size": 768, 31 | "id2label": { 32 | "0": "LABEL_0", 33 | "1": "LABEL_1" 34 | }, 35 | "initializer_factor": 1.0, 36 | "initializer_range": 0.02, 37 | "intermediate_size": 3072, 38 | "is_decoder": false, 39 | "is_encoder_decoder": false, 40 | "label2id": { 41 | "LABEL_0": 0, 42 | "LABEL_1": 1 43 | }, 44 | "layer_norm_eps": 1e-05, 45 | "length_penalty": 1.0, 46 | "max_length": 20, 47 | "max_position_embeddings": 77, 48 | "min_length": 0, 49 | "model_type": "clip_text_model", 50 | "no_repeat_ngram_size": 0, 51 | "num_attention_heads": 12, 52 | "num_beam_groups": 1, 53 | "num_beams": 1, 54 | "num_hidden_layers": 12, 55 | "num_return_sequences": 1, 56 | "output_attentions": false, 57 | "output_hidden_states": false, 58 | "output_scores": false, 59 | "pad_token_id": 1, 60 | "prefix": null, 61 | "problem_type": null, 62 | "projection_dim" : 768, 63 | "pruned_heads": {}, 64 | "remove_invalid_values": false, 65 | "repetition_penalty": 1.0, 66 | "return_dict": true, 67 | "return_dict_in_generate": false, 68 | "sep_token_id": null, 69 | "task_specific_params": null, 70 | "temperature": 1.0, 71 | "tie_encoder_decoder": false, 72 | "tie_word_embeddings": true, 73 | "tokenizer_class": null, 74 | "top_k": 50, 75 | "top_p": 1.0, 76 | "torch_dtype": null, 77 | "torchscript": false, 78 | "transformers_version": "4.16.0.dev0", 79 | "use_bfloat16": false, 80 | "vocab_size": 49408 81 | }, 82 | "text_config_dict": { 83 | "hidden_size": 768, 84 | "intermediate_size": 3072, 85 | "num_attention_heads": 12, 86 | "num_hidden_layers": 12, 87 | "projection_dim": 768 88 | }, 89 | "torch_dtype": "float32", 90 | "transformers_version": null, 91 | "vision_config": { 92 | "_name_or_path": "", 93 | "add_cross_attention": false, 94 | "architectures": null, 95 | "attention_dropout": 0.0, 96 | "bad_words_ids": null, 97 | "bos_token_id": null, 98 | "chunk_size_feed_forward": 0, 99 | "cross_attention_hidden_size": null, 100 | "decoder_start_token_id": null, 101 | "diversity_penalty": 0.0, 102 | "do_sample": false, 103 | "dropout": 0.0, 104 | "early_stopping": false, 105 | "encoder_no_repeat_ngram_size": 0, 106 | "eos_token_id": null, 107 | "finetuning_task": null, 108 | "forced_bos_token_id": null, 109 | "forced_eos_token_id": null, 110 | "hidden_act": "quick_gelu", 111 | "hidden_size": 1024, 112 | "id2label": { 113 | "0": "LABEL_0", 114 | "1": "LABEL_1" 115 | }, 116 | "image_size": 224, 117 | "initializer_factor": 1.0, 118 | "initializer_range": 0.02, 119 | "intermediate_size": 
4096, 120 | "is_decoder": false, 121 | "is_encoder_decoder": false, 122 | "label2id": { 123 | "LABEL_0": 0, 124 | "LABEL_1": 1 125 | }, 126 | "layer_norm_eps": 1e-05, 127 | "length_penalty": 1.0, 128 | "max_length": 20, 129 | "min_length": 0, 130 | "model_type": "clip_vision_model", 131 | "no_repeat_ngram_size": 0, 132 | "num_attention_heads": 16, 133 | "num_beam_groups": 1, 134 | "num_beams": 1, 135 | "num_hidden_layers": 24, 136 | "num_return_sequences": 1, 137 | "output_attentions": false, 138 | "output_hidden_states": false, 139 | "output_scores": false, 140 | "pad_token_id": null, 141 | "patch_size": 14, 142 | "prefix": null, 143 | "problem_type": null, 144 | "projection_dim" : 768, 145 | "pruned_heads": {}, 146 | "remove_invalid_values": false, 147 | "repetition_penalty": 1.0, 148 | "return_dict": true, 149 | "return_dict_in_generate": false, 150 | "sep_token_id": null, 151 | "task_specific_params": null, 152 | "temperature": 1.0, 153 | "tie_encoder_decoder": false, 154 | "tie_word_embeddings": true, 155 | "tokenizer_class": null, 156 | "top_k": 50, 157 | "top_p": 1.0, 158 | "torch_dtype": null, 159 | "torchscript": false, 160 | "transformers_version": "4.16.0.dev0", 161 | "use_bfloat16": false 162 | }, 163 | "vision_config_dict": { 164 | "hidden_size": 1024, 165 | "intermediate_size": 4096, 166 | "num_attention_heads": 16, 167 | "num_hidden_layers": 24, 168 | "patch_size": 14, 169 | "projection_dim": 768 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /checkpoints/openai-clip-vit-large-patch14/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_normalize": true, 5 | "do_resize": true, 6 | "feature_extractor_type": "CLIPFeatureExtractor", 7 | "image_mean": [ 8 | 0.48145466, 9 | 0.4578275, 10 | 0.40821073 11 | ], 12 | "image_std": [ 13 | 0.26862954, 14 | 0.26130258, 15 | 0.27577711 16 | ], 17 | "resample": 3, 18 | "size": 224 19 | } 20 | -------------------------------------------------------------------------------- /checkpoints/openai-clip-vit-large-patch14/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": { 3 | "content": "<|endoftext|>", 4 | "single_word": false, 5 | "lstrip": false, 6 | "rstrip": false, 7 | "normalized": true, 8 | "__type": "AddedToken" 9 | }, 10 | "bos_token": { 11 | "content": "<|startoftext|>", 12 | "single_word": false, 13 | "lstrip": false, 14 | "rstrip": false, 15 | "normalized": true, 16 | "__type": "AddedToken" 17 | }, 18 | "eos_token": { 19 | "content": "<|endoftext|>", 20 | "single_word": false, 21 | "lstrip": false, 22 | "rstrip": false, 23 | "normalized": true, 24 | "__type": "AddedToken" 25 | }, 26 | "pad_token": "<|endoftext|>", 27 | "add_prefix_space": false, 28 | "errors": "replace", 29 | "do_lower_case": true, 30 | "name_or_path": "openai/clip-vit-base-patch32", 31 | "model_max_length": 77, 32 | "special_tokens_map_file": "./special_tokens_map.json", 33 | "tokenizer_class": "CLIPTokenizer" 34 | } 35 | -------------------------------------------------------------------------------- /configs/controlcom.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 
1000 10 | first_stage_key: composition 11 | cond_stage_key: image 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | finetune_full_unet: true 16 | augment_config: 17 | augment_types: 18 | - - 0 19 | - 0 20 | - - 1 21 | - 0 22 | - - 0 23 | - 1 24 | - - 1 25 | - 1 26 | sample_prob: 27 | - 0.15 28 | - 0.15 29 | - 0.35 30 | - 0.35 31 | sample_mode: random 32 | augment_box: false 33 | augment_background: true 34 | replace_background_prob: 1 35 | use_inpaint_background: false 36 | conditioning_key: crossattn 37 | monitor: val/loss_simple_ema 38 | u_cond_percent: 0.2 39 | scale_factor: 0.18215 40 | use_ema: true 41 | use_guidance: true 42 | local_uncond: same 43 | scheduler_config: 44 | target: ldm.lr_scheduler.LambdaLinearScheduler 45 | params: 46 | warm_up_steps: 47 | - 1000 48 | cycle_lengths: 49 | - 10000000000000 50 | f_start: 51 | - 1.0e-06 52 | f_max: 53 | - 1.0 54 | f_min: 55 | - 1.0 56 | unet_config: 57 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 58 | params: 59 | image_size: 32 60 | in_channels: 11 61 | out_channels: 4 62 | model_channels: 320 63 | attention_resolutions: 64 | - 4 65 | - 2 66 | - 1 67 | num_res_blocks: 2 68 | channel_mult: 69 | - 1 70 | - 2 71 | - 4 72 | - 4 73 | num_heads: 8 74 | use_spatial_transformer: true 75 | transformer_depth: 1 76 | context_dim: 768 77 | use_checkpoint: true 78 | legacy: false 79 | add_conv_in_front_of_unet: false 80 | local_encoder_config: 81 | conditioning_key: ldm.modules.local_module.LocalRefineBlock 82 | add_position_emb: false 83 | roi_size: 16 84 | context_dim: 1024 85 | resolutions: 86 | - 1 87 | - 2 88 | add_in_encoder: true 89 | add_in_decoder: true 90 | add_before_crossattn: false 91 | first_stage_config: 92 | target: ldm.models.autoencoder.AutoencoderKL 93 | params: 94 | embed_dim: 4 95 | monitor: val/rec_loss 96 | ddconfig: 97 | double_z: true 98 | z_channels: 4 99 | resolution: 256 100 | in_channels: 3 101 | out_ch: 3 102 | ch: 128 103 | ch_mult: 104 | - 1 105 | - 2 106 | - 4 107 | - 4 108 | num_res_blocks: 2 109 | attn_resolutions: [] 110 | dropout: 0.0 111 | lossconfig: 112 | target: torch.nn.Identity 113 | cond_stage_config: 114 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 115 | params: 116 | version: openai-clip-vit-large-patch14 117 | local_hidden_index: 12 118 | use_foreground_mask: false 119 | patchtoken_for_global: false -------------------------------------------------------------------------------- /examples/background/000000049931.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/background/000000049931.png -------------------------------------------------------------------------------- /examples/background/000000330273.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/background/000000330273.png -------------------------------------------------------------------------------- /examples/background/000001219153.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/background/000001219153.png -------------------------------------------------------------------------------- /examples/bbox/000000049931.txt: 
-------------------------------------------------------------------------------- 1 | 168 137 488 413 000000049931_GT.png 2 | -------------------------------------------------------------------------------- /examples/bbox/000000330273.txt: -------------------------------------------------------------------------------- 1 | 250 30 504 497 000000330273_GT.png 2 | -------------------------------------------------------------------------------- /examples/bbox/000001219153.txt: -------------------------------------------------------------------------------- 1 | 199 163 379 490 000001219153_GT.png 2 | -------------------------------------------------------------------------------- /examples/foreground/000000049931.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/foreground/000000049931.png -------------------------------------------------------------------------------- /examples/foreground/000000330273.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/foreground/000000330273.png -------------------------------------------------------------------------------- /examples/foreground/000001219153.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/foreground/000001219153.png -------------------------------------------------------------------------------- /examples/foreground_mask/000000049931.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/foreground_mask/000000049931.png -------------------------------------------------------------------------------- /examples/foreground_mask/000000330273.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/foreground_mask/000000330273.png -------------------------------------------------------------------------------- /examples/foreground_mask/000001219153.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/foreground_mask/000001219153.png -------------------------------------------------------------------------------- /examples/mask_bbox/000000049931.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/mask_bbox/000000049931.png -------------------------------------------------------------------------------- /examples/mask_bbox/000000330273.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/mask_bbox/000000330273.png -------------------------------------------------------------------------------- /examples/mask_bbox/000001219153.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/examples/mask_bbox/000001219153.png -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/figures/architecture.png -------------------------------------------------------------------------------- /figures/controllability_necessity.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/figures/controllability_necessity.jpg -------------------------------------------------------------------------------- /figures/controllable_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/figures/controllable_results.jpg -------------------------------------------------------------------------------- /figures/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/figures/task.png -------------------------------------------------------------------------------- /ldm/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/datamodule.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/datamodule.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/open-images-control.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/open-images-control.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/open-images.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/open-images.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/open_images_control.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/open_images_control.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/open_images_pcache.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/open_images_pcache.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/oss_pcache.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/oss_pcache.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/preview/Pytorch/kernel_size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/data/__pycache__/preview/Pytorch/kernel_size.png -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob, importlib, 
csv 2 | import numpy as np 3 | import time 4 | import torch 5 | import torchvision 6 | import pytorch_lightning as pl 7 | 8 | from packaging import version 9 | from omegaconf import OmegaConf 10 | from torch.utils.data import random_split, DataLoader, Dataset, Subset 11 | from functools import partial 12 | from PIL import Image 13 | 14 | from pytorch_lightning import seed_everything 15 | from pytorch_lightning.trainer import Trainer 16 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 17 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 18 | from pytorch_lightning.utilities import rank_zero_info 19 | 20 | from ldm.data.base import Txt2ImgIterableBaseDataset 21 | from ldm.util import instantiate_from_config 22 | import socket 23 | from pytorch_lightning.plugins.environments import ClusterEnvironment,SLURMEnvironment 24 | 25 | class WrappedDataset(Dataset): 26 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 27 | 28 | def __init__(self, dataset): 29 | self.data = dataset 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, idx): 35 | return self.data[idx] 36 | 37 | def worker_init_fn(_): 38 | worker_info = torch.utils.data.get_worker_info() 39 | 40 | dataset = worker_info.dataset 41 | worker_id = worker_info.id 42 | 43 | if isinstance(dataset, Txt2ImgIterableBaseDataset): 44 | split_size = dataset.num_records // worker_info.num_workers 45 | # reset num_records to the true number to retain reliable length information 46 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 47 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 48 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 49 | else: 50 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 51 | 52 | class DataModuleFromConfig(pl.LightningDataModule): 53 | def __init__(self, batch_size, num_workers=None, train=None, validation=None, test=None, predict=None, 54 | wrap=False, shuffle_test_loader=False, use_worker_init_fn=False, 55 | shuffle_val_dataloader=False): 56 | super().__init__() 57 | self.batch_size = batch_size 58 | self.dataset_configs = dict() 59 | self.num_workers = num_workers if num_workers is not None else batch_size 60 | self.use_worker_init_fn = use_worker_init_fn 61 | if train is not None: 62 | self.dataset_configs["train"] = train 63 | self.train_dataloader = self._train_dataloader 64 | if validation is not None: 65 | self.dataset_configs["validation"] = validation 66 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 67 | if test is not None: 68 | self.dataset_configs["test"] = test 69 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 70 | if predict is not None: 71 | self.dataset_configs["predict"] = predict 72 | self.predict_dataloader = self._predict_dataloader 73 | self.wrap = wrap 74 | 75 | def prepare_data(self): 76 | for data_cfg in self.dataset_configs.values(): 77 | instantiate_from_config(data_cfg) 78 | 79 | def setup(self, stage=None): 80 | self.datasets = dict( 81 | (k, instantiate_from_config(self.dataset_configs[k])) 82 | for k in self.dataset_configs) 83 | if self.wrap: 84 | for k in self.datasets: 85 | self.datasets[k] = WrappedDataset(self.datasets[k]) 86 | 87 | def _train_dataloader(self): 88 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 89 | if is_iterable_dataset or 
self.use_worker_init_fn: 90 | init_fn = worker_init_fn 91 | else: 92 | init_fn = None 93 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 94 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, 95 | worker_init_fn=init_fn) 96 | 97 | def _val_dataloader(self, shuffle=False): 98 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 99 | init_fn = worker_init_fn 100 | else: 101 | init_fn = None 102 | return DataLoader(self.datasets["validation"], 103 | batch_size=self.batch_size, 104 | num_workers=self.num_workers, 105 | worker_init_fn=init_fn, 106 | shuffle=shuffle) 107 | 108 | def _test_dataloader(self, shuffle=False): 109 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 110 | if is_iterable_dataset or self.use_worker_init_fn: 111 | init_fn = worker_init_fn 112 | else: 113 | init_fn = None 114 | 115 | # do not shuffle dataloader for iterable dataset 116 | shuffle = shuffle and (not is_iterable_dataset) 117 | 118 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 119 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) 120 | 121 | def _predict_dataloader(self, shuffle=False): 122 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 123 | init_fn = worker_init_fn 124 | else: 125 | init_fn = None 126 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 127 | num_workers=self.num_workers, worker_init_fn=init_fn) 128 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/__pycache__/mask_generator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/__pycache__/mask_generator.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/plms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention_map.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/attention_map.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/diffusion_metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/diffusion_metrics.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/local_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/local_module.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/mask_blur.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/mask_blur.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/bbox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/bbox.png 
-------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/context.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/context.png -------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/idx_bbox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/idx_bbox.png -------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/indices.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/indices.png -------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/k.png -------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/q.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/q.png -------------------------------------------------------------------------------- /ldm/modules/__pycache__/preview/Pytorch/x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/preview/Pytorch/x.png -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/__pycache__/x_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/preview/Pytorch/bbox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/diffusionmodules/__pycache__/preview/Pytorch/bbox.png -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | 
self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/xf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/encoders/__pycache__/xf.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/xf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer implementation adapted from CLIP ViT: 3 | https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | def convert_module_to_f16(l): 13 | """ 14 | Convert primitive modules to float16. 15 | """ 16 | if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 17 | l.weight.data = l.weight.data.half() 18 | if l.bias is not None: 19 | l.bias.data = l.bias.data.half() 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """ 24 | Implementation that supports fp16 inputs but fp32 gains/biases. 
25 | """ 26 | 27 | def forward(self, x: th.Tensor): 28 | return super().forward(x.float()).to(x.dtype) 29 | 30 | 31 | class MultiheadAttention(nn.Module): 32 | def __init__(self, n_ctx, width, heads): 33 | super().__init__() 34 | self.n_ctx = n_ctx 35 | self.width = width 36 | self.heads = heads 37 | self.c_qkv = nn.Linear(width, width * 3) 38 | self.c_proj = nn.Linear(width, width) 39 | self.attention = QKVMultiheadAttention(heads, n_ctx) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = self.attention(x) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class MLP(nn.Module): 49 | def __init__(self, width): 50 | super().__init__() 51 | self.width = width 52 | self.c_fc = nn.Linear(width, width * 4) 53 | self.c_proj = nn.Linear(width * 4, width) 54 | self.gelu = nn.GELU() 55 | 56 | def forward(self, x): 57 | return self.c_proj(self.gelu(self.c_fc(x))) 58 | 59 | 60 | class QKVMultiheadAttention(nn.Module): 61 | def __init__(self, n_heads: int, n_ctx: int): 62 | super().__init__() 63 | self.n_heads = n_heads 64 | self.n_ctx = n_ctx 65 | 66 | def forward(self, qkv): 67 | bs, n_ctx, width = qkv.shape 68 | attn_ch = width // self.n_heads // 3 69 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 70 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 71 | q, k, v = th.split(qkv, attn_ch, dim=-1) 72 | weight = th.einsum( 73 | "bthc,bshc->bhts", q * scale, k * scale 74 | ) # More stable with f16 than dividing afterwards 75 | wdtype = weight.dtype 76 | weight = th.softmax(weight.float(), dim=-1).type(wdtype) 77 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 78 | 79 | 80 | class ResidualAttentionBlock(nn.Module): 81 | def __init__( 82 | self, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | ): 87 | super().__init__() 88 | 89 | self.attn = MultiheadAttention( 90 | n_ctx, 91 | width, 92 | heads, 93 | ) 94 | self.ln_1 = LayerNorm(width) 95 | self.mlp = MLP(width) 96 | self.ln_2 = LayerNorm(width) 97 | 98 | def forward(self, x: th.Tensor): 99 | x = x + self.attn(self.ln_1(x)) 100 | x = x + self.mlp(self.ln_2(x)) 101 | return x 102 | 103 | 104 | class Transformer(nn.Module): 105 | def __init__( 106 | self, 107 | n_ctx: int, 108 | width: int, 109 | layers: int, 110 | heads: int, 111 | ): 112 | super().__init__() 113 | self.n_ctx = n_ctx 114 | self.width = width 115 | self.layers = layers 116 | self.resblocks = nn.ModuleList( 117 | [ 118 | ResidualAttentionBlock( 119 | n_ctx, 120 | width, 121 | heads, 122 | ) 123 | for _ in range(layers) 124 | ] 125 | ) 126 | 127 | def forward(self, x: th.Tensor): 128 | for block in self.resblocks: 129 | x = block(x) 130 | return x 131 | -------------------------------------------------------------------------------- /ldm/modules/local_module.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | from numpy import inner 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, einsum 7 | from einops import rearrange, repeat 8 | import os, sys 9 | from ldm.modules.diffusionmodules.util import checkpoint 10 | from ldm.modules.attention import CrossAttention, zero_module, Normalize 11 | from torchvision.ops import roi_align 12 | 13 | class FDN(nn.Module): 14 | # Spatially-Adaptive Normalization, homepage: https://nvlabs.github.io/SPADE/ 15 | # this code borrows from https://github.com/ShihaoZhaoZSH/Uni-ControlNet/blob/591036b78d13fd17b002ecd3be44d7c84473b47c/models/local_adapter.py#L31 16 | def __init__(self, norm_nc, label_nc): 17 | 
super().__init__() 18 | ks = 3 19 | pw = ks // 2 20 | self.param_free_norm = nn.GroupNorm(32, norm_nc, affine=False) 21 | self.conv_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw) 22 | self.conv_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw) 23 | 24 | def forward(self, x, context): 25 | normalized = self.param_free_norm(x) 26 | assert context.size()[2:] == x.size()[2:] 27 | gamma = self.conv_gamma(context) 28 | beta = self.conv_beta(context) 29 | out = normalized * gamma + beta 30 | return out 31 | 32 | 33 | class LocalRefineBlock(nn.Module): 34 | def __init__(self, in_channels, n_heads=1, d_head=320, 35 | depth=1, dropout=0., context_dim=1024, roi_size=9, 36 | add_positional_embedding=False, block_spade=False): 37 | super().__init__() 38 | n_heads, d_head = 1, in_channels 39 | self.in_channels = in_channels 40 | self.heads = n_heads 41 | inner_dim = n_heads * d_head 42 | self.add_positional_embedding = add_positional_embedding 43 | if self.add_positional_embedding: 44 | self.local_positional_embedding = nn.Parameter( 45 | torch.randn(roi_size ** 2, in_channels) / in_channels ** 0.5) 46 | self.local_norm = Normalize(in_channels) 47 | self.context_norm = Normalize(context_dim) 48 | self.scale = d_head ** -0.5 49 | self.roi_size = roi_size 50 | self.local_proj_in = nn.Conv2d(in_channels+2, 51 | inner_dim, 52 | kernel_size=1, 53 | stride=1, 54 | padding=0) 55 | self.context_conv = nn.Conv2d(context_dim, 56 | inner_dim, 57 | kernel_size=1, 58 | stride=1, 59 | padding=0) 60 | self.local_proj_out = nn.Conv2d( 61 | context_dim, 62 | in_channels, 63 | kernel_size=3, 64 | stride=1, 65 | padding=1) 66 | self.SPADE = FDN(in_channels, context_dim)#zero_module(FDN(in_channels, context_dim)) 67 | 68 | def forward(self, global_x, context, **kwargs): 69 | indicator, bbox, mask, mask_method = kwargs.get('indicator'), kwargs.get('bbox'), kwargs.get('mask'), kwargs.get('mask_method') 70 | context_map = rearrange(context, 'b (h w) c -> b c h w', h=16) 71 | b, c, h, w = global_x.shape 72 | indices = torch.arange(b).reshape((-1,1)).to(bbox.dtype) 73 | indices = indices.to(bbox.device) 74 | idx_bbox = torch.cat([indices, bbox], dim=1) # B,5 75 | x = roi_align(global_x, idx_bbox, output_size=self.roi_size) # B,C,roi_size,roi_size 76 | # do something on local feature 77 | if self.add_positional_embedding: 78 | x = x + self.local_positional_embedding[None,:,:].to(x.dtype) 79 | # cross-attention 80 | x = self.local_norm(x) 81 | ind_map = repeat(indicator, 'b n -> b n h w', h=x.shape[-2], w=x.shape[-1]) 82 | ind_map = ind_map.to(x.dtype) 83 | x = torch.cat([x, ind_map], dim=1) 84 | q = self.local_proj_in(x) 85 | q = rearrange(q, 'b c h w -> b (h w) c') 86 | # k = self.context_conv(torch.cat([self.context_norm(context_map), ind_map], dim=1)) 87 | k = self.context_conv(self.context_norm(context_map)) 88 | k = rearrange(k, 'b c h w -> b (h w) c') 89 | # v = self.local_proj_out(torch.cat([context_map, ind_map], dim=1)) 90 | v = self.local_proj_out(context_map) 91 | v = rearrange(v, 'b c h w -> b (h w) c') 92 | sim = torch.einsum('b i d, b j d -> b i j', q, k) 93 | attn = sim.softmax(dim=-1) # b,256,256 94 | x = torch.einsum('b i j, b j d -> b i d', attn, v) 95 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.roi_size, w=self.roi_size) 96 | # align conditional foreground feature map with roi feature with using their cross-attention map 97 | align_context = torch.einsum('b i j, b j d -> b i d', attn, context) 98 | align_context = rearrange(align_context, 'b (h w) c -> b c h w', 
h=self.roi_size, w=self.roi_size) 99 | # update local feature with Spatially-Adaptive Normalization 100 | if mask != None: 101 | # only performing SPADE in the foreground area 102 | flat_mask = rearrange(mask, 'b c h w -> b (h w) c') 103 | if mask_method == 'argmax': 104 | thresh = torch.max(attn, dim=-1)[0].unsqueeze(-1) 105 | attn = torch.where(attn >= thresh, torch.ones_like(attn), torch.zeros_like(attn)) 106 | align_mask = torch.einsum('b i j, b j d -> b i d', attn, flat_mask) 107 | align_mask = torch.clamp(align_mask, max=1.0, min=0.0) 108 | align_mask = rearrange(align_mask, 'b (h w) c -> b c h w', h=self.roi_size, w=self.roi_size) 109 | x = torch.where(align_mask > 0.5, self.SPADE(x, align_context), x) 110 | else: 111 | align_mask = None 112 | x = self.SPADE(x, align_context) 113 | # paste the updated region feature into original global feature 114 | bbox_int = (bbox * h).int() 115 | bbox_int[:,2:] = torch.maximum(bbox_int[:,2:], bbox_int[:,:2] + 1) 116 | for i in range(b): 117 | x1,y1,x2,y2 = bbox_int[i] 118 | local_res = F.interpolate(x[i:i+1], (y2-y1,x2-x1)) 119 | local_x0 = global_x[i:i+1,:,y1:y2,x1:x2] 120 | # update foreground region feature by residual learning 121 | global_x[i:i+1,:,y1:y2,x1:x2] = local_res + local_x0 122 | 123 | if align_mask != None: 124 | return global_x, align_mask 125 | else: 126 | return global_x, attn 127 | 128 | if __name__ == '__main__': 129 | local_att = LocalRefineBlock(320, 1, 320, context_dim=1024, roi_size=16) 130 | H = W = 64 131 | feature = torch.randn((1, 1, H, W)).float() 132 | feature = feature.repeat(3, 320, 1, 1).float() 133 | bbox = torch.tensor([[0.,0.,0.3,0.3], 134 | [0.1,0.1,0.5,0.5], 135 | [0.2,0.2,0.4,0.4]]).reshape((-1,4)).float() 136 | indicator = torch.randint(0, 2, (3, 2)) 137 | context = torch.randn(3, 256, 1024) 138 | out = local_att(feature, context, bbox=bbox, indicator=indicator) 139 | if isinstance(out, tuple): 140 | print([o.shape for o in out]) 141 | else: 142 | print(out.shape) -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/losses/__pycache__/contperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/losses/__pycache__/mask_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/ldm/modules/losses/__pycache__/mask_loss.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = 
weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
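        # average over all elements (batch and spatial) instead of the per-sample sum above, so the NLL term's scale does not depend on resolution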
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | import cv2 16 | 17 | 18 | def log_txt_as_img(wh, xc, size=10): 19 | # wh a tuple of (width, height) 20 | # xc a list of captions to plot 21 | b = len(xc) 22 | txts = list() 23 | for bi in range(b): 24 | txt = Image.new("RGB", wh, color="white") 25 | draw = ImageDraw.Draw(txt) 26 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 27 | nc = int(40 * (wh[0] / 256)) 28 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, 
len(xc[bi]), nc)) 29 | 30 | try: 31 | draw.text((0, 0), lines, fill="black", font=font) 32 | except UnicodeEncodeError: 33 | print("Cant encode string for logging. Skipping.") 34 | 35 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 36 | txts.append(txt) 37 | txts = np.stack(txts) 38 | txts = torch.tensor(txts) 39 | return txts 40 | 41 | 42 | def ismap(x): 43 | if not isinstance(x, torch.Tensor): 44 | return False 45 | return (len(x.shape) == 4) and (x.shape[1] > 3) 46 | 47 | 48 | def isimage(x): 49 | if not isinstance(x, torch.Tensor): 50 | return False 51 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 52 | 53 | 54 | def exists(x): 55 | return x is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | def mean_flat(tensor): 65 | """ 66 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 67 | Take the mean over all non-batch dimensions. 68 | """ 69 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 70 | 71 | 72 | def count_params(model, verbose=False): 73 | total_params = sum(p.numel() for p in model.parameters()) 74 | if verbose: 75 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 76 | return total_params 77 | 78 | 79 | def instantiate_from_config(config): 80 | if not "target" in config: 81 | if config == '__is_first_stage__': 82 | return None 83 | elif config == "__is_unconditional__": 84 | return None 85 | raise KeyError("Expected key `target` to instantiate.") 86 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 87 | 88 | 89 | def get_obj_from_str(string, reload=False): 90 | module, cls = string.rsplit(".", 1) 91 | if reload: 92 | module_imp = importlib.import_module(module) 93 | importlib.reload(module_imp) 94 | return getattr(importlib.import_module(module, package=None), cls) 95 | 96 | 97 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 98 | # create dummy dataset instance 99 | 100 | # run prefetching 101 | if idx_to_fn: 102 | res = func(data, worker_id=idx) 103 | else: 104 | res = func(data) 105 | Q.put([idx, res]) 106 | Q.put("Done") 107 | 108 | 109 | def parallel_data_prefetch( 110 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 111 | ): 112 | # if target_data_type not in ["ndarray", "list"]: 113 | # raise ValueError( 114 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 115 | # ) 116 | if isinstance(data, np.ndarray) and target_data_type == "list": 117 | raise ValueError("list expected but function got ndarray.") 118 | elif isinstance(data, abc.Iterable): 119 | if isinstance(data, dict): 120 | print( 121 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 122 | ) 123 | data = list(data.values()) 124 | if target_data_type == "ndarray": 125 | data = np.asarray(data) 126 | else: 127 | data = list(data) 128 | else: 129 | raise TypeError( 130 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
131 | ) 132 | 133 | if cpu_intensive: 134 | Q = mp.Queue(1000) 135 | proc = mp.Process 136 | else: 137 | Q = Queue(1000) 138 | proc = Thread 139 | # spawn processes 140 | if target_data_type == "ndarray": 141 | arguments = [ 142 | [func, Q, part, i, use_worker_id] 143 | for i, part in enumerate(np.array_split(data, n_proc)) 144 | ] 145 | else: 146 | step = ( 147 | int(len(data) / n_proc + 1) 148 | if len(data) % n_proc != 0 149 | else int(len(data) / n_proc) 150 | ) 151 | arguments = [ 152 | [func, Q, part, i, use_worker_id] 153 | for i, part in enumerate( 154 | [data[i: i + step] for i in range(0, len(data), step)] 155 | ) 156 | ] 157 | processes = [] 158 | for i in range(n_proc): 159 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 160 | processes += [p] 161 | 162 | # start processes 163 | print(f"Start prefetching...") 164 | import time 165 | 166 | start = time.time() 167 | gather_res = [[] for _ in range(n_proc)] 168 | try: 169 | for p in processes: 170 | p.start() 171 | 172 | k = 0 173 | while k < n_proc: 174 | # get result 175 | res = Q.get() 176 | if res == "Done": 177 | k += 1 178 | else: 179 | gather_res[res[0]] = res[1] 180 | 181 | except Exception as e: 182 | print("Exception: ", e) 183 | for p in processes: 184 | p.terminate() 185 | 186 | raise e 187 | finally: 188 | for p in processes: 189 | p.join() 190 | print(f"Prefetching complete. [{time.time() - start} sec.]") 191 | 192 | if target_data_type == 'ndarray': 193 | if not isinstance(gather_res[0], np.ndarray): 194 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 195 | 196 | # order outputs 197 | return np.concatenate(gather_res, axis=0) 198 | elif target_data_type == 'list': 199 | out = [] 200 | for r in gather_res: 201 | out.extend(r) 202 | return out 203 | else: 204 | return gather_res 205 | 206 | def clip2sd(x): 207 | # clip input tensor to stable diffusion tensor 208 | MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1,-1,1,1).to(x.device) 209 | STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1,-1,1,1).to(x.device) 210 | denorm = x * STD + MEAN 211 | sd_x = denorm * 2 - 1 212 | return sd_x 213 | 214 | def numpy_to_pil(images): 215 | """ 216 | Convert a numpy image or a batch of images to a PIL image. 217 | """ 218 | if images.ndim == 3: 219 | images = images[None, ...] 220 | images = (images * 255).round().astype("uint8") 221 | pil_images = [Image.fromarray(image) for image in images] 222 | return pil_images 223 | 224 | def tensor2numpy(image, normalized=False, image_size=(512, 512)): 225 | image = Resize(image_size)(image) 226 | if not normalized: 227 | image = (image + 1.0) / 2.0 # -1,1 -> 0,1; b,c,h,w 228 | image = torch.clamp(image, 0., 1.) 
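    # add a batch dimension if missing, move channels last (B,H,W,C), and convert to uint8 for visualization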
229 | if image.dim() == 3: 230 | image = image.unsqueeze(0) 231 | image = image.permute(0, 2, 3, 1) 232 | image = image.numpy() 233 | image = (image * 255).astype(np.uint8) 234 | return image 235 | 236 | def draw_bbox_on_background(image_nps, norm_bbox, color=(255,215,0), thickness=3): 237 | dst_list = [] 238 | for i in range(image_nps.shape[0]): 239 | img = image_nps[i].copy() 240 | h,w,_ = img.shape 241 | x1 = int(norm_bbox[0,0] * w) 242 | y1 = int(norm_bbox[0,1] * h) 243 | x2 = int(norm_bbox[0,2] * w) 244 | y2 = int(norm_bbox[0,3] * h) 245 | dst = cv2.rectangle(img, (x1, y1), (x2, y2), color=color, thickness=thickness) 246 | dst_list.append(dst) 247 | dst_nps = np.stack(dst_list, axis=0) 248 | return dst_nps 249 | 250 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | academictorrents==2.3.3 2 | albumentations==1.3.0 3 | einops==0.3.0 4 | imageio==2.9.0 5 | more_itertools==10.2.0 6 | natsort==8.4.0 7 | numpy==1.21.6 8 | omegaconf==2.3.0 9 | opencv_python==4.1.2.30 10 | opencv_python_headless==4.7.0.72 11 | packaging==23.2 12 | Pillow==10.1.0 13 | pudb==2019.2 14 | pytorch_lightning==1.9.0 15 | PyYAML==6.0.1 16 | Requests==2.31.0 17 | scipy==1.9.1 18 | setuptools==65.6.3 19 | skimage==0.0 20 | streamlit==1.20.0 21 | torch==1.10.1 22 | torchmetrics==1.2.0 23 | torchvision==0.11.2 24 | tqdm==4.65.0 25 | transformers==4.27.4 26 | typing_extensions==4.10.0 27 | -------------------------------------------------------------------------------- /src/taming-transformers/License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE./ 20 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/coco_cond_stage.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQSegmentationModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | image_key: "segmentation" 8 | n_labels: 183 9 | ddconfig: 10 | double_z: false 11 | z_channels: 256 12 | resolution: 256 13 | in_channels: 183 14 | out_ch: 183 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 1 19 | - 2 20 | - 2 21 | - 4 22 | num_res_blocks: 2 23 | attn_resolutions: 24 | - 16 25 | dropout: 0.0 26 | 27 | lossconfig: 28 | target: taming.modules.losses.segmentation.BCELossWithQuant 29 | params: 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 12 36 | train: 37 | target: taming.data.coco.CocoImagesAndCaptionsTrain 38 | params: 39 | size: 296 40 | crop_size: 256 41 | onehot_segmentation: true 42 | use_stuffthing: true 43 | validation: 44 | target: taming.data.coco.CocoImagesAndCaptionsValidation 45 | params: 46 | size: 256 47 | crop_size: 256 48 | onehot_segmentation: true 49 | use_stuffthing: true 50 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/coco_scene_images_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: objects_bbox 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 8192 10 | block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim) 11 | n_layer: 40 12 | n_head: 16 13 | n_embd: 1408 14 | embd_pdrop: 0.1 15 | resid_pdrop: 0.1 16 | attn_pdrop: 0.1 17 | first_stage_config: 18 | target: taming.models.vqgan.VQModel 19 | params: 20 | ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/ 21 | embed_dim: 256 22 | n_embed: 8192 23 | ddconfig: 24 | double_z: false 25 | z_channels: 256 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 1 33 | - 2 34 | - 2 35 | - 4 36 | num_res_blocks: 2 37 | attn_resolutions: 38 | - 16 39 | dropout: 0.0 40 | lossconfig: 41 | target: taming.modules.losses.DummyLoss 42 | cond_stage_config: 43 | target: taming.models.dummy_cond_stage.DummyCondStage 44 | params: 45 | conditional_key: objects_bbox 46 | 47 | data: 48 | target: main.DataModuleFromConfig 49 | params: 50 | batch_size: 6 51 | train: 52 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 53 | params: 54 | data_path: data/coco_annotations_100 # substitute with path to full dataset 55 | split: train 56 | keys: [image, objects_bbox, file_name, annotations] 57 | no_tokens: 8192 58 | target_image_size: 256 59 | min_object_area: 0.00001 60 | min_objects_per_image: 2 61 | max_objects_per_image: 30 62 | crop_method: random-1d 63 | random_flip: true 64 | use_group_parameter: true 65 | encode_crop: true 66 | validation: 67 | target: 
taming.data.annotated_objects_coco.AnnotatedObjectsCoco 68 | params: 69 | data_path: data/coco_annotations_100 # substitute with path to full dataset 70 | split: validation 71 | keys: [image, objects_bbox, file_name, annotations] 72 | no_tokens: 8192 73 | target_image_size: 256 74 | min_object_area: 0.00001 75 | min_objects_per_image: 2 76 | max_objects_per_image: 30 77 | crop_method: center 78 | random_flip: false 79 | use_group_parameter: true 80 | encode_crop: true 81 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/custom_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 10000 25 | disc_weight: 0.8 26 | codebook_weight: 1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 5 32 | num_workers: 8 33 | train: 34 | target: taming.data.custom.CustomTrain 35 | params: 36 | training_images_list_file: some/training.txt 37 | size: 256 38 | validation: 39 | target: taming.data.custom.CustomTest 40 | params: 41 | test_images_list_file: some/test.txt 42 | size: 256 43 | 44 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/drin_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: depth 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 1024 10 | block_size: 512 11 | n_layer: 24 12 | n_head: 16 13 | n_embd: 1024 14 | first_stage_config: 15 | target: taming.models.vqgan.VQModel 16 | params: 17 | ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt 18 | embed_dim: 256 19 | n_embed: 1024 20 | ddconfig: 21 | double_z: false 22 | z_channels: 256 23 | resolution: 256 24 | in_channels: 3 25 | out_ch: 3 26 | ch: 128 27 | ch_mult: 28 | - 1 29 | - 1 30 | - 2 31 | - 2 32 | - 4 33 | num_res_blocks: 2 34 | attn_resolutions: 35 | - 16 36 | dropout: 0.0 37 | lossconfig: 38 | target: taming.modules.losses.DummyLoss 39 | cond_stage_config: 40 | target: taming.models.vqgan.VQModel 41 | params: 42 | ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt 43 | embed_dim: 256 44 | n_embed: 1024 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 1 50 | out_ch: 1 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | lossconfig: 63 | target: taming.modules.losses.DummyLoss 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 2 69 | num_workers: 8 70 | train: 71 | target: taming.data.imagenet.RINTrainWithDepth 72 | params: 73 | size: 256 74 | validation: 75 | target: taming.data.imagenet.RINValidationWithDepth 76 | params: 77 | 
size: 256 78 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/faceshq_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: coord 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 1024 10 | block_size: 512 11 | n_layer: 24 12 | n_head: 16 13 | n_embd: 1024 14 | first_stage_config: 15 | target: taming.models.vqgan.VQModel 16 | params: 17 | ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt 18 | embed_dim: 256 19 | n_embed: 1024 20 | ddconfig: 21 | double_z: false 22 | z_channels: 256 23 | resolution: 256 24 | in_channels: 3 25 | out_ch: 3 26 | ch: 128 27 | ch_mult: 28 | - 1 29 | - 1 30 | - 2 31 | - 2 32 | - 4 33 | num_res_blocks: 2 34 | attn_resolutions: 35 | - 16 36 | dropout: 0.0 37 | lossconfig: 38 | target: taming.modules.losses.DummyLoss 39 | cond_stage_config: 40 | target: taming.modules.misc.coord.CoordStage 41 | params: 42 | n_embed: 1024 43 | down_factor: 16 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 2 49 | num_workers: 8 50 | train: 51 | target: taming.data.faceshq.FacesHQTrain 52 | params: 53 | size: 256 54 | crop_size: 256 55 | coord: True 56 | validation: 57 | target: taming.data.faceshq.FacesHQValidation 58 | params: 59 | size: 256 60 | crop_size: 256 61 | coord: True 62 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/faceshq_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 30001 25 | disc_weight: 0.8 26 | codebook_weight: 1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 3 32 | num_workers: 8 33 | train: 34 | target: taming.data.faceshq.FacesHQTrain 35 | params: 36 | size: 256 37 | crop_size: 256 38 | validation: 39 | target: taming.data.faceshq.FacesHQValidation 40 | params: 41 | size: 256 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/imagenet_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 250001 25 | disc_weight: 0.8 26 | codebook_weight: 
1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 12 32 | num_workers: 24 33 | train: 34 | target: taming.data.imagenet.ImageNetTrain 35 | params: 36 | config: 37 | size: 256 38 | validation: 39 | target: taming.data.imagenet.ImageNetValidation 40 | params: 41 | config: 42 | size: 256 43 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/imagenetdepth_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | image_key: depth 8 | ddconfig: 9 | double_z: False 10 | z_channels: 256 11 | resolution: 256 12 | in_channels: 1 13 | out_ch: 1 14 | ch: 128 15 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 16 | num_res_blocks: 2 17 | attn_resolutions: [16] 18 | dropout: 0.0 19 | 20 | lossconfig: 21 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 22 | params: 23 | disc_conditional: False 24 | disc_in_channels: 1 25 | disc_start: 50001 26 | disc_weight: 0.75 27 | codebook_weight: 1.0 28 | 29 | data: 30 | target: main.DataModuleFromConfig 31 | params: 32 | batch_size: 3 33 | num_workers: 8 34 | train: 35 | target: taming.data.imagenet.ImageNetTrainWithDepth 36 | params: 37 | size: 256 38 | validation: 39 | target: taming.data.imagenet.ImageNetValidationWithDepth 40 | params: 41 | size: 256 42 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/open_images_scene_images_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: objects_bbox 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 8192 10 | block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim) 11 | n_layer: 36 12 | n_head: 16 13 | n_embd: 1536 14 | embd_pdrop: 0.1 15 | resid_pdrop: 0.1 16 | attn_pdrop: 0.1 17 | first_stage_config: 18 | target: taming.models.vqgan.VQModel 19 | params: 20 | ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/ 21 | embed_dim: 256 22 | n_embed: 8192 23 | ddconfig: 24 | double_z: false 25 | z_channels: 256 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 1 33 | - 2 34 | - 2 35 | - 4 36 | num_res_blocks: 2 37 | attn_resolutions: 38 | - 16 39 | dropout: 0.0 40 | lossconfig: 41 | target: taming.modules.losses.DummyLoss 42 | cond_stage_config: 43 | target: taming.models.dummy_cond_stage.DummyCondStage 44 | params: 45 | conditional_key: objects_bbox 46 | 47 | data: 48 | target: main.DataModuleFromConfig 49 | params: 50 | batch_size: 6 51 | train: 52 | target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages 53 | params: 54 | data_path: data/open_images_annotations_100 # substitute with path to full dataset 55 | split: train 56 | keys: [image, objects_bbox, file_name, annotations] 57 | no_tokens: 8192 58 | target_image_size: 256 59 | category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility 60 | category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco 61 | min_object_area: 0.0001 62 | 
min_objects_per_image: 2 63 | max_objects_per_image: 30 64 | crop_method: random-2d 65 | random_flip: true 66 | use_group_parameter: true 67 | use_additional_parameters: true 68 | encode_crop: true 69 | validation: 70 | target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages 71 | params: 72 | data_path: data/open_images_annotations_100 # substitute with path to full dataset 73 | split: validation 74 | keys: [image, objects_bbox, file_name, annotations] 75 | no_tokens: 8192 76 | target_image_size: 256 77 | category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility 78 | category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco 79 | min_object_area: 0.0001 80 | min_objects_per_image: 2 81 | max_objects_per_image: 30 82 | crop_method: center 83 | random_flip: false 84 | use_group_parameter: true 85 | use_additional_parameters: true 86 | encode_crop: true 87 | -------------------------------------------------------------------------------- /src/taming-transformers/configs/sflckr_cond_stage.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQSegmentationModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | image_key: "segmentation" 8 | n_labels: 182 9 | ddconfig: 10 | double_z: false 11 | z_channels: 256 12 | resolution: 256 13 | in_channels: 182 14 | out_ch: 182 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 1 19 | - 2 20 | - 2 21 | - 4 22 | num_res_blocks: 2 23 | attn_resolutions: 24 | - 16 25 | dropout: 0.0 26 | 27 | lossconfig: 28 | target: taming.modules.losses.segmentation.BCELossWithQuant 29 | params: 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: cutlit.DataModuleFromConfig 34 | params: 35 | batch_size: 12 36 | train: 37 | target: taming.data.sflckr.Examples # adjust 38 | params: 39 | size: 256 40 | validation: 41 | target: taming.data.sflckr.Examples # adjust 42 | params: 43 | size: 256 44 | -------------------------------------------------------------------------------- /src/taming-transformers/environment.yaml: -------------------------------------------------------------------------------- 1 | name: taming 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.0.8 19 | - omegaconf==2.0.0 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - more-itertools>=8.0.0 24 | - transformers==4.3.1 25 | - -e . 
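      # "-e ." installs the taming package from this repo in editable mode (see setup.py)
      # create the environment with: conda env create -f environment.yaml && conda activate taming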
26 | -------------------------------------------------------------------------------- /src/taming-transformers/scripts/extract_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import trange 5 | from PIL import Image 6 | 7 | 8 | def get_state(gpu): 9 | import torch 10 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS") 11 | if gpu: 12 | midas.cuda() 13 | midas.eval() 14 | 15 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 16 | transform = midas_transforms.default_transform 17 | 18 | state = {"model": midas, 19 | "transform": transform} 20 | return state 21 | 22 | 23 | def depth_to_rgba(x): 24 | assert x.dtype == np.float32 25 | assert len(x.shape) == 2 26 | y = x.copy() 27 | y.dtype = np.uint8 28 | y = y.reshape(x.shape+(4,)) 29 | return np.ascontiguousarray(y) 30 | 31 | 32 | def rgba_to_depth(x): 33 | assert x.dtype == np.uint8 34 | assert len(x.shape) == 3 and x.shape[2] == 4 35 | y = x.copy() 36 | y.dtype = np.float32 37 | y = y.reshape(x.shape[:2]) 38 | return np.ascontiguousarray(y) 39 | 40 | 41 | def run(x, state): 42 | model = state["model"] 43 | transform = state["transform"] 44 | hw = x.shape[:2] 45 | with torch.no_grad(): 46 | prediction = model(transform((x + 1.0) * 127.5).cuda()) 47 | prediction = torch.nn.functional.interpolate( 48 | prediction.unsqueeze(1), 49 | size=hw, 50 | mode="bicubic", 51 | align_corners=False, 52 | ).squeeze() 53 | output = prediction.cpu().numpy() 54 | return output 55 | 56 | 57 | def get_filename(relpath, level=-2): 58 | # save class folder structure and filename: 59 | fn = relpath.split(os.sep)[level:] 60 | folder = fn[-2] 61 | file = fn[-1].split('.')[0] 62 | return folder, file 63 | 64 | 65 | def save_depth(dataset, path, debug=False): 66 | os.makedirs(path) 67 | N = len(dset) 68 | if debug: 69 | N = 10 70 | state = get_state(gpu=True) 71 | for idx in trange(N, desc="Data"): 72 | ex = dataset[idx] 73 | image, relpath = ex["image"], ex["relpath"] 74 | folder, filename = get_filename(relpath) 75 | # prepare 76 | folderabspath = os.path.join(path, folder) 77 | os.makedirs(folderabspath, exist_ok=True) 78 | savepath = os.path.join(folderabspath, filename) 79 | # run model 80 | xout = run(image, state) 81 | I = depth_to_rgba(xout) 82 | Image.fromarray(I).save("{}.png".format(savepath)) 83 | 84 | 85 | if __name__ == "__main__": 86 | from taming.data.imagenet import ImageNetTrain, ImageNetValidation 87 | out = "data/imagenet_depth" 88 | if not os.path.exists(out): 89 | print("Please create a folder or symlink '{}' to extract depth data ".format(out) + 90 | "(be prepared that the output size will be larger than ImageNet itself).") 91 | exit(1) 92 | 93 | # go 94 | dset = ImageNetValidation() 95 | abspath = os.path.join(out, "val") 96 | if os.path.exists(abspath): 97 | print("{} exists - not doing anything.".format(abspath)) 98 | else: 99 | print("preparing {}".format(abspath)) 100 | save_depth(dset, abspath) 101 | print("done with validation split") 102 | 103 | dset = ImageNetTrain() 104 | abspath = os.path.join(out, "train") 105 | if os.path.exists(abspath): 106 | print("{} exists - not doing anything.".format(abspath)) 107 | else: 108 | print("preparing {}".format(abspath)) 109 | save_depth(dset, abspath) 110 | print("done with train split") 111 | 112 | print("done done.") 113 | -------------------------------------------------------------------------------- /src/taming-transformers/scripts/extract_segmentation.py: 
-------------------------------------------------------------------------------- 1 | import sys, os 2 | import numpy as np 3 | import scipy 4 | import torch 5 | import torch.nn as nn 6 | from scipy import ndimage 7 | from tqdm import tqdm, trange 8 | from PIL import Image 9 | import torch.hub 10 | import torchvision 11 | import torch.nn.functional as F 12 | 13 | # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from 14 | # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth 15 | # and put the path here 16 | CKPT_PATH = "TODO" 17 | 18 | rescale = lambda x: (x + 1.) / 2. 19 | 20 | def rescale_bgr(x): 21 | x = (x+1)*127.5 22 | x = torch.flip(x, dims=[0]) 23 | return x 24 | 25 | 26 | class COCOStuffSegmenter(nn.Module): 27 | def __init__(self, config): 28 | super().__init__() 29 | self.config = config 30 | self.n_labels = 182 31 | model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels) 32 | ckpt_path = CKPT_PATH 33 | model.load_state_dict(torch.load(ckpt_path)) 34 | self.model = model 35 | 36 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 37 | self.image_transform = torchvision.transforms.Compose([ 38 | torchvision.transforms.Lambda(lambda image: torch.stack( 39 | [normalize(rescale_bgr(x)) for x in image])) 40 | ]) 41 | 42 | def forward(self, x, upsample=None): 43 | x = self._pre_process(x) 44 | x = self.model(x) 45 | if upsample is not None: 46 | x = torch.nn.functional.upsample_bilinear(x, size=upsample) 47 | return x 48 | 49 | def _pre_process(self, x): 50 | x = self.image_transform(x) 51 | return x 52 | 53 | @property 54 | def mean(self): 55 | # bgr 56 | return [104.008, 116.669, 122.675] 57 | 58 | @property 59 | def std(self): 60 | return [1.0, 1.0, 1.0] 61 | 62 | @property 63 | def input_size(self): 64 | return [3, 224, 224] 65 | 66 | 67 | def run_model(img, model): 68 | model = model.eval() 69 | with torch.no_grad(): 70 | segmentation = model(img, upsample=(img.shape[2], img.shape[3])) 71 | segmentation = torch.argmax(segmentation, dim=1, keepdim=True) 72 | return segmentation.detach().cpu() 73 | 74 | 75 | def get_input(batch, k): 76 | x = batch[k] 77 | if len(x.shape) == 3: 78 | x = x[..., None] 79 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 80 | return x.float() 81 | 82 | 83 | def save_segmentation(segmentation, path): 84 | # --> class label to uint8, save as png 85 | os.makedirs(os.path.dirname(path), exist_ok=True) 86 | assert len(segmentation.shape)==4 87 | assert segmentation.shape[0]==1 88 | for seg in segmentation: 89 | seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8) 90 | seg = Image.fromarray(seg) 91 | seg.save(path) 92 | 93 | 94 | def iterate_dataset(dataloader, destpath, model): 95 | os.makedirs(destpath, exist_ok=True) 96 | num_processed = 0 97 | for i, batch in tqdm(enumerate(dataloader), desc="Data"): 98 | try: 99 | img = get_input(batch, "image") 100 | img = img.cuda() 101 | seg = run_model(img, model) 102 | 103 | path = batch["relative_file_path_"][0] 104 | path = os.path.splitext(path)[0] 105 | 106 | path = os.path.join(destpath, path + ".png") 107 | save_segmentation(seg, path) 108 | num_processed += 1 109 | except Exception as e: 110 | print(e) 111 | print("but anyhow..") 112 | 113 | print("Processed {} files. 
Bye.".format(num_processed)) 114 | 115 | 116 | from taming.data.sflckr import Examples 117 | from torch.utils.data import DataLoader 118 | 119 | if __name__ == "__main__": 120 | dest = sys.argv[1] 121 | batchsize = 1 122 | print("Running with batch-size {}, saving to {}...".format(batchsize, dest)) 123 | 124 | model = COCOStuffSegmenter({}).cuda() 125 | print("Instantiated model.") 126 | 127 | dataset = Examples() 128 | dloader = DataLoader(dataset, batch_size=batchsize) 129 | iterate_dataset(dataloader=dloader, destpath=dest, model=model) 130 | print("done.") 131 | -------------------------------------------------------------------------------- /src/taming-transformers/scripts/extract_submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | inpath = sys.argv[1] 6 | outpath = sys.argv[2] 7 | submodel = "cond_stage_model" 8 | if len(sys.argv) > 3: 9 | submodel = sys.argv[3] 10 | 11 | print("Extracting {} from {} to {}.".format(submodel, inpath, outpath)) 12 | 13 | sd = torch.load(inpath, map_location="cpu") 14 | new_sd = {"state_dict": dict((k.split(".", 1)[-1],v) 15 | for k,v in sd["state_dict"].items() 16 | if k.startswith("cond_stage_model"))} 17 | torch.save(new_sd, outpath) 18 | -------------------------------------------------------------------------------- /src/taming-transformers/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='taming-transformers', 5 | version='0.0.1', 6 | description='Taming Transformers for High-Resolution Image Synthesis', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from taming.data.sflckr import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | 
self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 | for l in self.image_paths], 39 | "relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 | interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 
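            # "scene_category" is a plain string with no .shape attribute, so the except branch below prints the value itself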
122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/annotated_objects_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | from typing import Iterable, Dict, List, Callable, Any 5 | from collections import defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 10 | from taming.data.helper_types import Annotation, ImageDescription, Category 11 | 12 | COCO_PATH_STRUCTURE = { 13 | 'train': { 14 | 'top_level': '', 15 | 'instances_annotations': 'annotations/instances_train2017.json', 16 | 'stuff_annotations': 'annotations/stuff_train2017.json', 17 | 'files': 'train2017' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'instances_annotations': 'annotations/instances_val2017.json', 22 | 'stuff_annotations': 'annotations/stuff_val2017.json', 23 | 'files': 'val2017' 24 | } 25 | } 26 | 27 | 28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: 29 | return { 30 | str(img['id']): ImageDescription( 31 | id=img['id'], 32 | license=img.get('license'), 33 | file_name=img['file_name'], 34 | coco_url=img['coco_url'], 35 | original_size=(img['width'], img['height']), 36 | date_captured=img.get('date_captured'), 37 | flickr_url=img.get('flickr_url') 38 | ) 39 | for img in description_json 40 | } 41 | 42 | 43 | def load_categories(category_json: Iterable) -> Dict[str, Category]: 44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) 45 | for cat in category_json if cat['name'] != 'other'} 46 | 47 | 48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], 49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: 50 | annotations = defaultdict(list) 51 | total = sum(len(a) for a in annotations_json) 52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): 53 | image_id = str(ann['image_id']) 54 | if image_id not in image_descriptions: 55 | raise ValueError(f'image_id [{image_id}] has no image description.') 56 | category_id = ann['category_id'] 57 | try: 58 | category_no = category_no_for_id(str(category_id)) 59 | except KeyError: 60 | continue 61 | 62 | width, height = image_descriptions[image_id].original_size 63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) 64 | 65 | annotations[image_id].append( 66 | Annotation( 67 | id=ann['id'], 68 | area=bbox[2]*bbox[3], # use bbox area 69 | is_group_of=ann['iscrowd'], 70 | image_id=ann['image_id'], 71 | bbox=bbox, 72 | category_id=str(category_id), 73 | category_no=category_no 74 | ) 75 | ) 76 | return dict(annotations) 77 | 78 | 79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset): 80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): 81 | """ 82 | @param data_path: is the path to the following folder structure: 83 | coco/ 84 | ├── annotations 85 | │ ├── instances_train2017.json 86 | │ ├── instances_val2017.json 87 | │ ├── stuff_train2017.json 88 | │ └── stuff_val2017.json 89 | ├── train2017 90 | │ ├── 000000000009.jpg 91 | │ ├── 000000000025.jpg 92 | │ └── ... 93 | ├── val2017 94 | │ ├── 000000000139.jpg 95 | │ ├── 000000000285.jpg 96 | │ └── ... 
97 | @param: split: one of 'train' or 'validation' 98 | @param: desired image size (give square images) 99 | """ 100 | super().__init__(**kwargs) 101 | self.use_things = use_things 102 | self.use_stuff = use_stuff 103 | 104 | with open(self.paths['instances_annotations']) as f: 105 | inst_data_json = json.load(f) 106 | with open(self.paths['stuff_annotations']) as f: 107 | stuff_data_json = json.load(f) 108 | 109 | category_jsons = [] 110 | annotation_jsons = [] 111 | if self.use_things: 112 | category_jsons.append(inst_data_json['categories']) 113 | annotation_jsons.append(inst_data_json['annotations']) 114 | if self.use_stuff: 115 | category_jsons.append(stuff_data_json['categories']) 116 | annotation_jsons.append(stuff_data_json['annotations']) 117 | 118 | self.categories = load_categories(chain(*category_jsons)) 119 | self.filter_categories() 120 | self.setup_category_id_and_number() 121 | 122 | self.image_descriptions = load_image_descriptions(inst_data_json['images']) 123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) 124 | self.annotations = self.filter_object_number(annotations, self.min_object_area, 125 | self.min_objects_per_image, self.max_objects_per_image) 126 | self.image_ids = list(self.annotations.keys()) 127 | self.clean_up_annotations_and_image_descriptions() 128 | 129 | def get_path_structure(self) -> Dict[str, str]: 130 | if self.split not in COCO_PATH_STRUCTURE: 131 | raise ValueError(f'Split [{self.split} does not exist for COCO data.]') 132 | return COCO_PATH_STRUCTURE[self.split] 133 | 134 | def get_image_path(self, image_id: str) -> Path: 135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) 136 | 137 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 138 | # noinspection PyProtectedMember 139 | return self.image_descriptions[image_id]._asdict() 140 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/annotated_objects_open_images.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from csv import DictReader, reader as TupleReader 3 | from pathlib import Path 4 | from typing import Dict, List, Any 5 | import warnings 6 | 7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 8 | from taming.data.helper_types import Annotation, Category 9 | from tqdm import tqdm 10 | 11 | OPEN_IMAGES_STRUCTURE = { 12 | 'train': { 13 | 'top_level': '', 14 | 'class_descriptions': 'class-descriptions-boxable.csv', 15 | 'annotations': 'oidv6-train-annotations-bbox.csv', 16 | 'file_list': 'train-images-boxable.csv', 17 | 'files': 'train' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'class_descriptions': 'class-descriptions-boxable.csv', 22 | 'annotations': 'validation-annotations-bbox.csv', 23 | 'file_list': 'validation-images.csv', 24 | 'files': 'validation' 25 | }, 26 | 'test': { 27 | 'top_level': '', 28 | 'class_descriptions': 'class-descriptions-boxable.csv', 29 | 'annotations': 'test-annotations-bbox.csv', 30 | 'file_list': 'test-images.csv', 31 | 'files': 'test' 32 | } 33 | } 34 | 35 | 36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], 37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: 38 | annotations: Dict[str, List[Annotation]] = defaultdict(list) 39 | with open(descriptor_path) as file: 40 | reader = 
DictReader(file) 41 | for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): 42 | width = float(row['XMax']) - float(row['XMin']) 43 | height = float(row['YMax']) - float(row['YMin']) 44 | area = width * height 45 | category_id = row['LabelName'] 46 | if category_id in category_mapping: 47 | category_id = category_mapping[category_id] 48 | if area >= min_object_area and category_id in category_no_for_id: 49 | annotations[row['ImageID']].append( 50 | Annotation( 51 | id=i, 52 | image_id=row['ImageID'], 53 | source=row['Source'], 54 | category_id=category_id, 55 | category_no=category_no_for_id[category_id], 56 | confidence=float(row['Confidence']), 57 | bbox=(float(row['XMin']), float(row['YMin']), width, height), 58 | area=area, 59 | is_occluded=bool(int(row['IsOccluded'])), 60 | is_truncated=bool(int(row['IsTruncated'])), 61 | is_group_of=bool(int(row['IsGroupOf'])), 62 | is_depiction=bool(int(row['IsDepiction'])), 63 | is_inside=bool(int(row['IsInside'])) 64 | ) 65 | ) 66 | if 'train' in str(descriptor_path) and i < 14000000: 67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') 68 | return dict(annotations) 69 | 70 | 71 | def load_image_ids(csv_path: Path) -> List[str]: 72 | with open(csv_path) as file: 73 | reader = DictReader(file) 74 | return [row['image_name'] for row in reader] 75 | 76 | 77 | def load_categories(csv_path: Path) -> Dict[str, Category]: 78 | with open(csv_path) as file: 79 | reader = TupleReader(file) 80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} 81 | 82 | 83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): 84 | def __init__(self, use_additional_parameters: bool, **kwargs): 85 | """ 86 | @param data_path: is the path to the following folder structure: 87 | open_images/ 88 | │ oidv6-train-annotations-bbox.csv 89 | ├── class-descriptions-boxable.csv 90 | ├── oidv6-train-annotations-bbox.csv 91 | ├── test 92 | │ ├── 000026e7ee790996.jpg 93 | │ ├── 000062a39995e348.jpg 94 | │ └── ... 95 | ├── test-annotations-bbox.csv 96 | ├── test-images.csv 97 | ├── train 98 | │ ├── 000002b66c9c498e.jpg 99 | │ ├── 000002b97e5471a0.jpg 100 | │ └── ... 101 | ├── train-images-boxable.csv 102 | ├── validation 103 | │ ├── 0001eeaf4aed83f9.jpg 104 | │ ├── 0004886b7d043cfd.jpg 105 | │ └── ... 
106 | ├── validation-annotations-bbox.csv 107 | └── validation-images.csv 108 | @param: split: one of 'train', 'validation' or 'test' 109 | @param: desired image size (returns square images) 110 | """ 111 | 112 | super().__init__(**kwargs) 113 | self.use_additional_parameters = use_additional_parameters 114 | 115 | self.categories = load_categories(self.paths['class_descriptions']) 116 | self.filter_categories() 117 | self.setup_category_id_and_number() 118 | 119 | self.image_descriptions = {} 120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, 121 | self.category_number) 122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, 123 | self.max_objects_per_image) 124 | self.image_ids = list(self.annotations.keys()) 125 | self.clean_up_annotations_and_image_descriptions() 126 | 127 | def get_path_structure(self) -> Dict[str, str]: 128 | if self.split not in OPEN_IMAGES_STRUCTURE: 129 | raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') 130 | return OPEN_IMAGES_STRUCTURE[self.split] 131 | 132 | def get_image_path(self, image_id: str) -> Path: 133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') 134 | 135 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 136 | image_path = self.get_image_path(image_id) 137 | return {'file_path': str(image_path), 'file_name': image_path.name} 138 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def 
__getitem__(self, i):
55 |         example = dict()
56 |         example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 |         for k in self.labels:
58 |             example[k] = self.labels[k][i]
59 |         return example
60 | 
61 | 
62 | class NumpyPaths(ImagePaths):
63 |     def preprocess_image(self, image_path):
64 |         image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 |         image = np.transpose(image, (1,2,0))
66 |         image = Image.fromarray(image, mode="RGB")
67 |         image = np.array(image).astype(np.uint8)
68 |         image = self.preprocessor(image=image)["image"]
69 |         image = (image/127.5 - 1.0).astype(np.float32)
70 |         return image
71 | 
--------------------------------------------------------------------------------
/src/taming-transformers/taming/data/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import albumentations
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset
8 | 
9 | from taming.data.sflckr import SegmentationBase # for examples included in repo
10 | 
11 | 
12 | class Examples(SegmentationBase):
13 |     def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
14 |         super().__init__(data_csv="data/coco_examples.txt",
15 |                          data_root="data/coco_images",
16 |                          segmentation_root="data/coco_segmentations",
17 |                          size=size, random_crop=random_crop,
18 |                          interpolation=interpolation,
19 |                          n_labels=183, shift_segmentation=True)
20 | 
21 | 
22 | class CocoBase(Dataset):
23 |     """needed for (image, caption, segmentation) pairs"""
24 |     def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
25 |                  crop_size=None, force_no_crop=False, given_files=None):
26 |         self.split = self.get_split()
27 |         self.size = size
28 |         if crop_size is None:
29 |             self.crop_size = size
30 |         else:
31 |             self.crop_size = crop_size
32 | 
33 |         self.onehot = onehot_segmentation # return segmentation as rgb or one hot
34 |         self.stuffthing = use_stuffthing # include thing in segmentation
35 |         if self.onehot and not self.stuffthing:
36 |             raise NotImplementedError("One hot mode is only supported for the "
37 |                                       "stuffthings version because labels are stored "
38 |                                       "a bit differently.")
39 | 
40 |         data_json = datajson
41 |         with open(data_json) as json_file:
42 |             self.json_data = json.load(json_file)
43 |             self.img_id_to_captions = dict()
44 |             self.img_id_to_filepath = dict()
45 |             self.img_id_to_segmentation_filepath = dict()
46 | 
47 |         assert data_json.split("/")[-1] in ["captions_train2017.json",
48 |                                             "captions_val2017.json"]
49 |         if self.stuffthing:
50 |             self.segmentation_prefix = (
51 |                 "data/cocostuffthings/val2017" if
52 |                 data_json.endswith("captions_val2017.json") else
53 |                 "data/cocostuffthings/train2017")
54 |         else:
55 |             self.segmentation_prefix = (
56 |                 "data/coco/annotations/stuff_val2017_pixelmaps" if
57 |                 data_json.endswith("captions_val2017.json") else
58 |                 "data/coco/annotations/stuff_train2017_pixelmaps")
59 | 
60 |         imagedirs = self.json_data["images"]
61 |         self.labels = {"image_ids": list()}
62 |         for imgdir in tqdm(imagedirs, desc="ImgToPath"):
63 |             self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
64 |             self.img_id_to_captions[imgdir["id"]] = list()
65 |             pngfilename = imgdir["file_name"].replace("jpg", "png")
66 |             self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
67 |                 self.segmentation_prefix, pngfilename)
68 |             if given_files is not None:
69 |                 if pngfilename in given_files:
70 | 
self.labels["image_ids"].append(imgdir["id"]) 71 | else: 72 | self.labels["image_ids"].append(imgdir["id"]) 73 | 74 | capdirs = self.json_data["annotations"] 75 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 76 | # there are in average 5 captions per image 77 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 78 | 79 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 80 | if self.split=="validation": 81 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 82 | else: 83 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 84 | self.preprocessor = albumentations.Compose( 85 | [self.rescaler, self.cropper], 86 | additional_targets={"segmentation": "image"}) 87 | if force_no_crop: 88 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 89 | self.preprocessor = albumentations.Compose( 90 | [self.rescaler], 91 | additional_targets={"segmentation": "image"}) 92 | 93 | def __len__(self): 94 | return len(self.labels["image_ids"]) 95 | 96 | def preprocess_image(self, image_path, segmentation_path): 97 | image = Image.open(image_path) 98 | if not image.mode == "RGB": 99 | image = image.convert("RGB") 100 | image = np.array(image).astype(np.uint8) 101 | 102 | segmentation = Image.open(segmentation_path) 103 | if not self.onehot and not segmentation.mode == "RGB": 104 | segmentation = segmentation.convert("RGB") 105 | segmentation = np.array(segmentation).astype(np.uint8) 106 | if self.onehot: 107 | assert self.stuffthing 108 | # stored in caffe format: unlabeled==255. stuff and thing from 109 | # 0-181. to be compatible with the labels in 110 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 111 | # we shift stuffthing one to the right and put unlabeled in zero 112 | # as long as segmentation is uint8 shifting to right handles the 113 | # latter too 114 | assert segmentation.dtype == np.uint8 115 | segmentation = segmentation + 1 116 | 117 | processed = self.preprocessor(image=image, segmentation=segmentation) 118 | image, segmentation = processed["image"], processed["segmentation"] 119 | image = (image / 127.5 - 1.0).astype(np.float32) 120 | 121 | if self.onehot: 122 | assert segmentation.dtype == np.uint8 123 | # make it one hot 124 | n_labels = 183 125 | flatseg = np.ravel(segmentation) 126 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 127 | onehot[np.arange(flatseg.size), flatseg] = True 128 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 129 | segmentation = onehot 130 | else: 131 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 132 | return image, segmentation 133 | 134 | def __getitem__(self, i): 135 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 136 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 137 | image, segmentation = self.preprocess_image(img_path, seg_path) 138 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 139 | # randomly draw one of all available captions per image 140 | caption = captions[np.random.randint(0, len(captions))] 141 | example = {"image": image, 142 | "caption": [str(caption[0])], 143 | "segmentation": segmentation, 144 | "img_path": img_path, 145 | "seg_path": seg_path, 146 | "filename_": img_path.split(os.sep)[-1] 147 | } 148 | return example 149 | 150 | 151 | class CocoImagesAndCaptionsTrain(CocoBase): 152 | """returns a pair of (image, caption)""" 153 | def __init__(self, size, 
onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): 154 | super().__init__(size=size, 155 | dataroot="data/coco/train2017", 156 | datajson="data/coco/annotations/captions_train2017.json", 157 | onehot_segmentation=onehot_segmentation, 158 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 159 | 160 | def get_split(self): 161 | return "train" 162 | 163 | 164 | class CocoImagesAndCaptionsValidation(CocoBase): 165 | """returns a pair of (image, caption)""" 166 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 167 | given_files=None): 168 | super().__init__(size=size, 169 | dataroot="data/coco/val2017", 170 | datajson="data/coco/annotations/captions_val2017.json", 171 | onehot_segmentation=onehot_segmentation, 172 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 173 | given_files=given_files) 174 | 175 | def get_split(self): 176 | return "validation" 177 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | 
"/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/conditional_builder/objects_center_points.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import warnings 4 | from itertools import cycle 5 | from typing import List, Optional, Tuple, Callable 6 | 7 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 8 | from more_itertools.recipes import grouper 9 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ 10 | additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ 11 | absolute_bbox, rescale_annotations 12 | from taming.data.helper_types import BoundingBox, Annotation 13 | from taming.data.image_transforms import convert_pil_to_tensor 14 | from torch import LongTensor, Tensor 15 | 16 | 17 | class ObjectsCenterPointsConditionalBuilder: 18 | def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, 19 | use_group_parameter: bool, use_additional_parameters: bool): 20 | self.no_object_classes = no_object_classes 21 | self.no_max_objects = no_max_objects 22 | self.no_tokens = no_tokens 23 | self.encode_crop = encode_crop 24 | self.no_sections = int(math.sqrt(self.no_tokens)) 25 | self.use_group_parameter = use_group_parameter 26 | self.use_additional_parameters = use_additional_parameters 27 | 28 | @property 29 | def none(self) -> int: 30 | return self.no_tokens - 1 31 | 32 | @property 33 | def object_descriptor_length(self) -> int: 34 | return 2 35 | 36 | @property 37 | def embedding_dim(self) -> int: 38 | extra_length = 2 if self.encode_crop else 0 39 | return self.no_max_objects * self.object_descriptor_length + extra_length 40 | 41 | def tokenize_coordinates(self, x: float, y: float) -> int: 42 | """ 43 | Express 2d coordinates with one number. 44 | Example: assume self.no_tokens = 16, then no_sections = 4: 45 | 0 0 0 0 46 | 0 0 # 0 47 | 0 0 0 0 48 | 0 0 0 x 49 | Then the # position corresponds to token 6, the x position to token 15. 
50 | @param x: float in [0, 1] 51 | @param y: float in [0, 1] 52 | @return: discrete tokenized coordinate 53 | """ 54 | x_discrete = int(round(x * (self.no_sections - 1))) 55 | y_discrete = int(round(y * (self.no_sections - 1))) 56 | return y_discrete * self.no_sections + x_discrete 57 | 58 | def coordinates_from_token(self, token: int) -> (float, float): 59 | x = token % self.no_sections 60 | y = token // self.no_sections 61 | return x / (self.no_sections - 1), y / (self.no_sections - 1) 62 | 63 | def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: 64 | x0, y0 = self.coordinates_from_token(token1) 65 | x1, y1 = self.coordinates_from_token(token2) 66 | return x0, y0, x1 - x0, y1 - y0 67 | 68 | def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: 69 | return self.tokenize_coordinates(bbox[0], bbox[1]), \ 70 | self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) 71 | 72 | def inverse_build(self, conditional: LongTensor) \ 73 | -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: 74 | conditional_list = conditional.tolist() 75 | crop_coordinates = None 76 | if self.encode_crop: 77 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 78 | conditional_list = conditional_list[:-2] 79 | table_of_content = grouper(conditional_list, self.object_descriptor_length) 80 | assert conditional.shape[0] == self.embedding_dim 81 | return [ 82 | (object_tuple[0], self.coordinates_from_token(object_tuple[1])) 83 | for object_tuple in table_of_content if object_tuple[0] != self.none 84 | ], crop_coordinates 85 | 86 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 87 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 88 | plot = pil_image.new('RGB', figure_size, WHITE) 89 | draw = pil_img_draw.Draw(plot) 90 | circle_size = get_circle_size(figure_size) 91 | font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', 92 | size=get_plot_font_size(font_size, figure_size)) 93 | width, height = plot.size 94 | description, crop_coordinates = self.inverse_build(conditional) 95 | for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): 96 | x_abs, y_abs = x * width, y * height 97 | ann = self.representation_to_annotation(representation) 98 | label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) 99 | ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] 100 | draw.ellipse(ellipse_bbox, fill=color, width=0) 101 | draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) 102 | if crop_coordinates is not None: 103 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 104 | return convert_pil_to_tensor(plot) / 127.5 - 1. 
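# A minimal usage sketch of the coordinate tokenization described above; this is an
# editorial example, not part of objects_center_points.py, and the constructor
# arguments are illustrative assumptions rather than values from this repository's configs.
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder

builder = ObjectsCenterPointsConditionalBuilder(
    no_object_classes=183, no_max_objects=8, no_tokens=1024,  # 1024 tokens -> a 32x32 grid
    encode_crop=False, use_group_parameter=False, use_additional_parameters=False)
token = builder.tokenize_coordinates(0.5, 0.25)      # single integer in [0, no_tokens)
x, y = builder.coordinates_from_token(token)         # roughly (0.5, 0.25) after 32x32 quantization
corners = builder.token_pair_from_bbox((0.1, 0.2, 0.3, 0.4))  # (top-left token, bottom-right token)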
105 | 106 | def object_representation(self, annotation: Annotation) -> int: 107 | modifier = 0 108 | if self.use_group_parameter: 109 | modifier |= 1 * (annotation.is_group_of is True) 110 | if self.use_additional_parameters: 111 | modifier |= 2 * (annotation.is_occluded is True) 112 | modifier |= 4 * (annotation.is_depiction is True) 113 | modifier |= 8 * (annotation.is_inside is True) 114 | return annotation.category_no + self.no_object_classes * modifier 115 | 116 | def representation_to_annotation(self, representation: int) -> Annotation: 117 | category_no = representation % self.no_object_classes 118 | modifier = representation // self.no_object_classes 119 | # noinspection PyTypeChecker 120 | return Annotation( 121 | area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, 122 | category_no=category_no, 123 | is_group_of=bool((modifier & 1) * self.use_group_parameter), 124 | is_occluded=bool((modifier & 2) * self.use_additional_parameters), 125 | is_depiction=bool((modifier & 4) * self.use_additional_parameters), 126 | is_inside=bool((modifier & 8) * self.use_additional_parameters) 127 | ) 128 | 129 | def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: 130 | return list(self.token_pair_from_bbox(crop_coordinates)) 131 | 132 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 133 | object_tuples = [ 134 | (self.object_representation(a), 135 | self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) 136 | for a in annotations 137 | ] 138 | empty_tuple = (self.none, self.none) 139 | object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) 140 | return object_tuples 141 | 142 | def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ 143 | -> LongTensor: 144 | if len(annotations) == 0: 145 | warnings.warn('Did not receive any annotations.') 146 | if len(annotations) > self.no_max_objects: 147 | warnings.warn('Received more annotations than allowed.') 148 | annotations = annotations[:self.no_max_objects] 149 | 150 | if not crop_coordinates: 151 | crop_coordinates = FULL_CROP 152 | 153 | random.shuffle(annotations) 154 | annotations = filter_annotations(annotations, crop_coordinates) 155 | if self.encode_crop: 156 | annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) 157 | if horizontal_flip: 158 | crop_coordinates = horizontally_flip_bbox(crop_coordinates) 159 | extra = self._crop_encoder(crop_coordinates) 160 | else: 161 | annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) 162 | extra = [] 163 | 164 | object_tuples = self._make_object_descriptors(annotations) 165 | flattened = [token for tuple_ in object_tuples for token in tuple_] + extra 166 | assert len(flattened) == self.embedding_dim 167 | assert all(0 <= value < self.no_tokens for value in flattened) 168 | return LongTensor(flattened) 169 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), 
(126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def 
load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def 
__init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 
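# A small illustrative sketch (not part of helper_types.py) of how the COCO and
# Open Images loaders above fill these records; every field value below is hypothetical.
from taming.data.helper_types import ImageDescription, Annotation

desc = ImageDescription(id=9, file_name='000000000009.jpg', original_size=(640, 480))
ann = Annotation(area=0.05, image_id='9', bbox=(0.1, 0.2, 0.25, 0.2),  # relative (x0, y0, w, h)
                 category_no=0, category_id='18', is_group_of=False)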
28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 
63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class 
Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | 
l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/modules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/discriminator/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/modules/discriminator/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | 
norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/modules/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/__pycache__/lpips.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/modules/losses/__pycache__/lpips.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | 
self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | 
param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | 
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/transformer/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractPermuter(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | def forward(self, x, reverse=False): 10 | raise NotImplementedError 11 | 12 | 13 | class Identity(AbstractPermuter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, reverse=False): 18 | return x 19 | 20 | 21 | class Subsample(AbstractPermuter): 22 | def __init__(self, H, W): 23 | super().__init__() 24 | C = 1 25 | indices = np.arange(H*W).reshape(C,H,W) 26 | while min(H, W) > 1: 27 | indices = indices.reshape(C,H//2,2,W//2,2) 28 | indices = indices.transpose(0,2,4,1,3) 29 | indices = indices.reshape(C*4,H//2, W//2) 30 | H = H//2 31 | W = W//2 32 | C = C*4 33 | assert H == W == 1 34 | idx = torch.tensor(indices.ravel()) 35 | 
self.register_buffer('forward_shuffle_idx', 36 | nn.Parameter(idx, requires_grad=False)) 37 | self.register_buffer('backward_shuffle_idx', 38 | nn.Parameter(torch.argsort(idx), requires_grad=False)) 39 | 40 | def forward(self, x, reverse=False): 41 | if not reverse: 42 | return x[:, self.forward_shuffle_idx] 43 | else: 44 | return x[:, self.backward_shuffle_idx] 45 | 46 | 47 | def mortonify(i, j): 48 | """(i,j) index to linear morton code""" 49 | i = np.uint64(i) 50 | j = np.uint64(j) 51 | 52 | z = np.uint(0) 53 | 54 | for pos in range(32): 55 | z = (z | 56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | 57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) 58 | ) 59 | return z 60 | 61 | 62 | class ZCurve(AbstractPermuter): 63 | def __init__(self, H, W): 64 | super().__init__() 65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] 66 | idx = np.argsort(reverseidx) 67 | idx = torch.tensor(idx) 68 | reverseidx = torch.tensor(reverseidx) 69 | self.register_buffer('forward_shuffle_idx', 70 | idx) 71 | self.register_buffer('backward_shuffle_idx', 72 | reverseidx) 73 | 74 | def forward(self, x, reverse=False): 75 | if not reverse: 76 | return x[:, self.forward_shuffle_idx] 77 | else: 78 | return x[:, self.backward_shuffle_idx] 79 | 80 | 81 | class SpiralOut(AbstractPermuter): 82 | def __init__(self, H, W): 83 | super().__init__() 84 | assert H == W 85 | size = W 86 | indices = np.arange(size*size).reshape(size,size) 87 | 88 | i0 = size//2 89 | j0 = size//2-1 90 | 91 | i = i0 92 | j = j0 93 | 94 | idx = [indices[i0, j0]] 95 | step_mult = 0 96 | for c in range(1, size//2+1): 97 | step_mult += 1 98 | # steps left 99 | for k in range(step_mult): 100 | i = i - 1 101 | j = j 102 | idx.append(indices[i, j]) 103 | 104 | # step down 105 | for k in range(step_mult): 106 | i = i 107 | j = j + 1 108 | idx.append(indices[i, j]) 109 | 110 | step_mult += 1 111 | if c < size//2: 112 | # step right 113 | for k in range(step_mult): 114 | i = i + 1 115 | j = j 116 | idx.append(indices[i, j]) 117 | 118 | # step up 119 | for k in range(step_mult): 120 | i = i 121 | j = j - 1 122 | idx.append(indices[i, j]) 123 | else: 124 | # end reached 125 | for k in range(step_mult-1): 126 | i = i + 1 127 | idx.append(indices[i, j]) 128 | 129 | assert len(idx) == size*size 130 | idx = torch.tensor(idx) 131 | self.register_buffer('forward_shuffle_idx', idx) 132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 133 | 134 | def forward(self, x, reverse=False): 135 | if not reverse: 136 | return x[:, self.forward_shuffle_idx] 137 | else: 138 | return x[:, self.backward_shuffle_idx] 139 | 140 | 141 | class SpiralIn(AbstractPermuter): 142 | def __init__(self, H, W): 143 | super().__init__() 144 | assert H == W 145 | size = W 146 | indices = np.arange(size*size).reshape(size,size) 147 | 148 | i0 = size//2 149 | j0 = size//2-1 150 | 151 | i = i0 152 | j = j0 153 | 154 | idx = [indices[i0, j0]] 155 | step_mult = 0 156 | for c in range(1, size//2+1): 157 | step_mult += 1 158 | # steps left 159 | for k in range(step_mult): 160 | i = i - 1 161 | j = j 162 | idx.append(indices[i, j]) 163 | 164 | # step down 165 | for k in range(step_mult): 166 | i = i 167 | j = j + 1 168 | idx.append(indices[i, j]) 169 | 170 | step_mult += 1 171 | if c < size//2: 172 | # step right 173 | for k in range(step_mult): 174 | i = i + 1 175 | j = j 176 | idx.append(indices[i, j]) 177 | 178 | # step up 179 | for k in range(step_mult): 180 | i = i 181 | j = j - 1 182 | idx.append(indices[i, j]) 183 
| else: 184 | # end reached 185 | for k in range(step_mult-1): 186 | i = i + 1 187 | idx.append(indices[i, j]) 188 | 189 | assert len(idx) == size*size 190 | idx = idx[::-1] 191 | idx = torch.tensor(idx) 192 | self.register_buffer('forward_shuffle_idx', idx) 193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 194 | 195 | def forward(self, x, reverse=False): 196 | if not reverse: 197 | return x[:, self.forward_shuffle_idx] 198 | else: 199 | return x[:, self.backward_shuffle_idx] 200 | 201 | 202 | class Random(nn.Module): 203 | def __init__(self, H, W): 204 | super().__init__() 205 | indices = np.random.RandomState(1).permutation(H*W) 206 | idx = torch.tensor(indices.ravel()) 207 | self.register_buffer('forward_shuffle_idx', idx) 208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 209 | 210 | def forward(self, x, reverse=False): 211 | if not reverse: 212 | return x[:, self.forward_shuffle_idx] 213 | else: 214 | return x[:, self.backward_shuffle_idx] 215 | 216 | 217 | class AlternateParsing(AbstractPermuter): 218 | def __init__(self, H, W): 219 | super().__init__() 220 | indices = np.arange(W*H).reshape(H,W) 221 | for i in range(1, H, 2): 222 | indices[i, :] = indices[i, ::-1] 223 | idx = indices.flatten() 224 | assert len(idx) == H*W 225 | idx = torch.tensor(idx) 226 | self.register_buffer('forward_shuffle_idx', idx) 227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 228 | 229 | def forward(self, x, reverse=False): 230 | if not reverse: 231 | return x[:, self.forward_shuffle_idx] 232 | else: 233 | return x[:, self.backward_shuffle_idx] 234 | 235 | 236 | if __name__ == "__main__": 237 | p0 = AlternateParsing(16, 16) 238 | print(p0.forward_shuffle_idx) 239 | print(p0.backward_shuffle_idx) 240 | 241 | x = torch.randint(0, 768, size=(11, 256)) 242 | y = p0(x) 243 | xre = p0(y, reverse=True) 244 | assert torch.equal(x, xre) 245 | 246 | p1 = SpiralOut(2, 2) 247 | print(p1.forward_shuffle_idx) 248 | print(p1.backward_shuffle_idx) 249 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = 
False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /src/taming-transformers/taming/modules/vqvae/__pycache__/quantize.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ControlCom-Image-Composition/61b6a29f19da811e2687a073a159ad955317343b/src/taming-transformers/taming/modules/vqvae/__pycache__/quantize.cpython-38.pyc -------------------------------------------------------------------------------- /src/taming-transformers/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in 
r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python scripts/inference.py \ 4 | --task harmonization \ 5 | --outdir results \ 6 | --testdir examples \ 7 | --num_samples 1 \ 8 | --sample_steps 50 \ 9 | --gpu 0 10 | 11 | python scripts/inference.py \ 12 | --task composition \ 13 | --outdir results \ 14 | --testdir examples \ 15 | --num_samples 1 \ 16 | --sample_steps 25 \ 17 | --plms \ 18 | --gpu 0 --------------------------------------------------------------------------------
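A minimal usage sketch for the LambdaWarmUpCosineScheduler defined in src/taming-transformers/taming/lr_scheduler.py above: the schedule emits the learning-rate value itself (hence the docstring's "use with a base_lr of 1.0"), so it is typically wrapped in torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base rate by schedule(step). The toy model and hyper-parameter values below are illustrative assumptions only, not values taken from this repository's configs, and the sketch assumes the taming package from src/taming-transformers has been installed.

import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

# Toy model; any parameter group works the same way.
model = torch.nn.Linear(4, 4)

# base_lr of 1.0 so that LambdaLR's product base_lr * schedule(step)
# equals the value produced by the cosine schedule itself.
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)

# Warm up linearly from lr_start to lr_max, then decay toward lr_min on a
# cosine curve over max_decay_steps (all step counts and rates are made up
# for this sketch).
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000,
    lr_min=1e-6,
    lr_max=1e-4,
    lr_start=1e-6,
    max_decay_steps=100000,
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(5):
    # forward pass and loss.backward() would go here in real training
    optimizer.step()
    lr_scheduler.step()
    print(step, lr_scheduler.get_last_lr())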