├── LICENSE ├── README.md ├── assets ├── example-01.png ├── example-02.png ├── example-03.png ├── example-04.png ├── example-05.png ├── example-06.png ├── seed.jpg └── teaser.png ├── ckpt └── download.sh ├── configs └── stable-diffusion │ ├── v2-inference copy.yaml │ ├── v2-inference-v.yaml │ ├── v2-inference.yaml │ ├── v2-inpainting-inference.yaml │ ├── v2-midas-inference.yaml │ └── x4-upscaling.yaml ├── environment.yaml ├── examples ├── a photo of a black sloth plushie │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a pixel cartoon │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a polygonal illustration duck toy │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a purple plushie toy │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a red clock │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a robot monster toy │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a sand statue plushie │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a sculpture golden retriever │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── a photo of a silver backpack │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png └── a photo of a wooden sculpture Chow Chow │ ├── scene.png │ ├── selected_mask.png │ ├── sub_mask.png │ └── subject.png ├── ldm ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── modules.cpython-38.pyc │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── ptp_scripts ├── ptp_scripts.py └── ptp_utils_ori.py ├── run.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 PENGZHI LI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 
11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tuning-Free Image Customization with Image and Text Guidance (ECCV 2024) 2 | 3 | [![Paper](https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=b31b1b)](https://arxiv.org/abs/2403.12658) 4 | [![Project Page](https://img.shields.io/badge/Project-Website-5B7493?logo=googlechrome&logoColor=5B7493)](https://zrealli.github.io/TIGIC) 5 | 6 | 7 | 8 | This repository contains the official implementation of the following paper: 9 | > **Tuning-Free Image Customization with Image and Text Guidance**
22 | ## :open_book: Overview 23 | Current diffusion-based image customization methods have limitations such as unintended changes, reliance on either a reference image or text alone, and time-consuming fine-tuning. We introduce **TIGIC**, a tuning-free framework for text- and image-guided image customization that enables precise, region-specific edits within seconds. Our approach preserves the reference image's semantic features while modifying details according to the text description, using an attention blending strategy in the UNet decoder. The method is applicable to image synthesis, design, and creative photography. 24 | 25 | 26 | 27 | ## :hammer: Quick Start 28 | 29 | ``` 30 | git clone https://github.com/zrealli/TIGIC.git 31 | cd TIGIC 32 | ``` 33 | ### 1. Prepare Environment 34 | To set up our environment, please follow these instructions: 35 | ``` 36 | conda env create -f environment.yaml 37 | conda activate TIGIC 38 | ``` 39 | Please note that our project requires 24GB of GPU memory to run. 40 | 41 | ### 2. Download Checkpoints 42 | Next, download the [Stable Diffusion weights](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt) and place the checkpoint in ./ckpt: 43 | ``` 44 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt 45 | ``` 46 | 47 | ### 3. Inference with TIGIC 48 | After downloading the base model, run inference with the following command: 49 | 50 | 51 | ``` 52 | python run.py --ckpt ./ckpt/v2-1_512-ema-pruned.ckpt --dpm_steps 20 --outdir ./outputs --root ./examples --scale 5 --seed 42 53 | ``` 54 | Each folder under ./examples is named after its text prompt and contains scene.png, subject.png, sub_mask.png, and selected_mask.png (a layout sketch is given further below). The generated results can be found in the ./outputs directory; some examples are shown next. 55 | 56 | 57 | ## :framed_picture: Generation Results 58 | From left to right, the images are the **background, subject, collage, and the result** generated by TIGIC (the corresponding example images are in `assets/`). 59 | 60 | *'a photo of a pixel cartoon'*
67 | *'a photo of a black sloth plushie'*
74 | *'a photo of a polygonal illustration duck toy'*
81 | *'a photo of a purple plushie toy'*
88 | *'a photo of a sculpture golden retriever'*
95 | *'a photo of a red clock'*
102 | :fountain_pen: Please note that if the generated results contain significant artifacts, adjust the random seed `--seed` to obtain the desired outcome. As the demo below shows (see `assets/seed.jpg`), different random seeds can produce results of noticeably different quality; a simple seed sweep is sketched after this example. 112 | *'a photo of a black sloth plushie'*
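The seed-variation example above is produced by re-running the same inference command with different `--seed` values. A minimal sweep is sketched below; the flags are exactly those documented in the inference section, while the per-seed output folders and the assumption that `run.py` creates `--outdir` on demand are conveniences added here, not behavior confirmed by this repository.

```
# Hedged sketch: try several seeds and keep each run in its own output folder.
for SEED in 0 42 123 2024; do
    python run.py --ckpt ./ckpt/v2-1_512-ema-pruned.ckpt --dpm_steps 20 \
        --outdir ./outputs/seed_${SEED} --root ./examples --scale 5 --seed ${SEED}
done
```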
119 | Please refer to our [project page](https://zrealli.github.io/TIGIC) for more visual comparisons. 120 | 121 | ### 🤗 Gradio Demo 122 | We will soon release a Gradio demo that includes an integrated automatic foreground-subject segmentation module. 123 | 124 | ## :four_leaf_clover: Acknowledgments 125 | This project is distributed under the MIT License. Our work builds upon the foundation laid by others: our code is based on [TF-ICON](https://github.com/Shilin-LU/TF-ICON) and [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt), and we thank the authors of these projects for their contributions. 126 | 127 | ## :fountain_pen: Citation 128 | 129 | If you find our repo useful for your research, please consider citing our paper: 130 | 131 | ```bibtex 132 | @inproceedings{li2024tuning, 133 | title={Tuning-Free Image Customization with Image and Text Guidance}, 134 | author={Li, Pengzhi and Nie, Qiang and Chen, Ying and Jiang, Xi and Wu, Kai and Lin, Yuhuan and Liu, Yong and Peng, Jinlong and Wang, Chengjie and Zheng, Feng}, 135 | booktitle={European Conference on Computer Vision}, 136 | year={2024} 137 | } 138 | 139 | ``` 140 | 141 | -------------------------------------------------------------------------------- /assets/example-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-01.png -------------------------------------------------------------------------------- /assets/example-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-02.png -------------------------------------------------------------------------------- /assets/example-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-03.png -------------------------------------------------------------------------------- /assets/example-04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-04.png -------------------------------------------------------------------------------- /assets/example-05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-05.png -------------------------------------------------------------------------------- /assets/example-06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-06.png -------------------------------------------------------------------------------- /assets/seed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/seed.jpg -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/teaser.png
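For reference, the `--root` directory passed to `run.py` above is organized like the `examples/` folder shipped with this repository: one sub-folder per text prompt, each containing four images. The tree below is taken from that folder listing; the role annotations are a best-effort reading of the file names and the README (background/subject/collage), not something confirmed by the code shown here.

```
examples/
└── a photo of a black sloth plushie/   # folder name doubles as the text prompt
    ├── scene.png           # background scene to be edited
    ├── subject.png         # reference image of the subject
    ├── sub_mask.png        # mask of the subject (assumed)
    └── selected_mask.png   # mask of the selected target region in the scene (assumed)
```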
-------------------------------------------------------------------------------- /ckpt/download.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference copy.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False # we set this to false because this is an inference only config 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | use_checkpoint: True 24 | use_fp16: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference-v.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False # we set this to false because this is an inference only config 20 | 21 | unet_config: 22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | use_fp16: True 26 | image_size: 32 # unused 27 | in_channels: 4 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: [ 4, 2, 1 ] 31 | num_res_blocks: 2 32 | channel_mult: [ 1, 2, 4, 4 ] 33 | num_head_channels: 64 # need to fix for flash-attn 34 | use_spatial_transformer: True 35 | use_linear_in_transformer: True 36 | transformer_depth: 1 37 | context_dim: 1024 38 | legacy: False 39 | 40 | first_stage_config: 41 | target: 
ldm.models.autoencoder.AutoencoderKL 42 | params: 43 | embed_dim: 4 44 | monitor: val/rec_loss 45 | ddconfig: 46 | #attn_type: "vanilla-xformers" 47 | double_z: true 48 | z_channels: 4 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: [] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 66 | params: 67 | freeze: True 68 | layer: "penultimate" 69 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False # we set this to false because this is an inference only config 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | use_checkpoint: True 24 | use_fp16: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | unet_config: 22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | image_size: 32 # unused 26 | in_channels: 9 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | 
use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | 69 | 70 | data: 71 | target: ldm.data.laion.WebDataModuleFromConfig 72 | params: 73 | tar_base: null # for concat as in LAION-A 74 | p_unsafe_threshold: 0.1 75 | filter_word_list: "data/filters.yaml" 76 | max_pwatermark: 0.45 77 | batch_size: 8 78 | num_workers: 6 79 | multinode: True 80 | min_size: 512 81 | train: 82 | shards: 83 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" 84 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" 85 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" 86 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" 87 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" 88 | shuffle: 10000 89 | image_key: jpg 90 | image_transforms: 91 | - target: torchvision.transforms.Resize 92 | params: 93 | size: 512 94 | interpolation: 3 95 | - target: torchvision.transforms.RandomCrop 96 | params: 97 | size: 512 98 | postprocess: 99 | target: ldm.data.laion.AddMask 100 | params: 101 | mode: "512train-large" 102 | p_drop: 0.25 103 | # NOTE use enough shards to avoid empty validation loops in workers 104 | validation: 105 | shards: 106 | - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " 107 | shuffle: 0 108 | image_key: jpg 109 | image_transforms: 110 | - target: torchvision.transforms.Resize 111 | params: 112 | size: 512 113 | interpolation: 3 114 | - target: torchvision.transforms.CenterCrop 115 | params: 116 | size: 512 117 | postprocess: 118 | target: ldm.data.laion.AddMask 119 | params: 120 | mode: "512train-large" 121 | p_drop: 0.25 122 | 123 | lightning: 124 | find_unused_parameters: True 125 | modelcheckpoint: 126 | params: 127 | every_n_train_steps: 5000 128 | 129 | callbacks: 130 | metrics_over_trainsteps_checkpoint: 131 | params: 132 | every_n_train_steps: 10000 133 | 134 | image_logger: 135 | target: main.ImageLogger 136 | params: 137 | enable_autocast: False 138 | disabled: False 139 | batch_frequency: 1000 140 | max_images: 4 141 | increase_log_steps: False 142 | log_first_step: False 143 | log_images_kwargs: 144 | use_ema_scope: False 145 | inpaint: False 146 | plot_progressive_rows: False 147 | plot_diffusion_rows: False 148 | N: 4 149 | unconditional_guidance_scale: 5.0 150 | unconditional_guidance_label: [""] 151 | ddim_steps: 50 # todo check these out for depth2img, 152 | ddim_eta: 0.0 # todo check these out for depth2img, 153 | 154 | trainer: 155 | benchmark: True 156 | val_check_interval: 5000000 157 | num_sanity_val_steps: 0 158 | accumulate_grad_batches: 1 159 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-midas-inference.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-07 3 | target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | depth_stage_config: 22 | target: ldm.modules.midas.api.MiDaSInference 23 | params: 24 | model_type: "dpt_hybrid" 25 | 26 | unet_config: 27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 5 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | #attn_type: "vanilla-xformers" 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [ ] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 70 | params: 71 | freeze: True 72 | layer: "penultimate" 73 | 74 | 75 | -------------------------------------------------------------------------------- /configs/stable-diffusion/x4-upscaling.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion 4 | params: 5 | parameterization: "v" 6 | low_scale_key: "lr" 7 | linear_start: 0.0001 8 | linear_end: 0.02 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: "jpg" 13 | cond_stage_key: "txt" 14 | image_size: 128 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: "hybrid-adm" 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.08333 20 | use_ema: False 21 | 22 | low_scale_config: 23 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation 24 | params: 25 | noise_schedule_config: # image space 26 | linear_start: 0.0001 27 | linear_end: 0.02 28 | max_noise_level: 350 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | use_checkpoint: True 34 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) 35 | image_size: 128 36 | in_channels: 7 37 | out_channels: 4 38 | model_channels: 256 39 | attention_resolutions: [ 2,4,8] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 2, 4] 42 | disable_self_attentions: [True, True, True, False] 43 | disable_middle_self_attn: False 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 1024 48 | legacy: False 49 | use_linear_in_transformer: True 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | 
params: 54 | embed_dim: 4 55 | ddconfig: 56 | # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) 57 | double_z: True 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 64 | num_res_blocks: 2 65 | attn_resolutions: [ ] 66 | dropout: 0.0 67 | 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 73 | params: 74 | freeze: True 75 | layer: "penultimate" 76 | 77 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: TIGIC 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - albumentations==1.3.0 14 | - opencv-python==4.6.0.66 15 | - imageio==2.9.0 16 | - imageio-ffmpeg==0.4.2 17 | - pytorch-lightning==1.4.2 18 | - omegaconf==2.1.1 19 | - test-tube>=0.7.5 20 | - streamlit==1.12.1 21 | - einops==0.3.0 22 | - transformers==4.19.2 23 | - webdataset==0.2.5 24 | - kornia==0.6 25 | - open_clip_torch==2.0.2 26 | - invisible-watermark>=0.1.5 27 | - streamlit-drawable-canvas==0.8.0 28 | - torchmetrics==0.6.0 29 | - diffusers==0.12.1 30 | - ipykernel 31 | - matplotlib 32 | - -e . 33 | 34 | 35 | -------------------------------------------------------------------------------- /examples/a photo of a black sloth plushie/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/scene.png -------------------------------------------------------------------------------- /examples/a photo of a black sloth plushie/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a black sloth plushie/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a black sloth plushie/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/subject.png -------------------------------------------------------------------------------- /examples/a photo of a pixel cartoon/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/scene.png -------------------------------------------------------------------------------- /examples/a photo of a pixel cartoon/selected_mask.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a pixel cartoon/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a pixel cartoon/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/subject.png -------------------------------------------------------------------------------- /examples/a photo of a polygonal illustration duck toy/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/scene.png -------------------------------------------------------------------------------- /examples/a photo of a polygonal illustration duck toy/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a polygonal illustration duck toy/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a polygonal illustration duck toy/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/subject.png -------------------------------------------------------------------------------- /examples/a photo of a purple plushie toy/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/scene.png -------------------------------------------------------------------------------- /examples/a photo of a purple plushie toy/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a purple plushie toy/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie 
toy/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a purple plushie toy/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/subject.png -------------------------------------------------------------------------------- /examples/a photo of a red clock/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/scene.png -------------------------------------------------------------------------------- /examples/a photo of a red clock/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a red clock/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a red clock/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/subject.png -------------------------------------------------------------------------------- /examples/a photo of a robot monster toy/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/scene.png -------------------------------------------------------------------------------- /examples/a photo of a robot monster toy/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a robot monster toy/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a robot monster toy/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/subject.png -------------------------------------------------------------------------------- /examples/a photo of a sand statue plushie/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/scene.png 
-------------------------------------------------------------------------------- /examples/a photo of a sand statue plushie/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a sand statue plushie/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a sand statue plushie/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/subject.png -------------------------------------------------------------------------------- /examples/a photo of a sculpture golden retriever/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/scene.png -------------------------------------------------------------------------------- /examples/a photo of a sculpture golden retriever/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a sculpture golden retriever/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a sculpture golden retriever/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/subject.png -------------------------------------------------------------------------------- /examples/a photo of a silver backpack/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/scene.png -------------------------------------------------------------------------------- /examples/a photo of a silver backpack/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a silver backpack/sub_mask.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a silver backpack/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/subject.png -------------------------------------------------------------------------------- /examples/a photo of a wooden sculpture Chow Chow/scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/scene.png -------------------------------------------------------------------------------- /examples/a photo of a wooden sculpture Chow Chow/selected_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/selected_mask.png -------------------------------------------------------------------------------- /examples/a photo of a wooden sculpture Chow Chow/sub_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/sub_mask.png -------------------------------------------------------------------------------- /examples/a photo of a wooden sculpture Chow Chow/subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/subject.png -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | import ptp_scripts.ptp_scripts as ptp 4 | import sys 5 | sys.path.append('..') 6 | import ptp_scripts.ptp_utils_ori as ptp_utils_ori 7 | 8 | from ldm.models.diffusion.dpm_solver.dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 9 | 10 | from tqdm import tqdm 11 | 12 | MODEL_TYPES = { 13 | "eps": "noise", 14 | "v": "v" 15 | } 16 | 17 | 18 | class DPMSolverSampler(object): 19 | def __init__(self, model, **kwargs): 20 | super().__init__() 21 | self.model = model 22 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 23 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 24 | 25 | def register_buffer(self, name, attr): 26 | if type(attr) == torch.Tensor: 27 | if attr.device != self.model.device: 28 | attr = attr.to(self.model.device) 29 | setattr(self, name, attr) 30 | 31 | @torch.no_grad() 32 | def sample(self, 33 | steps, 34 | batch_size, 35 | shape, 36 | conditioning=None, 37 | conditioning_edit=None, 38 | inv_emb=None, 39 | inv_emb_edit=None, 40 | callback=None, 41 | normals_sequence=None, 42 | img_callback=None, 43 | quantize_x0=False, 44 | eta=0., 45 | mask=None, 46 | x0=None, 47 | temperature=1., 48 | noise_dropout=0., 49 | score_corrector=None, 50 | corrector_kwargs=None, 51 | verbose=True, 52 | x_T=None, 53 | log_every_t=100, 54 | unconditional_guidance_scale=1., 55 | unconditional_conditioning=None, 56 | unconditional_conditioning_edit=None, 57 | t_start=None, 58 | t_end=None, 59 | DPMencode=False, 60 | order=2, 61 | width=None, 62 | height=None, 63 | ref=False, 64 | param=None, 65 | tau_a=0.5, 66 | tau_b=0.8, 67 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
68 | **kwargs 69 | ): 70 | if conditioning is not None: 71 | if isinstance(conditioning, dict): 72 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 73 | if cbs != batch_size: 74 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 75 | else: 76 | if conditioning.shape[0] != batch_size: 77 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 78 | 79 | # sampling 80 | C, H, W = shape 81 | size = (batch_size, C, H, W) 82 | 83 | 84 | device = self.model.betas.device 85 | if x_T is None: 86 | x = torch.randn(size, device=device) 87 | else: 88 | x = x_T 89 | 90 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 91 | 92 | 93 | 94 | if DPMencode: 95 | # x_T is not a list 96 | model_fn = model_wrapper( 97 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=None, inject=inject), 98 | ns, 99 | model_type=MODEL_TYPES[self.model.parameterization], 100 | guidance_type="classifier-free", 101 | condition=inv_emb, 102 | unconditional_condition=inv_emb, 103 | guidance_scale=unconditional_guidance_scale, 104 | ) 105 | 106 | 107 | dpm_solver = DPM_Solver(model_fn, ns) 108 | data, _ = dpm_solver.sample_lower(x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=DPMencode) 109 | 110 | for step in range(order, steps + 1): 111 | data = dpm_solver.sample_one_step(data, step, steps, order=order, DPMencode=DPMencode) 112 | 113 | return data['x'].to(device), None 114 | 115 | else: 116 | # x_T is a list 117 | 118 | model_fn_decode = model_wrapper( 119 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject), 120 | ns, 121 | model_type=MODEL_TYPES[self.model.parameterization], 122 | guidance_type="classifier-free", 123 | condition=inv_emb, 124 | unconditional_condition=inv_emb, 125 | guidance_scale=unconditional_guidance_scale, 126 | ) 127 | model_fn_gen = model_wrapper( 128 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject), 129 | ns, 130 | model_type=MODEL_TYPES[self.model.parameterization], 131 | guidance_type="classifier-free", 132 | condition=conditioning, 133 | unconditional_condition=unconditional_conditioning, 134 | guidance_scale=unconditional_guidance_scale, 135 | ) 136 | 137 | 138 | model_fn_decode_edit = model_wrapper( 139 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject), 140 | ns, 141 | model_type=MODEL_TYPES[self.model.parameterization], 142 | guidance_type="classifier-free", 143 | condition=inv_emb_edit, 144 | unconditional_condition=inv_emb_edit, 145 | guidance_scale=unconditional_guidance_scale, 146 | ) 147 | model_fn_gen_edit = model_wrapper( 148 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject), 149 | ns, 150 | model_type=MODEL_TYPES[self.model.parameterization], 151 | guidance_type="classifier-free", 152 | condition=conditioning_edit, 153 | unconditional_condition=unconditional_conditioning_edit, 154 | guidance_scale=unconditional_guidance_scale, 155 | ) 156 | 157 | orig_controller = ptp.AttentionStore() 158 | ref_controller = ptp.AttentionStore() 159 | gen_controller = ptp.AttentionStore() 160 | Inject_controller = ptp.AttentionStore() 161 | 162 | dpm_solver_decode = DPM_Solver(model_fn_decode, ns) 163 | 
dpm_solver_gen = DPM_Solver(model_fn_gen, ns) 164 | 165 | 166 | dpm_solver_decode_edit = DPM_Solver(model_fn_decode_edit, ns) 167 | 168 | dpm_solver_gen_edit = DPM_Solver(model_fn_gen_edit, ns) 169 | 170 | # decoded background 171 | 172 | ptp_utils_ori.register_attention_control(self.model, orig_controller) 173 | 174 | 175 | orig, orig_controller = dpm_solver_decode_edit.sample_lower(x[0], dpm_solver_decode, steps, order, t_start, t_end, device, DPMencode=DPMencode, controller=orig_controller) 176 | # decoded reference 177 | ptp_utils_ori.register_attention_control(self.model, ref_controller) 178 | ref, ref_controller = dpm_solver_decode_edit.sample_lower(x[3], dpm_solver_decode_edit, steps, order, t_start, t_end, device, DPMencode=DPMencode, controller=ref_controller) 179 | 180 | 181 | # generation 182 | Inject_controller = [orig_controller, ref_controller] 183 | ptp_utils_ori.register_attention_control(self.model, gen_controller, inject_bg=False) 184 | 185 | gen, _ = dpm_solver_decode_edit.sample_lower(x[1], dpm_solver_gen_edit, steps, order, t_start, t_end, device, 186 | DPMencode=DPMencode, controller=Inject_controller, inject=True) 187 | 188 | del orig_controller, ref_controller, gen_controller, Inject_controller 189 | 190 | orig_controller = ptp.AttentionStore() 191 | ref_controller = ptp.AttentionStore() 192 | gen_controller = ptp.AttentionStore() 193 | 194 | for step in range(order, 21): 195 | # decoded background 196 | ptp_utils_ori.register_attention_control(self.model, orig_controller) 197 | orig = dpm_solver_decode.sample_one_step(orig, step, steps, order=order, DPMencode=DPMencode) 198 | ptp_utils_ori.register_attention_control(self.model, ref_controller) 199 | ref = dpm_solver_decode_edit.sample_one_step(ref, step, steps, order=order, DPMencode=DPMencode) 200 | 201 | 202 | if step >= int(0.2*(steps) + 1 - order) and step <= int(0.5*(steps) + 1 - order): 203 | inject = True 204 | controller = [orig_controller, ref_controller] 205 | else: 206 | inject = False 207 | controller = [orig_controller, None] 208 | 209 | if step < int(0.5 * (steps) + 1 - order) and step > int(0.* (steps) + 1 - order) : 210 | inject_bg = True 211 | else: 212 | inject_bg = False 213 | 214 | 215 | ptp_utils_ori.register_attention_control(self.model, gen_controller, inject_bg=inject_bg) 216 | gen = dpm_solver_gen_edit.sample_one_step(gen, step, steps, order=order, DPMencode=DPMencode, controller=controller, inject=inject) 217 | 218 | if step < int(1.0*(steps) + 1 - order): 219 | blended = orig['x'].clone() 220 | blended[:, :, param[0] : param[1], param[2] : param[3]] \ 221 | = gen['x'][:, :, param[0] : param[1], param[2] : param[3]].clone() 222 | gen['x'] = blended.clone() 223 | 224 | 225 | return gen['x'].to(device), None 226 | 227 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, 
attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 87 | if cbs != batch_size: 88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 89 | else: 90 | if conditioning.shape[0] != batch_size: 91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 92 | 93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 94 | # sampling 95 | C, H, W = shape 96 | size = (batch_size, C, H, W) 97 | print(f'Data shape for PLMS sampling is {size}') 98 | 99 | samples, intermediates = self.plms_sampling(conditioning, size, 100 | callback=callback, 101 | img_callback=img_callback, 102 | quantize_denoised=quantize_x0, 103 | mask=mask, x0=x0, 104 | ddim_use_original_steps=False, 105 | noise_dropout=noise_dropout, 106 | temperature=temperature, 107 | score_corrector=score_corrector, 108 | corrector_kwargs=corrector_kwargs, 109 | x_T=x_T, 110 | log_every_t=log_every_t, 111 | unconditional_guidance_scale=unconditional_guidance_scale, 112 | unconditional_conditioning=unconditional_conditioning, 113 | dynamic_threshold=dynamic_threshold, 114 | ) 115 | return samples, intermediates 116 | 117 | @torch.no_grad() 118 | def plms_sampling(self, cond, shape, 119 | x_T=None, ddim_use_original_steps=False, 120 | callback=None, timesteps=None, quantize_denoised=False, 121 | mask=None, x0=None, img_callback=None, log_every_t=100, 122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 123 | unconditional_guidance_scale=1., unconditional_conditioning=None, 124 | dynamic_threshold=None): 125 | device = self.model.betas.device 126 | b = shape[0] 127 | if x_T is None: 128 | img = torch.randn(shape, device=device) 129 | else: 130 | img = x_T 131 | 132 | if timesteps is None: 133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 134 | elif timesteps is not None and not ddim_use_original_steps: 135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 136 | timesteps = self.ddim_timesteps[:subset_end] 137 | 138 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 141 | print(f"Running PLMS Sampling with {total_steps} timesteps") 142 | 143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 144 | old_eps = [] 145 | 146 | for i, step in enumerate(iterator): 147 | index = total_steps - i - 1 148 | ts = torch.full((b,), step, device=device, dtype=torch.long) 149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 150 | 151 | if mask is not None: 152 | assert x0 is not None 153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 154 | img = img_orig * mask + (1. 
- mask) * img 155 | 156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 157 | quantize_denoised=quantize_denoised, temperature=temperature, 158 | noise_dropout=noise_dropout, score_corrector=score_corrector, 159 | corrector_kwargs=corrector_kwargs, 160 | unconditional_guidance_scale=unconditional_guidance_scale, 161 | unconditional_conditioning=unconditional_conditioning, 162 | old_eps=old_eps, t_next=ts_next, 163 | dynamic_threshold=dynamic_threshold) 164 | img, pred_x0, e_t = outs 165 | old_eps.append(e_t) 166 | if len(old_eps) >= 4: 167 | old_eps.pop(0) 168 | if callback: callback(i) 169 | if img_callback: img_callback(pred_x0, i) 170 | 171 | if index % log_every_t == 0 or index == total_steps - 1: 172 | intermediates['x_inter'].append(img) 173 | intermediates['pred_x0'].append(pred_x0) 174 | 175 | return img, intermediates 176 | 177 | @torch.no_grad() 178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 181 | dynamic_threshold=None): 182 | b, *_, device = *x.shape, x.device 183 | 184 | def get_model_output(x, t): 185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 186 | e_t = self.model.apply_model(x, t, c) 187 | else: 188 | x_in = torch.cat([x] * 2) 189 | t_in = torch.cat([t] * 2) 190 | c_in = torch.cat([unconditional_conditioning, c]) 191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 193 | 194 | if score_corrector is not None: 195 | assert self.model.parameterization == "eps" 196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 197 | 198 | return e_t 199 | 200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 204 | 205 | def get_x_prev_and_pred_x0(e_t, index): 206 | # select parameters corresponding to the currently considered timestep 207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 211 | 212 | # current prediction for x_0 213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 214 | if quantize_denoised: 215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 216 | if dynamic_threshold is not None: 217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 218 | # direction pointing to x_t 219 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 221 | if noise_dropout > 0.: 222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 224 | return x_prev, pred_x0 225 | 226 | e_t = get_model_output(x, t) 227 | if len(old_eps) == 0: 228 | # Pseudo Improved Euler (2nd order) 229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 230 | e_t_next = get_model_output(x_prev, t_next) 231 | e_t_prime = (e_t + e_t_next) / 2 232 | elif len(old_eps) == 1: 233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 235 | elif len(old_eps) == 2: 236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 238 | elif len(old_eps) >= 3: 239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 241 | 242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 243 | 244 | return x_prev, pred_x0, e_t 245 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | from typing import Optional, Any 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from ldm.modules.diffusionmodules.util import checkpoint 11 | from PIL import Image 12 | 13 | try: 14 | import xformers 15 | import xformers.ops 16 | XFORMERS_IS_AVAILBLE = False 17 | except: 18 | XFORMERS_IS_AVAILBLE = False 19 | 20 | 21 | def exists(val): 22 | return val is not None 23 | 24 | 25 | def uniq(arr): 26 | return{el: True for el in arr}.keys() 27 | 28 | 29 | def default(val, d): 30 | if exists(val): 31 | return val 32 | return d() if isfunction(d) else d 33 | 34 | 35 | def max_neg_value(t): 36 | return -torch.finfo(t.dtype).max 37 | 38 | 39 | def init_(tensor): 40 | dim = tensor.shape[-1] 41 | std = 1 / math.sqrt(dim) 42 | tensor.uniform_(-std, std) 43 | return tensor 44 | 45 | 46 | # feedforward 47 | class GEGLU(nn.Module): 48 | def __init__(self, dim_in, dim_out): 49 | super().__init__() 50 | self.proj = nn.Linear(dim_in, dim_out * 2) 51 | 52 | def forward(self, x): 53 | x, gate = self.proj(x).chunk(2, 
dim=-1) 54 | return x * F.gelu(gate) 55 | 56 | 57 | class FeedForward(nn.Module): 58 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 59 | super().__init__() 60 | inner_dim = int(dim * mult) 61 | dim_out = default(dim_out, dim) 62 | project_in = nn.Sequential( 63 | nn.Linear(dim, inner_dim), 64 | nn.GELU() 65 | ) if not glu else GEGLU(dim, inner_dim) 66 | 67 | self.net = nn.Sequential( 68 | project_in, 69 | nn.Dropout(dropout), 70 | nn.Linear(inner_dim, dim_out) 71 | ) 72 | 73 | def forward(self, x): 74 | return self.net(x) 75 | 76 | 77 | def zero_module(module): 78 | """ 79 | Zero out the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().zero_() 83 | return module 84 | 85 | 86 | def Normalize(in_channels): 87 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 88 | 89 | 90 | class SpatialSelfAttention(nn.Module): 91 | def __init__(self, in_channels): 92 | super().__init__() 93 | self.in_channels = in_channels 94 | 95 | self.norm = Normalize(in_channels) 96 | self.q = torch.nn.Conv2d(in_channels, 97 | in_channels, 98 | kernel_size=1, 99 | stride=1, 100 | padding=0) 101 | self.k = torch.nn.Conv2d(in_channels, 102 | in_channels, 103 | kernel_size=1, 104 | stride=1, 105 | padding=0) 106 | self.v = torch.nn.Conv2d(in_channels, 107 | in_channels, 108 | kernel_size=1, 109 | stride=1, 110 | padding=0) 111 | self.proj_out = torch.nn.Conv2d(in_channels, 112 | in_channels, 113 | kernel_size=1, 114 | stride=1, 115 | padding=0) 116 | 117 | def forward(self, x): 118 | h_ = x 119 | h_ = self.norm(h_) 120 | q = self.q(h_) 121 | k = self.k(h_) 122 | v = self.v(h_) 123 | 124 | # compute attention 125 | b,c,h,w = q.shape 126 | q = rearrange(q, 'b c h w -> b (h w) c') 127 | k = rearrange(k, 'b c h w -> b c (h w)') 128 | w_ = torch.einsum('bij,bjk->bik', q, k) 129 | 130 | w_ = w_ * (int(c)**(-0.5)) 131 | w_ = torch.nn.functional.softmax(w_, dim=2) 132 | 133 | # attend to values 134 | v = rearrange(v, 'b c h w -> b c (h w)') 135 | w_ = rearrange(w_, 'b i j -> b j i') 136 | h_ = torch.einsum('bij,bjk->bik', v, w_) 137 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 138 | h_ = self.proj_out(h_) 139 | 140 | return x+h_ 141 | 142 | 143 | class CrossAttention(nn.Module): 144 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 145 | super().__init__() 146 | inner_dim = dim_head * heads 147 | context_dim = default(context_dim, query_dim) 148 | 149 | self.scale = dim_head ** -0.5 150 | self.heads = heads 151 | 152 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 153 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 154 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 155 | 156 | self.to_out = nn.Sequential( 157 | nn.Linear(inner_dim, query_dim), 158 | nn.Dropout(dropout) 159 | ) 160 | 161 | def forward(self, x, context=None, mask=None, encode=False, controller_for_inject=None, inject=False, layernum=None, main_height=None, main_width=None): 162 | h = self.heads 163 | 164 | q = self.to_q(x) 165 | context = default(context, x) 166 | k = self.to_k(context) 167 | v = self.to_v(context) 168 | 169 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 170 | 171 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 172 | del q, k 173 | 174 | if exists(mask): 175 | mask = rearrange(mask, 'b ... 
-> b (...)') 176 | max_neg_value = -torch.finfo(sim.dtype).max 177 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 178 | sim.masked_fill_(~mask, max_neg_value) 179 | 180 | # a = ((sim.mean(0).mean(1).resize(64,64)/torch.max(sim.max(), abs(sim.min())) + 1)*127.5).cpu().numpy().astype(np.uint8) 181 | # image = Image.fromarray(a) 182 | # image.resize((512,512)).save('2.jpg') 183 | 184 | # u, s, vh = np.linalg.svd(sim.mean(0).cpu().numpy().astype(np.float32) - np.mean(sim.mean(0).cpu().numpy().astype(np.float32), axis=1, keepdims=True)) 185 | # images = [] 186 | # for i in range(3): 187 | # image = u[:,i].reshape(64, 64) 188 | # image = image - image.min() 189 | # image = 255 * image / image.max() 190 | # image = np.expand_dims(image, axis=2).astype(np.uint8) 191 | # images.append(image) 192 | 193 | # final = np.dstack(images) 194 | # final = Image.fromarray(final).resize((256, 256)) 195 | # final = np.array(final) 196 | # import ptp_scripts.ptp_utils as ptp_utils 197 | # ptp_utils.view_images(final) 198 | 199 | # attention, what we cannot get enough of 200 | sim = sim.softmax(dim=-1) 201 | # sim = sim.sigmoid() 202 | 203 | out = einsum('b i j, b j d -> b i d', sim, v) 204 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 205 | return self.to_out(out) 206 | 207 | 208 | class MemoryEfficientCrossAttention(nn.Module): 209 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 210 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 211 | super().__init__() 212 | print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " 213 | f"{heads} heads.") 214 | inner_dim = dim_head * heads 215 | context_dim = default(context_dim, query_dim) 216 | 217 | self.heads = heads 218 | self.dim_head = dim_head 219 | 220 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 221 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 222 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 223 | 224 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 225 | self.attention_op: Optional[Any] = None 226 | 227 | def forward(self, x, context=None, mask=None): 228 | q = self.to_q(x) 229 | context = default(context, x) 230 | k = self.to_k(context) 231 | v = self.to_v(context) 232 | 233 | b, _, _ = q.shape 234 | q, k, v = map( 235 | lambda t: t.unsqueeze(3) 236 | .reshape(b, t.shape[1], self.heads, self.dim_head) 237 | .permute(0, 2, 1, 3) 238 | .reshape(b * self.heads, t.shape[1], self.dim_head) 239 | .contiguous(), 240 | (q, k, v), 241 | ) 242 | 243 | # actually compute the attention, what we cannot get enough of 244 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 245 | 246 | if exists(mask): 247 | raise NotImplementedError 248 | out = ( 249 | out.unsqueeze(0) 250 | .reshape(b, self.heads, out.shape[1], self.dim_head) 251 | .permute(0, 2, 1, 3) 252 | .reshape(b, out.shape[1], self.heads * self.dim_head) 253 | ) 254 | return self.to_out(out) 255 | 256 | 257 | class BasicTransformerBlock(nn.Module): 258 | ATTENTION_MODES = { 259 | "softmax": CrossAttention, # vanilla attention 260 | "softmax-xformers": MemoryEfficientCrossAttention 261 | } 262 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 263 | disable_self_attn=False): 264 | super().__init__() 265 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else
"softmax" 266 | assert attn_mode in self.ATTENTION_MODES 267 | attn_cls = self.ATTENTION_MODES[attn_mode] 268 | self.disable_self_attn = disable_self_attn 269 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 270 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 271 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 272 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 273 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 274 | self.norm1 = nn.LayerNorm(dim) 275 | self.norm2 = nn.LayerNorm(dim) 276 | self.norm3 = nn.LayerNorm(dim) 277 | self.checkpoint = checkpoint 278 | 279 | def forward(self, x, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0, h=None, w=None): 280 | return checkpoint(self._forward, (x, context, encode, encode_uncon, decode_uncon, controller, inject, layernum, h, w), self.parameters(), self.checkpoint) 281 | 282 | def _forward(self, x, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0, h=None, w=None): 283 | 284 | if encode_uncon == True and decode_uncon == True: 285 | # pass 286 | x = self.attn1(self.norm1(x), context=None, encode=encode) + x 287 | x = self.attn1(self.norm1(x), context=None, encode=encode) + x # if you add more attention layers here, remember to update register_attention_control accordingly 288 | 289 | elif encode_uncon == True and decode_uncon == False: 290 | if encode: 291 | x = self.attn1(self.norm1(x), context=None, encode=encode) + x 292 | x = self.attn1(self.norm1(x), context=None, encode=encode) + x # if you add more attention layers here, remember to update register_attention_control accordingly 293 | else: 294 | x = self.attn1(self.norm1(x), context=context 295 | if self.disable_self_attn else None, controller_for_inject=controller, inject=inject, layernum=layernum) + x 296 | x = self.attn1(self.norm1(x), context=context 297 | if self.disable_self_attn else None, controller_for_inject=controller, inject=inject, layernum=layernum+1) + x 298 | x = self.attn2(self.norm2(x), context=context) + x 299 | 300 | elif encode_uncon == False and decode_uncon == False: 301 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, encode=encode, 302 | controller_for_inject=controller, inject=inject, layernum=layernum, main_height=h, main_width=w) + x 303 | x = self.attn2(self.norm2(x), context=context, encode=encode) + x 304 | # pass 305 | 306 | x = self.ff(self.norm3(x)) + x 307 | return x 308 | 309 | 310 | class SpatialTransformer(nn.Module): 311 | """ 312 | Transformer block for image-like data. 313 | First, project the input (aka embedding) 314 | and reshape to b, t, d. 315 | Then apply standard transformer action.
316 | Finally, reshape to image 317 | NEW: use_linear for more efficiency instead of the 1x1 convs 318 | """ 319 | def __init__(self, in_channels, n_heads, d_head, 320 | depth=1, dropout=0., context_dim=None, 321 | disable_self_attn=False, use_linear=False, 322 | use_checkpoint=True): 323 | super().__init__() 324 | if exists(context_dim) and not isinstance(context_dim, list): 325 | context_dim = [context_dim] 326 | self.in_channels = in_channels 327 | inner_dim = n_heads * d_head 328 | self.norm = Normalize(in_channels) 329 | if not use_linear: 330 | self.proj_in = nn.Conv2d(in_channels, 331 | inner_dim, 332 | kernel_size=1, 333 | stride=1, 334 | padding=0) 335 | else: 336 | self.proj_in = nn.Linear(in_channels, inner_dim) 337 | 338 | self.transformer_blocks = nn.ModuleList( 339 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 340 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 341 | for d in range(depth)] 342 | ) 343 | if not use_linear: 344 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 345 | in_channels, 346 | kernel_size=1, 347 | stride=1, 348 | padding=0)) 349 | else: 350 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 351 | self.use_linear = use_linear 352 | 353 | def forward(self, x, context=None, encode=False, encode_uncon=True, decode_uncon=True, controller=None, inject=False, layernum=0): 354 | # note: if no context is given, cross-attention defaults to self-attention 355 | if not isinstance(context, list): 356 | context = [context] 357 | b, c, h, w = x.shape 358 | x_in = x 359 | x = self.norm(x) 360 | if not self.use_linear: 361 | x = self.proj_in(x) 362 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 363 | if self.use_linear: 364 | x = self.proj_in(x) 365 | for i, block in enumerate(self.transformer_blocks): 366 | x = block(x, context=context[i], encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, 367 | controller=controller, inject=inject, layernum=layernum, h=h, w=w) 368 | if self.use_linear: 369 | x = self.proj_out(x) 370 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 371 | if not self.use_linear: 372 | x = self.proj_out(x) 373 | 374 | layernum = layernum + 1 # keep this in sync with register_recr 375 | 376 | return x + x_in, layernum 377 | 378 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc --------------------------------------------------------------------------------
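For reference, the multi-head attention computed in CrossAttention.forward above reduces to standard scaled dot-product attention. The snippet below is a minimal standalone sketch of that math, not part of the repository; the batch size, token counts, and the heads/dim_head values are illustrative assumptions.

# Minimal sketch of the attention math in CrossAttention.forward (standalone illustration;
# the toy batch size, token counts, heads and dim_head below are assumptions).
import torch
from einops import rearrange
from torch import einsum

heads, dim_head = 8, 64
b, n_q, n_kv = 2, 4096, 77                                 # e.g. 64x64 latent tokens attending to 77 text tokens
q = torch.randn(b * heads, n_q, dim_head)                  # after to_q and rearrange 'b n (h d) -> (b h) n d'
k = torch.randn(b * heads, n_kv, dim_head)                 # after to_k on the context
v = torch.randn(b * heads, n_kv, dim_head)                 # after to_v on the context

sim = einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5   # scaled similarities
attn = sim.softmax(dim=-1)                                       # per-query attention maps
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)          # merge heads before to_out
print(out.shape)                                                 # torch.Size([2, 4096, 512])

In the repository, these per-layer attention maps are what the sampler's ptp controllers (e.g. ptp.AttentionStore registered via ptp_utils_ori.register_attention_control) record and re-inject, through the controller_for_inject, inject and layernum arguments threaded through BasicTransformerBlock and SpatialTransformer.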
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | if ddim_timesteps[-1] == 1000: 66 | ddim_timesteps = ddim_timesteps - 1 67 | alphas = alphacums[ddim_timesteps] 68 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 69 | alphas_next = np.asarray(alphacums[ddim_timesteps[1:]].tolist() + [alphacums[-1].tolist()]) 70 | 71 | # according the the formula provided in https://arxiv.org/abs/2010.02502 72 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 73 | if verbose: 74 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 75 | print(f'For the chosen value of eta, which is {eta}, ' 76 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 77 | return sigmas, alphas, alphas_prev, alphas_next 78 | 79 | 80 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 81 | """ 82 | Create a beta schedule that discretizes the given alpha_t_bar function, 83 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 84 | :param num_diffusion_timesteps: the number of betas to produce. 85 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 86 | produces the cumulative product of (1-beta) up to that 87 | part of the diffusion process. 88 | :param max_beta: the maximum beta to use; use values lower than 1 to 89 | prevent singularities. 
90 | """ 91 | betas = [] 92 | for i in range(num_diffusion_timesteps): 93 | t1 = i / num_diffusion_timesteps 94 | t2 = (i + 1) / num_diffusion_timesteps 95 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 96 | return np.array(betas) 97 | 98 | 99 | def extract_into_tensor(a, t, x_shape): 100 | b, *_ = t.shape 101 | out = a.gather(-1, t) 102 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 103 | 104 | 105 | def checkpoint(func, inputs, params, flag): 106 | """ 107 | Evaluate a function without caching intermediate activations, allowing for 108 | reduced memory at the expense of extra compute in the backward pass. 109 | :param func: the function to evaluate. 110 | :param inputs: the argument sequence to pass to `func`. 111 | :param params: a sequence of parameters `func` depends on but does not 112 | explicitly take as arguments. 113 | :param flag: if False, disable gradient checkpointing. 114 | """ 115 | if flag: 116 | args = tuple(inputs) + tuple(params) 117 | return CheckpointFunction.apply(func, len(inputs), *args) 118 | else: 119 | return func(*inputs) 120 | 121 | 122 | class CheckpointFunction(torch.autograd.Function): 123 | @staticmethod 124 | def forward(ctx, run_function, length, *args): 125 | ctx.run_function = run_function 126 | ctx.input_tensors = list(args[:length]) 127 | ctx.input_params = list(args[length:]) 128 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 129 | "dtype": torch.get_autocast_gpu_dtype(), 130 | "cache_enabled": torch.is_autocast_cache_enabled()} 131 | with torch.no_grad(): 132 | output_tensors = ctx.run_function(*ctx.input_tensors) 133 | return output_tensors 134 | 135 | @staticmethod 136 | def backward(ctx, *output_grads): 137 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 138 | with torch.enable_grad(), \ 139 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 140 | # Fixes a bug where the first op in run_function modifies the 141 | # Tensor storage in place, which is not allowed for detach()'d 142 | # Tensors. 143 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 144 | output_tensors = ctx.run_function(*shallow_copies) 145 | input_grads = torch.autograd.grad( 146 | output_tensors, 147 | ctx.input_tensors + ctx.input_params, 148 | output_grads, 149 | allow_unused=True, 150 | ) 151 | del ctx.input_tensors 152 | del ctx.input_params 153 | del output_tensors 154 | return (None, None) + input_grads 155 | 156 | 157 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 158 | """ 159 | Create sinusoidal timestep embeddings. 160 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 161 | These may be fractional. 162 | :param dim: the dimension of the output. 163 | :param max_period: controls the minimum frequency of the embeddings. 164 | :return: an [N x dim] Tensor of positional embeddings. 165 | """ 166 | if not repeat_only: 167 | half = dim // 2 168 | freqs = torch.exp( 169 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 170 | ).to(device=timesteps.device) 171 | args = timesteps[:, None].float() * freqs[None] 172 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 173 | if dim % 2: 174 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 175 | else: 176 | embedding = repeat(timesteps, 'b -> b d', d=dim) 177 | return embedding 178 | 179 | 180 | def zero_module(module): 181 | """ 182 | Zero out the parameters of a module and return it. 
183 | """ 184 | for p in module.parameters(): 185 | p.detach().zero_() 186 | return module 187 | 188 | 189 | def scale_module(module, scale): 190 | """ 191 | Scale the parameters of a module and return it. 192 | """ 193 | for p in module.parameters(): 194 | p.detach().mul_(scale) 195 | return module 196 | 197 | 198 | def mean_flat(tensor): 199 | """ 200 | Take the mean over all non-batch dimensions. 201 | """ 202 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 203 | 204 | 205 | def normalization(channels): 206 | """ 207 | Make a standard normalization layer. 208 | :param channels: number of input channels. 209 | :return: an nn.Module for normalization. 210 | """ 211 | return GroupNorm32(32, channels) 212 | 213 | 214 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 215 | class SiLU(nn.Module): 216 | def forward(self, x): 217 | return x * torch.sigmoid(x) 218 | 219 | 220 | class GroupNorm32(nn.GroupNorm): 221 | def forward(self, x): 222 | return super().forward(x.float()).type(x.dtype) 223 | 224 | def conv_nd(dims, *args, **kwargs): 225 | """ 226 | Create a 1D, 2D, or 3D convolution module. 227 | """ 228 | if dims == 1: 229 | return nn.Conv1d(*args, **kwargs) 230 | elif dims == 2: 231 | return nn.Conv2d(*args, **kwargs) 232 | elif dims == 3: 233 | return nn.Conv3d(*args, **kwargs) 234 | raise ValueError(f"unsupported dimensions: {dims}") 235 | 236 | 237 | def linear(*args, **kwargs): 238 | """ 239 | Create a linear module. 240 | """ 241 | return nn.Linear(*args, **kwargs) 242 | 243 | 244 | def avg_pool_nd(dims, *args, **kwargs): 245 | """ 246 | Create a 1D, 2D, or 3D average pooling module. 247 | """ 248 | if dims == 1: 249 | return nn.AvgPool1d(*args, **kwargs) 250 | elif dims == 2: 251 | return nn.AvgPool2d(*args, **kwargs) 252 | elif dims == 3: 253 | return nn.AvgPool3d(*args, **kwargs) 254 | raise ValueError(f"unsupported dimensions: {dims}") 255 | 256 | 257 | class HybridConditioner(nn.Module): 258 | 259 | def __init__(self, c_concat_config, c_crossattn_config): 260 | super().__init__() 261 | self.concat_conditioner = instantiate_from_config(c_concat_config) 262 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 263 | 264 | def forward(self, c_concat, c_crossattn): 265 | c_concat = self.concat_conditioner(c_concat) 266 | c_crossattn = self.crossattn_conditioner(c_crossattn) 267 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 268 | 269 | 270 | def noise_like(shape, device, repeat=False): 271 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 272 | noise = lambda: torch.randn(shape, device=device) 273 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
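For orientation, the sketch below shows how the schedule helpers defined in ldm/modules/diffusionmodules/util.py above fit together. It is an illustrative usage example rather than repository code, and it assumes the repository root is on PYTHONPATH; the 1000-step linear schedule and the toy tensor shapes are likewise assumptions.

# Usage sketch for make_beta_schedule / extract_into_tensor / timestep_embedding above
# (assumes the repo root is importable; shapes and step count are illustrative).
import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, extract_into_tensor, timestep_embedding)

betas = make_beta_schedule("linear", n_timestep=1000)               # numpy array, shape (1000,)
alphas_cumprod = torch.tensor((1.0 - betas).cumprod(axis=0), dtype=torch.float32)

x0 = torch.randn(4, 3, 64, 64)                                      # toy latent batch
t = torch.randint(0, 1000, (4,))                                    # one timestep per sample
sqrt_ac = extract_into_tensor(alphas_cumprod.sqrt(), t, x0.shape)   # broadcast to (4, 1, 1, 1)
sqrt_om = extract_into_tensor((1.0 - alphas_cumprod).sqrt(), t, x0.shape)
x_t = sqrt_ac * x0 + sqrt_om * torch.randn_like(x0)                 # forward-process sample q(x_t | x_0)

emb = timestep_embedding(t, dim=320)                                # sinusoidal embedding, shape (4, 320)
print(x_t.shape, emb.shape)

This mirrors how AbstractLowScaleModel.q_sample in upscaling.py combines the same buffers to produce a noised input.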
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from ldm.util import default, count_params 9 | import einops 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 135 | """ 136 | Uses the OpenCLIP transformer encoder for text 137 | """ 138 | LAYERS = [ 139 | #"pooled", 140 | "last", 141 | "penultimate" 142 | ] 143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 144 | freeze=True, layer="last"): 145 | super().__init__() 146 | assert layer in self.LAYERS 147 | model, _, _ = 
open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 148 | del model.visual 149 | self.model = model 150 | 151 | self.device = device 152 | self.max_length = max_length 153 | if freeze: 154 | self.freeze() 155 | self.layer = layer 156 | if self.layer == "last": 157 | self.layer_idx = 0 158 | elif self.layer == "penultimate": 159 | self.layer_idx = 1 160 | else: 161 | raise NotImplementedError() 162 | 163 | def freeze(self): 164 | self.model = self.model.eval() 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | # def forward(self, text, inv=False): 169 | # tokens = open_clip.tokenize(text) 170 | # # print(tokens) 171 | # if inv: 172 | 173 | # # tokens[0] = torch.zeros(77) 174 | # # tokens[0] = torch.zeros(77)+7788 175 | # print(tokens[0]) 176 | # z = self.encode_with_transformer(tokens.to(self.device), inv=True) 177 | # # print(z.shape) 178 | # else: 179 | # z = self.encode_with_transformer(tokens.to(self.device),inv=True) 180 | # # z = self.encode_with_transformer(tokens.to(self.device),inv) 181 | # # print(z.shape) 182 | # return z 183 | 184 | # def encode_with_transformer(self, text, inv=False): 185 | # x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 186 | 187 | # if inv == True: 188 | # # x = einops.repeat(x[:,0], 'i j -> i c j', c=77) 189 | # x = x + self.model.positional_embedding 190 | 191 | # x = x.permute(1, 0, 2) # NLD -> LND 192 | # x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 193 | # x = x.permute(1, 0, 2) # LND -> NLD 194 | # x = self.model.ln_final(x) 195 | # return x 196 | def forward(self, text, inv=False): 197 | tokens = open_clip.tokenize(text) 198 | if inv: 199 | if tokens[0][1] == 49407: 200 | # tokens[0] = torch.zeros(77) 201 | # print(tokens[0]) 202 | z = self.encode_with_transformer(tokens.to(self.device), inv=False) 203 | else: 204 | # tokens[0] = torch.zeros(77) 205 | # print(tokens[0]) 206 | z = self.encode_with_transformer(tokens.to(self.device), inv=False) 207 | else: 208 | # tokens[0] = torch.zeros(77) + 788 209 | z = self.encode_with_transformer(tokens.to(self.device)) 210 | return z 211 | 212 | def encode_with_transformer(self, text, inv=False): 213 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 214 | if inv == False: 215 | # x = einops.repeat(x[:,0], 'i j -> i c j', c=77) 216 | x = x + self.model.positional_embedding 217 | x = x.permute(1, 0, 2) # NLD -> LND 218 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 219 | x = x.permute(1, 0, 2) # LND -> NLD 220 | x = self.model.ln_final(x) 221 | return x 222 | 223 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 224 | for i, r in enumerate(self.model.transformer.resblocks): 225 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 226 | break 227 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 228 | x = checkpoint(r, x, attn_mask) 229 | else: 230 | x = r(x, attn_mask=attn_mask) 231 | return x 232 | 233 | def encode(self, text, inv=False, device=None): 234 | if device is not None: 235 | self.device = device 236 | return self(text, inv) 237 | 238 | 239 | class FrozenCLIPT5Encoder(AbstractEncoder): 240 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 241 | clip_max_length=77, t5_max_length=77): 242 | super().__init__() 243 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 244 | 
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 245 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 246 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 247 | 248 | def encode(self, text): 249 | return self(text) 250 | 251 | def forward(self, text): 252 | clip_z = self.clip_encoder.encode(text) 253 | t5_z = self.t5_encoder.encode(text) 254 | return [clip_z, t5_z] 255 | 256 | 257 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 
0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 
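        # A minimal usage sketch (illustrative assumption, not code from this file): the input is
        # expected to be a (B, 3, H, W) float tensor that has already been normalized by the
        # transform returned by load_midas_transform(model_type), e.g.
        #   midas = MiDaSInference("dpt_hybrid")
        #   depth = midas(x)  # -> (B, 1, H, W) relative depth, see the shape assert below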
160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | 
out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 
179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as little as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as little as possible 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class Slice(nn.Module): 10 | def __init__(self, start_index=1): 11 | super(Slice, self).__init__() 12 | self.start_index = start_index 13 | 14 | def forward(self, x): 15 | return x[:, self.start_index :] 16 | 17 | 18 | class AddReadout(nn.Module): 19 | def __init__(self, start_index=1): 20 | super(AddReadout, self).__init__() 21 | self.start_index = start_index 22 | 23 | def forward(self, x): 24 | if self.start_index == 2: 25 | readout = (x[:, 0] + x[:, 1]) / 2 26 | else: 27 | readout = x[:, 0] 28 | return x[:, self.start_index :] + readout.unsqueeze(1) 29 | 30 | 31 | class ProjectReadout(nn.Module): 32 | def __init__(self, in_features, start_index=1): 33 | super(ProjectReadout, self).__init__() 34 | self.start_index = start_index 35 | 36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 37 | 38 | def forward(self, x): 39 | 
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 40 | features = torch.cat((x[:, self.start_index :], readout), -1) 41 | 42 | return self.project(features) 43 | 44 | 45 | class Transpose(nn.Module): 46 | def __init__(self, dim0, dim1): 47 | super(Transpose, self).__init__() 48 | self.dim0 = dim0 49 | self.dim1 = dim1 50 | 51 | def forward(self, x): 52 | x = x.transpose(self.dim0, self.dim1) 53 | return x 54 | 55 | 56 | def forward_vit(pretrained, x): 57 | b, c, h, w = x.shape 58 | 59 | glob = pretrained.model.forward_flex(x) 60 | 61 | layer_1 = pretrained.activations["1"] 62 | layer_2 = pretrained.activations["2"] 63 | layer_3 = pretrained.activations["3"] 64 | layer_4 = pretrained.activations["4"] 65 | 66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 70 | 71 | unflatten = nn.Sequential( 72 | nn.Unflatten( 73 | 2, 74 | torch.Size( 75 | [ 76 | h // pretrained.model.patch_size[1], 77 | w // pretrained.model.patch_size[0], 78 | ] 79 | ), 80 | ) 81 | ) 82 | 83 | if layer_1.ndim == 3: 84 | layer_1 = unflatten(layer_1) 85 | if layer_2.ndim == 3: 86 | layer_2 = unflatten(layer_2) 87 | if layer_3.ndim == 3: 88 | layer_3 = unflatten(layer_3) 89 | if layer_4.ndim == 3: 90 | layer_4 = unflatten(layer_4) 91 | 92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 96 | 97 | return layer_1, layer_2, layer_3, layer_4 98 | 99 | 100 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 101 | posemb_tok, posemb_grid = ( 102 | posemb[:, : self.start_index], 103 | posemb[0, self.start_index :], 104 | ) 105 | 106 | gs_old = int(math.sqrt(len(posemb_grid))) 107 | 108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 111 | 112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 113 | 114 | return posemb 115 | 116 | 117 | def forward_flex(self, x): 118 | b, c, h, w = x.shape 119 | 120 | pos_embed = self._resize_pos_embed( 121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 122 | ) 123 | 124 | B = x.shape[0] 125 | 126 | if hasattr(self.patch_embed, "backbone"): 127 | x = self.patch_embed.backbone(x) 128 | if isinstance(x, (list, tuple)): 129 | x = x[-1] # last feature if backbone outputs list/tuple of features 130 | 131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 132 | 133 | if getattr(self, "dist_token", None) is not None: 134 | cls_tokens = self.cls_token.expand( 135 | B, -1, -1 136 | ) # stole cls_tokens impl from Phil Wang, thanks 137 | dist_token = self.dist_token.expand(B, -1, -1) 138 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 139 | else: 140 | cls_tokens = self.cls_token.expand( 141 | B, -1, -1 142 | ) # stole cls_tokens impl from Phil Wang, thanks 143 | x = torch.cat((cls_tokens, x), dim=1) 144 | 145 | x = x + pos_embed 146 | x = self.pos_drop(x) 147 | 148 | for blk in self.blocks: 149 | x = blk(x) 150 | 151 | x = self.norm(x) 152 | 153 | return x 154 | 155 | 156 | activations = {} 
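# The module-level dictionary above acts as a cache that the forward hooks created by
# get_activation() write into. A rough sketch of how it is populated (hypothetical usage,
# mirroring the register_forward_hook calls further down in this file):
#   model.blocks[11].register_forward_hook(get_activation("4"))
#   _ = model.forward_flex(x)   # afterwards, activations["4"] holds that block's output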
157 | 158 | 159 | def get_activation(name): 160 | def hook(model, input, output): 161 | activations[name] = output 162 | 163 | return hook 164 | 165 | 166 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 167 | if use_readout == "ignore": 168 | readout_oper = [Slice(start_index)] * len(features) 169 | elif use_readout == "add": 170 | readout_oper = [AddReadout(start_index)] * len(features) 171 | elif use_readout == "project": 172 | readout_oper = [ 173 | ProjectReadout(vit_features, start_index) for out_feat in features 174 | ] 175 | else: 176 | assert ( 177 | False 178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 179 | 180 | return readout_oper 181 | 182 | 183 | def _make_vit_b16_backbone( 184 | model, 185 | features=[96, 192, 384, 768], 186 | size=[384, 384], 187 | hooks=[2, 5, 8, 11], 188 | vit_features=768, 189 | use_readout="ignore", 190 | start_index=1, 191 | ): 192 | pretrained = nn.Module() 193 | 194 | pretrained.model = model 195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 199 | 200 | pretrained.activations = activations 201 | 202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 203 | 204 | # 32, 48, 136, 384 205 | pretrained.act_postprocess1 = nn.Sequential( 206 | readout_oper[0], 207 | Transpose(1, 2), 208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 209 | nn.Conv2d( 210 | in_channels=vit_features, 211 | out_channels=features[0], 212 | kernel_size=1, 213 | stride=1, 214 | padding=0, 215 | ), 216 | nn.ConvTranspose2d( 217 | in_channels=features[0], 218 | out_channels=features[0], 219 | kernel_size=4, 220 | stride=4, 221 | padding=0, 222 | bias=True, 223 | dilation=1, 224 | groups=1, 225 | ), 226 | ) 227 | 228 | pretrained.act_postprocess2 = nn.Sequential( 229 | readout_oper[1], 230 | Transpose(1, 2), 231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 232 | nn.Conv2d( 233 | in_channels=vit_features, 234 | out_channels=features[1], 235 | kernel_size=1, 236 | stride=1, 237 | padding=0, 238 | ), 239 | nn.ConvTranspose2d( 240 | in_channels=features[1], 241 | out_channels=features[1], 242 | kernel_size=2, 243 | stride=2, 244 | padding=0, 245 | bias=True, 246 | dilation=1, 247 | groups=1, 248 | ), 249 | ) 250 | 251 | pretrained.act_postprocess3 = nn.Sequential( 252 | readout_oper[2], 253 | Transpose(1, 2), 254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 255 | nn.Conv2d( 256 | in_channels=vit_features, 257 | out_channels=features[2], 258 | kernel_size=1, 259 | stride=1, 260 | padding=0, 261 | ), 262 | ) 263 | 264 | pretrained.act_postprocess4 = nn.Sequential( 265 | readout_oper[3], 266 | Transpose(1, 2), 267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 268 | nn.Conv2d( 269 | in_channels=vit_features, 270 | out_channels=features[3], 271 | kernel_size=1, 272 | stride=1, 273 | padding=0, 274 | ), 275 | nn.Conv2d( 276 | in_channels=features[3], 277 | out_channels=features[3], 278 | kernel_size=3, 279 | stride=2, 280 | padding=1, 281 | ), 282 | ) 283 | 284 | pretrained.model.start_index = start_index 285 | pretrained.model.patch_size = [16, 16] 286 | 287 | # We inject this function into the VisionTransformer instances so that 288 | # we 
can use it with interpolated position embeddings without modifying the library source. 289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 290 | pretrained.model._resize_pos_embed = types.MethodType( 291 | _resize_pos_embed, pretrained.model 292 | ) 293 | 294 | return pretrained 295 | 296 | 297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): 298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 299 | 300 | hooks = [5, 11, 17, 23] if hooks == None else hooks 301 | return _make_vit_b16_backbone( 302 | model, 303 | features=[256, 512, 1024, 1024], 304 | hooks=hooks, 305 | vit_features=1024, 306 | use_readout=use_readout, 307 | ) 308 | 309 | 310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): 311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 312 | 313 | hooks = [2, 5, 8, 11] if hooks == None else hooks 314 | return _make_vit_b16_backbone( 315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 316 | ) 317 | 318 | 319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): 320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 321 | 322 | hooks = [2, 5, 8, 11] if hooks == None else hooks 323 | return _make_vit_b16_backbone( 324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout 325 | ) 326 | 327 | 328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): 329 | model = timm.create_model( 330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 331 | ) 332 | 333 | hooks = [2, 5, 8, 11] if hooks == None else hooks 334 | return _make_vit_b16_backbone( 335 | model, 336 | features=[96, 192, 384, 768], 337 | hooks=hooks, 338 | use_readout=use_readout, 339 | start_index=2, 340 | ) 341 | 342 | 343 | def _make_vit_b_rn50_backbone( 344 | model, 345 | features=[256, 512, 768, 768], 346 | size=[384, 384], 347 | hooks=[0, 1, 8, 11], 348 | vit_features=768, 349 | use_vit_only=False, 350 | use_readout="ignore", 351 | start_index=1, 352 | ): 353 | pretrained = nn.Module() 354 | 355 | pretrained.model = model 356 | 357 | if use_vit_only == True: 358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 360 | else: 361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 362 | get_activation("1") 363 | ) 364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 365 | get_activation("2") 366 | ) 367 | 368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 370 | 371 | pretrained.activations = activations 372 | 373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 374 | 375 | if use_vit_only == True: 376 | pretrained.act_postprocess1 = nn.Sequential( 377 | readout_oper[0], 378 | Transpose(1, 2), 379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 380 | nn.Conv2d( 381 | in_channels=vit_features, 382 | out_channels=features[0], 383 | kernel_size=1, 384 | stride=1, 385 | padding=0, 386 | ), 387 | nn.ConvTranspose2d( 388 | in_channels=features[0], 389 | out_channels=features[0], 390 | kernel_size=4, 391 | stride=4, 392 | padding=0, 393 | bias=True, 394 | dilation=1, 395 | groups=1, 396 | ), 397 | ) 398 | 399 | 
pretrained.act_postprocess2 = nn.Sequential( 400 | readout_oper[1], 401 | Transpose(1, 2), 402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 403 | nn.Conv2d( 404 | in_channels=vit_features, 405 | out_channels=features[1], 406 | kernel_size=1, 407 | stride=1, 408 | padding=0, 409 | ), 410 | nn.ConvTranspose2d( 411 | in_channels=features[1], 412 | out_channels=features[1], 413 | kernel_size=2, 414 | stride=2, 415 | padding=0, 416 | bias=True, 417 | dilation=1, 418 | groups=1, 419 | ), 420 | ) 421 | else: 422 | pretrained.act_postprocess1 = nn.Sequential( 423 | nn.Identity(), nn.Identity(), nn.Identity() 424 | ) 425 | pretrained.act_postprocess2 = nn.Sequential( 426 | nn.Identity(), nn.Identity(), nn.Identity() 427 | ) 428 | 429 | pretrained.act_postprocess3 = nn.Sequential( 430 | readout_oper[2], 431 | Transpose(1, 2), 432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 433 | nn.Conv2d( 434 | in_channels=vit_features, 435 | out_channels=features[2], 436 | kernel_size=1, 437 | stride=1, 438 | padding=0, 439 | ), 440 | ) 441 | 442 | pretrained.act_postprocess4 = nn.Sequential( 443 | readout_oper[3], 444 | Transpose(1, 2), 445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 446 | nn.Conv2d( 447 | in_channels=vit_features, 448 | out_channels=features[3], 449 | kernel_size=1, 450 | stride=1, 451 | padding=0, 452 | ), 453 | nn.Conv2d( 454 | in_channels=features[3], 455 | out_channels=features[3], 456 | kernel_size=3, 457 | stride=2, 458 | padding=1, 459 | ), 460 | ) 461 | 462 | pretrained.model.start_index = start_index 463 | pretrained.model.patch_size = [16, 16] 464 | 465 | # We inject this function into the VisionTransformer instances so that 466 | # we can use it with interpolated position embeddings without modifying the library source. 467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 468 | 469 | # We inject this function into the VisionTransformer instances so that 470 | # we can use it with interpolated position embeddings without modifying the library source. 471 | pretrained.model._resize_pos_embed = types.MethodType( 472 | _resize_pos_embed, pretrained.model 473 | ) 474 | 475 | return pretrained 476 | 477 | 478 | def _make_pretrained_vitb_rn50_384( 479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False 480 | ): 481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 482 | 483 | hooks = [0, 1, 8, 11] if hooks == None else hooks 484 | return _make_vit_b_rn50_backbone( 485 | model, 486 | features=[256, 512, 768, 768], 487 | size=[384, 384], 488 | hooks=hooks, 489 | use_vit_only=use_vit_only, 490 | use_readout=use_readout, 491 | ) 492 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Can't encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions.
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
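        Example (a rough sketch; the update relies on the private optim._functional.adamw
        call used below, so it depends on the installed torch version exposing it):

            >>> net = torch.nn.Linear(4, 2)
            >>> opt = AdamWwithEMAandWings(net.parameters(), lr=1e-4, ema_decay=0.999)
            >>> net(torch.randn(8, 4)).pow(2).mean().backward()
            >>> _ = opt.step()     # AdamW update plus a refresh of the per-parameter EMA copies
            >>> opt.zero_grad()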
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /ptp_scripts/ptp_scripts.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
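# Note on ldm/util.py above: instantiate_from_config is the glue between the YAML files in
# configs/stable-diffusion/ and the classes they name; it imports `target` by dotted path and
# forwards `params` as keyword arguments. A minimal sketch (the config dict below is
# illustrative, not taken from the shipped YAMLs):
#
#     from ldm.util import instantiate_from_config
#     cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
#     layer = instantiate_from_config(cfg)   # equivalent to torch.nn.Linear(4, 2)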
14 | 15 | from typing import Optional, Union, Tuple, List, Callable, Dict 16 | import torch 17 | from diffusers import StableDiffusionPipeline 18 | import torch.nn.functional as nnf 19 | import numpy as np 20 | import abc 21 | import sys 22 | sys.path.append('..') 23 | import matplotlib.pyplot as plt 24 | from PIL import Image 25 | 26 | LOW_RESOURCE = False 27 | 28 | class AttentionControl(abc.ABC): 29 | 30 | def step_callback(self, x_t): 31 | return x_t 32 | 33 | def between_steps(self): 34 | return 35 | 36 | @property 37 | def num_uncond_att_layers(self): 38 | return self.num_att_layers if LOW_RESOURCE else 0 39 | 40 | @abc.abstractmethod 41 | def forward (self, attn, is_cross: bool, place_in_unet: str): 42 | raise NotImplementedError 43 | 44 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 45 | if self.cur_att_layer >= self.num_uncond_att_layers: 46 | if LOW_RESOURCE: 47 | attn = self.forward(attn, is_cross, place_in_unet) 48 | else: 49 | h = attn.shape[0] 50 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 51 | 52 | self.cur_att_layer += 1 53 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 54 | self.cur_att_layer = 0 55 | self.cur_step += 1 56 | self.between_steps() 57 | return attn 58 | 59 | def reset(self): 60 | self.cur_step = 0 61 | self.cur_att_layer = 0 62 | 63 | def __init__(self): 64 | self.cur_step = 0 65 | self.num_att_layers = -1 66 | self.cur_att_layer = 0 67 | 68 | 69 | class AttentionStore(AttentionControl): 70 | 71 | @staticmethod 72 | def get_empty_store(): 73 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 74 | "down_self": [], "mid_self": [], "up_self": []} 75 | 76 | def forward(self, attn, is_cross: bool, place_in_unet: str): 77 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 78 | # if attn.shape[1] <= 16 ** 2: # and attn.shape[1] > 16 ** 2: # avoid memory overhead 79 | self.step_store[key].append(attn) 80 | return attn 81 | 82 | def between_steps(self): 83 | for key in self.step_store: 84 | self.attention_store[key] = self.step_store[key] 85 | 86 | self.step_store = self.get_empty_store() 87 | 88 | def get_average_attention(self): 89 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 90 | return average_attention 91 | 92 | 93 | def reset(self): 94 | super(AttentionStore, self).reset() 95 | self.step_store = self.get_empty_store() 96 | self.attention_store = self.get_empty_store() 97 | 98 | def __init__(self): 99 | super(AttentionStore, self).__init__() 100 | self.step_store = self.get_empty_store() 101 | self.attention_store = self.get_empty_store() -------------------------------------------------------------------------------- /ptp_scripts/ptp_utils_ori.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
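# Note on ptp_scripts/ptp_scripts.py above: every patched attention layer calls the controller
# with its attention map, and AttentionStore files the maps of the current step under the keys
# "{down,mid,up}_{cross,self}". In this variant between_steps() overwrites attention_store with
# the latest step's maps instead of accumulating them. A rough wiring sketch (hypothetical
# sampling loop):
#
#     controller = AttentionStore()
#     register_attention_control(model, controller)   # defined later in this file
#     # ... run the sampler; each denoising step refreshes controller.attention_store ...
#     up_self_maps = controller.attention_store["up_self"]   # list of self-attention maps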
14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image, ImageDraw, ImageFont 18 | import cv2 19 | from typing import Optional, Union, Tuple, List, Callable, Dict 20 | from IPython.display import display 21 | from tqdm.notebook import tqdm 22 | import matplotlib.pyplot as plt 23 | from torch import nn, einsum 24 | from einops import rearrange, repeat 25 | from inspect import isfunction 26 | 27 | 28 | def exists(val): 29 | return val is not None 30 | 31 | 32 | def default(val, d): 33 | if exists(val): 34 | return val 35 | return d() if isfunction(d) else d 36 | 37 | 38 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 39 | h, w, c = image.shape 40 | offset = int(h * .2) 41 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 42 | font = cv2.FONT_HERSHEY_SIMPLEX 43 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 44 | img[:h] = image 45 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 46 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 47 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 48 | return img 49 | 50 | 51 | def view_images(images, num_rows=1, offset_ratio=0.02, name='image', timestamp=0, layernum=None): 52 | if type(images) is list: 53 | num_empty = len(images) % num_rows 54 | elif images.ndim == 4: 55 | num_empty = images.shape[0] % num_rows 56 | else: 57 | images = [images] 58 | num_empty = 0 59 | 60 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 61 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 62 | num_items = len(images) 63 | 64 | h, w, c = images[0].shape 65 | offset = int(h * offset_ratio) 66 | num_cols = num_items // num_rows 67 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 68 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 69 | for i in range(num_rows): 70 | for j in range(num_cols): 71 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 72 | i * num_cols + j] 73 | 74 | pil_img = Image.fromarray(image_) 75 | display(pil_img) 76 | 77 | 78 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 79 | if low_resource: 80 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 81 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 82 | else: 83 | latents_input = torch.cat([latents] * 2) 84 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 85 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 86 | 87 | # classifier-free guidance during inference 88 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 89 | # noise_pred = guidance_scale * noise_prediction_text 90 | 91 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 92 | latents = controller.step_callback(latents) 93 | return latents 94 | 95 | 96 | def latent2image(vae, latents): 97 | latents = 1 / 0.18215 * latents 98 | image = vae.decode(latents)['sample'] 99 | image = (image / 2 + 0.5).clamp(0, 1) 100 | image = image.cpu().permute(0, 2, 3, 1).numpy() 101 | image = (image * 255).astype(np.uint8) 102 | return image 103 | 104 | 105 | def init_latent(latent, model, height, width, generator, batch_size): 106 | if latent is None: 107 | latent = torch.randn( 108 | (1, model.unet.in_channels, height // 8, width // 
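# Note on diffusion_step above: classifier-free guidance combines the two UNet passes as
#     noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond),
# so guidance_scale = 1 reduces to the text-conditioned prediction and larger values (such as
# the 7.5 default of text2image_ldm_stable below) push samples harder toward the prompt.
# latent2image divides out the 0.18215 latent scaling of Stable Diffusion's VAE before decoding.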
8), 109 | generator=generator, 110 | ) 111 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 112 | return latent, latents 113 | 114 | 115 | @torch.no_grad() 116 | def text2image_ldm( 117 | model, 118 | prompt: List[str], 119 | controller, 120 | num_inference_steps: int = 50, 121 | guidance_scale: Optional[float] = 7., 122 | generator: Optional[torch.Generator] = None, 123 | latent: Optional[torch.FloatTensor] = None, 124 | ): 125 | register_attention_control(model, controller) 126 | height = width = 256 127 | batch_size = len(prompt) 128 | 129 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 130 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 131 | 132 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 133 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 134 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 135 | context = torch.cat([uncond_embeddings, text_embeddings]) 136 | 137 | model.scheduler.set_timesteps(num_inference_steps) 138 | for t in tqdm(model.scheduler.timesteps): 139 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 140 | 141 | image = latent2image(model.vqvae, latents) 142 | 143 | return image, latent 144 | 145 | 146 | @torch.no_grad() 147 | def text2image_ldm_stable( 148 | model, 149 | prompt: List[str], 150 | controller, 151 | num_inference_steps: int = 50, 152 | guidance_scale: float = 7.5, 153 | generator: Optional[torch.Generator] = None, 154 | latent: Optional[torch.FloatTensor] = None, 155 | low_resource: bool = False, 156 | ): 157 | register_attention_control(model, controller) 158 | height = width = 512 159 | batch_size = len(prompt) 160 | 161 | text_input = model.tokenizer( 162 | prompt, 163 | padding="max_length", 164 | max_length=model.tokenizer.model_max_length, 165 | truncation=True, 166 | return_tensors="pt", 167 | ) 168 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 169 | max_length = text_input.input_ids.shape[-1] 170 | uncond_input = model.tokenizer( 171 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 172 | ) 173 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 174 | 175 | context = [uncond_embeddings, text_embeddings] 176 | if not low_resource: 177 | context = torch.cat(context) 178 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 179 | 180 | model.scheduler.set_timesteps(num_inference_steps) 181 | for t in tqdm(model.scheduler.timesteps): 182 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 183 | 184 | image = latent2image(model.vae, latents) 185 | 186 | return image, latent 187 | 188 | 189 | def register_attention_control(model, controller, inject_bg=False, pseudo_cross=False): 190 | def ca_forward(self, place_in_unet): 191 | def forward(x, context=None, mask=None, encode=False, controller_for_inject=None, inject=False, layernum=None, main_height=None, main_width=None): 192 | 193 | is_cross = context is not None 194 | h = self.heads 195 | 196 | q = self.to_q(x) 197 | context = default(context, x) 198 | k = self.to_k(context) 199 | v = self.to_v(context) 200 | 201 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 202 | sim = einsum('b i d, b j d -> b i j', 
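# Note on the samplers above: init_latent draws one (1, in_channels, height // 8, width // 8)
# noise tensor and expands it across the batch, so all prompts share the same starting noise
# unless `latent` is supplied. text2image_ldm_stable encodes the empty-string unconditional
# prompt and the text prompt, concatenates them into a single `context`, runs the scheduler
# loop through diffusion_step, and decodes the final latents with the VAE. Note that
# register_attention_control (below) walks model.model.diffusion_model, i.e. the CompVis/LDM
# UNet wrapper used in this repository, rather than a diffusers pipeline's model.unet.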
q, k) * self.scale 203 | 204 | if exists(mask): 205 | mask = rearrange(mask, 'b ... -> b (...)') 206 | max_neg_value = -torch.finfo(sim.dtype).max 207 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 208 | sim.masked_fill_(~mask, max_neg_value) 209 | 210 | sim_2 = sim.clone() 211 | if encode == False: 212 | sim = controller(sim, is_cross, place_in_unet) 213 | 214 | if inject or inject_bg and is_cross == False: 215 | 216 | if layernum >= -1 : 217 | if place_in_unet == 'down': 218 | 219 | if inject_bg: 220 | sim[h:] = sim[h:] 221 | 222 | if inject: 223 | sim[h:] = sim[h:] 224 | 225 | 226 | elif place_in_unet == 'up': 227 | 228 | if inject_bg: 229 | sim[h:] = controller_for_inject[0].attention_store['up_self'][layernum] * 0.5 \ 230 | + 0.5* sim[h:] 231 | 232 | if inject: 233 | sim[h:] = controller_for_inject[1].attention_store['up_self'][layernum] * 0.5 \ 234 | + 0.5 * sim[h:] 235 | 236 | 237 | elif place_in_unet == 'mid': 238 | 239 | if inject_bg: 240 | 241 | sim[h:] = sim[h:] 242 | 243 | if inject: 244 | 245 | sim[h:] = sim[h:] 246 | 247 | sim = sim.softmax(dim=-1) 248 | sim_2 = sim_2.softmax(dim=-1) 249 | 250 | 251 | out = einsum('b i j, b j d -> b i d', sim, v) 252 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 253 | 254 | del sim, v, q, k, context 255 | 256 | return self.to_out(out) 257 | 258 | return forward 259 | 260 | def register_recr(net_, count, place_in_unet): 261 | if 'CrossAttention' in net_.__class__.__name__: 262 | if net_.to_k.in_features != 1024: 263 | net_.forward = ca_forward(net_, place_in_unet) 264 | return count + 1 265 | else: 266 | return count 267 | elif hasattr(net_, 'children'): 268 | for net__ in net_.children(): 269 | count = register_recr(net__, count, place_in_unet) 270 | return count 271 | 272 | cross_att_count = 0 273 | # sub_nets = model.unet.named_children() 274 | sub_nets = model.model.diffusion_model.named_children() 275 | 276 | for net in sub_nets: 277 | if "input" in net[0]: 278 | cross_att_count += register_recr(net[1], 0, "down") 279 | elif "output" in net[0]: 280 | cross_att_count += register_recr(net[1], 0, "up") 281 | elif "middle" in net[0]: 282 | cross_att_count += register_recr(net[1], 0, "mid") 283 | controller.num_att_layers = cross_att_count 284 | 285 | 286 | 287 | 288 | def register_recr(net_, count, place_in_unet): 289 | if 'CrossAttention' in net_.__class__.__name__: 290 | if net_.to_k.in_features != 1024: 291 | net_.forward = ca_forward(net_, place_in_unet) 292 | return count + 1 293 | else: 294 | return count 295 | elif hasattr(net_, 'children'): 296 | for net__ in net_.children(): 297 | count = register_recr(net__, count, place_in_unet) 298 | return count 299 | 300 | cross_att_count = 0 301 | # sub_nets = model.unet.named_children() 302 | sub_nets = model.model.diffusion_model.named_children() 303 | 304 | for net in sub_nets: 305 | if "input" in net[0]: 306 | cross_att_count += register_recr(net[1], 0, "down") 307 | elif "output" in net[0]: 308 | cross_att_count += register_recr(net[1], 0, "up") 309 | elif "middle" in net[0]: 310 | cross_att_count += register_recr(net[1], 0, "mid") 311 | controller.num_att_layers = cross_att_count 312 | 313 | 314 | def get_word_inds(text: str, word_place: int, tokenizer): 315 | split_text = text.split(" ") 316 | if type(word_place) is str: 317 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 318 | elif type(word_place) is int: 319 | word_place = [word_place] 320 | out = [] 321 | if len(word_place) > 0: 322 | words_encode = [tokenizer.decode([item]).strip("#") for item 
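# Note on the patched forward above: when `inject` or `inject_bg` is set, only the decoder
# ("up") blocks are changed; the conditional half of the self-attention logits is blended with
# maps recorded by a reference AttentionStore,
#     sim[h:] = 0.5 * controller_for_inject[k].attention_store["up_self"][layernum] + 0.5 * sim[h:]
# with k = 0 for the background controller and k = 1 for the subject controller, while the
# unconditional half and the down/mid blocks pass through unchanged. register_recr patches only
# attention layers whose to_k.in_features differs from 1024, leaving the SD 2.x text
# cross-attention (context dimension 1024) untouched.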
in tokenizer.encode(text)][1:-1] 323 | cur_len, ptr = 0, 0 324 | 325 | for i in range(len(words_encode)): 326 | cur_len += len(words_encode[i]) 327 | if ptr in word_place: 328 | out.append(i + 1) 329 | if cur_len >= len(split_text[ptr]): 330 | ptr += 1 331 | cur_len = 0 332 | return np.array(out) 333 | 334 | 335 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None): 336 | if type(bounds) is float: 337 | bounds = 0, bounds 338 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 339 | if word_inds is None: 340 | word_inds = torch.arange(alpha.shape[2]) 341 | alpha[: start, prompt_ind, word_inds] = 0 342 | alpha[start: end, prompt_ind, word_inds] = 1 343 | alpha[end:, prompt_ind, word_inds] = 0 344 | return alpha 345 | 346 | 347 | def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 348 | tokenizer, max_num_words=77): 349 | if type(cross_replace_steps) is not dict: 350 | cross_replace_steps = {"default_": cross_replace_steps} 351 | if "default_" not in cross_replace_steps: 352 | cross_replace_steps["default_"] = (0., 1.) 353 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 354 | 355 | for i in range(len(prompts) - 1): 356 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i) 357 | 358 | for key, item in cross_replace_steps.items(): 359 | if key != "default_": 360 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 361 | for i, ind in enumerate(inds): 362 | if len(ind) > 0: 363 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 364 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words 365 | return alpha_time_words 366 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='stable-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
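A closing sketch of the prompt-token helpers in ptp_scripts/ptp_utils_ori.py: get_word_inds maps a word of the prompt to its token positions (offset by the BOS token), and update_alpha_time_word switches those positions on for a fraction of the denoising steps. The toy word-level tokenizer below is invented purely for illustration; the real code is driven by the model's CLIP tokenizer, and the import assumes you run from the repository root with the repo's dependencies (torch, cv2, einops, IPython, tqdm) installed.

import sys
import torch

sys.path.append("ptp_scripts")  # make ptp_utils_ori importable from the repo root
from ptp_utils_ori import get_word_inds, update_alpha_time_word


class ToyTokenizer:
    """Word-level stand-in for a CLIP tokenizer: BOS = 0, EOS = 1, words mapped to ids >= 2."""

    def __init__(self, text):
        self.vocab = {w: i + 2 for i, w in enumerate(dict.fromkeys(text.split()))}
        self.inv = {i: w for w, i in self.vocab.items()}

    def encode(self, text):
        return [0] + [self.vocab[w] for w in text.split()] + [1]

    def decode(self, ids):
        return "".join(self.inv.get(i, "") for i in ids)


prompt = "a photo of a red clock"
tok = ToyTokenizer(prompt)

inds = get_word_inds(prompt, "red", tok)
print(inds)                                      # -> [5], the token index of "red"

alpha = torch.zeros(51, 1, 77)                   # (num_steps + 1, num_prompts - 1, max_words)
alpha = update_alpha_time_word(alpha, (0.0, 0.4), 0, inds)
print(int(alpha[:, 0, 5].sum()))                 # -> 20: active for the first 40% of the 51 slots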