├── LICENSE
├── README.md
├── assets
│   ├── example-01.png
│   ├── example-02.png
│   ├── example-03.png
│   ├── example-04.png
│   ├── example-05.png
│   ├── example-06.png
│   ├── seed.jpg
│   └── teaser.png
├── ckpt
│   └── download.sh
├── configs
│   └── stable-diffusion
│       ├── v2-inference copy.yaml
│       ├── v2-inference-v.yaml
│       ├── v2-inference.yaml
│       ├── v2-inpainting-inference.yaml
│       ├── v2-midas-inference.yaml
│       └── x4-upscaling.yaml
├── environment.yaml
├── examples
│   ├── a photo of a black sloth plushie
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a pixel cartoon
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a polygonal illustration duck toy
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a purple plushie toy
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a red clock
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a robot monster toy
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a sand statue plushie
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a sculpture golden retriever
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   ├── a photo of a silver backpack
│   │   ├── scene.png
│   │   ├── selected_mask.png
│   │   ├── sub_mask.png
│   │   └── subject.png
│   └── a photo of a wooden sculpture Chow Chow
│       ├── scene.png
│       ├── selected_mask.png
│       ├── sub_mask.png
│       └── subject.png
├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   └── util.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── dpm_solver
│   │       │   ├── __init__.py
│   │       │   ├── dpm_solver.py
│   │       │   └── sampler.py
│   │       ├── plms.py
│   │       └── sampling_util.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   ├── openaimodel.cpython-38.pyc
│   │   │   │   └── util.cpython-38.pyc
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   ├── upscaling.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── distributions.cpython-38.pyc
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── modules.cpython-38.pyc
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   └── midas
│   │       ├── __init__.py
│   │       ├── api.py
│   │       ├── midas
│   │       │   ├── __init__.py
│   │       │   ├── base_model.py
│   │       │   ├── blocks.py
│   │       │   ├── dpt_depth.py
│   │       │   ├── midas_net.py
│   │       │   ├── midas_net_custom.py
│   │       │   ├── transforms.py
│   │       │   └── vit.py
│   │       └── utils.py
│   └── util.py
├── ptp_scripts
│   ├── ptp_scripts.py
│   └── ptp_utils_ori.py
├── run.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 PENGZHI LI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tuning-Free Image Customization with Image and Text Guidance (ECCV 2024)
2 |
3 | [arXiv paper](https://arxiv.org/abs/2403.12658)
4 | [Project page](https://zrealli.github.io/TIGIC)
5 |
6 |
7 |
8 | This repository contains the official implementation of the following paper:
9 | > **Tuning-Free Image Customization with Image and Text Guidance**
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | ## :open_book: Overview
23 | Current diffusion-based image customization methods have notable limitations, such as unintended changes to the image, reliance on either a reference image or text alone, and time-consuming fine-tuning. We introduce **TIGIC**, a tuning-free framework for text-image-guided image customization that enables precise, region-specific edits within seconds. Our approach preserves the reference image's semantic features while modifying details according to the text description, using an attention blending strategy in the UNet decoder. The method is applicable to image synthesis, design, and creative photography.
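
To make the region-specific aspect concrete, the toy sketch below shows the kind of per-step, bounding-box-constrained latent blending that keeps edits local (the actual logic lives in `ldm/models/diffusion/dpm_solver/sampler.py` and `ptp_scripts/`); it is an illustration of the idea, not the TIGIC implementation.

```python
import torch

def blend_into_region(bg_latent, gen_latent, box):
    """Copy the generated latent into the background latent only inside the
    (top, bottom, left, right) box, similar to the per-step blend in sampler.py."""
    top, bottom, left, right = box
    blended = bg_latent.clone()
    blended[:, :, top:bottom, left:right] = gen_latent[:, :, top:bottom, left:right]
    return blended

# dummy 64x64 latents (the latent resolution for 512x512 images in SD 2.1)
bg, gen = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64)
print(blend_into_region(bg, gen, box=(16, 48, 16, 48)).shape)  # torch.Size([1, 4, 64, 64])
```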
24 |
25 |
26 |
27 | ## :hammer: Quick Start
28 |
29 | ```
30 | git clone https://github.com/zrealli/TIGIC.git
31 | cd TIGIC
32 | ```
33 | ### 1. Prepare Environment
34 | To set up our environment, please follow these instructions:
35 | ```
36 | conda env create -f environment.yaml
37 | conda activate TIGIC
38 | ```
39 | Please note that our project requires 24 GB of GPU memory to run.
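
If you are unsure whether your GPU meets this requirement, a quick check with PyTorch (installed by the environment above) is:

```python
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 1024**3:.1f} GB of GPU memory")
else:
    print("No CUDA device found")
```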
40 |
41 | ### 2. Download Checkpoints
42 | Next, download the [Stable Diffusion weights](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt) and place them in `./ckpt`.
43 | ```
44 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt
45 | ```
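
Alternatively, the same checkpoint can be fetched from Python with `huggingface_hub`; note that this package is not listed in `environment.yaml`, so treat it as an optional extra dependency.

```python
from huggingface_hub import hf_hub_download  # pip install huggingface_hub

ckpt_path = hf_hub_download(
    repo_id="stabilityai/stable-diffusion-2-1-base",
    filename="v2-1_512-ema-pruned.ckpt",
    local_dir="./ckpt",  # place the checkpoint where run.py expects it
)
print(ckpt_path)
```

Either way, make sure the file ends up at `./ckpt/v2-1_512-ema-pruned.ckpt`, which is the path used in the inference command below.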
46 |
47 | ### 3. Inference with TIGIC
48 | After downloading the base model, run inference with the following command:
49 |
50 |
51 | ```
52 | python run.py --ckpt ./ckpt/v2-1_512-ema-pruned.ckpt --dpm_steps 20 --outdir ./outputs --root ./examples --scale 5 --seed 42
53 | ```
54 | The generated results can be found in the `./outputs` directory. Some examples are shown below.
55 |
56 |
57 | ## :framed_picture: Generation Results
58 | From left to right, the images are the **background, subject, collage, and the result** generated by TIGIC.
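
If you want to assemble a similar side-by-side strip for your own runs, a small Pillow helper like the one below works; the `./outputs/result.png` path is a placeholder (the exact filename written by `run.py` may differ), so adjust it to your actual output.

```python
from PIL import Image

def make_strip(paths, height=256, out_path="strip.png"):
    """Resize each image to a common height and paste them side by side."""
    imgs = [Image.open(p).convert("RGB") for p in paths]
    imgs = [im.resize((int(im.width * height / im.height), height)) for im in imgs]
    strip = Image.new("RGB", (sum(im.width for im in imgs), height), "white")
    x = 0
    for im in imgs:
        strip.paste(im, (x, 0))
        x += im.width
    strip.save(out_path)

example = "./examples/a photo of a red clock"
make_strip([f"{example}/scene.png", f"{example}/subject.png",
            "./outputs/result.png"])  # placeholder result filename
```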
59 |
60 | *'a photo of a pixel cartoon'*
61 |
62 |
63 |
64 |
65 |
66 |
67 | *'a photo of a black sloth plushie'*
68 |
69 |
70 |
71 |
72 |
73 |
74 | *'a photo of a polygonal illustration duck toy'*
75 |
76 |
77 |
78 |
79 |
80 |
81 | *'a photo of a purple plushie toy'*
82 |
83 |
84 |
85 |
86 |
87 |
88 | *'a photo of a sculpture golden retriever'*
89 |
90 |
91 |
92 |
93 |
94 |
95 | *'a photo of a red clock'*
96 |
97 |
98 |
99 |
100 |
101 |
102 | :fountain_pen: Please note that if the generated results contain significant artifacts, adjust the random seed via `--seed` to obtain the desired outcome. As the examples below show, different random seeds can produce varying image quality.
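
A quick way to do this is to sweep a handful of seeds with the same flags as in the Quick Start; a minimal sketch, assuming `--outdir` accepts a fresh directory per run:

```python
import subprocess

for seed in (0, 42, 123, 2024):
    subprocess.run([
        "python", "run.py",
        "--ckpt", "./ckpt/v2-1_512-ema-pruned.ckpt",
        "--dpm_steps", "20",
        "--outdir", f"./outputs/seed_{seed}",
        "--root", "./examples",
        "--scale", "5",
        "--seed", str(seed),
    ], check=True)
```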
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 | *'a photo of a black sloth plushie'*
113 |
114 |
115 |
116 |
117 |
118 |
119 | Please refer to our [project page](https://zrealli.github.io/TIGIC) for more visual comparisons.
120 |
121 | ### 🤗 Gradio Demo
122 | We will soon release a Gradio demo, which will include an integrated automatic foreground-subject segmentation module.
123 |
124 | ## :four_leaf_clover: Acknowledgments
125 | This project is distributed under the MIT License. Our code builds upon [TF-ICON](https://github.com/Shilin-LU/TF-ICON) and [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt); we thank the authors of these projects for their contributions.
126 |
127 | ## :fountain_pen: Citation
128 |
129 | If you find our repo useful for your research, please consider citing our paper:
130 |
131 | ```bibtex
132 | @inproceedings{li2024tuning,
133 | title={Tuning-Free Image Customization with Image and Text Guidance},
134 | author={Li, Pengzhi and Nie, Qiang and Chen, Ying and Jiang, Xi and Wu, Kai and Lin, Yuhuan and Liu, Yong and Peng, Jinlong and Wang, Chengjie and Zheng, Feng},
135 | booktitle={European Conference on Computer Vision},
136 | year={2024}
137 | }
138 |
139 | ```
140 |
141 |
--------------------------------------------------------------------------------
/assets/example-01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-01.png
--------------------------------------------------------------------------------
/assets/example-02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-02.png
--------------------------------------------------------------------------------
/assets/example-03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-03.png
--------------------------------------------------------------------------------
/assets/example-04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-04.png
--------------------------------------------------------------------------------
/assets/example-05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-05.png
--------------------------------------------------------------------------------
/assets/example-06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/example-06.png
--------------------------------------------------------------------------------
/assets/seed.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/seed.jpg
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/assets/teaser.png
--------------------------------------------------------------------------------
/ckpt/download.sh:
--------------------------------------------------------------------------------
1 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt
--------------------------------------------------------------------------------
/configs/stable-diffusion/v2-inference copy.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False # we set this to false because this is an inference only config
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | use_checkpoint: True
24 | use_fp16: True
25 | image_size: 32 # unused
26 | in_channels: 4
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions: [ 4, 2, 1 ]
30 | num_res_blocks: 2
31 | channel_mult: [ 1, 2, 4, 4 ]
32 | num_head_channels: 64 # need to fix for flash-attn
33 | use_spatial_transformer: True
34 | use_linear_in_transformer: True
35 | transformer_depth: 1
36 | context_dim: 1024
37 | legacy: False
38 |
39 | first_stage_config:
40 | target: ldm.models.autoencoder.AutoencoderKL
41 | params:
42 | embed_dim: 4
43 | monitor: val/rec_loss
44 | ddconfig:
45 | #attn_type: "vanilla-xformers"
46 | double_z: true
47 | z_channels: 4
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 2
55 | - 4
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65 | params:
66 | freeze: True
67 | layer: "penultimate"
68 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v2-inference-v.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | parameterization: "v"
6 | linear_start: 0.00085
7 | linear_end: 0.0120
8 | num_timesteps_cond: 1
9 | log_every_t: 200
10 | timesteps: 1000
11 | first_stage_key: "jpg"
12 | cond_stage_key: "txt"
13 | image_size: 64
14 | channels: 4
15 | cond_stage_trainable: false
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple_ema
18 | scale_factor: 0.18215
19 | use_ema: False # we set this to false because this is an inference only config
20 |
21 | unet_config:
22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23 | params:
24 | use_checkpoint: True
25 | use_fp16: True
26 | image_size: 32 # unused
27 | in_channels: 4
28 | out_channels: 4
29 | model_channels: 320
30 | attention_resolutions: [ 4, 2, 1 ]
31 | num_res_blocks: 2
32 | channel_mult: [ 1, 2, 4, 4 ]
33 | num_head_channels: 64 # need to fix for flash-attn
34 | use_spatial_transformer: True
35 | use_linear_in_transformer: True
36 | transformer_depth: 1
37 | context_dim: 1024
38 | legacy: False
39 |
40 | first_stage_config:
41 | target: ldm.models.autoencoder.AutoencoderKL
42 | params:
43 | embed_dim: 4
44 | monitor: val/rec_loss
45 | ddconfig:
46 | #attn_type: "vanilla-xformers"
47 | double_z: true
48 | z_channels: 4
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | - 4
58 | num_res_blocks: 2
59 | attn_resolutions: []
60 | dropout: 0.0
61 | lossconfig:
62 | target: torch.nn.Identity
63 |
64 | cond_stage_config:
65 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
66 | params:
67 | freeze: True
68 | layer: "penultimate"
69 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v2-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False # we set this to false because this is an inference only config
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | use_checkpoint: True
24 | use_fp16: True
25 | image_size: 32 # unused
26 | in_channels: 4
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions: [ 4, 2, 1 ]
30 | num_res_blocks: 2
31 | channel_mult: [ 1, 2, 4, 4 ]
32 | num_head_channels: 64 # need to fix for flash-attn
33 | use_spatial_transformer: True
34 | use_linear_in_transformer: True
35 | transformer_depth: 1
36 | context_dim: 1024
37 | legacy: False
38 |
39 | first_stage_config:
40 | target: ldm.models.autoencoder.AutoencoderKL
41 | params:
42 | embed_dim: 4
43 | monitor: val/rec_loss
44 | ddconfig:
45 | #attn_type: "vanilla-xformers"
46 | double_z: true
47 | z_channels: 4
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 2
55 | - 4
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65 | params:
66 | freeze: True
67 | layer: "penultimate"
68 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v2-inpainting-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: hybrid
16 | scale_factor: 0.18215
17 | monitor: val/loss_simple_ema
18 | finetune_keys: null
19 | use_ema: False
20 |
21 | unet_config:
22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23 | params:
24 | use_checkpoint: True
25 | image_size: 32 # unused
26 | in_channels: 9
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions: [ 4, 2, 1 ]
30 | num_res_blocks: 2
31 | channel_mult: [ 1, 2, 4, 4 ]
32 | num_head_channels: 64 # need to fix for flash-attn
33 | use_spatial_transformer: True
34 | use_linear_in_transformer: True
35 | transformer_depth: 1
36 | context_dim: 1024
37 | legacy: False
38 |
39 | first_stage_config:
40 | target: ldm.models.autoencoder.AutoencoderKL
41 | params:
42 | embed_dim: 4
43 | monitor: val/rec_loss
44 | ddconfig:
45 | #attn_type: "vanilla-xformers"
46 | double_z: true
47 | z_channels: 4
48 | resolution: 256
49 | in_channels: 3
50 | out_ch: 3
51 | ch: 128
52 | ch_mult:
53 | - 1
54 | - 2
55 | - 4
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65 | params:
66 | freeze: True
67 | layer: "penultimate"
68 |
69 |
70 | data:
71 | target: ldm.data.laion.WebDataModuleFromConfig
72 | params:
73 | tar_base: null # for concat as in LAION-A
74 | p_unsafe_threshold: 0.1
75 | filter_word_list: "data/filters.yaml"
76 | max_pwatermark: 0.45
77 | batch_size: 8
78 | num_workers: 6
79 | multinode: True
80 | min_size: 512
81 | train:
82 | shards:
83 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
84 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
85 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
86 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
87 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
88 | shuffle: 10000
89 | image_key: jpg
90 | image_transforms:
91 | - target: torchvision.transforms.Resize
92 | params:
93 | size: 512
94 | interpolation: 3
95 | - target: torchvision.transforms.RandomCrop
96 | params:
97 | size: 512
98 | postprocess:
99 | target: ldm.data.laion.AddMask
100 | params:
101 | mode: "512train-large"
102 | p_drop: 0.25
103 | # NOTE use enough shards to avoid empty validation loops in workers
104 | validation:
105 | shards:
106 | - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
107 | shuffle: 0
108 | image_key: jpg
109 | image_transforms:
110 | - target: torchvision.transforms.Resize
111 | params:
112 | size: 512
113 | interpolation: 3
114 | - target: torchvision.transforms.CenterCrop
115 | params:
116 | size: 512
117 | postprocess:
118 | target: ldm.data.laion.AddMask
119 | params:
120 | mode: "512train-large"
121 | p_drop: 0.25
122 |
123 | lightning:
124 | find_unused_parameters: True
125 | modelcheckpoint:
126 | params:
127 | every_n_train_steps: 5000
128 |
129 | callbacks:
130 | metrics_over_trainsteps_checkpoint:
131 | params:
132 | every_n_train_steps: 10000
133 |
134 | image_logger:
135 | target: main.ImageLogger
136 | params:
137 | enable_autocast: False
138 | disabled: False
139 | batch_frequency: 1000
140 | max_images: 4
141 | increase_log_steps: False
142 | log_first_step: False
143 | log_images_kwargs:
144 | use_ema_scope: False
145 | inpaint: False
146 | plot_progressive_rows: False
147 | plot_diffusion_rows: False
148 | N: 4
149 | unconditional_guidance_scale: 5.0
150 | unconditional_guidance_label: [""]
151 | ddim_steps: 50 # todo check these out for depth2img,
152 | ddim_eta: 0.0 # todo check these out for depth2img,
153 |
154 | trainer:
155 | benchmark: True
156 | val_check_interval: 5000000
157 | num_sanity_val_steps: 0
158 | accumulate_grad_batches: 1
159 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/v2-midas-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-07
3 | target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: hybrid
16 | scale_factor: 0.18215
17 | monitor: val/loss_simple_ema
18 | finetune_keys: null
19 | use_ema: False
20 |
21 | depth_stage_config:
22 | target: ldm.modules.midas.api.MiDaSInference
23 | params:
24 | model_type: "dpt_hybrid"
25 |
26 | unet_config:
27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
28 | params:
29 | use_checkpoint: True
30 | image_size: 32 # unused
31 | in_channels: 5
32 | out_channels: 4
33 | model_channels: 320
34 | attention_resolutions: [ 4, 2, 1 ]
35 | num_res_blocks: 2
36 | channel_mult: [ 1, 2, 4, 4 ]
37 | num_head_channels: 64 # need to fix for flash-attn
38 | use_spatial_transformer: True
39 | use_linear_in_transformer: True
40 | transformer_depth: 1
41 | context_dim: 1024
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | #attn_type: "vanilla-xformers"
51 | double_z: true
52 | z_channels: 4
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult:
58 | - 1
59 | - 2
60 | - 4
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: [ ]
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 |
68 | cond_stage_config:
69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
70 | params:
71 | freeze: True
72 | layer: "penultimate"
73 |
74 |
75 |
--------------------------------------------------------------------------------
/configs/stable-diffusion/x4-upscaling.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
4 | params:
5 | parameterization: "v"
6 | low_scale_key: "lr"
7 | linear_start: 0.0001
8 | linear_end: 0.02
9 | num_timesteps_cond: 1
10 | log_every_t: 200
11 | timesteps: 1000
12 | first_stage_key: "jpg"
13 | cond_stage_key: "txt"
14 | image_size: 128
15 | channels: 4
16 | cond_stage_trainable: false
17 | conditioning_key: "hybrid-adm"
18 | monitor: val/loss_simple_ema
19 | scale_factor: 0.08333
20 | use_ema: False
21 |
22 | low_scale_config:
23 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
24 | params:
25 | noise_schedule_config: # image space
26 | linear_start: 0.0001
27 | linear_end: 0.02
28 | max_noise_level: 350
29 |
30 | unet_config:
31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32 | params:
33 | use_checkpoint: True
34 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
35 | image_size: 128
36 | in_channels: 7
37 | out_channels: 4
38 | model_channels: 256
39 | attention_resolutions: [ 2,4,8]
40 | num_res_blocks: 2
41 | channel_mult: [ 1, 2, 2, 4]
42 | disable_self_attentions: [True, True, True, False]
43 | disable_middle_self_attn: False
44 | num_heads: 8
45 | use_spatial_transformer: True
46 | transformer_depth: 1
47 | context_dim: 1024
48 | legacy: False
49 | use_linear_in_transformer: True
50 |
51 | first_stage_config:
52 | target: ldm.models.autoencoder.AutoencoderKL
53 | params:
54 | embed_dim: 4
55 | ddconfig:
56 | # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
57 | double_z: True
58 | z_channels: 4
59 | resolution: 256
60 | in_channels: 3
61 | out_ch: 3
62 | ch: 128
63 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
64 | num_res_blocks: 2
65 | attn_resolutions: [ ]
66 | dropout: 0.0
67 |
68 | lossconfig:
69 | target: torch.nn.Identity
70 |
71 | cond_stage_config:
72 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
73 | params:
74 | freeze: True
75 | layer: "penultimate"
76 |
77 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: TIGIC
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - numpy=1.23.1
12 | - pip:
13 | - albumentations==1.3.0
14 | - opencv-python==4.6.0.66
15 | - imageio==2.9.0
16 | - imageio-ffmpeg==0.4.2
17 | - pytorch-lightning==1.4.2
18 | - omegaconf==2.1.1
19 | - test-tube>=0.7.5
20 | - streamlit==1.12.1
21 | - einops==0.3.0
22 | - transformers==4.19.2
23 | - webdataset==0.2.5
24 | - kornia==0.6
25 | - open_clip_torch==2.0.2
26 | - invisible-watermark>=0.1.5
27 | - streamlit-drawable-canvas==0.8.0
28 | - torchmetrics==0.6.0
29 | - diffusers==0.12.1
30 | - ipykernel
31 | - matplotlib
32 | - -e .
33 |
34 |
35 |
--------------------------------------------------------------------------------
/examples/a photo of a black sloth plushie/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a black sloth plushie/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a black sloth plushie/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a black sloth plushie/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a black sloth plushie/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a pixel cartoon/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a pixel cartoon/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a pixel cartoon/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a pixel cartoon/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a pixel cartoon/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a polygonal illustration duck toy/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a polygonal illustration duck toy/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a polygonal illustration duck toy/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a polygonal illustration duck toy/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a polygonal illustration duck toy/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a purple plushie toy/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a purple plushie toy/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a purple plushie toy/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a purple plushie toy/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a purple plushie toy/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a red clock/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a red clock/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a red clock/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a red clock/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a red clock/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a robot monster toy/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a robot monster toy/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a robot monster toy/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a robot monster toy/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a robot monster toy/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a sand statue plushie/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a sand statue plushie/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a sand statue plushie/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a sand statue plushie/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sand statue plushie/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a sculpture golden retriever/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a sculpture golden retriever/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a sculpture golden retriever/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a sculpture golden retriever/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a sculpture golden retriever/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a silver backpack/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a silver backpack/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a silver backpack/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a silver backpack/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a silver backpack/subject.png
--------------------------------------------------------------------------------
/examples/a photo of a wooden sculpture Chow Chow/scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/scene.png
--------------------------------------------------------------------------------
/examples/a photo of a wooden sculpture Chow Chow/selected_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/selected_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a wooden sculpture Chow Chow/sub_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/sub_mask.png
--------------------------------------------------------------------------------
/examples/a photo of a wooden sculpture Chow Chow/subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/examples/a photo of a wooden sculpture Chow Chow/subject.png
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ldm.modules.midas.api import load_midas_transform
4 |
5 |
6 | class AddMiDaS(object):
7 | def __init__(self, model_type):
8 | super().__init__()
9 | self.transform = load_midas_transform(model_type)
10 |
11 | def pt2np(self, x):
12 | x = ((x + 1.0) * .5).detach().cpu().numpy()
13 | return x
14 |
15 | def np2pt(self, x):
16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x
18 |
19 | def __call__(self, sample):
20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point
21 | x = self.pt2np(sample['jpg'])
22 | x = self.transform({"image": x})["image"]
23 | sample['midas_in'] = x
24 | return sample
--------------------------------------------------------------------------------
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8 |
9 | from ldm.util import instantiate_from_config
10 | from ldm.modules.ema import LitEma
11 |
12 |
13 | class AutoencoderKL(pl.LightningModule):
14 | def __init__(self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | ckpt_path=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | ema_decay=None,
24 | learn_logvar=False
25 | ):
26 | super().__init__()
27 | self.learn_logvar = learn_logvar
28 | self.image_key = image_key
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 | self.loss = instantiate_from_config(lossconfig)
32 | assert ddconfig["double_z"]
33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35 | self.embed_dim = embed_dim
36 | if colorize_nlabels is not None:
37 | assert type(colorize_nlabels)==int
38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39 | if monitor is not None:
40 | self.monitor = monitor
41 |
42 | self.use_ema = ema_decay is not None
43 | if self.use_ema:
44 | self.ema_decay = ema_decay
45 | assert 0. < ema_decay < 1.
46 | self.model_ema = LitEma(self, decay=ema_decay)
47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48 |
49 | if ckpt_path is not None:
50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51 |
52 | def init_from_ckpt(self, path, ignore_keys=list()):
53 | sd = torch.load(path, map_location="cpu")["state_dict"]
54 | keys = list(sd.keys())
55 | for k in keys:
56 | for ik in ignore_keys:
57 | if k.startswith(ik):
58 | print("Deleting key {} from state_dict.".format(k))
59 | del sd[k]
60 | self.load_state_dict(sd, strict=False)
61 | print(f"Restored from {path}")
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def on_train_batch_end(self, *args, **kwargs):
79 | if self.use_ema:
80 | self.model_ema(self)
81 |
82 | def encode(self, x):
83 | h = self.encoder(x)
84 | moments = self.quant_conv(h)
85 | posterior = DiagonalGaussianDistribution(moments)
86 | return posterior
87 |
88 | def decode(self, z):
89 | z = self.post_quant_conv(z)
90 | dec = self.decoder(z)
91 | return dec
92 |
93 | def forward(self, input, sample_posterior=True):
94 | posterior = self.encode(input)
95 | if sample_posterior:
96 | z = posterior.sample()
97 | else:
98 | z = posterior.mode()
99 | dec = self.decode(z)
100 | return dec, posterior
101 |
102 | def get_input(self, batch, k):
103 | x = batch[k]
104 | if len(x.shape) == 3:
105 | x = x[..., None]
106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107 | return x
108 |
109 | def training_step(self, batch, batch_idx, optimizer_idx):
110 | inputs = self.get_input(batch, self.image_key)
111 | reconstructions, posterior = self(inputs)
112 |
113 | if optimizer_idx == 0:
114 | # train encoder+decoder+logvar
115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116 | last_layer=self.get_last_layer(), split="train")
117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119 | return aeloss
120 |
121 | if optimizer_idx == 1:
122 | # train the discriminator
123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124 | last_layer=self.get_last_layer(), split="train")
125 |
126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return discloss
129 |
130 | def validation_step(self, batch, batch_idx):
131 | log_dict = self._validation_step(batch, batch_idx)
132 | with self.ema_scope():
133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134 | return log_dict
135 |
136 | def _validation_step(self, batch, batch_idx, postfix=""):
137 | inputs = self.get_input(batch, self.image_key)
138 | reconstructions, posterior = self(inputs)
139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140 | last_layer=self.get_last_layer(), split="val"+postfix)
141 |
142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143 | last_layer=self.get_last_layer(), split="val"+postfix)
144 |
145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146 | self.log_dict(log_dict_ae)
147 | self.log_dict(log_dict_disc)
148 | return self.log_dict
149 |
150 | def configure_optimizers(self):
151 | lr = self.learning_rate
152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154 | if self.learn_logvar:
155 | print(f"{self.__class__.__name__}: Learning logvar")
156 | ae_params_list.append(self.loss.logvar)
157 | opt_ae = torch.optim.Adam(ae_params_list,
158 | lr=lr, betas=(0.5, 0.9))
159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160 | lr=lr, betas=(0.5, 0.9))
161 | return [opt_ae, opt_disc], []
162 |
163 | def get_last_layer(self):
164 | return self.decoder.conv_out.weight
165 |
166 | @torch.no_grad()
167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168 | log = dict()
169 | x = self.get_input(batch, self.image_key)
170 | x = x.to(self.device)
171 | if not only_inputs:
172 | xrec, posterior = self(x)
173 | if x.shape[1] > 3:
174 | # colorize with random projection
175 | assert xrec.shape[1] > 3
176 | x = self.to_rgb(x)
177 | xrec = self.to_rgb(xrec)
178 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179 | log["reconstructions"] = xrec
180 | if log_ema or self.use_ema:
181 | with self.ema_scope():
182 | xrec_ema, posterior_ema = self(x)
183 | if x.shape[1] > 3:
184 | # colorize with random projection
185 | assert xrec_ema.shape[1] > 3
186 | xrec_ema = self.to_rgb(xrec_ema)
187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188 | log["reconstructions_ema"] = xrec_ema
189 | log["inputs"] = x
190 | return log
191 |
192 | def to_rgb(self, x):
193 | assert self.image_key == "segmentation"
194 | if not hasattr(self, "colorize"):
195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196 | x = F.conv2d(x, weight=self.colorize)
197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198 | return x
199 |
200 |
201 | class IdentityFirstStage(torch.nn.Module):
202 | def __init__(self, *args, vq_interface=False, **kwargs):
203 | self.vq_interface = vq_interface
204 | super().__init__()
205 |
206 | def encode(self, x, *args, **kwargs):
207 | return x
208 |
209 | def decode(self, x, *args, **kwargs):
210 | return x
211 |
212 | def quantize(self, x, *args, **kwargs):
213 | if self.vq_interface:
214 | return x, None, [None, None, None]
215 | return x
216 |
217 | def forward(self, x, *args, **kwargs):
218 | return x
219 |
220 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 | import ptp_scripts.ptp_scripts as ptp
4 | import sys
5 | sys.path.append('..')
6 | import ptp_scripts.ptp_utils_ori as ptp_utils_ori
7 |
8 | from ldm.models.diffusion.dpm_solver.dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
9 |
10 | from tqdm import tqdm
11 |
12 | MODEL_TYPES = {
13 | "eps": "noise",
14 | "v": "v"
15 | }
16 |
17 |
18 | class DPMSolverSampler(object):
19 | def __init__(self, model, **kwargs):
20 | super().__init__()
21 | self.model = model
22 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
23 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
24 |
25 | def register_buffer(self, name, attr):
26 | if type(attr) == torch.Tensor:
27 | if attr.device != self.model.device:
28 | attr = attr.to(self.model.device)
29 | setattr(self, name, attr)
30 |
31 | @torch.no_grad()
32 | def sample(self,
33 | steps,
34 | batch_size,
35 | shape,
36 | conditioning=None,
37 | conditioning_edit=None,
38 | inv_emb=None,
39 | inv_emb_edit=None,
40 | callback=None,
41 | normals_sequence=None,
42 | img_callback=None,
43 | quantize_x0=False,
44 | eta=0.,
45 | mask=None,
46 | x0=None,
47 | temperature=1.,
48 | noise_dropout=0.,
49 | score_corrector=None,
50 | corrector_kwargs=None,
51 | verbose=True,
52 | x_T=None,
53 | log_every_t=100,
54 | unconditional_guidance_scale=1.,
55 | unconditional_conditioning=None,
56 | unconditional_conditioning_edit=None,
57 | t_start=None,
58 | t_end=None,
59 | DPMencode=False,
60 | order=2,
61 | width=None,
62 | height=None,
63 | ref=False,
64 | param=None,
65 | tau_a=0.5,
66 | tau_b=0.8,
67 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
68 | **kwargs
69 | ):
70 | if conditioning is not None:
71 | if isinstance(conditioning, dict):
72 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
73 | if cbs != batch_size:
74 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
75 | else:
76 | if conditioning.shape[0] != batch_size:
77 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
78 |
79 | # sampling
80 | C, H, W = shape
81 | size = (batch_size, C, H, W)
82 |
83 |
84 | device = self.model.betas.device
85 | if x_T is None:
86 | x = torch.randn(size, device=device)
87 | else:
88 | x = x_T
89 |
90 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
91 |
92 |
93 |
94 | if DPMencode:
95 | # x_T is not a list
96 | model_fn = model_wrapper(
97 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=None, inject=inject),
98 | ns,
99 | model_type=MODEL_TYPES[self.model.parameterization],
100 | guidance_type="classifier-free",
101 | condition=inv_emb,
102 | unconditional_condition=inv_emb,
103 | guidance_scale=unconditional_guidance_scale,
104 | )
105 |
106 |
107 | dpm_solver = DPM_Solver(model_fn, ns)
108 | data, _ = dpm_solver.sample_lower(x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=DPMencode)
109 |
110 | for step in range(order, steps + 1):
111 | data = dpm_solver.sample_one_step(data, step, steps, order=order, DPMencode=DPMencode)
112 |
113 | return data['x'].to(device), None
114 |
115 | else:
116 | # x_T is a list
117 |
118 | model_fn_decode = model_wrapper(
119 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject),
120 | ns,
121 | model_type=MODEL_TYPES[self.model.parameterization],
122 | guidance_type="classifier-free",
123 | condition=inv_emb,
124 | unconditional_condition=inv_emb,
125 | guidance_scale=unconditional_guidance_scale,
126 | )
127 | model_fn_gen = model_wrapper(
128 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject),
129 | ns,
130 | model_type=MODEL_TYPES[self.model.parameterization],
131 | guidance_type="classifier-free",
132 | condition=conditioning,
133 | unconditional_condition=unconditional_conditioning,
134 | guidance_scale=unconditional_guidance_scale,
135 | )
136 |
137 |
138 | model_fn_decode_edit = model_wrapper(
139 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject),
140 | ns,
141 | model_type=MODEL_TYPES[self.model.parameterization],
142 | guidance_type="classifier-free",
143 | condition=inv_emb_edit,
144 | unconditional_condition=inv_emb_edit,
145 | guidance_scale=unconditional_guidance_scale,
146 | )
147 | model_fn_gen_edit = model_wrapper(
148 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject),
149 | ns,
150 | model_type=MODEL_TYPES[self.model.parameterization],
151 | guidance_type="classifier-free",
152 | condition=conditioning_edit,
153 | unconditional_condition=unconditional_conditioning_edit,
154 | guidance_scale=unconditional_guidance_scale,
155 | )
156 |
157 | orig_controller = ptp.AttentionStore()
158 | ref_controller = ptp.AttentionStore()
159 | gen_controller = ptp.AttentionStore()
160 | Inject_controller = ptp.AttentionStore()
161 |
162 | dpm_solver_decode = DPM_Solver(model_fn_decode, ns)
163 | dpm_solver_gen = DPM_Solver(model_fn_gen, ns)
164 |
165 |
166 | dpm_solver_decode_edit = DPM_Solver(model_fn_decode_edit, ns)
167 |
168 | dpm_solver_gen_edit = DPM_Solver(model_fn_gen_edit, ns)
169 |
170 | # decoded background
171 |
172 | ptp_utils_ori.register_attention_control(self.model, orig_controller)
173 |
174 |
175 | orig, orig_controller = dpm_solver_decode_edit.sample_lower(x[0], dpm_solver_decode, steps, order, t_start, t_end, device, DPMencode=DPMencode, controller=orig_controller)
176 | # decoded reference
177 | ptp_utils_ori.register_attention_control(self.model, ref_controller)
178 | ref, ref_controller = dpm_solver_decode_edit.sample_lower(x[3], dpm_solver_decode_edit, steps, order, t_start, t_end, device, DPMencode=DPMencode, controller=ref_controller)
179 |
180 |
181 | # generation
182 | Inject_controller = [orig_controller, ref_controller]
183 | ptp_utils_ori.register_attention_control(self.model, gen_controller, inject_bg=False)
184 |
185 | gen, _ = dpm_solver_decode_edit.sample_lower(x[1], dpm_solver_gen_edit, steps, order, t_start, t_end, device,
186 | DPMencode=DPMencode, controller=Inject_controller, inject=True)
187 |
188 | del orig_controller, ref_controller, gen_controller, Inject_controller
189 |
190 | orig_controller = ptp.AttentionStore()
191 | ref_controller = ptp.AttentionStore()
192 | gen_controller = ptp.AttentionStore()
193 |
194 |             for step in range(order, 21):  # NOTE: upper bound is hard-coded to 21 here (the encode branch uses steps + 1)
195 | # decoded background
196 | ptp_utils_ori.register_attention_control(self.model, orig_controller)
197 | orig = dpm_solver_decode.sample_one_step(orig, step, steps, order=order, DPMencode=DPMencode)
198 | ptp_utils_ori.register_attention_control(self.model, ref_controller)
199 | ref = dpm_solver_decode_edit.sample_one_step(ref, step, steps, order=order, DPMencode=DPMencode)
200 |
201 |
202 | if step >= int(0.2*(steps) + 1 - order) and step <= int(0.5*(steps) + 1 - order):
203 | inject = True
204 | controller = [orig_controller, ref_controller]
205 | else:
206 | inject = False
207 | controller = [orig_controller, None]
208 |
209 |                 if step < int(0.5 * steps + 1 - order) and step > int(0.0 * steps + 1 - order):
210 | inject_bg = True
211 | else:
212 | inject_bg = False
213 |
214 |
215 | ptp_utils_ori.register_attention_control(self.model, gen_controller, inject_bg=inject_bg)
216 | gen = dpm_solver_gen_edit.sample_one_step(gen, step, steps, order=order, DPMencode=DPMencode, controller=controller, inject=inject)
217 |
218 | if step < int(1.0*(steps) + 1 - order):
219 | blended = orig['x'].clone()
220 | blended[:, :, param[0] : param[1], param[2] : param[3]] \
221 | = gen['x'][:, :, param[0] : param[1], param[2] : param[3]].clone()
222 | gen['x'] = blended.clone()
223 |
224 |
225 | return gen['x'].to(device), None
226 |
227 |
--------------------------------------------------------------------------------
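The final blend in the sampler above copies only the `param` rectangle of the generated latent into the decoded background latent at each step, so everything outside the box stays pinned to the background trajectory. A minimal, self-contained sketch of that update with dummy tensors; the shapes and the `param` box are illustrative assumptions, not values from the repo:

import torch

orig_x = torch.randn(1, 4, 64, 64)   # decoded background latent (4-channel, 64x64)
gen_x  = torch.randn(1, 4, 64, 64)   # generated latent with the injected subject
param  = (16, 48, 16, 48)            # hypothetical (h0, h1, w0, w1) box in latent coordinates

blended = orig_x.clone()
blended[:, :, param[0]:param[1], param[2]:param[3]] = \
    gen_x[:, :, param[0]:param[1], param[2]:param[3]].clone()
gen_x = blended                      # sampling then continues from the blended latent
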
/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class PLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81 | dynamic_threshold=None,
82 | **kwargs
83 | ):
84 | if conditioning is not None:
85 | if isinstance(conditioning, dict):
86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87 | if cbs != batch_size:
88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89 | else:
90 | if conditioning.shape[0] != batch_size:
91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92 |
93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94 | # sampling
95 | C, H, W = shape
96 | size = (batch_size, C, H, W)
97 | print(f'Data shape for PLMS sampling is {size}')
98 |
99 | samples, intermediates = self.plms_sampling(conditioning, size,
100 | callback=callback,
101 | img_callback=img_callback,
102 | quantize_denoised=quantize_x0,
103 | mask=mask, x0=x0,
104 | ddim_use_original_steps=False,
105 | noise_dropout=noise_dropout,
106 | temperature=temperature,
107 | score_corrector=score_corrector,
108 | corrector_kwargs=corrector_kwargs,
109 | x_T=x_T,
110 | log_every_t=log_every_t,
111 | unconditional_guidance_scale=unconditional_guidance_scale,
112 | unconditional_conditioning=unconditional_conditioning,
113 | dynamic_threshold=dynamic_threshold,
114 | )
115 | return samples, intermediates
116 |
117 | @torch.no_grad()
118 | def plms_sampling(self, cond, shape,
119 | x_T=None, ddim_use_original_steps=False,
120 | callback=None, timesteps=None, quantize_denoised=False,
121 | mask=None, x0=None, img_callback=None, log_every_t=100,
122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123 | unconditional_guidance_scale=1., unconditional_conditioning=None,
124 | dynamic_threshold=None):
125 | device = self.model.betas.device
126 | b = shape[0]
127 | if x_T is None:
128 | img = torch.randn(shape, device=device)
129 | else:
130 | img = x_T
131 |
132 | if timesteps is None:
133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134 | elif timesteps is not None and not ddim_use_original_steps:
135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136 | timesteps = self.ddim_timesteps[:subset_end]
137 |
138 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141 | print(f"Running PLMS Sampling with {total_steps} timesteps")
142 |
143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144 | old_eps = []
145 |
146 | for i, step in enumerate(iterator):
147 | index = total_steps - i - 1
148 | ts = torch.full((b,), step, device=device, dtype=torch.long)
149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150 |
151 | if mask is not None:
152 | assert x0 is not None
153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154 | img = img_orig * mask + (1. - mask) * img
155 |
156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157 | quantize_denoised=quantize_denoised, temperature=temperature,
158 | noise_dropout=noise_dropout, score_corrector=score_corrector,
159 | corrector_kwargs=corrector_kwargs,
160 | unconditional_guidance_scale=unconditional_guidance_scale,
161 | unconditional_conditioning=unconditional_conditioning,
162 | old_eps=old_eps, t_next=ts_next,
163 | dynamic_threshold=dynamic_threshold)
164 | img, pred_x0, e_t = outs
165 | old_eps.append(e_t)
166 | if len(old_eps) >= 4:
167 | old_eps.pop(0)
168 | if callback: callback(i)
169 | if img_callback: img_callback(pred_x0, i)
170 |
171 | if index % log_every_t == 0 or index == total_steps - 1:
172 | intermediates['x_inter'].append(img)
173 | intermediates['pred_x0'].append(pred_x0)
174 |
175 | return img, intermediates
176 |
177 | @torch.no_grad()
178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181 | dynamic_threshold=None):
182 | b, *_, device = *x.shape, x.device
183 |
184 | def get_model_output(x, t):
185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186 | e_t = self.model.apply_model(x, t, c)
187 | else:
188 | x_in = torch.cat([x] * 2)
189 | t_in = torch.cat([t] * 2)
190 | c_in = torch.cat([unconditional_conditioning, c])
191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193 |
194 | if score_corrector is not None:
195 | assert self.model.parameterization == "eps"
196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197 |
198 | return e_t
199 |
200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204 |
205 | def get_x_prev_and_pred_x0(e_t, index):
206 | # select parameters corresponding to the currently considered timestep
207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211 |
212 | # current prediction for x_0
213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214 | if quantize_denoised:
215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216 | if dynamic_threshold is not None:
217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218 | # direction pointing to x_t
219 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221 | if noise_dropout > 0.:
222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224 | return x_prev, pred_x0
225 |
226 | e_t = get_model_output(x, t)
227 | if len(old_eps) == 0:
228 | # Pseudo Improved Euler (2nd order)
229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230 | e_t_next = get_model_output(x_prev, t_next)
231 | e_t_prime = (e_t + e_t_next) / 2
232 | elif len(old_eps) == 1:
233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
235 | elif len(old_eps) == 2:
236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238 | elif len(old_eps) >= 3:
239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241 |
242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243 |
244 | return x_prev, pred_x0, e_t
245 |
--------------------------------------------------------------------------------
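The `e_t_prime` combinations in `p_sample_plms` above are the standard Adams-Bashforth multistep weights; each weight set sums to its divisor, so `e_t_prime` remains a properly normalized average of recent noise predictions. A quick stand-alone check of that property:

# Adams-Bashforth weights used by the 2nd/3rd/4th order branches above
coeffs   = {2: [3, -1], 3: [23, -16, 5], 4: [55, -59, 37, -9]}
divisors = {2: 2, 3: 12, 4: 24}

for order, weights in coeffs.items():
    assert sum(weights) == divisors[order], (order, sum(weights))
print("all weight sets sum to their divisor")
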
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
--------------------------------------------------------------------------------
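`norm_thresholding` caps the per-sample RMS of a prediction at `value`: samples whose RMS is already below the threshold pass through unchanged, larger ones are scaled down onto it. A short usage sketch, assuming the repo is on PYTHONPATH; the threshold of 1.0 is just an example:

import torch
from ldm.models.diffusion.sampling_util import norm_thresholding

x0 = 5.0 * torch.randn(2, 4, 64, 64)        # predictions with a large RMS
x0_thr = norm_thresholding(x0, value=1.0)   # per-sample RMS is now capped at 1.0

print(x0_thr.pow(2).flatten(1).mean(1).sqrt())   # both entries close to 1.0
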
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 | from typing import Optional, Any
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from ldm.modules.diffusionmodules.util import checkpoint
11 | from PIL import Image
12 |
13 | try:
14 | import xformers
15 | import xformers.ops
16 |     XFORMERS_IS_AVAILBLE = False  # kept False even when xformers imports: the injection hooks below require the vanilla CrossAttention path
17 | except:
18 | XFORMERS_IS_AVAILBLE = False
19 |
20 |
21 | def exists(val):
22 | return val is not None
23 |
24 |
25 | def uniq(arr):
26 |     return {el: True for el in arr}.keys()
27 |
28 |
29 | def default(val, d):
30 | if exists(val):
31 | return val
32 | return d() if isfunction(d) else d
33 |
34 |
35 | def max_neg_value(t):
36 | return -torch.finfo(t.dtype).max
37 |
38 |
39 | def init_(tensor):
40 | dim = tensor.shape[-1]
41 | std = 1 / math.sqrt(dim)
42 | tensor.uniform_(-std, std)
43 | return tensor
44 |
45 |
46 | # feedforward
47 | class GEGLU(nn.Module):
48 | def __init__(self, dim_in, dim_out):
49 | super().__init__()
50 | self.proj = nn.Linear(dim_in, dim_out * 2)
51 |
52 | def forward(self, x):
53 | x, gate = self.proj(x).chunk(2, dim=-1)
54 | return x * F.gelu(gate)
55 |
56 |
57 | class FeedForward(nn.Module):
58 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
59 | super().__init__()
60 | inner_dim = int(dim * mult)
61 | dim_out = default(dim_out, dim)
62 | project_in = nn.Sequential(
63 | nn.Linear(dim, inner_dim),
64 | nn.GELU()
65 | ) if not glu else GEGLU(dim, inner_dim)
66 |
67 | self.net = nn.Sequential(
68 | project_in,
69 | nn.Dropout(dropout),
70 | nn.Linear(inner_dim, dim_out)
71 | )
72 |
73 | def forward(self, x):
74 | return self.net(x)
75 |
76 |
77 | def zero_module(module):
78 | """
79 | Zero out the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().zero_()
83 | return module
84 |
85 |
86 | def Normalize(in_channels):
87 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
88 |
89 |
90 | class SpatialSelfAttention(nn.Module):
91 | def __init__(self, in_channels):
92 | super().__init__()
93 | self.in_channels = in_channels
94 |
95 | self.norm = Normalize(in_channels)
96 | self.q = torch.nn.Conv2d(in_channels,
97 | in_channels,
98 | kernel_size=1,
99 | stride=1,
100 | padding=0)
101 | self.k = torch.nn.Conv2d(in_channels,
102 | in_channels,
103 | kernel_size=1,
104 | stride=1,
105 | padding=0)
106 | self.v = torch.nn.Conv2d(in_channels,
107 | in_channels,
108 | kernel_size=1,
109 | stride=1,
110 | padding=0)
111 | self.proj_out = torch.nn.Conv2d(in_channels,
112 | in_channels,
113 | kernel_size=1,
114 | stride=1,
115 | padding=0)
116 |
117 | def forward(self, x):
118 | h_ = x
119 | h_ = self.norm(h_)
120 | q = self.q(h_)
121 | k = self.k(h_)
122 | v = self.v(h_)
123 |
124 | # compute attention
125 | b,c,h,w = q.shape
126 | q = rearrange(q, 'b c h w -> b (h w) c')
127 | k = rearrange(k, 'b c h w -> b c (h w)')
128 | w_ = torch.einsum('bij,bjk->bik', q, k)
129 |
130 | w_ = w_ * (int(c)**(-0.5))
131 | w_ = torch.nn.functional.softmax(w_, dim=2)
132 |
133 | # attend to values
134 | v = rearrange(v, 'b c h w -> b c (h w)')
135 | w_ = rearrange(w_, 'b i j -> b j i')
136 | h_ = torch.einsum('bij,bjk->bik', v, w_)
137 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
138 | h_ = self.proj_out(h_)
139 |
140 | return x+h_
141 |
142 |
143 | class CrossAttention(nn.Module):
144 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
145 | super().__init__()
146 | inner_dim = dim_head * heads
147 | context_dim = default(context_dim, query_dim)
148 |
149 | self.scale = dim_head ** -0.5
150 | self.heads = heads
151 |
152 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
153 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
154 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
155 |
156 | self.to_out = nn.Sequential(
157 | nn.Linear(inner_dim, query_dim),
158 | nn.Dropout(dropout)
159 | )
160 |
161 | def forward(self, x, context=None, mask=None, encode=False, controller_for_inject=None, inject=False, layernum=None, main_height=None, main_width=None):
162 | h = self.heads
163 |
164 | q = self.to_q(x)
165 | context = default(context, x)
166 | k = self.to_k(context)
167 | v = self.to_v(context)
168 |
169 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
170 |
171 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
172 | del q, k
173 |
174 | if exists(mask):
175 | mask = rearrange(mask, 'b ... -> b (...)')
176 | max_neg_value = -torch.finfo(sim.dtype).max
177 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
178 | sim.masked_fill_(~mask, max_neg_value)
179 |
180 | # a = ((sim.mean(0).mean(1).resize(64,64)/torch.max(sim.max(), abs(sim.min())) + 1)*127.5).cpu().numpy().astype(np.uint8)
181 | # image = Image.fromarray(a)
182 | # image.resize((512,512)).save('2.jpg')
183 |
184 | # u, s, vh = np.linalg.svd(sim.mean(0).cpu().numpy().astype(np.float32) - np.mean(sim.mean(0).cpu().numpy().astype(np.float32), axis=1, keepdims=True))
185 | # images = []
186 | # for i in range(3):
187 | # image = u[:,i].reshape(64, 64)
188 | # image = image - image.min()
189 | # image = 255 * image / image.max()
190 | # image = np.expand_dims(image, axis=2).astype(np.uint8)
191 | # images.append(image)
192 |
193 | # final = np.dstack(images)
194 | # final = Image.fromarray(final).resize((256, 256))
195 | # final = np.array(final)
196 | # import ptp_scripts.ptp_utils as ptp_utils
197 | # ptp_utils.view_images(final)
198 |
199 | # attention, what we cannot get enough of
200 | sim = sim.softmax(dim=-1)
201 | # sim = sim.sigmoid()
202 |
203 | out = einsum('b i j, b j d -> b i d', sim, v)
204 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
205 | return self.to_out(out)
206 |
207 |
208 | class MemoryEfficientCrossAttention(nn.Module):
209 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
210 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
211 | super().__init__()
212 | print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
213 | f"{heads} heads.")
214 | inner_dim = dim_head * heads
215 | context_dim = default(context_dim, query_dim)
216 |
217 | self.heads = heads
218 | self.dim_head = dim_head
219 |
220 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
221 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
222 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
223 |
224 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
225 | self.attention_op: Optional[Any] = None
226 |
227 | def forward(self, x, context=None, mask=None):
228 | q = self.to_q(x)
229 | context = default(context, x)
230 | k = self.to_k(context)
231 | v = self.to_v(context)
232 |
233 | b, _, _ = q.shape
234 | q, k, v = map(
235 | lambda t: t.unsqueeze(3)
236 | .reshape(b, t.shape[1], self.heads, self.dim_head)
237 | .permute(0, 2, 1, 3)
238 | .reshape(b * self.heads, t.shape[1], self.dim_head)
239 | .contiguous(),
240 | (q, k, v),
241 | )
242 |
243 | # actually compute the attention, what we cannot get enough of
244 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
245 |
246 | if exists(mask):
247 | raise NotImplementedError
248 | out = (
249 | out.unsqueeze(0)
250 | .reshape(b, self.heads, out.shape[1], self.dim_head)
251 | .permute(0, 2, 1, 3)
252 | .reshape(b, out.shape[1], self.heads * self.dim_head)
253 | )
254 | return self.to_out(out)
255 |
256 |
257 | class BasicTransformerBlock(nn.Module):
258 | ATTENTION_MODES = {
259 | "softmax": CrossAttention, # vanilla attention
260 | "softmax-xformers": MemoryEfficientCrossAttention
261 | }
262 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
263 | disable_self_attn=False):
264 | super().__init__()
265 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
266 | assert attn_mode in self.ATTENTION_MODES
267 | attn_cls = self.ATTENTION_MODES[attn_mode]
268 | self.disable_self_attn = disable_self_attn
269 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
270 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
271 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
272 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
273 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
274 | self.norm1 = nn.LayerNorm(dim)
275 | self.norm2 = nn.LayerNorm(dim)
276 | self.norm3 = nn.LayerNorm(dim)
277 | self.checkpoint = checkpoint
278 |
279 | def forward(self, x, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0, h=None, w=None):
280 | return checkpoint(self._forward, (x, context, encode, encode_uncon, decode_uncon, controller, inject, layernum, h, w), self.parameters(), self.checkpoint)
281 |
282 | def _forward(self, x, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0, h=None, w=None):
283 |
284 | if encode_uncon == True and decode_uncon == True:
285 | # pass
286 | x = self.attn1(self.norm1(x), context=None, encode=encode) + x
287 |             x = self.attn1(self.norm1(x), context=None, encode=encode) + x # if you add more attention layers here, remember to update register_attention_control accordingly
288 |
289 | elif encode_uncon == True and decode_uncon == False:
290 | if encode:
291 | x = self.attn1(self.norm1(x), context=None, encode=encode) + x
292 |                 x = self.attn1(self.norm1(x), context=None, encode=encode) + x # if you add more attention layers here, remember to update register_attention_control accordingly
293 | else:
294 | x = self.attn1(self.norm1(x), context=context
295 | if self.disable_self_attn else None, controller_for_inject=controller, inject=inject, layernum=layernum) + x
296 | x = self.attn1(self.norm1(x), context=context
297 | if self.disable_self_attn else None, controller_for_inject=controller, inject=inject, layernum=layernum+1) + x
298 | x = self.attn2(self.norm2(x), context=context) + x
299 |
300 | elif encode_uncon == False and decode_uncon == False:
301 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, encode=encode,
302 | controller_for_inject=controller, inject=inject, layernum=layernum, main_height=h, main_width=w) + x
303 | x = self.attn2(self.norm2(x), context=context, encode=encode) + x
304 | # pass
305 |
306 | x = self.ff(self.norm3(x)) + x
307 | return x
308 |
309 |
310 | class SpatialTransformer(nn.Module):
311 | """
312 | Transformer block for image-like data.
313 | First, project the input (aka embedding)
314 | and reshape to b, t, d.
315 | Then apply standard transformer action.
316 | Finally, reshape to image
317 | NEW: use_linear for more efficiency instead of the 1x1 convs
318 | """
319 | def __init__(self, in_channels, n_heads, d_head,
320 | depth=1, dropout=0., context_dim=None,
321 | disable_self_attn=False, use_linear=False,
322 | use_checkpoint=True):
323 | super().__init__()
324 | if exists(context_dim) and not isinstance(context_dim, list):
325 | context_dim = [context_dim]
326 | self.in_channels = in_channels
327 | inner_dim = n_heads * d_head
328 | self.norm = Normalize(in_channels)
329 | if not use_linear:
330 | self.proj_in = nn.Conv2d(in_channels,
331 | inner_dim,
332 | kernel_size=1,
333 | stride=1,
334 | padding=0)
335 | else:
336 | self.proj_in = nn.Linear(in_channels, inner_dim)
337 |
338 | self.transformer_blocks = nn.ModuleList(
339 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
340 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
341 | for d in range(depth)]
342 | )
343 | if not use_linear:
344 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
345 | in_channels,
346 | kernel_size=1,
347 | stride=1,
348 | padding=0))
349 | else:
350 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
351 | self.use_linear = use_linear
352 |
353 | def forward(self, x, context=None, encode=False, encode_uncon=True, decode_uncon=True, controller=None, inject=False, layernum=0):
354 | # note: if no context is given, cross-attention defaults to self-attention
355 | if not isinstance(context, list):
356 | context = [context]
357 | b, c, h, w = x.shape
358 | x_in = x
359 | x = self.norm(x)
360 | if not self.use_linear:
361 | x = self.proj_in(x)
362 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
363 | if self.use_linear:
364 | x = self.proj_in(x)
365 | for i, block in enumerate(self.transformer_blocks):
366 | x = block(x, context=context[i], encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon,
367 | controller=controller, inject=inject, layernum=layernum, h=h, w=w)
368 | if self.use_linear:
369 | x = self.proj_out(x)
370 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
371 | if not self.use_linear:
372 | x = self.proj_out(x)
373 |
374 |         layernum = layernum + 1 # keep this counter in sync with register_recr
375 |
376 | return x + x_in, layernum
377 |
378 |
--------------------------------------------------------------------------------
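A small shape check for the vanilla `CrossAttention` above, assuming the repo and its dependencies are importable. The dimensions (320-dim queries, 77 text tokens of width 1024) mirror typical SD 2.x settings but are only illustrative:

import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=1024, heads=8, dim_head=40)
x    = torch.randn(1, 64 * 64, 320)   # flattened 64x64 feature map
ctx  = torch.randn(1, 77, 1024)       # text conditioning tokens

out = attn(x, context=ctx)
print(out.shape)                      # torch.Size([1, 4096, 320])
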
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/upscaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7 | from ldm.util import default
8 |
9 |
10 | class AbstractLowScaleModel(nn.Module):
11 | # for concatenating a downsampled image to the latent representation
12 | def __init__(self, noise_schedule_config=None):
13 | super(AbstractLowScaleModel, self).__init__()
14 | if noise_schedule_config is not None:
15 | self.register_schedule(**noise_schedule_config)
16 |
17 | def register_schedule(self, beta_schedule="linear", timesteps=1000,
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20 | cosine_s=cosine_s)
21 | alphas = 1. - betas
22 | alphas_cumprod = np.cumprod(alphas, axis=0)
23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24 |
25 | timesteps, = betas.shape
26 | self.num_timesteps = int(timesteps)
27 | self.linear_start = linear_start
28 | self.linear_end = linear_end
29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30 |
31 | to_torch = partial(torch.tensor, dtype=torch.float32)
32 |
33 | self.register_buffer('betas', to_torch(betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43 |
44 | def q_sample(self, x_start, t, noise=None):
45 | noise = default(noise, lambda: torch.randn_like(x_start))
46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48 |
49 | def forward(self, x):
50 | return x, None
51 |
52 | def decode(self, x):
53 | return x
54 |
55 |
56 | class SimpleImageConcat(AbstractLowScaleModel):
57 | # no noise level conditioning
58 | def __init__(self):
59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60 | self.max_noise_level = 0
61 |
62 | def forward(self, x):
63 | # fix to constant noise level
64 | return x, torch.zeros(x.shape[0], device=x.device).long()
65 |
66 |
67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69 | super().__init__(noise_schedule_config=noise_schedule_config)
70 | self.max_noise_level = max_noise_level
71 |
72 | def forward(self, x, noise_level=None):
73 | if noise_level is None:
74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75 | else:
76 | assert isinstance(noise_level, torch.Tensor)
77 | z = self.q_sample(x, noise_level)
78 | return z, noise_level
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
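`AbstractLowScaleModel.q_sample` is the usual closed-form forward diffusion, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A minimal sketch of the noise-augmentation wrapper, assuming the repo is importable; the schedule and `max_noise_level` values are illustrative:

import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config=dict(beta_schedule="linear", timesteps=1000,
                               linear_start=1e-4, linear_end=2e-2),
    max_noise_level=350,
)
lowres = torch.randn(2, 3, 128, 128)   # stand-in for a downsampled conditioning image
z, noise_level = aug(lowres)           # noised image and the sampled noise level t
print(z.shape, noise_level)
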
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | if ddim_timesteps[-1] == 1000:
66 | ddim_timesteps = ddim_timesteps - 1
67 | alphas = alphacums[ddim_timesteps]
68 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
69 | alphas_next = np.asarray(alphacums[ddim_timesteps[1:]].tolist() + [alphacums[-1].tolist()])
70 |
71 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
72 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
73 | if verbose:
74 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
75 | print(f'For the chosen value of eta, which is {eta}, '
76 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
77 | return sigmas, alphas, alphas_prev, alphas_next
78 |
79 |
80 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
81 | """
82 | Create a beta schedule that discretizes the given alpha_t_bar function,
83 | which defines the cumulative product of (1-beta) over time from t = [0,1].
84 | :param num_diffusion_timesteps: the number of betas to produce.
85 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
86 | produces the cumulative product of (1-beta) up to that
87 | part of the diffusion process.
88 | :param max_beta: the maximum beta to use; use values lower than 1 to
89 | prevent singularities.
90 | """
91 | betas = []
92 | for i in range(num_diffusion_timesteps):
93 | t1 = i / num_diffusion_timesteps
94 | t2 = (i + 1) / num_diffusion_timesteps
95 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
96 | return np.array(betas)
97 |
98 |
99 | def extract_into_tensor(a, t, x_shape):
100 | b, *_ = t.shape
101 | out = a.gather(-1, t)
102 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
103 |
104 |
105 | def checkpoint(func, inputs, params, flag):
106 | """
107 | Evaluate a function without caching intermediate activations, allowing for
108 | reduced memory at the expense of extra compute in the backward pass.
109 | :param func: the function to evaluate.
110 | :param inputs: the argument sequence to pass to `func`.
111 | :param params: a sequence of parameters `func` depends on but does not
112 | explicitly take as arguments.
113 | :param flag: if False, disable gradient checkpointing.
114 | """
115 | if flag:
116 | args = tuple(inputs) + tuple(params)
117 | return CheckpointFunction.apply(func, len(inputs), *args)
118 | else:
119 | return func(*inputs)
120 |
121 |
122 | class CheckpointFunction(torch.autograd.Function):
123 | @staticmethod
124 | def forward(ctx, run_function, length, *args):
125 | ctx.run_function = run_function
126 | ctx.input_tensors = list(args[:length])
127 | ctx.input_params = list(args[length:])
128 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
129 | "dtype": torch.get_autocast_gpu_dtype(),
130 | "cache_enabled": torch.is_autocast_cache_enabled()}
131 | with torch.no_grad():
132 | output_tensors = ctx.run_function(*ctx.input_tensors)
133 | return output_tensors
134 |
135 | @staticmethod
136 | def backward(ctx, *output_grads):
137 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
138 | with torch.enable_grad(), \
139 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
140 | # Fixes a bug where the first op in run_function modifies the
141 | # Tensor storage in place, which is not allowed for detach()'d
142 | # Tensors.
143 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
144 | output_tensors = ctx.run_function(*shallow_copies)
145 | input_grads = torch.autograd.grad(
146 | output_tensors,
147 | ctx.input_tensors + ctx.input_params,
148 | output_grads,
149 | allow_unused=True,
150 | )
151 | del ctx.input_tensors
152 | del ctx.input_params
153 | del output_tensors
154 | return (None, None) + input_grads
155 |
156 |
157 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
158 | """
159 | Create sinusoidal timestep embeddings.
160 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
161 | These may be fractional.
162 | :param dim: the dimension of the output.
163 | :param max_period: controls the minimum frequency of the embeddings.
164 | :return: an [N x dim] Tensor of positional embeddings.
165 | """
166 | if not repeat_only:
167 | half = dim // 2
168 | freqs = torch.exp(
169 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
170 | ).to(device=timesteps.device)
171 | args = timesteps[:, None].float() * freqs[None]
172 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
173 | if dim % 2:
174 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
175 | else:
176 | embedding = repeat(timesteps, 'b -> b d', d=dim)
177 | return embedding
178 |
179 |
180 | def zero_module(module):
181 | """
182 | Zero out the parameters of a module and return it.
183 | """
184 | for p in module.parameters():
185 | p.detach().zero_()
186 | return module
187 |
188 |
189 | def scale_module(module, scale):
190 | """
191 | Scale the parameters of a module and return it.
192 | """
193 | for p in module.parameters():
194 | p.detach().mul_(scale)
195 | return module
196 |
197 |
198 | def mean_flat(tensor):
199 | """
200 | Take the mean over all non-batch dimensions.
201 | """
202 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
203 |
204 |
205 | def normalization(channels):
206 | """
207 | Make a standard normalization layer.
208 | :param channels: number of input channels.
209 | :return: an nn.Module for normalization.
210 | """
211 | return GroupNorm32(32, channels)
212 |
213 |
214 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
215 | class SiLU(nn.Module):
216 | def forward(self, x):
217 | return x * torch.sigmoid(x)
218 |
219 |
220 | class GroupNorm32(nn.GroupNorm):
221 | def forward(self, x):
222 | return super().forward(x.float()).type(x.dtype)
223 |
224 | def conv_nd(dims, *args, **kwargs):
225 | """
226 | Create a 1D, 2D, or 3D convolution module.
227 | """
228 | if dims == 1:
229 | return nn.Conv1d(*args, **kwargs)
230 | elif dims == 2:
231 | return nn.Conv2d(*args, **kwargs)
232 | elif dims == 3:
233 | return nn.Conv3d(*args, **kwargs)
234 | raise ValueError(f"unsupported dimensions: {dims}")
235 |
236 |
237 | def linear(*args, **kwargs):
238 | """
239 | Create a linear module.
240 | """
241 | return nn.Linear(*args, **kwargs)
242 |
243 |
244 | def avg_pool_nd(dims, *args, **kwargs):
245 | """
246 | Create a 1D, 2D, or 3D average pooling module.
247 | """
248 | if dims == 1:
249 | return nn.AvgPool1d(*args, **kwargs)
250 | elif dims == 2:
251 | return nn.AvgPool2d(*args, **kwargs)
252 | elif dims == 3:
253 | return nn.AvgPool3d(*args, **kwargs)
254 | raise ValueError(f"unsupported dimensions: {dims}")
255 |
256 |
257 | class HybridConditioner(nn.Module):
258 |
259 | def __init__(self, c_concat_config, c_crossattn_config):
260 | super().__init__()
261 | self.concat_conditioner = instantiate_from_config(c_concat_config)
262 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
263 |
264 | def forward(self, c_concat, c_crossattn):
265 | c_concat = self.concat_conditioner(c_concat)
266 | c_crossattn = self.crossattn_conditioner(c_crossattn)
267 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
268 |
269 |
270 | def noise_like(shape, device, repeat=False):
271 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
272 | noise = lambda: torch.randn(shape, device=device)
273 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
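Two quick checks for the helpers above, assuming the repo is importable: `timestep_embedding` produces the sinusoidal features the UNet consumes, and `make_beta_schedule` returns a numpy array of betas. Sizes below are illustrative:

import torch
from ldm.modules.diffusionmodules.util import timestep_embedding, make_beta_schedule

emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=320)
print(emb.shape)                          # torch.Size([3, 320])

betas = make_beta_schedule("linear", 1000)
print(betas.shape, betas[0], betas[-1])   # (1000,) ~1e-4 ~2e-2
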
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
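A minimal sketch of `DiagonalGaussianDistribution`, which splits its parameter tensor into mean and log-variance halves along the channel axis. The shapes are illustrative, assuming the repo is importable:

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 32, 32)    # [mean | logvar] stacked on the channel axis
post = DiagonalGaussianDistribution(params)

z  = post.sample()                    # (2, 4, 32, 32) latent draw
kl = post.kl()                        # per-sample KL to a standard normal
print(z.shape, kl.shape)              # torch.Size([2, 4, 32, 32]) torch.Size([2])
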
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
--------------------------------------------------------------------------------
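A compact usage sketch for `LitEma`; the model, optimizer, and step count below are stand-ins. Update the shadow copy after each optimizer step, then temporarily swap it in for evaluation:

import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):                   # dummy training steps
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad(); loss.backward(); opt.step()
    ema(model)                        # update the shadow parameters

ema.store(model.parameters())         # stash the live weights
ema.copy_to(model)                    # evaluate with EMA weights ...
ema.restore(model.parameters())       # ... then restore for further training
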
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 |
5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6 |
7 | import open_clip
8 | from ldm.util import default, count_params
9 | import einops
10 |
11 | class AbstractEncoder(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def encode(self, *args, **kwargs):
16 | raise NotImplementedError
17 |
18 |
19 | class IdentityEncoder(AbstractEncoder):
20 |
21 | def encode(self, x):
22 | return x
23 |
24 |
25 | class ClassEmbedder(nn.Module):
26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
27 | super().__init__()
28 | self.key = key
29 | self.embedding = nn.Embedding(n_classes, embed_dim)
30 | self.n_classes = n_classes
31 | self.ucg_rate = ucg_rate
32 |
33 | def forward(self, batch, key=None, disable_dropout=False):
34 | if key is None:
35 | key = self.key
36 | # this is for use in crossattn
37 | c = batch[key][:, None]
38 | if self.ucg_rate > 0. and not disable_dropout:
39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
41 | c = c.long()
42 | c = self.embedding(c)
43 | return c
44 |
45 | def get_unconditional_conditioning(self, bs, device="cuda"):
46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
47 | uc = torch.ones((bs,), device=device) * uc_class
48 | uc = {self.key: uc}
49 | return uc
50 |
51 |
52 | def disabled_train(self, mode=True):
53 | """Overwrite model.train with this function to make sure train/eval mode
54 | does not change anymore."""
55 | return self
56 |
57 |
58 | class FrozenT5Embedder(AbstractEncoder):
59 | """Uses the T5 transformer encoder for text"""
60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
61 | super().__init__()
62 | self.tokenizer = T5Tokenizer.from_pretrained(version)
63 | self.transformer = T5EncoderModel.from_pretrained(version)
64 | self.device = device
65 | self.max_length = max_length # TODO: typical value?
66 | if freeze:
67 | self.freeze()
68 |
69 | def freeze(self):
70 | self.transformer = self.transformer.eval()
71 | #self.train = disabled_train
72 | for param in self.parameters():
73 | param.requires_grad = False
74 |
75 | def forward(self, text):
76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
78 | tokens = batch_encoding["input_ids"].to(self.device)
79 | outputs = self.transformer(input_ids=tokens)
80 |
81 | z = outputs.last_hidden_state
82 | return z
83 |
84 | def encode(self, text):
85 | return self(text)
86 |
87 |
88 | class FrozenCLIPEmbedder(AbstractEncoder):
89 | """Uses the CLIP transformer encoder for text (from huggingface)"""
90 | LAYERS = [
91 | "last",
92 | "pooled",
93 | "hidden"
94 | ]
95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
97 | super().__init__()
98 | assert layer in self.LAYERS
99 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
100 | self.transformer = CLIPTextModel.from_pretrained(version)
101 | self.device = device
102 | self.max_length = max_length
103 | if freeze:
104 | self.freeze()
105 | self.layer = layer
106 | self.layer_idx = layer_idx
107 | if layer == "hidden":
108 | assert layer_idx is not None
109 | assert 0 <= abs(layer_idx) <= 12
110 |
111 | def freeze(self):
112 | self.transformer = self.transformer.eval()
113 | #self.train = disabled_train
114 | for param in self.parameters():
115 | param.requires_grad = False
116 |
117 | def forward(self, text):
118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120 | tokens = batch_encoding["input_ids"].to(self.device)
121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122 | if self.layer == "last":
123 | z = outputs.last_hidden_state
124 | elif self.layer == "pooled":
125 | z = outputs.pooler_output[:, None, :]
126 | else:
127 | z = outputs.hidden_states[self.layer_idx]
128 | return z
129 |
130 | def encode(self, text):
131 | return self(text)
132 |
133 |
134 | class FrozenOpenCLIPEmbedder(AbstractEncoder):
135 | """
136 | Uses the OpenCLIP transformer encoder for text
137 | """
138 | LAYERS = [
139 | #"pooled",
140 | "last",
141 | "penultimate"
142 | ]
143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
144 | freeze=True, layer="last"):
145 | super().__init__()
146 | assert layer in self.LAYERS
147 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
148 | del model.visual
149 | self.model = model
150 |
151 | self.device = device
152 | self.max_length = max_length
153 | if freeze:
154 | self.freeze()
155 | self.layer = layer
156 | if self.layer == "last":
157 | self.layer_idx = 0
158 | elif self.layer == "penultimate":
159 | self.layer_idx = 1
160 | else:
161 | raise NotImplementedError()
162 |
163 | def freeze(self):
164 | self.model = self.model.eval()
165 | for param in self.parameters():
166 | param.requires_grad = False
167 |
168 | # def forward(self, text, inv=False):
169 | # tokens = open_clip.tokenize(text)
170 | # # print(tokens)
171 | # if inv:
172 |
173 | # # tokens[0] = torch.zeros(77)
174 | # # tokens[0] = torch.zeros(77)+7788
175 | # print(tokens[0])
176 | # z = self.encode_with_transformer(tokens.to(self.device), inv=True)
177 | # # print(z.shape)
178 | # else:
179 | # z = self.encode_with_transformer(tokens.to(self.device),inv=True)
180 | # # z = self.encode_with_transformer(tokens.to(self.device),inv)
181 | # # print(z.shape)
182 | # return z
183 |
184 | # def encode_with_transformer(self, text, inv=False):
185 | # x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
186 |
187 | # if inv == True:
188 | # # x = einops.repeat(x[:,0], 'i j -> i c j', c=77)
189 | # x = x + self.model.positional_embedding
190 |
191 | # x = x.permute(1, 0, 2) # NLD -> LND
192 | # x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
193 | # x = x.permute(1, 0, 2) # LND -> NLD
194 | # x = self.model.ln_final(x)
195 | # return x
196 |     def forward(self, text, inv=False):
197 |         tokens = open_clip.tokenize(text)
198 |         if inv:
199 |             if tokens[0][1] == 49407:
200 |                 # 49407 is the CLIP end-of-text token, i.e. the first prompt is empty.
201 |                 # NOTE: both branches of this check make the same call, so `inv` has no effect here.
202 |                 z = self.encode_with_transformer(tokens.to(self.device), inv=False)
203 |             else:
204 |                 # identical call to the branch above, kept to preserve the
205 |                 # original control flow
206 |                 z = self.encode_with_transformer(tokens.to(self.device), inv=False)
207 |         else:
208 |             # default path: encode_with_transformer adds the positional embedding
209 |             z = self.encode_with_transformer(tokens.to(self.device))
210 |         return z
211 |
212 | def encode_with_transformer(self, text, inv=False):
213 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
214 | if inv == False:
215 | # x = einops.repeat(x[:,0], 'i j -> i c j', c=77)
216 | x = x + self.model.positional_embedding
217 | x = x.permute(1, 0, 2) # NLD -> LND
218 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
219 | x = x.permute(1, 0, 2) # LND -> NLD
220 | x = self.model.ln_final(x)
221 | return x
222 |
223 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
224 | for i, r in enumerate(self.model.transformer.resblocks):
225 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
226 | break
227 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
228 | x = checkpoint(r, x, attn_mask)
229 | else:
230 | x = r(x, attn_mask=attn_mask)
231 | return x
232 |
233 | def encode(self, text, inv=False, device=None):
234 | if device is not None:
235 | self.device = device
236 | return self(text, inv)
237 |
238 |
239 | class FrozenCLIPT5Encoder(AbstractEncoder):
240 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
241 | clip_max_length=77, t5_max_length=77):
242 | super().__init__()
243 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
244 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
245 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
246 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
247 |
248 | def encode(self, text):
249 | return self(text)
250 |
251 | def forward(self, text):
252 | clip_z = self.clip_encoder.encode(text)
253 | t5_z = self.t5_encoder.encode(text)
254 | return [clip_z, t5_z]
255 |
256 |
257 |
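A minimal usage sketch for the frozen text encoders above, assuming the repository root is on PYTHONPATH (so this file is importable as ldm.modules.encoders.modules) and the dependencies from environment.yaml are installed; transformers will fetch openai/clip-vit-large-patch14 from the Hugging Face hub on first use. FrozenOpenCLIPEmbedder works the same way but additionally needs open_clip and the laion2b_s32b_b79k weights.

    # Encode a prompt with FrozenCLIPEmbedder on CPU; freeze() has already set
    # requires_grad=False on every parameter, so no gradients are tracked.
    import torch
    from ldm.modules.encoders.modules import FrozenCLIPEmbedder

    encoder = FrozenCLIPEmbedder(device="cpu")          # downloads the CLIP ViT-L/14 text weights
    with torch.no_grad():
        z = encoder.encode(["a photo of a red clock"])  # equivalent to encoder([...])
    print(z.shape)  # torch.Size([1, 77, 768]): one 77-token sequence of 768-d embeddings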
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 |         assert False, f"model_type '{model_type}' not implemented, use one of: {list(ISL_PATHS.keys())}"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 |         print(f"model_type '{model_type}' not implemented, use one of: {list(ISL_PATHS.keys())}")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 |         # x is the already-normalized input produced by the MiDaS transform
159 |         # (see load_midas_transform) applied to a 0..1 float numpy array during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
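A small sketch of the preprocessing path defined above, assuming the repository root is on PYTHONPATH. load_midas_transform only builds the Resize / NormalizeImage / PrepareForNet pipeline, so it runs without any checkpoint; MiDaSInference itself additionally expects the corresponding weight file from ISL_PATHS (e.g. midas_models/dpt_large-midas-2f21e586.pt) on disk.

    # Apply the "dpt_large" transform to a dummy 0..1 float image.
    import numpy as np
    import torch
    from ldm.modules.midas.api import load_midas_transform

    img = np.random.rand(480, 640, 3).astype(np.float32)   # H x W x 3 in [0, 1]
    transform = load_midas_transform("dpt_large")
    sample = transform({"image": img})
    x = torch.from_numpy(sample["image"]).unsqueeze(0)      # 1 x 3 x H' x W'
    print(x.shape)  # torch.Size([1, 3, 384, 512]): aspect ratio kept, sides rounded to multiples of 32

    # With the dpt_large checkpoint in place, a depth map would then be:
    #   depth = MiDaSInference("dpt_large")(x)              # shape (1, 1, 384, 512)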
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zrealli/TIGIC/c7c1fd39e0c077c07ba5b917ad408dcd507833a3/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         )  # R50 + ViT-B/16 hybrid backbone (vitb_rn50_384)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
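A shape-level sketch of the two building blocks above, using random tensors and illustrative sizes: _make_scratch projects the four backbone feature maps to a common channel count, and FeatureFusionBlock_custom fuses a decoder tensor with a skip connection while upsampling by a factor of two.

    import torch
    import torch.nn as nn
    from ldm.modules.midas.midas.blocks import _make_scratch, FeatureFusionBlock_custom

    # Project a 768-channel feature map down to the shared 256 channels.
    scratch = _make_scratch([256, 512, 768, 768], 256)
    feat4 = torch.randn(1, 768, 24, 24)
    print(scratch.layer4_rn(feat4).shape)   # torch.Size([1, 256, 24, 24])

    # Fuse a decoder tensor with a same-shaped skip connection; the output is upsampled 2x.
    fuse = FeatureFusionBlock_custom(256, nn.ReLU(False), bn=False)
    path = torch.randn(1, 256, 12, 12)
    skip = torch.randn(1, 256, 12, 12)
    print(fuse(path, skip).shape)           # torch.Size([1, 256, 24, 24])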
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False,  # use_pretrained: set to True to initialise the backbone with ImageNet weights (e.g. when training from scratch)
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 |             x = x.contiguous(memory_format=torch.channels_last)  # contiguous() is not in-place; keep the result
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
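A quick smoke test of the DPT decoder above with randomly initialised weights, assuming timm is installed; path=None skips checkpoint loading and _make_encoder is called with use_pretrained=False, so nothing is downloaded. The vitb16_384 backbone is chosen here because its timm model name (vit_base_patch16_384) is stable.

    import torch
    from ldm.modules.midas.midas.dpt_depth import DPTDepthModel

    model = DPTDepthModel(path=None, backbone="vitb16_384", non_negative=True).eval()

    x = torch.randn(1, 3, 384, 384)   # 384x384 matches the backbone's native input size
    with torch.no_grad():
        depth = model(x)
    print(depth.shape)                # torch.Size([1, 384, 384]): one depth map per image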
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             Note: the encoder backbone is fixed to "resnext101_wsl"; there is no backbone argument.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to "efficientnet_lite3".
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 |             x = x.contiguous(memory_format=torch.channels_last)  # contiguous() is not in-place; keep the result
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 | """Rezise the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 | """Normlize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
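A short sketch contrasting the resize_method options above on a 640x480 input, assuming the repository root is on PYTHONPATH; get_size(width, height) returns the (new_width, new_height) the sample would be resized to (ensure_multiple_of is left at its default of 1 for clarity).

    from ldm.modules.midas.midas.transforms import Resize

    lower = Resize(384, 384, keep_aspect_ratio=True, resize_method="lower_bound")
    upper = Resize(384, 384, keep_aspect_ratio=True, resize_method="upper_bound")

    print(lower.get_size(640, 480))  # width 512, height 384: both sides at least 384
    print(upper.get_size(640, 480))  # width 384, height 288: both sides at most 384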
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | class Slice(nn.Module):
10 | def __init__(self, start_index=1):
11 | super(Slice, self).__init__()
12 | self.start_index = start_index
13 |
14 | def forward(self, x):
15 | return x[:, self.start_index :]
16 |
17 |
18 | class AddReadout(nn.Module):
19 | def __init__(self, start_index=1):
20 | super(AddReadout, self).__init__()
21 | self.start_index = start_index
22 |
23 | def forward(self, x):
24 | if self.start_index == 2:
25 | readout = (x[:, 0] + x[:, 1]) / 2
26 | else:
27 | readout = x[:, 0]
28 | return x[:, self.start_index :] + readout.unsqueeze(1)
29 |
30 |
31 | class ProjectReadout(nn.Module):
32 | def __init__(self, in_features, start_index=1):
33 | super(ProjectReadout, self).__init__()
34 | self.start_index = start_index
35 |
36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37 |
38 | def forward(self, x):
39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40 | features = torch.cat((x[:, self.start_index :], readout), -1)
41 |
42 | return self.project(features)
43 |
44 |
45 | class Transpose(nn.Module):
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0 = dim0
49 | self.dim1 = dim1
50 |
51 | def forward(self, x):
52 | x = x.transpose(self.dim0, self.dim1)
53 | return x
54 |
55 |
56 | def forward_vit(pretrained, x):
57 | b, c, h, w = x.shape
58 |
59 | glob = pretrained.model.forward_flex(x)
60 |
61 | layer_1 = pretrained.activations["1"]
62 | layer_2 = pretrained.activations["2"]
63 | layer_3 = pretrained.activations["3"]
64 | layer_4 = pretrained.activations["4"]
65 |
66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70 |
71 | unflatten = nn.Sequential(
72 | nn.Unflatten(
73 | 2,
74 | torch.Size(
75 | [
76 | h // pretrained.model.patch_size[1],
77 | w // pretrained.model.patch_size[0],
78 | ]
79 | ),
80 | )
81 | )
82 |
83 | if layer_1.ndim == 3:
84 | layer_1 = unflatten(layer_1)
85 | if layer_2.ndim == 3:
86 | layer_2 = unflatten(layer_2)
87 | if layer_3.ndim == 3:
88 | layer_3 = unflatten(layer_3)
89 | if layer_4.ndim == 3:
90 | layer_4 = unflatten(layer_4)
91 |
92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96 |
97 | return layer_1, layer_2, layer_3, layer_4
98 |
99 |
100 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
101 | posemb_tok, posemb_grid = (
102 | posemb[:, : self.start_index],
103 | posemb[0, self.start_index :],
104 | )
105 |
106 | gs_old = int(math.sqrt(len(posemb_grid)))
107 |
108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111 |
112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113 |
114 | return posemb
115 |
116 |
117 | def forward_flex(self, x):
118 | b, c, h, w = x.shape
119 |
120 | pos_embed = self._resize_pos_embed(
121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122 | )
123 |
124 | B = x.shape[0]
125 |
126 | if hasattr(self.patch_embed, "backbone"):
127 | x = self.patch_embed.backbone(x)
128 | if isinstance(x, (list, tuple)):
129 | x = x[-1] # last feature if backbone outputs list/tuple of features
130 |
131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132 |
133 | if getattr(self, "dist_token", None) is not None:
134 | cls_tokens = self.cls_token.expand(
135 | B, -1, -1
136 | ) # stole cls_tokens impl from Phil Wang, thanks
137 | dist_token = self.dist_token.expand(B, -1, -1)
138 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
139 | else:
140 | cls_tokens = self.cls_token.expand(
141 | B, -1, -1
142 | ) # stole cls_tokens impl from Phil Wang, thanks
143 | x = torch.cat((cls_tokens, x), dim=1)
144 |
145 | x = x + pos_embed
146 | x = self.pos_drop(x)
147 |
148 | for blk in self.blocks:
149 | x = blk(x)
150 |
151 | x = self.norm(x)
152 |
153 | return x
154 |
155 |
156 | activations = {}
157 |
158 |
159 | def get_activation(name):
160 | def hook(model, input, output):
161 | activations[name] = output
162 |
163 | return hook
164 |
165 |
166 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
167 | if use_readout == "ignore":
168 | readout_oper = [Slice(start_index)] * len(features)
169 | elif use_readout == "add":
170 | readout_oper = [AddReadout(start_index)] * len(features)
171 | elif use_readout == "project":
172 | readout_oper = [
173 | ProjectReadout(vit_features, start_index) for out_feat in features
174 | ]
175 | else:
176 | assert (
177 | False
178 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179 |
180 | return readout_oper
181 |
182 |
183 | def _make_vit_b16_backbone(
184 | model,
185 | features=[96, 192, 384, 768],
186 | size=[384, 384],
187 | hooks=[2, 5, 8, 11],
188 | vit_features=768,
189 | use_readout="ignore",
190 | start_index=1,
191 | ):
192 | pretrained = nn.Module()
193 |
194 | pretrained.model = model
195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199 |
200 | pretrained.activations = activations
201 |
202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203 |
204 | # 32, 48, 136, 384
205 | pretrained.act_postprocess1 = nn.Sequential(
206 | readout_oper[0],
207 | Transpose(1, 2),
208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209 | nn.Conv2d(
210 | in_channels=vit_features,
211 | out_channels=features[0],
212 | kernel_size=1,
213 | stride=1,
214 | padding=0,
215 | ),
216 | nn.ConvTranspose2d(
217 | in_channels=features[0],
218 | out_channels=features[0],
219 | kernel_size=4,
220 | stride=4,
221 | padding=0,
222 | bias=True,
223 | dilation=1,
224 | groups=1,
225 | ),
226 | )
227 |
228 | pretrained.act_postprocess2 = nn.Sequential(
229 | readout_oper[1],
230 | Transpose(1, 2),
231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232 | nn.Conv2d(
233 | in_channels=vit_features,
234 | out_channels=features[1],
235 | kernel_size=1,
236 | stride=1,
237 | padding=0,
238 | ),
239 | nn.ConvTranspose2d(
240 | in_channels=features[1],
241 | out_channels=features[1],
242 | kernel_size=2,
243 | stride=2,
244 | padding=0,
245 | bias=True,
246 | dilation=1,
247 | groups=1,
248 | ),
249 | )
250 |
251 | pretrained.act_postprocess3 = nn.Sequential(
252 | readout_oper[2],
253 | Transpose(1, 2),
254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255 | nn.Conv2d(
256 | in_channels=vit_features,
257 | out_channels=features[2],
258 | kernel_size=1,
259 | stride=1,
260 | padding=0,
261 | ),
262 | )
263 |
264 | pretrained.act_postprocess4 = nn.Sequential(
265 | readout_oper[3],
266 | Transpose(1, 2),
267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268 | nn.Conv2d(
269 | in_channels=vit_features,
270 | out_channels=features[3],
271 | kernel_size=1,
272 | stride=1,
273 | padding=0,
274 | ),
275 | nn.Conv2d(
276 | in_channels=features[3],
277 | out_channels=features[3],
278 | kernel_size=3,
279 | stride=2,
280 | padding=1,
281 | ),
282 | )
283 |
284 | pretrained.model.start_index = start_index
285 | pretrained.model.patch_size = [16, 16]
286 |
287 | # We inject this function into the VisionTransformer instances so that
288 | # we can use it with interpolated position embeddings without modifying the library source.
289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290 | pretrained.model._resize_pos_embed = types.MethodType(
291 | _resize_pos_embed, pretrained.model
292 | )
293 |
294 | return pretrained
295 |
296 |
297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299 |
300 | hooks = [5, 11, 17, 23] if hooks == None else hooks
301 | return _make_vit_b16_backbone(
302 | model,
303 | features=[256, 512, 1024, 1024],
304 | hooks=hooks,
305 | vit_features=1024,
306 | use_readout=use_readout,
307 | )
308 |
309 |
310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312 |
313 | hooks = [2, 5, 8, 11] if hooks == None else hooks
314 | return _make_vit_b16_backbone(
315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316 | )
317 |
318 |
319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321 |
322 | hooks = [2, 5, 8, 11] if hooks == None else hooks
323 | return _make_vit_b16_backbone(
324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325 | )
326 |
327 |
328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329 | model = timm.create_model(
330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331 | )
332 |
333 | hooks = [2, 5, 8, 11] if hooks == None else hooks
334 | return _make_vit_b16_backbone(
335 | model,
336 | features=[96, 192, 384, 768],
337 | hooks=hooks,
338 | use_readout=use_readout,
339 | start_index=2,
340 | )
341 |
342 |
343 | def _make_vit_b_rn50_backbone(
344 | model,
345 | features=[256, 512, 768, 768],
346 | size=[384, 384],
347 | hooks=[0, 1, 8, 11],
348 | vit_features=768,
349 | use_vit_only=False,
350 | use_readout="ignore",
351 | start_index=1,
352 | ):
353 | pretrained = nn.Module()
354 |
355 | pretrained.model = model
356 |
357 | if use_vit_only == True:
358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360 | else:
361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362 | get_activation("1")
363 | )
364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365 | get_activation("2")
366 | )
367 |
368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370 |
371 | pretrained.activations = activations
372 |
373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374 |
375 | if use_vit_only == True:
376 | pretrained.act_postprocess1 = nn.Sequential(
377 | readout_oper[0],
378 | Transpose(1, 2),
379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380 | nn.Conv2d(
381 | in_channels=vit_features,
382 | out_channels=features[0],
383 | kernel_size=1,
384 | stride=1,
385 | padding=0,
386 | ),
387 | nn.ConvTranspose2d(
388 | in_channels=features[0],
389 | out_channels=features[0],
390 | kernel_size=4,
391 | stride=4,
392 | padding=0,
393 | bias=True,
394 | dilation=1,
395 | groups=1,
396 | ),
397 | )
398 |
399 | pretrained.act_postprocess2 = nn.Sequential(
400 | readout_oper[1],
401 | Transpose(1, 2),
402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403 | nn.Conv2d(
404 | in_channels=vit_features,
405 | out_channels=features[1],
406 | kernel_size=1,
407 | stride=1,
408 | padding=0,
409 | ),
410 | nn.ConvTranspose2d(
411 | in_channels=features[1],
412 | out_channels=features[1],
413 | kernel_size=2,
414 | stride=2,
415 | padding=0,
416 | bias=True,
417 | dilation=1,
418 | groups=1,
419 | ),
420 | )
421 | else:
422 | pretrained.act_postprocess1 = nn.Sequential(
423 | nn.Identity(), nn.Identity(), nn.Identity()
424 | )
425 | pretrained.act_postprocess2 = nn.Sequential(
426 | nn.Identity(), nn.Identity(), nn.Identity()
427 | )
428 |
429 | pretrained.act_postprocess3 = nn.Sequential(
430 | readout_oper[2],
431 | Transpose(1, 2),
432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433 | nn.Conv2d(
434 | in_channels=vit_features,
435 | out_channels=features[2],
436 | kernel_size=1,
437 | stride=1,
438 | padding=0,
439 | ),
440 | )
441 |
442 | pretrained.act_postprocess4 = nn.Sequential(
443 | readout_oper[3],
444 | Transpose(1, 2),
445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446 | nn.Conv2d(
447 | in_channels=vit_features,
448 | out_channels=features[3],
449 | kernel_size=1,
450 | stride=1,
451 | padding=0,
452 | ),
453 | nn.Conv2d(
454 | in_channels=features[3],
455 | out_channels=features[3],
456 | kernel_size=3,
457 | stride=2,
458 | padding=1,
459 | ),
460 | )
461 |
462 | pretrained.model.start_index = start_index
463 | pretrained.model.patch_size = [16, 16]
464 |
465 | # We inject this function into the VisionTransformer instances so that
466 | # we can use it with interpolated position embeddings without modifying the library source.
467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468 |
469 | # We inject this function into the VisionTransformer instances so that
470 | # we can use it with interpolated position embeddings without modifying the library source.
471 | pretrained.model._resize_pos_embed = types.MethodType(
472 | _resize_pos_embed, pretrained.model
473 | )
474 |
475 | return pretrained
476 |
477 |
478 | def _make_pretrained_vitb_rn50_384(
479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480 | ):
481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482 |
483 | hooks = [0, 1, 8, 11] if hooks == None else hooks
484 | return _make_vit_b_rn50_backbone(
485 | model,
486 | features=[256, 512, 768, 768],
487 | size=[384, 384],
488 | hooks=hooks,
489 | use_vit_only=use_vit_only,
490 | use_readout=use_readout,
491 | )
492 |
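A tiny sketch of the readout operators above on a dummy token sequence (batch of 2, one class token followed by a 24x24 patch grid, 768 dimensions), showing that each variant removes the readout token or folds it back into the patch tokens.

    import torch
    from ldm.modules.midas.midas.vit import Slice, AddReadout, ProjectReadout, get_readout_oper

    tokens = torch.randn(2, 1 + 24 * 24, 768)   # [cls] token followed by 576 patch tokens

    print(Slice()(tokens).shape)                # torch.Size([2, 576, 768]): cls token dropped
    print(AddReadout()(tokens).shape)           # torch.Size([2, 576, 768]): cls token added to every patch token
    print(ProjectReadout(768)(tokens).shape)    # torch.Size([2, 576, 768]): cls token concatenated, then projected back

    # get_readout_oper builds one such operator per hooked feature map:
    ops = get_readout_oper(768, [96, 192, 384, 768], use_readout="project")
    print(len(ops))                             # 4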
--------------------------------------------------------------------------------
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 |         len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1)
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 |         if endian == "<" or (endian == "=" and sys.byteorder == "little"):
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
--------------------------------------------------------------------------------
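As a quick sanity check on the PFM helpers above, here is a hypothetical round trip; the file name is illustrative. write_depth emits both a raw float32 .pfm and a normalized .png next to it (bits=2 stores uint16, the default bits=1 stores uint8).

    # Hypothetical round-trip sketch (not part of the repository).
    import numpy as np

    from ldm.modules.midas.utils import read_pfm, write_depth

    depth = np.random.rand(240, 320).astype(np.float32)

    # Writes demo_depth.pfm (raw float32) and demo_depth.png (min/max normalized).
    write_depth("demo_depth", depth, bits=2)

    recovered, scale = read_pfm("demo_depth.pfm")
    print(recovered.shape, scale, np.allclose(recovered, depth))  # (240, 320) 1.0 True
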
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 |
11 | def log_txt_as_img(wh, xc, size=10):
12 | # wh a tuple of (width, height)
13 | # xc a list of captions to plot
14 | b = len(xc)
15 | txts = list()
16 | for bi in range(b):
17 | txt = Image.new("RGB", wh, color="white")
18 | draw = ImageDraw.Draw(txt)
19 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
20 | nc = int(40 * (wh[0] / 256))
21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22 |
23 | try:
24 | draw.text((0, 0), lines, fill="black", font=font)
25 | except UnicodeEncodeError:
26 |             print("Can't encode string for logging. Skipping.")
27 |
28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29 | txts.append(txt)
30 | txts = np.stack(txts)
31 | txts = torch.tensor(txts)
32 | return txts
33 |
34 |
35 | def ismap(x):
36 | if not isinstance(x, torch.Tensor):
37 | return False
38 | return (len(x.shape) == 4) and (x.shape[1] > 3)
39 |
40 |
41 | def isimage(x):
42 | if not isinstance(x,torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45 |
46 |
47 | def exists(x):
48 | return x is not None
49 |
50 |
51 | def default(val, d):
52 | if exists(val):
53 | return val
54 | return d() if isfunction(d) else d
55 |
56 |
57 | def mean_flat(tensor):
58 | """
59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60 | Take the mean over all non-batch dimensions.
61 | """
62 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
63 |
64 |
65 | def count_params(model, verbose=False):
66 | total_params = sum(p.numel() for p in model.parameters())
67 | if verbose:
68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69 | return total_params
70 |
71 |
72 | def instantiate_from_config(config):
73 |     if "target" not in config:
74 | if config == '__is_first_stage__':
75 | return None
76 | elif config == "__is_unconditional__":
77 | return None
78 | raise KeyError("Expected key `target` to instantiate.")
79 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
80 |
81 |
82 | def get_obj_from_str(string, reload=False):
83 | module, cls = string.rsplit(".", 1)
84 | if reload:
85 | module_imp = importlib.import_module(module)
86 | importlib.reload(module_imp)
87 | return getattr(importlib.import_module(module, package=None), cls)
88 |
89 |
90 | class AdamWwithEMAandWings(optim.Optimizer):
91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94 | ema_power=1., param_names=()):
95 | """AdamW that saves EMA versions of the parameters."""
96 | if not 0.0 <= lr:
97 | raise ValueError("Invalid learning rate: {}".format(lr))
98 | if not 0.0 <= eps:
99 | raise ValueError("Invalid epsilon value: {}".format(eps))
100 | if not 0.0 <= betas[0] < 1.0:
101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102 | if not 0.0 <= betas[1] < 1.0:
103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104 | if not 0.0 <= weight_decay:
105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106 | if not 0.0 <= ema_decay <= 1.0:
107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108 | defaults = dict(lr=lr, betas=betas, eps=eps,
109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110 | ema_power=ema_power, param_names=param_names)
111 | super().__init__(params, defaults)
112 |
113 | def __setstate__(self, state):
114 | super().__setstate__(state)
115 | for group in self.param_groups:
116 | group.setdefault('amsgrad', False)
117 |
118 | @torch.no_grad()
119 | def step(self, closure=None):
120 | """Performs a single optimization step.
121 | Args:
122 | closure (callable, optional): A closure that reevaluates the model
123 | and returns the loss.
124 | """
125 | loss = None
126 | if closure is not None:
127 | with torch.enable_grad():
128 | loss = closure()
129 |
130 | for group in self.param_groups:
131 | params_with_grad = []
132 | grads = []
133 | exp_avgs = []
134 | exp_avg_sqs = []
135 | ema_params_with_grad = []
136 | state_sums = []
137 | max_exp_avg_sqs = []
138 | state_steps = []
139 | amsgrad = group['amsgrad']
140 | beta1, beta2 = group['betas']
141 | ema_decay = group['ema_decay']
142 | ema_power = group['ema_power']
143 |
144 | for p in group['params']:
145 | if p.grad is None:
146 | continue
147 | params_with_grad.append(p)
148 | if p.grad.is_sparse:
149 | raise RuntimeError('AdamW does not support sparse gradients')
150 | grads.append(p.grad)
151 |
152 | state = self.state[p]
153 |
154 | # State initialization
155 | if len(state) == 0:
156 | state['step'] = 0
157 | # Exponential moving average of gradient values
158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159 | # Exponential moving average of squared gradient values
160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161 | if amsgrad:
162 | # Maintains max of all exp. moving avg. of sq. grad. values
163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164 | # Exponential moving average of parameter values
165 | state['param_exp_avg'] = p.detach().float().clone()
166 |
167 | exp_avgs.append(state['exp_avg'])
168 | exp_avg_sqs.append(state['exp_avg_sq'])
169 | ema_params_with_grad.append(state['param_exp_avg'])
170 |
171 | if amsgrad:
172 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173 |
174 | # update the steps for each param group update
175 | state['step'] += 1
176 | # record the step after step update
177 | state_steps.append(state['step'])
178 |
179 | optim._functional.adamw(params_with_grad,
180 | grads,
181 | exp_avgs,
182 | exp_avg_sqs,
183 | max_exp_avg_sqs,
184 | state_steps,
185 | amsgrad=amsgrad,
186 | beta1=beta1,
187 | beta2=beta2,
188 | lr=group['lr'],
189 | weight_decay=group['weight_decay'],
190 | eps=group['eps'],
191 | maximize=False)
192 |
193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196 |
197 | return loss
--------------------------------------------------------------------------------
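instantiate_from_config and get_obj_from_str implement the target/params convention used by the YAML files under configs/stable-diffusion. A minimal sketch with an illustrative target follows; any importable class works, it does not have to come from this repository.

    # Hypothetical sketch of config-driven instantiation (the target here is illustrative).
    from ldm.util import count_params, instantiate_from_config

    config = {
        "target": "torch.nn.Linear",                      # dotted path resolved by get_obj_from_str
        "params": {"in_features": 4, "out_features": 8},  # forwarded as keyword arguments
    }

    layer = instantiate_from_config(config)
    count_params(layer, verbose=True)  # prints: Linear has 0.00 M params.
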
/ptp_scripts/ptp_scripts.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional, Union, Tuple, List, Callable, Dict
16 | import torch
17 | from diffusers import StableDiffusionPipeline
18 | import torch.nn.functional as nnf
19 | import numpy as np
20 | import abc
21 | import sys
22 | sys.path.append('..')
23 | import matplotlib.pyplot as plt
24 | from PIL import Image
25 |
26 | LOW_RESOURCE = False
27 |
28 | class AttentionControl(abc.ABC):
29 |
30 | def step_callback(self, x_t):
31 | return x_t
32 |
33 | def between_steps(self):
34 | return
35 |
36 | @property
37 | def num_uncond_att_layers(self):
38 | return self.num_att_layers if LOW_RESOURCE else 0
39 |
40 | @abc.abstractmethod
41 |     def forward(self, attn, is_cross: bool, place_in_unet: str):
42 | raise NotImplementedError
43 |
44 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
45 | if self.cur_att_layer >= self.num_uncond_att_layers:
46 | if LOW_RESOURCE:
47 | attn = self.forward(attn, is_cross, place_in_unet)
48 | else:
49 | h = attn.shape[0]
50 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
51 |
52 | self.cur_att_layer += 1
53 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
54 | self.cur_att_layer = 0
55 | self.cur_step += 1
56 | self.between_steps()
57 | return attn
58 |
59 | def reset(self):
60 | self.cur_step = 0
61 | self.cur_att_layer = 0
62 |
63 | def __init__(self):
64 | self.cur_step = 0
65 | self.num_att_layers = -1
66 | self.cur_att_layer = 0
67 |
68 |
69 | class AttentionStore(AttentionControl):
70 |
71 | @staticmethod
72 | def get_empty_store():
73 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
74 | "down_self": [], "mid_self": [], "up_self": []}
75 |
76 | def forward(self, attn, is_cross: bool, place_in_unet: str):
77 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
78 | # if attn.shape[1] <= 16 ** 2: # and attn.shape[1] > 16 ** 2: # avoid memory overhead
79 | self.step_store[key].append(attn)
80 | return attn
81 |
82 | def between_steps(self):
83 | for key in self.step_store:
84 | self.attention_store[key] = self.step_store[key]
85 |
86 | self.step_store = self.get_empty_store()
87 |
88 | def get_average_attention(self):
89 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
90 | return average_attention
91 |
92 |
93 | def reset(self):
94 | super(AttentionStore, self).reset()
95 | self.step_store = self.get_empty_store()
96 | self.attention_store = self.get_empty_store()
97 |
98 | def __init__(self):
99 | super(AttentionStore, self).__init__()
100 | self.step_store = self.get_empty_store()
101 | self.attention_store = self.get_empty_store()
--------------------------------------------------------------------------------
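A standalone sketch of AttentionStore's bookkeeping, assuming the repository root is on sys.path and the module's own imports (diffusers, matplotlib) resolve. In the real pipeline num_att_layers is set by register_attention_control in ptp_utils_ori.py, and the maps come from the UNet's attention layers rather than from torch.rand.

    # Hypothetical exercise of the controller bookkeeping (not how run.py wires it up).
    import torch

    from ptp_scripts.ptp_scripts import AttentionStore

    controller = AttentionStore()
    controller.num_att_layers = 2  # normally set by register_attention_control

    # With LOW_RESOURCE = False only the second half of the batch dimension
    # (the conditional branch) is handed to forward() and stored.
    for _ in range(controller.num_att_layers):
        attn = torch.rand(16, 64, 64)  # (batch * heads, queries, keys)
        controller(attn, is_cross=False, place_in_unet="up")

    print(len(controller.attention_store["up_self"]), controller.cur_step)  # 2 1
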
/ptp_scripts/ptp_utils_ori.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | from PIL import Image, ImageDraw, ImageFont
18 | import cv2
19 | from typing import Optional, Union, Tuple, List, Callable, Dict
20 | from IPython.display import display
21 | from tqdm.notebook import tqdm
22 | import matplotlib.pyplot as plt
23 | from torch import nn, einsum
24 | from einops import rearrange, repeat
25 | from inspect import isfunction
26 |
27 |
28 | def exists(val):
29 | return val is not None
30 |
31 |
32 | def default(val, d):
33 | if exists(val):
34 | return val
35 | return d() if isfunction(d) else d
36 |
37 |
38 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
39 | h, w, c = image.shape
40 | offset = int(h * .2)
41 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
42 | font = cv2.FONT_HERSHEY_SIMPLEX
43 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
44 | img[:h] = image
45 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
46 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
47 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
48 | return img
49 |
50 |
51 | def view_images(images, num_rows=1, offset_ratio=0.02, name='image', timestamp=0, layernum=None):
52 | if type(images) is list:
53 | num_empty = len(images) % num_rows
54 | elif images.ndim == 4:
55 | num_empty = images.shape[0] % num_rows
56 | else:
57 | images = [images]
58 | num_empty = 0
59 |
60 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
61 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
62 | num_items = len(images)
63 |
64 | h, w, c = images[0].shape
65 | offset = int(h * offset_ratio)
66 | num_cols = num_items // num_rows
67 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
68 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
69 | for i in range(num_rows):
70 | for j in range(num_cols):
71 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
72 | i * num_cols + j]
73 |
74 | pil_img = Image.fromarray(image_)
75 | display(pil_img)
76 |
77 |
78 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
79 | if low_resource:
80 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
81 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
82 | else:
83 | latents_input = torch.cat([latents] * 2)
84 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
85 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
86 |
87 | # classifier-free guidance during inference
88 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
89 | # noise_pred = guidance_scale * noise_prediction_text
90 |
91 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
92 | latents = controller.step_callback(latents)
93 | return latents
94 |
95 |
96 | def latent2image(vae, latents):
97 | latents = 1 / 0.18215 * latents
98 | image = vae.decode(latents)['sample']
99 | image = (image / 2 + 0.5).clamp(0, 1)
100 | image = image.cpu().permute(0, 2, 3, 1).numpy()
101 | image = (image * 255).astype(np.uint8)
102 | return image
103 |
104 |
105 | def init_latent(latent, model, height, width, generator, batch_size):
106 | if latent is None:
107 | latent = torch.randn(
108 | (1, model.unet.in_channels, height // 8, width // 8),
109 | generator=generator,
110 | )
111 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
112 | return latent, latents
113 |
114 |
115 | @torch.no_grad()
116 | def text2image_ldm(
117 | model,
118 | prompt: List[str],
119 | controller,
120 | num_inference_steps: int = 50,
121 | guidance_scale: Optional[float] = 7.,
122 | generator: Optional[torch.Generator] = None,
123 | latent: Optional[torch.FloatTensor] = None,
124 | ):
125 | register_attention_control(model, controller)
126 | height = width = 256
127 | batch_size = len(prompt)
128 |
129 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
130 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
131 |
132 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
133 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
134 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
135 | context = torch.cat([uncond_embeddings, text_embeddings])
136 |
137 | model.scheduler.set_timesteps(num_inference_steps)
138 | for t in tqdm(model.scheduler.timesteps):
139 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
140 |
141 | image = latent2image(model.vqvae, latents)
142 |
143 | return image, latent
144 |
145 |
146 | @torch.no_grad()
147 | def text2image_ldm_stable(
148 | model,
149 | prompt: List[str],
150 | controller,
151 | num_inference_steps: int = 50,
152 | guidance_scale: float = 7.5,
153 | generator: Optional[torch.Generator] = None,
154 | latent: Optional[torch.FloatTensor] = None,
155 | low_resource: bool = False,
156 | ):
157 | register_attention_control(model, controller)
158 | height = width = 512
159 | batch_size = len(prompt)
160 |
161 | text_input = model.tokenizer(
162 | prompt,
163 | padding="max_length",
164 | max_length=model.tokenizer.model_max_length,
165 | truncation=True,
166 | return_tensors="pt",
167 | )
168 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
169 | max_length = text_input.input_ids.shape[-1]
170 | uncond_input = model.tokenizer(
171 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
172 | )
173 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
174 |
175 | context = [uncond_embeddings, text_embeddings]
176 | if not low_resource:
177 | context = torch.cat(context)
178 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
179 |
180 | model.scheduler.set_timesteps(num_inference_steps)
181 | for t in tqdm(model.scheduler.timesteps):
182 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
183 |
184 | image = latent2image(model.vae, latents)
185 |
186 | return image, latent
187 |
188 |
189 | def register_attention_control(model, controller, inject_bg=False, pseudo_cross=False):
190 | def ca_forward(self, place_in_unet):
191 | def forward(x, context=None, mask=None, encode=False, controller_for_inject=None, inject=False, layernum=None, main_height=None, main_width=None):
192 |
193 | is_cross = context is not None
194 | h = self.heads
195 |
196 | q = self.to_q(x)
197 | context = default(context, x)
198 | k = self.to_k(context)
199 | v = self.to_v(context)
200 |
201 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
202 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
203 |
204 | if exists(mask):
205 | mask = rearrange(mask, 'b ... -> b (...)')
206 | max_neg_value = -torch.finfo(sim.dtype).max
207 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
208 | sim.masked_fill_(~mask, max_neg_value)
209 |
210 | sim_2 = sim.clone()
211 | if encode == False:
212 | sim = controller(sim, is_cross, place_in_unet)
213 |
214 |             if inject or (inject_bg and is_cross == False):
215 |
216 | if layernum >= -1 :
217 | if place_in_unet == 'down':
218 |
219 | if inject_bg:
220 | sim[h:] = sim[h:]
221 |
222 | if inject:
223 | sim[h:] = sim[h:]
224 |
225 |
226 | elif place_in_unet == 'up':
227 |
228 | if inject_bg:
229 | sim[h:] = controller_for_inject[0].attention_store['up_self'][layernum] * 0.5 \
230 | + 0.5* sim[h:]
231 |
232 | if inject:
233 | sim[h:] = controller_for_inject[1].attention_store['up_self'][layernum] * 0.5 \
234 | + 0.5 * sim[h:]
235 |
236 |
237 | elif place_in_unet == 'mid':
238 |
239 | if inject_bg:
240 |
241 | sim[h:] = sim[h:]
242 |
243 | if inject:
244 |
245 | sim[h:] = sim[h:]
246 |
247 | sim = sim.softmax(dim=-1)
248 | sim_2 = sim_2.softmax(dim=-1)
249 |
250 |
251 | out = einsum('b i j, b j d -> b i d', sim, v)
252 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
253 |
254 | del sim, v, q, k, context
255 |
256 | return self.to_out(out)
257 |
258 | return forward
259 |
260 | def register_recr(net_, count, place_in_unet):
261 | if 'CrossAttention' in net_.__class__.__name__:
262 | if net_.to_k.in_features != 1024:
263 | net_.forward = ca_forward(net_, place_in_unet)
264 | return count + 1
265 | else:
266 | return count
267 | elif hasattr(net_, 'children'):
268 | for net__ in net_.children():
269 | count = register_recr(net__, count, place_in_unet)
270 | return count
271 |
272 | cross_att_count = 0
273 | # sub_nets = model.unet.named_children()
274 | sub_nets = model.model.diffusion_model.named_children()
275 |
276 | for net in sub_nets:
277 | if "input" in net[0]:
278 | cross_att_count += register_recr(net[1], 0, "down")
279 | elif "output" in net[0]:
280 | cross_att_count += register_recr(net[1], 0, "up")
281 | elif "middle" in net[0]:
282 | cross_att_count += register_recr(net[1], 0, "mid")
283 | controller.num_att_layers = cross_att_count
284 |
313 |
314 | def get_word_inds(text: str, word_place: int, tokenizer):
315 | split_text = text.split(" ")
316 | if type(word_place) is str:
317 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
318 | elif type(word_place) is int:
319 | word_place = [word_place]
320 | out = []
321 | if len(word_place) > 0:
322 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
323 | cur_len, ptr = 0, 0
324 |
325 | for i in range(len(words_encode)):
326 | cur_len += len(words_encode[i])
327 | if ptr in word_place:
328 | out.append(i + 1)
329 | if cur_len >= len(split_text[ptr]):
330 | ptr += 1
331 | cur_len = 0
332 | return np.array(out)
333 |
334 |
335 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None):
336 | if type(bounds) is float:
337 | bounds = 0, bounds
338 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
339 | if word_inds is None:
340 | word_inds = torch.arange(alpha.shape[2])
341 | alpha[: start, prompt_ind, word_inds] = 0
342 | alpha[start: end, prompt_ind, word_inds] = 1
343 | alpha[end:, prompt_ind, word_inds] = 0
344 | return alpha
345 |
346 |
347 | def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
348 | tokenizer, max_num_words=77):
349 | if type(cross_replace_steps) is not dict:
350 | cross_replace_steps = {"default_": cross_replace_steps}
351 | if "default_" not in cross_replace_steps:
352 | cross_replace_steps["default_"] = (0., 1.)
353 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
354 |
355 | for i in range(len(prompts) - 1):
356 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
357 |
358 | for key, item in cross_replace_steps.items():
359 | if key != "default_":
360 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
361 | for i, ind in enumerate(inds):
362 | if len(ind) > 0:
363 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
364 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words
365 | return alpha_time_words
366 |
--------------------------------------------------------------------------------
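get_time_words_attention_alpha builds the per-step, per-word schedule that decides when cross-attention is replaced. A hypothetical call with a plain float schedule follows; in that case the tokenizer argument is never consulted, so None suffices, and the prompts are illustrative. Note that importing this module pulls in notebook-only dependencies (IPython, tqdm.notebook).

    # Hypothetical sketch of the cross-replace schedule (prompts and numbers are illustrative).
    from ptp_scripts.ptp_utils_ori import get_time_words_attention_alpha

    prompts = ["a photo of a red clock", "a photo of a blue clock"]

    # Replace cross-attention during the first 40% of 50 sampling steps.
    alpha = get_time_words_attention_alpha(
        prompts, num_steps=50, cross_replace_steps=0.4, tokenizer=None
    )

    print(alpha.shape)           # torch.Size([51, 1, 1, 1, 77])
    print(alpha[:, 0, 0, 0, 0])  # 1.0 for the first 20 entries, then 0.0
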
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='stable-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------