├── GPH Benchmark demo data ├── background_data │ ├── 229.png │ ├── 330.png │ ├── 45.png │ └── 68.png ├── composite_data │ ├── 229.png │ ├── 330.png │ ├── 45.png │ └── 68.png ├── foreground_data │ ├── 229.png │ ├── 330.png │ ├── 45.png │ └── 68.png ├── harmonized_data │ ├── 229.png │ ├── 330.png │ ├── 45.png │ └── 68.png └── mask_data │ ├── 229.png │ ├── 330.png │ ├── 45.png │ └── 68.png ├── LICENSE ├── README.md ├── app_util.py ├── attention_control ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── diffuser_utils.cpython-310.pyc │ ├── masactrl.cpython-310.pyc │ ├── masactrl.cpython-38.pyc │ ├── masactrl_utils.cpython-310.pyc │ ├── masactrl_utils.cpython-38.pyc │ └── share_attention.cpython-38.pyc ├── masactrl_utils.py └── share_attention.py ├── compute_metrics.py ├── configs └── stable-diffusion │ ├── v2-inference-v.yaml │ ├── v2-inference.yaml │ ├── v2-inpainting-inference.yaml │ ├── v2-midas-inference.yaml │ └── x4-upscaling.yaml ├── demo_outputs ├── harmonized │ └── kangaroo_starry_15S_12L_total.jpg └── reconstruct │ └── composite_recon.jpg ├── github_source ├── fig1.png ├── fig14.png ├── fig15.png ├── fig17.png ├── fig18.png ├── fig1_new.png.png ├── fig4.png └── tf-gph_demo.gif ├── gradio ├── background │ ├── bg03.png │ ├── bg36.png │ ├── bg52.png │ ├── bg58.png │ └── bg62.png ├── foreground │ ├── fg10_63d22a7f1f5b66e8e5ac28f7.jpg │ ├── fg50_63d22c871f5b66e8e5ac95e1.jpg │ ├── fg88_63d9d508b82cf5cb1db01976.jpg │ ├── fg90_63d9d4a0b82cf5cb1db00800.jpg │ └── fg92_63d9d6c9b82cf5cb1db05fda.jpg └── seg_foreground │ ├── fg10_mask.jpg │ ├── fg50_mask.png │ ├── fg88_mask.png │ ├── fg90_mask.png │ └── fg92_mask.png ├── inputs └── demo_input │ ├── kangaroo.jpg │ ├── kangaroo_starry.jpg │ └── starry_night.jpg ├── ldm ├── __pycache__ │ ├── util.cpython-310.pyc │ └── util.cpython-38.pyc ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-38.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── ddim.cpython-310.pyc │ │ ├── ddim.cpython-38.pyc │ │ ├── ddpm.cpython-310.pyc │ │ ├── ddpm.cpython-38.pyc │ │ ├── plms.cpython-38.pyc │ │ └── sampling_util.cpython-38.pyc │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── dpm_solver.cpython-38.pyc │ │ │ ├── dpm_solver_pytorch.cpython-38.pyc │ │ │ ├── sampler.cpython-310.pyc │ │ │ └── sampler.cpython-38.pyc │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-38.pyc │ │ └── ema.cpython-38.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ ├── util.cpython-310.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── openaimodel_new.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── modules.cpython-38.pyc │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── 
utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── scripts ├── __pycache__ │ ├── dpm_solver_pytorch.cpython-310.pyc │ ├── dpm_solver_pytorch.cpython-38.pyc │ └── txt2img.cpython-38.pyc ├── dpm_solver_pytorch.py ├── dpm_solver_pytorch_new.py └── header.html ├── setup.py ├── tfgph_app.py ├── tfgph_demo.ipynb ├── tfgph_env.yaml └── tfgph_main.py /GPH Benchmark demo data/background_data/229.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/background_data/229.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/background_data/330.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/background_data/330.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/background_data/45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/background_data/45.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/background_data/68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/background_data/68.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/composite_data/229.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/composite_data/229.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/composite_data/330.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/composite_data/330.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/composite_data/45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/composite_data/45.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/composite_data/68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/composite_data/68.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/foreground_data/229.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/foreground_data/229.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/foreground_data/330.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/foreground_data/330.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/foreground_data/45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/foreground_data/45.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/foreground_data/68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/foreground_data/68.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/harmonized_data/229.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/harmonized_data/229.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/harmonized_data/330.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/harmonized_data/330.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/harmonized_data/45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/harmonized_data/45.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/harmonized_data/68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/harmonized_data/68.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/mask_data/229.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/mask_data/229.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/mask_data/330.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/mask_data/330.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/mask_data/45.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/mask_data/45.png -------------------------------------------------------------------------------- /GPH Benchmark demo data/mask_data/68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/GPH Benchmark demo data/mask_data/68.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 蕭登方(Hsiao, Teng-Fang) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2404.12900-b31b1b.svg)](https://arxiv.org/abs/2404.12900) 2 | # TF-GPH (AAAI'25) 3 | **Training-and-Prompt-Free General Painterly Harmonization via Zero-Shot Disentanglement on Style and Content References** 4 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/fig1_new.png.png) 5 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/tf-gph_demo.gif) 6 | ## Setup 7 | Our codebase is built on [Stable-Diffusion](https://github.com/Stability-AI/stablediffusion) 8 | and shares its dependencies and model architecture. 23 GB of VRAM is recommended (an RTX 3090, for example), though this may vary depending on the input samples (20 GB minimum).
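Before running anything, you can optionally check that the visible GPU meets this requirement. The snippet below is a minimal sketch (not part of the repo); it only assumes the PyTorch installed by the environment created above:

```
# Optional sanity check: report the first visible GPU and its total VRAM,
# and warn if it falls below the ~20 GB minimum mentioned above.
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    total_gb = props.total_memory / 1024**3
    print(f"GPU: {props.name}, total VRAM: {total_gb:.1f} GB")
    if total_gb < 20:
        print("Warning: below the ~20 GB minimum recommended for TF-GPH.")
else:
    print("No CUDA device detected; TF-GPH expects a GPU.")
```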
9 | 10 | This GitHub repo is based on [TF-ICON](https://github.com/Shilin-LU/TF-ICON) and [MasaCtrl](https://github.com/TencentARC/MasaCtrl/tree/main). 11 | ### Creating a Conda Environment 12 | 13 | ``` 14 | git clone https://github.com/BlueDyee/TF-GPH.git 15 | cd TF-GPH 16 | conda env create -f tfgph_env.yaml 17 | conda activate tfgph 18 | ``` 19 | 20 | ### Downloading Stable-Diffusion Weights 21 | 22 | Download the Stable-Diffusion weights from [Stability AI on Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.ckpt) 23 | (download the `v2-1_512-ema-pruned.ckpt` file; this will occupy around 5 GB of storage). 24 | For example: 25 | 26 | ``` 27 | wget -O v2-1_512-ema-pruned.ckpt https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt?download=true 28 | ``` 29 | ## Run 30 | We provide three ways to run our repo: **web app (Gradio) / ipynb / py**. 31 | ### app 32 | Run the TF-GPH web UI: 33 | ``` 34 | python tfgph_app.py 35 | ``` 36 | 37 | 38 | ### ipynb 39 | Open `tfgph_demo.ipynb` and run all cells. 40 | 41 | 42 | 43 | ### py 44 | Run with default parameters: 45 | ``` 46 | python tfgph_main.py 47 | ``` 48 | Or customize the parameters: 49 | 50 | (Due to a trade-off between mathematical correctness and code conciseness, the effect of `share_step` in the code differs from that in the paper.) 51 | 52 | (In the code, `share_step` treats the inverted latent z_T as the 0th step, so `share_step 15` means denoising normally for 15 steps and then denoising with shared attention for the remaining steps.) 53 | 54 | ``` 55 | python tfgph_main.py --ref1 "./inputs/demo_input/kangaroo.jpg" \ 56 | --ref2 "./inputs/demo_input/starry_night.jpg" \ 57 | --comp "./inputs/demo_input/kangaroo_starry.jpg" \ 58 | --share_step 15 \ 59 | --share_layer 12 60 | ``` 61 | ## Evaluation on the GPH Benchmark 62 | Your data directory should look like: 63 | ``` 64 | GPH Benchmark demo data 65 | ├── background_data 66 | │ ├── x.png 67 | │ └── xx.png 68 | ├── composite_data 69 | │ ├── x.png 70 | │ └── xx.png 71 | ├── foreground_data 72 | │ ├── x.png 73 | │ └── xx.png 74 | ├── harmonized_data **(your generated results)** 75 | │ ├── x.png 76 | │ └── xx.png 77 | ├── mask_data 78 | │ ├── x.png 79 | │ └── xx.png 80 | ``` 81 | ``` 82 | python compute_metrics.py -r "GPH Benchmark demo data" 83 | ``` 84 | 85 | ### More Results 86 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/fig4.png) 87 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/fig17.png) 88 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/fig18.png) 89 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/fig14.png) 90 | ![image](https://github.com/BlueDyee/TF-GPH/blob/main/github_source/fig15.png) 91 | 92 | 93 | -------------------------------------------------------------------------------- /app_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import PIL 4 | import torch 5 | import cv2 6 | import time 7 | import shutil 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from omegaconf import OmegaConf 12 | from PIL import Image 13 | from itertools import islice 14 | from einops import rearrange, repeat 15 | from torch import autocast 16 | from pytorch_lightning import seed_everything 17 | import gradio as gr 18 | 19 | from ldm.util import instantiate_from_config, load_model_from_config, load_img, load_model_and_get_prompt_embedding 20 | from
ldm.models.diffusion.ddim import DDIMSampler 21 | from ldm.models.diffusion.dpm_solver import DPMSolverSampler 22 | from attention_control.masactrl_utils import regiter_attention_editor_ldm 23 | from attention_control.share_attention import ShareSelfAttentionControl 24 | from torchvision.utils import save_image 25 | 26 | def pil_load_img(image, SCALE, pad=False, seg=False, target_size=None): 27 | w, h = image.size 28 | w_,h_=w,h 29 | print(f"loaded input image of size ({w}, {h})") 30 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 31 | w = h = 512 32 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 33 | 34 | image = np.array(image).astype(np.float32) / 255.0 35 | image = image[None].transpose(0, 3, 1, 2) 36 | image = torch.from_numpy(image) 37 | print(f"resize input image of size ({w_}, {h_}) to {w}, {h}") 38 | 39 | return 2. * image - 1., w, h 40 | 41 | def tfgph_load(opt, device): 42 | # Load Model 43 | config = OmegaConf.load(opt["config"]) 44 | model = load_model_from_config(config, opt["ckpt"]) 45 | 46 | model = model.to(device) 47 | sampler = DPMSolverSampler(model) 48 | 49 | print("##----Model LOAD Success---##") 50 | return model,sampler 51 | 52 | def tfgph_inverse(ref_img,opt,model,sampler,device,ref1_path=None,ref2_path=None,comp_path=None): 53 | # Read Image 54 | ref_image, target_width, target_height = pil_load_img(ref_img, 1) 55 | ref_image = repeat(ref_image.to(device), '1 ... -> b ...', b=1) 56 | print("##----Image LOAD Success---##") 57 | 58 | # Reconstruct 59 | uncond_scale=2.5 60 | precision_scope = autocast 61 | with precision_scope("cuda"): 62 | c, uc, inv_emb = load_model_and_get_prompt_embedding(model, uncond_scale, device, opt["prompt"], inv=True) 63 | 64 | T1 = time.time() 65 | ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref_image)) 66 | shape = ref_latent.shape[1:] 67 | z_ref, _ = sampler.sample(steps=opt["total_steps"], 68 | inv_emb=inv_emb, 69 | unconditional_conditioning=uc, 70 | conditioning=c, 71 | batch_size=1, 72 | shape=shape, 73 | verbose=False, 74 | unconditional_guidance_scale=uncond_scale, 75 | eta=0, 76 | order=opt["order"], 77 | x_T=ref_latent, 78 | width=512, 79 | height=512, 80 | DPMencode=True, 81 | ) 82 | return z_ref 83 | def tfgph_harmonize(z_ref1, z_ref2, z_comp, opt,model,sampler,device,ref1_path=None,ref2_path=None,comp_path=None): 84 | precision_scope = autocast 85 | uncond_scale=2.5 86 | c, uc, inv_emb = load_model_and_get_prompt_embedding(model, uncond_scale, device, opt["prompt"], inv=True) 87 | sim_scales = torch.tensor([opt["scale_alpha"],opt["scale_beta"]]).to(device) 88 | shape=z_ref1.shape[1:] 89 | with precision_scope("cuda"): 90 | 91 | # hijack the attention module (sclaed share) 92 | editor = ShareSelfAttentionControl(opt["share_step"], opt["share_layer"],scales=sim_scales,total_steps=opt["total_steps"]) 93 | regiter_attention_editor_ldm(model, editor) 94 | latents_harmonized = sampler.sample(steps=opt["total_steps"], 95 | inv_emb=torch.cat([inv_emb,inv_emb,inv_emb]), 96 | conditioning=torch.cat([c,c,c]), 97 | shape=shape, 98 | verbose=False, 99 | unconditional_guidance_scale=uncond_scale, 100 | unconditional_conditioning=torch.cat([uc,uc,uc]), 101 | eta=0, 102 | order=opt["order"], 103 | x_T=torch.cat([z_ref1,z_ref2,z_comp]), 104 | width=512, 105 | height=512, 106 | ) 107 | x_harmonized = model.decode_first_stage(latents_harmonized) 108 | x_harmonized = torch.clamp((x_harmonized + 1.0) / 2.0, min=0.0, max=1.0) 109 | 110 | 111 | x_sample=x_harmonized[-1] 112 | 
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 113 | img = Image.fromarray(x_sample.astype(np.uint8)) 114 | 115 | return img 116 | -------------------------------------------------------------------------------- /attention_control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__init__.py -------------------------------------------------------------------------------- /attention_control/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/diffuser_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/diffuser_utils.cpython-310.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/masactrl.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/masactrl.cpython-310.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/masactrl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/masactrl.cpython-38.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/masactrl_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/masactrl_utils.cpython-310.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/masactrl_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/masactrl_utils.cpython-38.pyc -------------------------------------------------------------------------------- /attention_control/__pycache__/share_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/attention_control/__pycache__/share_attention.cpython-38.pyc -------------------------------------------------------------------------------- /attention_control/masactrl_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from typing import Optional, Union, Tuple, List, Callable, Dict 9 | 10 | from torchvision.utils import save_image 11 | from einops import rearrange, repeat 12 | 13 | 14 | class AttentionBase: 15 | def __init__(self): 16 | self.cur_step = 0 17 | self.num_att_layers = -1 18 | self.cur_att_layer = 0 19 | print("Create hijack attention") 20 | def after_step(self): 21 | pass 22 | 23 | def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): 24 | out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) 25 | self.cur_att_layer += 1 26 | if self.cur_att_layer == self.num_att_layers: 27 | self.cur_att_layer = 0 28 | self.cur_step += 1 29 | # after step 30 | self.after_step() 31 | return out 32 | 33 | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): 34 | out = torch.einsum('b i j, b j d -> b i d', attn, v) 35 | # print("Normal out1:",out.shape) 36 | # print("Normal atten:",attn.shape) 37 | out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads) 38 | # print("q",q.shape) 39 | # print("k",k.shape) 40 | # print("v",v.shape) 41 | # print("sim",sim.shape) 42 | # print("attn",attn.shape) 43 | # print("num_heads",num_heads) 44 | # print("Normal out2:",out.shape) 45 | return out 46 | 47 | def reset(self): 48 | self.cur_step = 0 49 | self.cur_att_layer = 0 50 | 51 | 52 | class AttentionStore(AttentionBase): 53 | def __init__(self, res=[32], min_step=0, max_step=1000): 54 | super().__init__() 55 | self.res = res 56 | self.min_step = min_step 57 | self.max_step = max_step 58 | self.valid_steps = 0 59 | 60 | self.self_attns = [] # store the all attns 61 | self.cross_attns = [] 62 | 63 | self.self_attns_step = [] # store the attns in each step 64 | self.cross_attns_step = [] 65 | 66 | def after_step(self): 67 | if self.cur_step > self.min_step and self.cur_step < self.max_step: 68 | self.valid_steps += 1 69 | if len(self.self_attns) == 0: 70 | self.self_attns = self.self_attns_step 71 | self.cross_attns = self.cross_attns_step 72 | else: 73 | for i in range(len(self.self_attns)): 74 | self.self_attns[i] += self.self_attns_step[i] 75 | self.cross_attns[i] += self.cross_attns_step[i] 76 | self.self_attns_step.clear() 77 | self.cross_attns_step.clear() 78 | 79 | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): 80 | if attn.shape[1] <= 64 ** 2: # avoid OOM 81 | if is_cross: 82 | self.cross_attns_step.append(attn) 83 | else: 84 | self.self_attns_step.append(attn) 85 | return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) 86 | 87 | 88 | def regiter_attention_editor_diffusers(model, editor: AttentionBase): 89 | """ 90 | Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt] 91 | """ 92 | def ca_forward(self, place_in_unet,**kwarg): 93 | def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None,**kwarg): 94 | """ 95 | The attention is similar to the original implementation of LDM CrossAttention class 96 | except adding some modifications on the attention 97 | """ 98 | if encoder_hidden_states is not None: 99 | context = encoder_hidden_states 100 | if attention_mask is not None: 101 | mask = attention_mask 102 | 103 | to_out = self.to_out 104 | if isinstance(to_out, 
nn.modules.container.ModuleList): 105 | to_out = self.to_out[0] 106 | else: 107 | to_out = self.to_out 108 | 109 | h = self.heads 110 | q = self.to_q(x) 111 | is_cross = context is not None 112 | context = context if is_cross else x 113 | k = self.to_k(context) 114 | v = self.to_v(context) 115 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 116 | 117 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale 118 | 119 | if mask is not None: 120 | mask = rearrange(mask, 'b ... -> b (...)') 121 | max_neg_value = -torch.finfo(sim.dtype).max 122 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 123 | mask = mask[:, None, :].repeat(h, 1, 1) 124 | sim.masked_fill_(~mask, max_neg_value) 125 | 126 | attn = sim.softmax(dim=-1) 127 | # the only difference 128 | out = editor( 129 | q, k, v, sim, attn, is_cross, place_in_unet, 130 | self.heads, scale=self.scale) 131 | 132 | return to_out(out) 133 | 134 | return forward 135 | # !!! 136 | def register_editor(net, count, place_in_unet): 137 | for name, subnet in net.named_children(): 138 | if net.__class__.__name__ == 'Attention': # spatial Transformer layer 139 | net.forward = ca_forward(net, place_in_unet) 140 | return count + 1 141 | elif hasattr(net, 'children'): 142 | count = register_editor(subnet, count, place_in_unet) 143 | return count 144 | 145 | cross_att_count = 0 146 | for net_name, net in model.unet.named_children(): 147 | if "down" in net_name: 148 | cross_att_count += register_editor(net, 0, "down") 149 | elif "mid" in net_name: 150 | cross_att_count += register_editor(net, 0, "mid") 151 | elif "up" in net_name: 152 | cross_att_count += register_editor(net, 0, "up") 153 | editor.num_att_layers = cross_att_count 154 | 155 | 156 | def regiter_attention_editor_ldm(model, editor: AttentionBase): 157 | """ 158 | Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt] 159 | """ 160 | def ca_forward(self, place_in_unet,**kwarg): 161 | def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None,**kwarg): 162 | """ 163 | The attention is similar to the original implementation of LDM CrossAttention class 164 | except adding some modifications on the attention 165 | """ 166 | if encoder_hidden_states is not None: 167 | context = encoder_hidden_states 168 | if attention_mask is not None: 169 | mask = attention_mask 170 | 171 | to_out = self.to_out 172 | if isinstance(to_out, nn.modules.container.ModuleList): 173 | to_out = self.to_out[0] 174 | else: 175 | to_out = self.to_out 176 | 177 | h = self.heads 178 | q = self.to_q(x) 179 | is_cross = context is not None 180 | context = context if is_cross else x 181 | k = self.to_k(context) 182 | v = self.to_v(context) 183 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 184 | 185 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale 186 | 187 | if mask is not None: 188 | mask = rearrange(mask, 'b ... 
-> b (...)') 189 | max_neg_value = -torch.finfo(sim.dtype).max 190 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 191 | mask = mask[:, None, :].repeat(h, 1, 1) 192 | sim.masked_fill_(~mask, max_neg_value) 193 | 194 | attn = sim.softmax(dim=-1) 195 | # the only difference 196 | out = editor( 197 | q, k, v, sim, attn, is_cross, place_in_unet, 198 | self.heads, scale=self.scale) 199 | 200 | return to_out(out) 201 | 202 | return forward 203 | 204 | def register_editor(net, count, place_in_unet): 205 | for name, subnet in net.named_children(): 206 | if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer 207 | net.forward = ca_forward(net, place_in_unet) 208 | return count + 1 209 | elif hasattr(net, 'children'): 210 | count = register_editor(subnet, count, place_in_unet) 211 | return count 212 | 213 | cross_att_count = 0 214 | for net_name, net in model.model.diffusion_model.named_children(): 215 | if "input" in net_name: 216 | cross_att_count += register_editor(net, 0, "input") 217 | elif "middle" in net_name: 218 | cross_att_count += register_editor(net, 0, "middle") 219 | elif "output" in net_name: 220 | cross_att_count += register_editor(net, 0, "output") 221 | print("Editor -->", cross_att_count) 222 | editor.num_att_layers = cross_att_count 223 | -------------------------------------------------------------------------------- /attention_control/share_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from einops import rearrange 8 | 9 | from .masactrl_utils import AttentionBase 10 | 11 | from torchvision.utils import save_image 12 | 13 | # From diffuser/example/interpolate diffusion model 14 | def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): 15 | """helper function to spherically interpolate two arrays v1 v2""" 16 | 17 | 18 | dot = torch.sum(v0 * v1 / (torch.norm(v0) * torch.norm(v1))) 19 | if torch.abs(dot) > DOT_THRESHOLD: 20 | v2 = (1 - t) * v0 + t * v1 21 | else: 22 | theta_0 = torch.arccos(dot) 23 | sin_theta_0 = torch.sin(theta_0) 24 | theta_t = theta_0 * t 25 | sin_theta_t = torch.sin(theta_t) 26 | s0 = torch.sin(theta_0 - theta_t) / sin_theta_0 27 | s1 = sin_theta_t / sin_theta_0 28 | v2 = s0 * v0 + s1 * v1 29 | 30 | return v2 31 | 32 | class ShareSelfAttentionControl(AttentionBase): 33 | MODEL_TYPE = { 34 | "SD": 16, 35 | "SDXL": 70 36 | } 37 | 38 | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=20,scales=None, model_type="SD"): 39 | """ 40 | Mutual self-attention control for Stable-Diffusion model 41 | Args: 42 | start_step: the step to start mutual self-attention control 43 | start_layer: the layer to start mutual self-attention control 44 | layer_idx: list of the layers to apply mutual self-attention control 45 | step_idx: list the steps to apply mutual self-attention control 46 | total_steps: the total number of steps 47 | model_type: the model type, SD or SDXL 48 | """ 49 | super().__init__() 50 | self.total_steps = total_steps 51 | self.total_layers = self.MODEL_TYPE.get(model_type, 16) 52 | self.h_layer=self.total_layers//2 53 | self.start_step = start_step 54 | self.start_layer = start_layer 55 | self.scales=scales 56 | self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers)) 57 | self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) 58 | print("ShareSelfAttentionControl at denoising 
steps: ", self.step_idx) 59 | print("ShareSelfAttentionControl at U-Net layers: ", self.layer_idx) 60 | if self.scales!=None: 61 | print("Similarity Rescales=",self.scales) 62 | 63 | def share_attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,scales=None, **kwargs): 64 | """ 65 | Performing attention for a batch of queries, keys, and values 66 | """ 67 | b = q.shape[0] // num_heads 68 | n, d = q.shape[1],q.shape[2] 69 | q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) 70 | k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) 71 | v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) 72 | sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") 73 | 74 | # print("q",q.shape) 75 | # print("k",k.shape) 76 | # print("v",v.shape) 77 | # print("sim",sim.shape) 78 | # Attention Mask 79 | mask = torch.zeros_like(input=sim,dtype=torch.bool) 80 | # mask cur to attend all but not to it self 81 | # X X X 82 | # X X X 83 | # O O X 84 | mask[:,-n:,:-n]=True 85 | # mask ref to attend only self 86 | # O X X 87 | # X O X 88 | # X X X 89 | for ref_idx in range(b-1): 90 | mask[:,ref_idx*n:(ref_idx+1)*n,ref_idx*n:(ref_idx+1)*n]=True 91 | 92 | max_neg_value = -torch.finfo(sim.dtype).max 93 | masked_sim=sim.masked_fill_(~mask, max_neg_value) 94 | 95 | if scales!=None: 96 | assert len(scales)==(b-1), "length of scales should equal to batch size-1 (-1 because self-value-ignorance)" 97 | for cur_scale,ref_idx in zip(scales,range(b-1)): 98 | masked_sim[:,-n:,ref_idx*n:(ref_idx+1)*n]*=cur_scale 99 | 100 | attn = (masked_sim).softmax(dim=-1) 101 | out = torch.einsum("h i j, h j d -> h i d", attn, v) 102 | out = rearrange(out, "h (b n) d -> b n (h d)", b=b) 103 | return out 104 | 105 | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): 106 | """ 107 | Attention forward function 108 | """ 109 | # !!! is_cross: using original attention forward 110 | # self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx : 111 | # performing multual attention after some step and after some layer 112 | if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: 113 | return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) 114 | # Maintaining two image via batch manner 115 | # Normal: latents -> [uncond,cond] 116 | # Mutual: latents -> [ref-uncond,cur-uncond,ref-cond,ref-uncond] 117 | qu, qc = q.chunk(2) 118 | ku, kc = k.chunk(2) 119 | vu, vc = v.chunk(2) 120 | attnu, attnc = attn.chunk(2) 121 | 122 | # print("q",q.shape) 123 | # print("k",k.shape) 124 | # print("v",v.shape) 125 | # print("heads",num_heads) 126 | # qu->[q of ref-uncond, q of cur-uncond] 127 | # ku->[k of ref-uncond] 128 | # vu->[v of ref-uncond] 129 | # out-> [ref-latent, cur-latent] 130 | out_u = self.share_attn_batch(qu, ku, vu, sim, attnu, is_cross, place_in_unet, num_heads,scales=self.scales, **kwargs) 131 | out_c = self.share_attn_batch(qc, kc, vc, sim, attnc, is_cross, place_in_unet, num_heads,scales=self.scales, **kwargs) 132 | out = torch.cat([out_u, out_c], dim=0) 133 | # !!! 
Debug 134 | 135 | return out 136 | 137 | class ShareSelfAttentionControlSlerp(AttentionBase): 138 | MODEL_TYPE = { 139 | "SD": 16, 140 | "SDXL": 70 141 | } 142 | 143 | def __init__(self, start_step=4, start_layer=10, layer_idx=None, 144 | step_idx=None, total_steps=50,scales=None,slerp_ratio=0.1, model_type="SD"): 145 | """ 146 | Mutual self-attention control for Stable-Diffusion model 147 | Args: 148 | start_step: the step to start mutual self-attention control 149 | start_layer: the layer to start mutual self-attention control 150 | layer_idx: list of the layers to apply mutual self-attention control 151 | step_idx: list the steps to apply mutual self-attention control 152 | total_steps: the total number of steps 153 | model_type: the model type, SD or SDXL 154 | """ 155 | super().__init__() 156 | self.total_steps = total_steps 157 | self.total_layers = self.MODEL_TYPE.get(model_type, 16) 158 | self.start_step = start_step 159 | self.start_layer = start_layer 160 | self.scales=scales 161 | self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers)) 162 | self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) 163 | self.slerp_ratio=slerp_ratio 164 | print("ShareSelfCtrl at denoising steps: ", self.step_idx) 165 | print("ShareSelfCtrl at U-Net layers: ", self.layer_idx) 166 | if self.scales!=None: 167 | print("Rescales=",self.scales) 168 | 169 | def share_attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,scales=None, **kwargs): 170 | """ 171 | Performing attention for a batch of queries, keys, and values 172 | """ 173 | b = q.shape[0] // num_heads 174 | n, d = q.shape[1],q.shape[2] 175 | q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) 176 | k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) 177 | v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) 178 | sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") 179 | 180 | # print("q",q.shape) 181 | # print("k",k.shape) 182 | # print("v",v.shape) 183 | # print("sim",sim.shape) 184 | # Attention Mask 185 | mask = torch.zeros_like(input=sim,dtype=torch.bool) 186 | # mask cur to attend all but not to it self 187 | # X X X 188 | # X X X 189 | # O O X 190 | mask[:,-n:,:-n]=True 191 | # mask ref to attend only self 192 | # O X X 193 | # X O X 194 | # X X X 195 | for ref_idx in range(b-1): 196 | mask[:,ref_idx*n:(ref_idx+1)*n,ref_idx*n:(ref_idx+1)*n]=True 197 | 198 | max_neg_value = -torch.finfo(sim.dtype).max 199 | masked_sim=sim.masked_fill_(~mask, max_neg_value) 200 | 201 | if scales!=None: 202 | assert len(scales)==(b-1), "length of scales should equal to batch size-1 (-1 because self-value-ignorance)" 203 | for cur_scale,ref_idx in zip(scales,range(b-1)): 204 | masked_sim[:,-n:,ref_idx*n:(ref_idx+1)*n]*=cur_scale 205 | 206 | attn = (masked_sim).softmax(dim=-1) 207 | out = torch.einsum("h i j, h j d -> h i d", attn, v) 208 | out = rearrange(out, "h (b n) d -> b n (h d)", b=b) 209 | # print("out shape",out.shape) 210 | # (3,64,1280) 211 | if out.shape[1]<=64: 212 | # print("H-space") 213 | # !!! 214 | tmp_h=slerp(self.slerp_ratio,out[-1],out[-2]/torch.norm(out[-2])*torch.norm(out[-1])) 215 | out[-1]=tmp_h 216 | #out[0]=out[1] 217 | 218 | 219 | return out 220 | 221 | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): 222 | """ 223 | Attention forward function 224 | """ 225 | # !!! 
is_cross: using original attention forward 226 | # self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx : 227 | # performing multual attention after some step and after some layer 228 | # print(f"Cur layer {self.cur_att_layer}, q.shape{q.shape}, attn.shape{attn.shape}") 229 | if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: 230 | return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) 231 | # Maintaining two image via batch manner 232 | # Normal: latents -> [uncond,cond] 233 | # Mutual: latents -> [ref-uncond,cur-uncond,ref-cond,ref-uncond] 234 | qu, qc = q.chunk(2) 235 | ku, kc = k.chunk(2) 236 | vu, vc = v.chunk(2) 237 | attnu, attnc = attn.chunk(2) 238 | 239 | # print("q",q.shape) 240 | # print("k",k.shape) 241 | # print("v",v.shape) 242 | # print("heads",num_heads) 243 | # qu->[q of ref-uncond, q of cur-uncond] 244 | # ku->[k of ref-uncond] 245 | # vu->[v of ref-uncond] 246 | # out-> [ref-latent, cur-latent] 247 | out_u = self.share_attn_batch(qu, ku, vu, sim, attnu, is_cross, place_in_unet, num_heads,scales=self.scales, **kwargs) 248 | out_c = self.share_attn_batch(qc, kc, vc, sim, attnc, is_cross, place_in_unet, num_heads,scales=self.scales, **kwargs) 249 | out = torch.cat([out_u, out_c], dim=0) 250 | # !!! Debug 251 | 252 | return out 253 | -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from collections import defaultdict 5 | from numbers import Number 6 | from typing import Any, Optional, Tuple, Union 7 | 8 | import clip 9 | import lpips 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F_torch 13 | import torchvision.transforms.functional as F_vision 14 | from PIL import Image 15 | from tqdm import tqdm 16 | import cv2 17 | 18 | 19 | def parse_argument(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "-r", "--imgs-root", type=str, default="./GPH Benchmark demo data", help="Path to ground truth images" 23 | ) 24 | parser.add_argument( 25 | "-cm", 26 | "--clip-model-type", 27 | type=str, 28 | default="ViT-B/32", 29 | choices=["ViT-B/32", "ViT-L/14"], 30 | help="CLIP model type", 31 | ) 32 | parser.add_argument( 33 | "-b", 34 | "--box_meter", 35 | type=bool, 36 | default=False, 37 | help="If enable box evaluation", 38 | ) 39 | args = parser.parse_args() 40 | return args 41 | 42 | 43 | def to_rgb_tensor(img: np.ndarray, size: Union[Tuple[int, int]] = None): 44 | tensor = F_vision.to_tensor(img) 45 | if size is not None: 46 | tensor = F_vision.resize(tensor, size, antialias=True) 47 | return tensor.unsqueeze(0) 48 | 49 | 50 | class AverageMeter(object): 51 | def __init__(self): 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | 60 | def update(self, val: Number, n: int = 1): 61 | self.val = val 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | 66 | 67 | class MetricMeter(object): 68 | def __init__(self, delimiter: str = "\t"): 69 | self.meters = defaultdict(AverageMeter) 70 | self.delimiter = delimiter 71 | 72 | def update(self, input_dict: Optional[dict]): 73 | if input_dict is None: 74 | return 75 | 76 | if not isinstance(input_dict, dict): 77 | raise TypeError("Input to MetricMeter.update() must be a dictionary") 78 | 79 | 
for k, v in input_dict.items(): 80 | if v is not None: 81 | self.meters[k].update(v) 82 | 83 | def __str__(self): 84 | output_str = [] 85 | for name, meter in self.meters.items(): 86 | output_str.append(f"{name} {meter.avg:.4f}") 87 | return self.delimiter.join(output_str) 88 | 89 | def get_log_dict(self): 90 | log_dict = {} 91 | for name, meter in self.meters.items(): 92 | log_dict[name] = meter.val 93 | log_dict[f"avg_{name}"] = meter.avg 94 | return log_dict 95 | 96 | 97 | @torch.no_grad() 98 | def compute_lpips( 99 | source_img: np.ndarray, 100 | style_img: np.ndarray, 101 | stylized_img: np.ndarray, 102 | composite_mask: np.ndarray, 103 | lpips_model: torch.nn.Module, 104 | ): 105 | 106 | non_x, non_y = np.nonzero(composite_mask) 107 | left, right = non_y.min(), non_y.max() 108 | top, bottom = non_x.min(), non_x.max() 109 | #fg_img = (stylized_img * composite_mask[:, :, None])[top:bottom, left:right] 110 | crop_stylized_img = (stylized_img * composite_mask[:, :, None])[top:bottom, left:right] 111 | crop_source= (source_img * composite_mask[:, :, None])[top:bottom, left:right] 112 | bg_stylized_img = stylized_img * (1 - composite_mask[:, :, None]) 113 | bg_style_img = style_img * (1 - composite_mask[:, :, None]) 114 | 115 | """ 116 | pil_image = Image.fromarray(bg_stylized_img) 117 | pil_image.save("./test/bg_img.png") 118 | 119 | pil_image = Image.fromarray(bg_style_img) 120 | pil_image.save("./test/bg_style_img.png") 121 | exit() 122 | """ 123 | 124 | 125 | fg_score = lpips_model( 126 | to_rgb_tensor(crop_stylized_img,size=source_img.shape[:2]), to_rgb_tensor(crop_source,size=source_img.shape[:2]) 127 | ).item() 128 | bg_score = lpips_model(to_rgb_tensor(bg_stylized_img), to_rgb_tensor(bg_style_img)).item() 129 | 130 | return fg_score, bg_score 131 | 132 | 133 | @torch.no_grad() 134 | def compute_clip_score( 135 | source_img: np.ndarray, 136 | style_img: np.ndarray, 137 | stylized_img: np.ndarray, 138 | composite_mask: np.ndarray, 139 | clip_model: Any, 140 | device: Union[torch.device, str], 141 | prompt: Union[str] = None, 142 | ): 143 | non_x, non_y = np.nonzero(composite_mask) 144 | left, right = non_y.min(), non_y.max() 145 | top, bottom = non_x.min(), non_x.max() 146 | #fg_img = (stylized_img * composite_mask[:, :, None])[top:bottom, left:right] 147 | crop_stylized_img = (stylized_img * composite_mask[:, :, None])[top:bottom, left:right] 148 | crop_source_img= (source_img * composite_mask[:, :, None])[top:bottom, left:right] 149 | #bg_stylized_img = stylized_img * (1 - composite_mask[:, :, None]) 150 | #bg_style_img = style_img * (1 - composite_mask[:, :, None]) 151 | 152 | preprocess_crop_stylized_img = clip_model[1](Image.fromarray(crop_stylized_img)).unsqueeze(0).to(device) 153 | preprocess_crop_fg_img = clip_model[1](Image.fromarray(crop_source_img)).unsqueeze(0).to(device) 154 | 155 | crop_stylized_features = F_torch.normalize(clip_model[0].encode_image(preprocess_crop_stylized_img), dim=-1) 156 | crop_fg_features = F_torch.normalize(clip_model[0].encode_image(preprocess_crop_fg_img), dim=-1) 157 | 158 | img_score = (crop_stylized_features .squeeze() @ crop_fg_features.squeeze()).item() * 100 159 | 160 | preprocess_stylized_img = clip_model[1](Image.fromarray(stylized_img)).unsqueeze(0).to(device) 161 | preprocess_style_img = clip_model[1](Image.fromarray(style_img)).unsqueeze(0).to(device) 162 | #stylized_features = F_torch.normalize(clip_model[0].encode_image(preprocess_stylized_img), dim=-1) 163 | style_features = 
F_torch.normalize(clip_model[0].encode_image(preprocess_style_img), dim=-1) 164 | 165 | style_score = (crop_stylized_features .squeeze() @ style_features.squeeze()).item() * 100 166 | 167 | preprocess_source_img = clip_model[1](Image.fromarray(source_img)).unsqueeze(0).to(device) 168 | source_features = F_torch.normalize(clip_model[0].encode_image(preprocess_source_img), dim=-1) 169 | dfg=crop_stylized_features-crop_fg_features 170 | dbg=style_features-source_features 171 | dir_score=(dfg .squeeze() @ dbg.squeeze()).item() * 100 172 | 173 | text_score = None 174 | if prompt is not None: 175 | tokenized_prompt = clip.tokenize([prompt]).to(device) 176 | text_features = F_torch.normalize(clip_model[0].encode_text(tokenized_prompt), dim=-1) 177 | #text_score = (stylized_features.squeeze() @ text_features.squeeze()).item() * 100 178 | 179 | 180 | return img_score, style_score, dir_score, text_score 181 | 182 | 183 | def main(args): 184 | device = "cuda" if torch.cuda.is_available() else "cpu" 185 | lpips_fn = lpips.LPIPS(net="alex") 186 | clip_model = clip.load(args.clip_model_type, device=device) 187 | 188 | meter = MetricMeter() 189 | if args.box_meter: 190 | box_meter = MetricMeter() 191 | else: 192 | box_meter=False 193 | idxs=[file[:-4] for file in os.listdir(os.path.join(args.imgs_root, "harmonized_data"))] 194 | for i in tqdm(idxs): 195 | source_img_path = os.path.join(args.imgs_root, "composite_data", f"{i}.png") 196 | composite_mask_path = os.path.join(args.imgs_root, "mask_data", f"{i}.png") 197 | style_img_path = os.path.join(args.imgs_root, "background_data", f"{i}.png") 198 | stylized_img_path = os.path.join(args.imgs_root, "harmonized_data", f"{i}.png") 199 | 200 | source_img = np.array(Image.open(source_img_path).convert("RGB")) 201 | composite_mask = np.array(Image.open(composite_mask_path).convert("L"))//255 202 | style_img = np.array(Image.open(style_img_path).convert("RGB")) 203 | stylized_img = np.array(Image.open(stylized_img_path).convert("RGB")) 204 | 205 | fg_lpips, bg_lpips = compute_lpips( 206 | source_img, style_img, stylized_img, composite_mask, lpips_fn 207 | ) 208 | img_clip, style_clip, dir_clip, text_clip = compute_clip_score( 209 | source_img, 210 | style_img, 211 | stylized_img, 212 | composite_mask, 213 | clip_model, 214 | device, 215 | #prompt=f"a photo in {style_text} style", 216 | ) 217 | 218 | meter.update( 219 | { 220 | "bg_lpips": bg_lpips, 221 | "fg_lpips": fg_lpips, 222 | "img_clip": img_clip, 223 | "style_clip": style_clip, 224 | "dir_clip": dir_clip, 225 | #"text_clip": text_clip, 226 | } 227 | ) 228 | 229 | print("Segment meter:",meter) 230 | 231 | 232 | if __name__ == "__main__": 233 | args = parse_argument() 234 | main(args) 235 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference-v.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "jpg" 12 | cond_stage_key: "txt" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False # we set this to false because this is an inference only config 20 | 21 | unet_config: 22 | target: 
ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | use_fp16: True 26 | image_size: 32 # unused 27 | in_channels: 4 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: [ 4, 2, 1 ] 31 | num_res_blocks: 2 32 | channel_mult: [ 1, 2, 4, 4 ] 33 | num_head_channels: 64 # need to fix for flash-attn 34 | use_spatial_transformer: True 35 | use_linear_in_transformer: True 36 | transformer_depth: 1 37 | context_dim: 1024 38 | legacy: False 39 | 40 | first_stage_config: 41 | target: ldm.models.autoencoder.AutoencoderKL 42 | params: 43 | embed_dim: 4 44 | monitor: val/rec_loss 45 | ddconfig: 46 | #attn_type: "vanilla-xformers" 47 | double_z: true 48 | z_channels: 4 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: [] 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | 64 | cond_stage_config: 65 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 66 | params: 67 | freeze: True 68 | layer: "penultimate" 69 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False # we set this to false because this is an inference only config 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | use_checkpoint: True 24 | use_fp16: True 25 | image_size: 32 # unused 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | 
cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | unet_config: 22 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 23 | params: 24 | use_checkpoint: True 25 | image_size: 32 # unused 26 | in_channels: 9 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: 2 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_head_channels: 64 # need to fix for flash-attn 33 | use_spatial_transformer: True 34 | use_linear_in_transformer: True 35 | transformer_depth: 1 36 | context_dim: 1024 37 | legacy: False 38 | 39 | first_stage_config: 40 | target: ldm.models.autoencoder.AutoencoderKL 41 | params: 42 | embed_dim: 4 43 | monitor: val/rec_loss 44 | ddconfig: 45 | #attn_type: "vanilla-xformers" 46 | double_z: true 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 4 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 65 | params: 66 | freeze: True 67 | layer: "penultimate" 68 | 69 | 70 | data: 71 | target: ldm.data.laion.WebDataModuleFromConfig 72 | params: 73 | tar_base: null # for concat as in LAION-A 74 | p_unsafe_threshold: 0.1 75 | filter_word_list: "data/filters.yaml" 76 | max_pwatermark: 0.45 77 | batch_size: 8 78 | num_workers: 6 79 | multinode: True 80 | min_size: 512 81 | train: 82 | shards: 83 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" 84 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" 85 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" 86 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" 87 | - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" 88 | shuffle: 10000 89 | image_key: jpg 90 | image_transforms: 91 | - target: torchvision.transforms.Resize 92 | params: 93 | size: 512 94 | interpolation: 3 95 | - target: torchvision.transforms.RandomCrop 96 | params: 97 | size: 512 98 | postprocess: 99 | target: ldm.data.laion.AddMask 100 | params: 101 | mode: "512train-large" 102 | p_drop: 0.25 103 | # NOTE use enough shards to avoid empty validation loops in workers 104 | validation: 105 | shards: 106 | - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " 107 | shuffle: 0 108 | image_key: jpg 109 | image_transforms: 110 | - target: torchvision.transforms.Resize 111 | params: 112 | size: 512 113 | interpolation: 3 114 | - target: torchvision.transforms.CenterCrop 115 | params: 116 | size: 512 117 | postprocess: 118 | target: ldm.data.laion.AddMask 119 | params: 120 | mode: "512train-large" 121 | p_drop: 0.25 122 | 123 | lightning: 124 | find_unused_parameters: True 125 | modelcheckpoint: 126 | params: 127 | every_n_train_steps: 5000 128 | 129 | callbacks: 130 | metrics_over_trainsteps_checkpoint: 131 | params: 132 | every_n_train_steps: 10000 133 | 134 | image_logger: 135 | target: main.ImageLogger 136 | params: 137 | enable_autocast: False 138 | disabled: False 139 | batch_frequency: 1000 140 | max_images: 4 141 | increase_log_steps: False 142 | log_first_step: False 143 | log_images_kwargs: 144 | use_ema_scope: False 145 | inpaint: False 146 | 
plot_progressive_rows: False 147 | plot_diffusion_rows: False 148 | N: 4 149 | unconditional_guidance_scale: 5.0 150 | unconditional_guidance_label: [""] 151 | ddim_steps: 50 # todo check these out for depth2img, 152 | ddim_eta: 0.0 # todo check these out for depth2img, 153 | 154 | trainer: 155 | benchmark: True 156 | val_check_interval: 5000000 157 | num_sanity_val_steps: 0 158 | accumulate_grad_batches: 1 159 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v2-midas-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-07 3 | target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: hybrid 16 | scale_factor: 0.18215 17 | monitor: val/loss_simple_ema 18 | finetune_keys: null 19 | use_ema: False 20 | 21 | depth_stage_config: 22 | target: ldm.modules.midas.api.MiDaSInference 23 | params: 24 | model_type: "dpt_hybrid" 25 | 26 | unet_config: 27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 5 32 | out_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | #attn_type: "vanilla-xformers" 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [ ] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 70 | params: 71 | freeze: True 72 | layer: "penultimate" 73 | 74 | 75 | -------------------------------------------------------------------------------- /configs/stable-diffusion/x4-upscaling.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion 4 | params: 5 | parameterization: "v" 6 | low_scale_key: "lr" 7 | linear_start: 0.0001 8 | linear_end: 0.02 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: "jpg" 13 | cond_stage_key: "txt" 14 | image_size: 128 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: "hybrid-adm" 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.08333 20 | use_ema: False 21 | 22 | low_scale_config: 23 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation 24 | params: 25 | noise_schedule_config: # image space 26 | linear_start: 0.0001 27 | linear_end: 0.02 28 | max_noise_level: 350 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | use_checkpoint: True 34 | num_classes: 1000 # timesteps for noise conditioning 
(here constant, just need one) 35 | image_size: 128 36 | in_channels: 7 37 | out_channels: 4 38 | model_channels: 256 39 | attention_resolutions: [ 2,4,8] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 2, 4] 42 | disable_self_attentions: [True, True, True, False] 43 | disable_middle_self_attn: False 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 1024 48 | legacy: False 49 | use_linear_in_transformer: True 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | ddconfig: 56 | # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) 57 | double_z: True 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 64 | num_res_blocks: 2 65 | attn_resolutions: [ ] 66 | dropout: 0.0 67 | 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 73 | params: 74 | freeze: True 75 | layer: "penultimate" 76 | 77 | -------------------------------------------------------------------------------- /demo_outputs/harmonized/kangaroo_starry_15S_12L_total.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/demo_outputs/harmonized/kangaroo_starry_15S_12L_total.jpg -------------------------------------------------------------------------------- /demo_outputs/reconstruct/composite_recon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/demo_outputs/reconstruct/composite_recon.jpg -------------------------------------------------------------------------------- /github_source/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig1.png -------------------------------------------------------------------------------- /github_source/fig14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig14.png -------------------------------------------------------------------------------- /github_source/fig15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig15.png -------------------------------------------------------------------------------- /github_source/fig17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig17.png -------------------------------------------------------------------------------- /github_source/fig18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig18.png -------------------------------------------------------------------------------- /github_source/fig1_new.png.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig1_new.png.png -------------------------------------------------------------------------------- /github_source/fig4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/fig4.png -------------------------------------------------------------------------------- /github_source/tf-gph_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/github_source/tf-gph_demo.gif -------------------------------------------------------------------------------- /gradio/background/bg03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/background/bg03.png -------------------------------------------------------------------------------- /gradio/background/bg36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/background/bg36.png -------------------------------------------------------------------------------- /gradio/background/bg52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/background/bg52.png -------------------------------------------------------------------------------- /gradio/background/bg58.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/background/bg58.png -------------------------------------------------------------------------------- /gradio/background/bg62.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/background/bg62.png -------------------------------------------------------------------------------- /gradio/foreground/fg10_63d22a7f1f5b66e8e5ac28f7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/foreground/fg10_63d22a7f1f5b66e8e5ac28f7.jpg -------------------------------------------------------------------------------- /gradio/foreground/fg50_63d22c871f5b66e8e5ac95e1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/foreground/fg50_63d22c871f5b66e8e5ac95e1.jpg -------------------------------------------------------------------------------- /gradio/foreground/fg88_63d9d508b82cf5cb1db01976.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/foreground/fg88_63d9d508b82cf5cb1db01976.jpg -------------------------------------------------------------------------------- 
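For orientation, the sketch below shows how a Stable Diffusion YAML config of the shape listed earlier (a "model:" block carrying "target:" and "params:" keys) is typically turned into a live model. It is a minimal sketch, not the repository's own entry point: it assumes ldm.util.instantiate_from_config follows the upstream Stable Diffusion convention of importing the "target" class and forwarding "params" as keyword arguments, and the checkpoint path is a placeholder rather than a file shipped with this repository.

import torch
import yaml

from ldm.util import instantiate_from_config

# Parse one of the inference configs; the "model" block holds the class path
# ("target") and its constructor arguments ("params").
with open("configs/stable-diffusion/v2-inference.yaml") as f:
    config = yaml.safe_load(f)

# Builds LatentDiffusion together with its UNet, autoencoder and CLIP encoder
# (randomly initialized until a checkpoint is loaded).
model = instantiate_from_config(config["model"])

# Placeholder checkpoint path -- pretrained weights are not part of this repository.
state_dict = torch.load("path/to/sd-v2-checkpoint.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(state_dict, strict=False)
model = model.cuda().eval()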
/gradio/foreground/fg90_63d9d4a0b82cf5cb1db00800.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/foreground/fg90_63d9d4a0b82cf5cb1db00800.jpg -------------------------------------------------------------------------------- /gradio/foreground/fg92_63d9d6c9b82cf5cb1db05fda.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/foreground/fg92_63d9d6c9b82cf5cb1db05fda.jpg -------------------------------------------------------------------------------- /gradio/seg_foreground/fg10_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/seg_foreground/fg10_mask.jpg -------------------------------------------------------------------------------- /gradio/seg_foreground/fg50_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/seg_foreground/fg50_mask.png -------------------------------------------------------------------------------- /gradio/seg_foreground/fg88_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/seg_foreground/fg88_mask.png -------------------------------------------------------------------------------- /gradio/seg_foreground/fg90_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/seg_foreground/fg90_mask.png -------------------------------------------------------------------------------- /gradio/seg_foreground/fg92_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/gradio/seg_foreground/fg92_mask.png -------------------------------------------------------------------------------- /inputs/demo_input/kangaroo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/inputs/demo_input/kangaroo.jpg -------------------------------------------------------------------------------- /inputs/demo_input/kangaroo_starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/inputs/demo_input/kangaroo_starry.jpg -------------------------------------------------------------------------------- /inputs/demo_input/starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/inputs/demo_input/starry_night.jpg -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 
| assert 0. < ema_decay < 1. 46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = 
self.loss(inputs, reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/plms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc -------------------------------------------------------------------------------- 
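The ddconfig block that recurs in the configs above maps one-to-one onto the AutoencoderKL constructor defined in ldm/models/autoencoder.py. The following sketch, with random weights and the identity loss used at inference time, is only meant to make the shapes concrete: with ch_mult [1, 2, 4, 4] the encoder downsamples three times, so a 256x256 RGB input yields a 4x32x32 latent (embed_dim 4). It assumes the Encoder/Decoder in ldm/modules/diffusionmodules/model.py accept these keyword arguments as in the upstream Stable Diffusion code.

import torch
from ldm.models.autoencoder import AutoencoderKL

# ddconfig copied from the first_stage_config shared by the configs above.
ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)

ae = AutoencoderKL(ddconfig=ddconfig,
                   lossconfig={"target": "torch.nn.Identity"},  # inference-only loss
                   embed_dim=4).eval()

x = torch.randn(1, 3, 256, 256)          # stand-in for an image scaled to [-1, 1]
with torch.no_grad():
    posterior = ae.encode(x)             # DiagonalGaussianDistribution
    z = posterior.mode()                 # (1, 4, 32, 32)
    rec = ae.decode(z)                   # (1, 3, 256, 256)
print(z.shape, rec.shape)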
/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver_pytorch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver_pytorch.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | from tqdm import tqdm 4 | # from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | from scripts.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | MODEL_TYPES = { 9 | "eps": "noise", 10 | "v": "v" 11 | } 12 | 13 | 14 | class DPMSolverSampler(object): 15 | def __init__(self, model, **kwargs): 16 | super().__init__() 17 | self.model = 
model 18 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 19 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 20 | 21 | def register_buffer(self, name, attr): 22 | if type(attr) == torch.Tensor: 23 | if attr.device != self.model.device: 24 | attr = attr.to(self.model.device) 25 | setattr(self, name, attr) 26 | 27 | @torch.no_grad() 28 | def sample(self, 29 | steps, 30 | shape, 31 | batch_size=1, 32 | conditioning=None, 33 | inv_emb=None, 34 | callback=None, 35 | normals_sequence=None, 36 | img_callback=None, 37 | quantize_x0=False, 38 | eta=0., 39 | mask=None, 40 | x0=None, 41 | temperature=1., 42 | noise_dropout=0., 43 | score_corrector=None, 44 | corrector_kwargs=None, 45 | verbose=True, 46 | x_T=None, 47 | log_every_t=100, 48 | unconditional_guidance_scale=1., 49 | unconditional_conditioning=None, 50 | t_start=None, 51 | t_end=None, 52 | DPMencode=False, 53 | order=3, 54 | width=None, 55 | height=None, 56 | ref=False, 57 | top=None, 58 | left=None, 59 | bottom=None, 60 | right=None, 61 | segmentation_map=None, 62 | param=None, 63 | target_height=None, 64 | target_width=None, 65 | center_row_rm=None, 66 | center_col_rm=None, 67 | tau_a=0.4, 68 | tau_b=0.8, 69 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 70 | **kwargs 71 | ): 72 | if conditioning is not None: 73 | if isinstance(conditioning, dict): 74 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 75 | if cbs != batch_size: 76 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 77 | else: 78 | if conditioning.shape[0] != batch_size: 79 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 80 | 81 | # sampling 82 | C, H, W = shape 83 | size = (batch_size, C, H, W) 84 | 85 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {steps}') 86 | 87 | device = self.model.betas.device 88 | if x_T is None: 89 | x = torch.randn(size, device=device) 90 | else: 91 | x = x_T 92 | 93 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 94 | 95 | if DPMencode: 96 | # x_T is not a list 97 | model_fn = model_wrapper( 98 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=None, inject=inject), 99 | ns, 100 | model_type=MODEL_TYPES[self.model.parameterization], 101 | guidance_type="classifier-free", 102 | condition=inv_emb, 103 | unconditional_condition=inv_emb, 104 | guidance_scale=unconditional_guidance_scale, 105 | ) 106 | 107 | dpm_solver = DPM_Solver(model_fn, ns) 108 | 109 | 110 | data, _ = self.low_order_sample(x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=DPMencode) 111 | 112 | for step in tqdm(range(order, steps + 1), desc='DPM++ inversion (adding noise) Z0-->ZT'): 113 | data = dpm_solver.sample_one_step(data, step, steps, order=order, DPMencode=DPMencode) 114 | 115 | return data['x'].to(device), None 116 | 117 | else: 118 | # x_T is not a list 119 | model_fn = model_wrapper( 120 | lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=None, inject=inject), 121 | ns, 122 | model_type=MODEL_TYPES[self.model.parameterization], 123 | guidance_type="classifier-free", 124 | condition=inv_emb, 125 | unconditional_condition=inv_emb, 126 | guidance_scale=unconditional_guidance_scale, 127 | ) 128 | 129 | dpm_solver = DPM_Solver(model_fn, ns) 130 | 131 | x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, 
t_end=t_end, order=order, 132 | skip_type='time_uniform', method='multistep') 133 | return x_sample 134 | 135 | 136 | def low_order_sample(self, x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=False, controller=None, inject=False, ref_init=None): 137 | 138 | t_0 = 1. / dpm_solver.noise_schedule.total_N if t_end is None else t_end 139 | t_T = dpm_solver.noise_schedule.T if t_start is None else t_start 140 | 141 | total_controller = [] 142 | assert steps >= order 143 | timesteps = dpm_solver.get_time_steps(skip_type="time_uniform", t_T=t_T, t_0=t_0, N=steps, device=device, DPMencode=DPMencode) 144 | assert timesteps.shape[0] - 1 == steps 145 | with torch.no_grad(): 146 | vec_t = timesteps[0].expand((x.shape[0])) 147 | model_prev_list = [dpm_solver.model_fn(x, vec_t, DPMencode=DPMencode, 148 | controller=[controller[0][0], controller[1][0], controller[2][0]] if isinstance(controller, list) else controller, 149 | inject=inject, ref_init=ref_init)] 150 | 151 | total_controller.append(controller) 152 | t_prev_list = [vec_t] 153 | # Init the first `order` values by lower order multistep DPM-Solver. 154 | for init_order in range(1, order): 155 | vec_t = timesteps[init_order].expand(x.shape[0]) 156 | x = dpm_solver.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, 157 | solver_type='dpmsolver', DPMencode=DPMencode) 158 | model_prev_list.append(dpm_solver.model_fn(x, vec_t, DPMencode=DPMencode, 159 | controller=[controller[0][init_order], controller[1][init_order], controller[2][init_order]] if isinstance(controller, list) else controller, 160 | inject=inject, ref_init=ref_init)) 161 | total_controller.append(controller) 162 | t_prev_list.append(vec_t) 163 | 164 | return {'x': x, 'model_prev_list': model_prev_list, 't_prev_list': t_prev_list, 'timesteps':timesteps}, total_controller 165 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | 
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 87 | if cbs != batch_size: 88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 89 | else: 90 | if conditioning.shape[0] != batch_size: 91 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 92 | 93 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 94 | # sampling 95 | C, H, W = shape 96 | size = (batch_size, C, H, W) 97 | print(f'Data shape for PLMS sampling is {size}') 98 | 99 | samples, intermediates = self.plms_sampling(conditioning, size, 100 | callback=callback, 101 | img_callback=img_callback, 102 | quantize_denoised=quantize_x0, 103 | mask=mask, x0=x0, 104 | ddim_use_original_steps=False, 105 | noise_dropout=noise_dropout, 106 | temperature=temperature, 107 | score_corrector=score_corrector, 108 | corrector_kwargs=corrector_kwargs, 109 | x_T=x_T, 110 | log_every_t=log_every_t, 111 | unconditional_guidance_scale=unconditional_guidance_scale, 112 | unconditional_conditioning=unconditional_conditioning, 113 | dynamic_threshold=dynamic_threshold, 114 | ) 115 | return samples, intermediates 116 | 117 | @torch.no_grad() 118 | def plms_sampling(self, cond, shape, 119 | x_T=None, ddim_use_original_steps=False, 120 | callback=None, timesteps=None, quantize_denoised=False, 121 | mask=None, x0=None, img_callback=None, log_every_t=100, 122 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 123 | unconditional_guidance_scale=1., unconditional_conditioning=None, 124 | dynamic_threshold=None): 125 | device = self.model.betas.device 126 | b = shape[0] 127 | if x_T is None: 128 | img = torch.randn(shape, device=device) 129 | else: 130 | img = x_T 131 | 132 | if timesteps is None: 133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 134 | elif timesteps is not None and not ddim_use_original_steps: 135 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 136 | timesteps = self.ddim_timesteps[:subset_end] 137 | 138 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 139 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 140 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 141 | print(f"Running PLMS Sampling with {total_steps} timesteps") 142 | 143 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 144 | old_eps = [] 145 | 146 | for i, step in enumerate(iterator): 147 | index = total_steps - i - 1 148 | ts = torch.full((b,), step, device=device, dtype=torch.long) 149 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 150 | 151 | if mask is not None: 152 | assert x0 is not None 153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 154 | img = img_orig * mask + (1. 
- mask) * img 155 | 156 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 157 | quantize_denoised=quantize_denoised, temperature=temperature, 158 | noise_dropout=noise_dropout, score_corrector=score_corrector, 159 | corrector_kwargs=corrector_kwargs, 160 | unconditional_guidance_scale=unconditional_guidance_scale, 161 | unconditional_conditioning=unconditional_conditioning, 162 | old_eps=old_eps, t_next=ts_next, 163 | dynamic_threshold=dynamic_threshold) 164 | img, pred_x0, e_t = outs 165 | old_eps.append(e_t) 166 | if len(old_eps) >= 4: 167 | old_eps.pop(0) 168 | if callback: callback(i) 169 | if img_callback: img_callback(pred_x0, i) 170 | 171 | if index % log_every_t == 0 or index == total_steps - 1: 172 | intermediates['x_inter'].append(img) 173 | intermediates['pred_x0'].append(pred_x0) 174 | 175 | return img, intermediates 176 | 177 | @torch.no_grad() 178 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 179 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 180 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 181 | dynamic_threshold=None): 182 | b, *_, device = *x.shape, x.device 183 | 184 | def get_model_output(x, t): 185 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 186 | e_t = self.model.apply_model(x, t, c) 187 | else: 188 | x_in = torch.cat([x] * 2) 189 | t_in = torch.cat([t] * 2) 190 | c_in = torch.cat([unconditional_conditioning, c]) 191 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 192 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 193 | 194 | if score_corrector is not None: 195 | assert self.model.parameterization == "eps" 196 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 197 | 198 | return e_t 199 | 200 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 201 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 202 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 203 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 204 | 205 | def get_x_prev_and_pred_x0(e_t, index): 206 | # select parameters corresponding to the currently considered timestep 207 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 208 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 209 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 210 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 211 | 212 | # current prediction for x_0 213 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 214 | if quantize_denoised: 215 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 216 | if dynamic_threshold is not None: 217 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 218 | # direction pointing to x_t 219 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 220 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 221 | if noise_dropout > 0.: 222 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 223 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 224 | return x_prev, pred_x0 225 | 226 | e_t = get_model_output(x, t) 227 | if len(old_eps) == 0: 228 | # Pseudo Improved Euler (2nd order) 229 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 230 | e_t_next = get_model_output(x_prev, t_next) 231 | e_t_prime = (e_t + e_t_next) / 2 232 | elif len(old_eps) == 1: 233 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 234 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 235 | elif len(old_eps) == 2: 236 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 237 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 238 | elif len(old_eps) >= 3: 239 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 240 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 241 | 242 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 243 | 244 | return x_prev, pred_x0, e_t 245 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. 
- betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | if ddim_timesteps[-1] == 1000: 66 | ddim_timesteps = ddim_timesteps - 1 67 | alphas = alphacums[ddim_timesteps] 68 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 69 | alphas_next = np.asarray(alphacums[ddim_timesteps[1:]].tolist() + [alphacums[-1].tolist()]) 70 | 71 | # according the the formula provided in https://arxiv.org/abs/2010.02502 72 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 73 | if verbose: 74 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 75 | print(f'For the chosen value of eta, which is {eta}, ' 76 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 77 | return sigmas, alphas, alphas_prev, alphas_next 78 | 79 | 80 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 81 | """ 82 | Create a beta schedule that discretizes the given alpha_t_bar function, 83 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 84 | :param num_diffusion_timesteps: the number of betas to produce. 85 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 86 | produces the cumulative product of (1-beta) up to that 87 | part of the diffusion process. 88 | :param max_beta: the maximum beta to use; use values lower than 1 to 89 | prevent singularities. 
90 | """ 91 | betas = [] 92 | for i in range(num_diffusion_timesteps): 93 | t1 = i / num_diffusion_timesteps 94 | t2 = (i + 1) / num_diffusion_timesteps 95 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 96 | return np.array(betas) 97 | 98 | 99 | def extract_into_tensor(a, t, x_shape): 100 | b, *_ = t.shape 101 | out = a.gather(-1, t) 102 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 103 | 104 | 105 | def checkpoint(func, inputs, params, flag): 106 | """ 107 | Evaluate a function without caching intermediate activations, allowing for 108 | reduced memory at the expense of extra compute in the backward pass. 109 | :param func: the function to evaluate. 110 | :param inputs: the argument sequence to pass to `func`. 111 | :param params: a sequence of parameters `func` depends on but does not 112 | explicitly take as arguments. 113 | :param flag: if False, disable gradient checkpointing. 114 | """ 115 | if flag: 116 | args = tuple(inputs) + tuple(params) 117 | return CheckpointFunction.apply(func, len(inputs), *args) 118 | else: 119 | return func(*inputs) 120 | 121 | 122 | class CheckpointFunction(torch.autograd.Function): 123 | @staticmethod 124 | def forward(ctx, run_function, length, *args): 125 | ctx.run_function = run_function 126 | ctx.input_tensors = list(args[:length]) 127 | ctx.input_params = list(args[length:]) 128 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 129 | "dtype": torch.get_autocast_gpu_dtype(), 130 | "cache_enabled": torch.is_autocast_cache_enabled()} 131 | with torch.no_grad(): 132 | output_tensors = ctx.run_function(*ctx.input_tensors) 133 | return output_tensors 134 | 135 | @staticmethod 136 | def backward(ctx, *output_grads): 137 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 138 | with torch.enable_grad(), \ 139 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 140 | # Fixes a bug where the first op in run_function modifies the 141 | # Tensor storage in place, which is not allowed for detach()'d 142 | # Tensors. 143 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 144 | output_tensors = ctx.run_function(*shallow_copies) 145 | input_grads = torch.autograd.grad( 146 | output_tensors, 147 | ctx.input_tensors + ctx.input_params, 148 | output_grads, 149 | allow_unused=True, 150 | ) 151 | del ctx.input_tensors 152 | del ctx.input_params 153 | del output_tensors 154 | return (None, None) + input_grads 155 | 156 | 157 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 158 | """ 159 | Create sinusoidal timestep embeddings. 160 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 161 | These may be fractional. 162 | :param dim: the dimension of the output. 163 | :param max_period: controls the minimum frequency of the embeddings. 164 | :return: an [N x dim] Tensor of positional embeddings. 165 | """ 166 | if not repeat_only: 167 | half = dim // 2 168 | freqs = torch.exp( 169 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 170 | ).to(device=timesteps.device) 171 | args = timesteps[:, None].float() * freqs[None] 172 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 173 | if dim % 2: 174 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 175 | else: 176 | embedding = repeat(timesteps, 'b -> b d', d=dim) 177 | return embedding 178 | 179 | 180 | def zero_module(module): 181 | """ 182 | Zero out the parameters of a module and return it. 
183 | """ 184 | for p in module.parameters(): 185 | p.detach().zero_() 186 | return module 187 | 188 | 189 | def scale_module(module, scale): 190 | """ 191 | Scale the parameters of a module and return it. 192 | """ 193 | for p in module.parameters(): 194 | p.detach().mul_(scale) 195 | return module 196 | 197 | 198 | def mean_flat(tensor): 199 | """ 200 | Take the mean over all non-batch dimensions. 201 | """ 202 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 203 | 204 | 205 | def normalization(channels): 206 | """ 207 | Make a standard normalization layer. 208 | :param channels: number of input channels. 209 | :return: an nn.Module for normalization. 210 | """ 211 | return GroupNorm32(32, channels) 212 | 213 | 214 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 215 | class SiLU(nn.Module): 216 | def forward(self, x): 217 | return x * torch.sigmoid(x) 218 | 219 | 220 | class GroupNorm32(nn.GroupNorm): 221 | def forward(self, x): 222 | return super().forward(x.float()).type(x.dtype) 223 | 224 | def conv_nd(dims, *args, **kwargs): 225 | """ 226 | Create a 1D, 2D, or 3D convolution module. 227 | """ 228 | if dims == 1: 229 | return nn.Conv1d(*args, **kwargs) 230 | elif dims == 2: 231 | return nn.Conv2d(*args, **kwargs) 232 | elif dims == 3: 233 | return nn.Conv3d(*args, **kwargs) 234 | raise ValueError(f"unsupported dimensions: {dims}") 235 | 236 | 237 | def linear(*args, **kwargs): 238 | """ 239 | Create a linear module. 240 | """ 241 | return nn.Linear(*args, **kwargs) 242 | 243 | 244 | def avg_pool_nd(dims, *args, **kwargs): 245 | """ 246 | Create a 1D, 2D, or 3D average pooling module. 247 | """ 248 | if dims == 1: 249 | return nn.AvgPool1d(*args, **kwargs) 250 | elif dims == 2: 251 | return nn.AvgPool2d(*args, **kwargs) 252 | elif dims == 3: 253 | return nn.AvgPool3d(*args, **kwargs) 254 | raise ValueError(f"unsupported dimensions: {dims}") 255 | 256 | 257 | class HybridConditioner(nn.Module): 258 | 259 | def __init__(self, c_concat_config, c_crossattn_config): 260 | super().__init__() 261 | self.concat_conditioner = instantiate_from_config(c_concat_config) 262 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 263 | 264 | def forward(self, c_concat, c_crossattn): 265 | c_concat = self.concat_conditioner(c_concat) 266 | c_crossattn = self.crossattn_conditioner(c_crossattn) 267 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 268 | 269 | 270 | def noise_like(shape, device, repeat=False): 271 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 272 | noise = lambda: torch.randn(shape, device=device) 273 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
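    # For example, normal_kl(mean, logvar, torch.zeros_like(mean), torch.zeros_like(logvar))
    # reduces to 0.5 * (mean ** 2 + logvar.exp() - 1. - logvar), i.e. the KL to a standard normal.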
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
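        A typical EMA validation pass therefore looks like (with `ema` a LitEma
        instance tracking `model`):
            ema.store(model.parameters())
            ema.copy_to(model)
            # ... validate or save checkpoints with the EMA weights ...
            ema.restore(model.parameters())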
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from ldm.util import default, count_params 9 | import einops 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 135 | """ 136 | Uses the OpenCLIP transformer encoder for text 137 | """ 138 | LAYERS = [ 139 | #"pooled", 140 | "last", 141 | "penultimate" 142 | ] 143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 144 | freeze=True, layer="last"): 145 | super().__init__() 146 | assert layer in self.LAYERS 147 | model, _, _ = 
open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 148 | del model.visual 149 | self.model = model 150 | 151 | self.device = device 152 | self.max_length = max_length 153 | if freeze: 154 | self.freeze() 155 | self.layer = layer 156 | if self.layer == "last": 157 | self.layer_idx = 0 158 | elif self.layer == "penultimate": 159 | self.layer_idx = 1 160 | else: 161 | raise NotImplementedError() 162 | 163 | def freeze(self): 164 | self.model = self.model.eval() 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | def forward(self, text, inv=False): 169 | tokens = open_clip.tokenize(text) 170 | if inv: 171 | # print("encoders/modules: tmp mark off [startoftext]") 172 | tokens[0] = torch.zeros(77) + 7788 173 | z = self.encode_with_transformer(tokens.to(self.device), inv) 174 | else: 175 | z = self.encode_with_transformer(tokens.to(self.device)) 176 | return z 177 | 178 | def encode_with_transformer(self, text, inv=False): 179 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 180 | if inv == False: 181 | # x = einops.repeat(x[:,0], 'i j -> i c j', c=77) 182 | # print("encoders/modules: tmp mark off pass positional_embedding") 183 | x = x + self.model.positional_embedding 184 | 185 | 186 | x = x.permute(1, 0, 2) # NLD -> LND 187 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 188 | x = x.permute(1, 0, 2) # LND -> NLD 189 | x = self.model.ln_final(x) 190 | return x 191 | 192 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 193 | for i, r in enumerate(self.model.transformer.resblocks): 194 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 195 | break 196 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 197 | x = checkpoint(r, x, attn_mask) 198 | else: 199 | x = r(x, attn_mask=attn_mask) 200 | return x 201 | 202 | def encode(self, text, inv=False, device=None): 203 | if device is not None: 204 | self.device = device 205 | return self(text, inv) 206 | 207 | 208 | class FrozenCLIPT5Encoder(AbstractEncoder): 209 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 210 | clip_max_length=77, t5_max_length=77): 211 | super().__init__() 212 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 213 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 214 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 215 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 216 | 217 | def encode(self, text): 218 | return self(text) 219 | 220 | def forward(self, text): 221 | clip_z = self.clip_encoder.encode(text) 222 | t5_z = self.t5_encoder.encode(text) 223 | return [clip_z, t5_z] 224 | 225 | 226 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 
384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 
| def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 
237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
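    As in ldm.modules.midas.api.load_midas_transform above, this is typically
    composed with NormalizeImage and PrepareForNet, e.g.

        Compose([Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
                        ensure_multiple_of=32, resize_method="minimal",
                        image_interpolation_method=cv2.INTER_CUBIC),
                 NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                 PrepareForNet()])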
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
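    The image data is returned flipped vertically (via np.flipud) so that row 0
    is the top of the image; `scale` is the PFM scale factor with the
    endianness sign removed.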
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | import PIL 9 | from PIL import Image, ImageDraw, ImageFont 10 | from itertools import islice 11 | import cv2 12 | 13 | def load_model_from_config(config, ckpt, verbose=False): 14 | print(f"Loading model from {ckpt}") 15 | pl_sd = torch.load(ckpt, map_location="cpu") 16 | if "global_step" in pl_sd: 17 | print(f"Global Step: {pl_sd['global_step']}") 18 | sd = pl_sd["state_dict"] 19 | model = instantiate_from_config(config.model) 20 | m, u = model.load_state_dict(sd, strict=False) 21 | if len(m) > 0 and verbose: 22 | print("missing keys:") 23 | print(m) 24 | if len(u) > 0 and verbose: 25 | print("unexpected keys:") 26 | print(u) 27 | 28 | model.cuda() 29 | model.eval() 30 | return model 31 | 32 | 33 | def chunk(it, size): 34 | it = iter(it) 35 | return iter(lambda: tuple(islice(it, size)), ()) 36 | 37 | 38 | def load_img(path, SCALE, pad=False, seg=False, target_size=None): 39 | image = Image.open(path).convert("RGB") 40 | w, h = image.size 41 | w_,h_=w,h 42 | print(f"loaded input image of size ({w}, {h}) from {path}") 43 | w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 44 | w = h = 512 45 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 46 | 47 | image = np.array(image).astype(np.float32) / 255.0 48 | image = image[None].transpose(0, 3, 1, 2) 49 | image = torch.from_numpy(image) 50 | print(f"resize input image of size ({w_}, {h_}) to {w}, {h}") 51 | if pad or seg: 52 | return 2. * image - 1., new_w, new_h, padded_segmentation_map 53 | 54 | return 2. 
* image - 1., w, h 55 | 56 | 57 | def load_model_and_get_prompt_embedding(model, cfg_scale, device, prompts, inv=False): 58 | 59 | if inv: 60 | inv_emb = model.get_learned_conditioning(prompts, inv) 61 | c = uc = inv_emb 62 | else: 63 | inv_emb = None 64 | 65 | if cfg_scale != 1.0: 66 | uc = model.get_learned_conditioning(1 * [""]) 67 | else: 68 | uc = None 69 | c = model.get_learned_conditioning(prompts) 70 | 71 | return c, uc, inv_emb 72 | 73 | 74 | def log_txt_as_img(wh, xc, size=10): 75 | # wh a tuple of (width, height) 76 | # xc a list of captions to plot 77 | b = len(xc) 78 | txts = list() 79 | for bi in range(b): 80 | txt = Image.new("RGB", wh, color="white") 81 | draw = ImageDraw.Draw(txt) 82 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 83 | nc = int(40 * (wh[0] / 256)) 84 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 85 | 86 | try: 87 | draw.text((0, 0), lines, fill="black", font=font) 88 | except UnicodeEncodeError: 89 | print("Cant encode string for logging. Skipping.") 90 | 91 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 92 | txts.append(txt) 93 | txts = np.stack(txts) 94 | txts = torch.tensor(txts) 95 | return txts 96 | 97 | 98 | def ismap(x): 99 | if not isinstance(x, torch.Tensor): 100 | return False 101 | return (len(x.shape) == 4) and (x.shape[1] > 3) 102 | 103 | 104 | def isimage(x): 105 | if not isinstance(x,torch.Tensor): 106 | return False 107 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 108 | 109 | 110 | def exists(x): 111 | return x is not None 112 | 113 | 114 | def default(val, d): 115 | if exists(val): 116 | return val 117 | return d() if isfunction(d) else d 118 | 119 | 120 | def mean_flat(tensor): 121 | """ 122 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 123 | Take the mean over all non-batch dimensions. 
124 | """ 125 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 126 | 127 | 128 | def count_params(model, verbose=False): 129 | total_params = sum(p.numel() for p in model.parameters()) 130 | if verbose: 131 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 132 | return total_params 133 | 134 | 135 | def instantiate_from_config(config): 136 | if not "target" in config: 137 | if config == '__is_first_stage__': 138 | return None 139 | elif config == "__is_unconditional__": 140 | return None 141 | raise KeyError("Expected key `target` to instantiate.") 142 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 143 | 144 | 145 | def get_obj_from_str(string, reload=False): 146 | module, cls = string.rsplit(".", 1) 147 | if reload: 148 | module_imp = importlib.import_module(module) 149 | importlib.reload(module_imp) 150 | return getattr(importlib.import_module(module, package=None), cls) 151 | 152 | 153 | class AdamWwithEMAandWings(optim.Optimizer): 154 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 155 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 156 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 157 | ema_power=1., param_names=()): 158 | """AdamW that saves EMA versions of the parameters.""" 159 | if not 0.0 <= lr: 160 | raise ValueError("Invalid learning rate: {}".format(lr)) 161 | if not 0.0 <= eps: 162 | raise ValueError("Invalid epsilon value: {}".format(eps)) 163 | if not 0.0 <= betas[0] < 1.0: 164 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 165 | if not 0.0 <= betas[1] < 1.0: 166 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 167 | if not 0.0 <= weight_decay: 168 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 169 | if not 0.0 <= ema_decay <= 1.0: 170 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 171 | defaults = dict(lr=lr, betas=betas, eps=eps, 172 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 173 | ema_power=ema_power, param_names=param_names) 174 | super().__init__(params, defaults) 175 | 176 | def __setstate__(self, state): 177 | super().__setstate__(state) 178 | for group in self.param_groups: 179 | group.setdefault('amsgrad', False) 180 | 181 | @torch.no_grad() 182 | def step(self, closure=None): 183 | """Performs a single optimization step. 184 | Args: 185 | closure (callable, optional): A closure that reevaluates the model 186 | and returns the loss. 
187 | """ 188 | loss = None 189 | if closure is not None: 190 | with torch.enable_grad(): 191 | loss = closure() 192 | 193 | for group in self.param_groups: 194 | params_with_grad = [] 195 | grads = [] 196 | exp_avgs = [] 197 | exp_avg_sqs = [] 198 | ema_params_with_grad = [] 199 | state_sums = [] 200 | max_exp_avg_sqs = [] 201 | state_steps = [] 202 | amsgrad = group['amsgrad'] 203 | beta1, beta2 = group['betas'] 204 | ema_decay = group['ema_decay'] 205 | ema_power = group['ema_power'] 206 | 207 | for p in group['params']: 208 | if p.grad is None: 209 | continue 210 | params_with_grad.append(p) 211 | if p.grad.is_sparse: 212 | raise RuntimeError('AdamW does not support sparse gradients') 213 | grads.append(p.grad) 214 | 215 | state = self.state[p] 216 | 217 | # State initialization 218 | if len(state) == 0: 219 | state['step'] = 0 220 | # Exponential moving average of gradient values 221 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 222 | # Exponential moving average of squared gradient values 223 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 224 | if amsgrad: 225 | # Maintains max of all exp. moving avg. of sq. grad. values 226 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 227 | # Exponential moving average of parameter values 228 | state['param_exp_avg'] = p.detach().float().clone() 229 | 230 | exp_avgs.append(state['exp_avg']) 231 | exp_avg_sqs.append(state['exp_avg_sq']) 232 | ema_params_with_grad.append(state['param_exp_avg']) 233 | 234 | if amsgrad: 235 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 236 | 237 | # update the steps for each param group update 238 | state['step'] += 1 239 | # record the step after step update 240 | state_steps.append(state['step']) 241 | 242 | optim._functional.adamw(params_with_grad, 243 | grads, 244 | exp_avgs, 245 | exp_avg_sqs, 246 | max_exp_avg_sqs, 247 | state_steps, 248 | amsgrad=amsgrad, 249 | beta1=beta1, 250 | beta2=beta2, 251 | lr=group['lr'], 252 | weight_decay=group['weight_decay'], 253 | eps=group['eps'], 254 | maximize=False) 255 | 256 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 257 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 258 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 259 | 260 | return loss -------------------------------------------------------------------------------- /scripts/__pycache__/dpm_solver_pytorch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/scripts/__pycache__/dpm_solver_pytorch.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/dpm_solver_pytorch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/scripts/__pycache__/dpm_solver_pytorch.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/txt2img.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlueDyee/TF-GPH/81fbbec68fc0d35db887837ff04344d3106b932c/scripts/__pycache__/txt2img.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/header.html: 
-------------------------------------------------------------------------------- 1 |
2 |
9 |

10 | TF-ICON 🦄️ 11 |

12 |
13 |
14 |

15 | With TF-ICON, upload a background image and click twice to specify the mask region that you want to replace with a foreground image. 16 | </p>

17 |

18 | The paper is available on arXiv. If you like this demo, please ⭐ the GitHub repo 😊. 19 | </p>

20 |
21 |
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='stable-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /tfgph_app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from app_util import tfgph_load, tfgph_inverse, tfgph_harmonize 3 | import time 4 | import torch 5 | import PIL 6 | 7 | opt = { 8 | "seed": 4753, 9 | "ckpt": "v2-1_512-ema-pruned.ckpt", 10 | "config": "./configs/stable-diffusion/v2-inference.yaml", 11 | "scale": 5, 12 | "n_samples": 1, 13 | "f": 16, 14 | "C": 4, 15 | "total_steps":25, 16 | "ddim_eta": 0.0, 17 | "outdir": "./outputs/", 18 | "order":2, 19 | "prompt":"" 20 | } 21 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 22 | 23 | def app_load_model(x=None,progress=gr.Progress()): 24 | progress((50,100),desc="Loading model (this might take 0~30 seconds)... ") 25 | #img = PIL.Image.open("./inputs/demo_input/kangaroo.jpg") 26 | model,sampler= tfgph_load(opt,device) 27 | opt["model"]=model 28 | opt["sampler"]=sampler 29 | return "Model Loaded 💪" 30 | def app_inv_img(img,idx,progress=gr.Progress()): 31 | progress((50,100),desc="Prepocessing...") 32 | #img = PIL.Image.open("./inputs/demo_input/kangaroo.jpg") 33 | # Resize the image with the LANCZOS resampling filter 34 | resized_img = img.resize((512, 512), resample=PIL.Image.LANCZOS) 35 | z_ref=tfgph_inverse(resized_img,opt,opt["model"],opt["sampler"],device) 36 | opt[f"z_ref{idx}"]=z_ref 37 | return img 38 | def app_har(share_step,share_layer,alpha,beta): 39 | 40 | opt["share_step"]=int(share_step) 41 | opt["share_layer"]=int(share_layer) 42 | opt["scale_alpha"]=alpha 43 | opt["scale_beta"]=beta 44 | 45 | out=tfgph_harmonize(opt["z_ref1"],opt["z_ref2"],opt["z_ref3"],opt,opt["model"],opt["sampler"],device) 46 | return out 47 | 48 | with gr.Blocks(theme=gr.themes.Soft()) as demo: 49 | TITLE="""# Welcome to our TF-GPH web-demo🥰, 50 | the TF-GPH method is aiming to support artistic image generation🖼️ (e.g. Style Transfer, Painterly Harmonization), 51 | using user-specific image composition as instruction. 
52 | """ 53 | gr.Markdown(TITLE) 54 | #model, sampler, z_ref1, z_ref2, z_comp=gr.State(), gr.State(), gr.State(), gr.State(), gr.State() 55 | load_btn=gr.Button("Load TF-GPH model (❗ Insure you have already load the model before you start)") 56 | load_label = gr.Label(value="Model not yet loadded ❌",label="TF-GPH model status") 57 | load_btn.click(app_load_model,inputs=None,outputs=[load_label]) 58 | 59 | gallery_val=["./inputs/demo_input/kangaroo.jpg","./inputs/demo_input/starry_night.jpg","./inputs/demo_input/kangaroo_starry.jpg"] 60 | gallery = gr.Gallery(value=gallery_val,label="Generated images", show_label=False, elem_id="gallery", columns=[6], rows=[1], object_fit="contain", height="auto") 61 | with gr.Row(): 62 | layout_h=300 63 | layout_w=300 64 | 65 | 66 | ref1=gr.Image(scale=1,height=layout_h,width=layout_w,image_mode="RGB",label="reference 1 (foreground object)",type="pil") 67 | ref2=gr.Image(scale=1,height=layout_h,width=layout_w,image_mode="RGB",label="reference 2 (background image)",type="pil") 68 | ref3=gr.Image(scale=1,height=layout_h,width=layout_w,image_mode="RGB",label="composite (composite image)",type="pil") 69 | 70 | ref1.upload(app_inv_img, [ref1,gr.State(value=1)],[ref1]) 71 | ref2.upload(app_inv_img, [ref2,gr.State(value=2)],[ref2]) 72 | ref3.upload(app_inv_img, [ref3,gr.State(value=3)],[ref3]) 73 | with gr.Row(): 74 | with gr.Column(): 75 | share_step=gr.Slider(value=20,minimum=0,maximum=opt["total_steps"],label="share_step (lower->stronger)") 76 | share_layer=gr.Slider(value=8,minimum=0,maximum=16,label="share_layer (lower->stronger)") 77 | alpha=gr.Slider(value=0.9,minimum=0,maximum=10,label="ref1 weight (higher->stronger)") 78 | beta=gr.Slider(value=1.1,minimum=0,maximum=10,label="ref2 weight (higher->stronger)") 79 | har_btn=gr.Button("Run !") 80 | 81 | 82 | opt["share_step"]=share_step 83 | opt["share_layer"]=share_layer 84 | opt["scale_alpha"]=alpha 85 | opt["scale_beta"]=beta 86 | 87 | out=gr.Image(scale=1,image_mode="RGB",label="Harmonized image") 88 | har_btn.click(app_har,inputs=[share_step,share_layer,alpha,beta],outputs=[out]) 89 | if __name__ == "__main__": 90 | demo.launch(debug=True, enable_queue=True) 91 | #demo.launch(server_port=8002, debug=True, enable_queue=True) -------------------------------------------------------------------------------- /tfgph_env.yaml: -------------------------------------------------------------------------------- 1 | name: tfgph 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - albumentations==1.3.0 14 | - opencv-python==4.6.0.66 15 | - imageio==2.9.0 16 | - gradio==3.44.1 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - diffusers==0.12.1 31 | - ipykernel 32 | - matplotlib 33 | - -e . 
34 | -------------------------------------------------------------------------------- /tfgph_main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import PIL 5 | import torch 6 | import cv2 7 | import time 8 | 9 | import numpy as np 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | from itertools import islice 13 | from einops import rearrange, repeat 14 | from torch import autocast 15 | from pytorch_lightning import seed_everything 16 | 17 | from ldm.util import instantiate_from_config, load_model_from_config, load_img, load_model_and_get_prompt_embedding 18 | from ldm.models.diffusion.ddim import DDIMSampler 19 | from ldm.models.diffusion.dpm_solver import DPMSolverSampler 20 | from attention_control.masactrl_utils import regiter_attention_editor_ldm 21 | from attention_control.share_attention import ShareSelfAttentionControl 22 | from torchvision.utils import save_image 23 | 24 | def tfphd_main(opt): 25 | # Load Model 26 | config = OmegaConf.load(opt.config) 27 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 28 | print(f"Running on {device}") 29 | model = load_model_from_config(config, opt.ckpt) 30 | 31 | model = model.to(device) 32 | sampler = DPMSolverSampler(model) 33 | 34 | print("##----Model LOAD Success---##") 35 | 36 | # Read Image 37 | ref_1_image, target_width, target_height = load_img(opt.ref1,1) 38 | ref_1_image = repeat(ref_1_image.to(device), '1 ... -> b ...', b=1) 39 | 40 | ref_2_image, width, height= load_img(opt.ref2, 1) 41 | ref_2_image = repeat(ref_2_image.to(device), '1 ... -> b ...', b=1) 42 | 43 | composite_image, width, height= load_img(opt.comp, 1) 44 | composite_image = repeat(composite_image.to(device), '1 ... 
-> b ...', b=1) 45 | 46 | print("##----Image LOAD Success---##") 47 | 48 | # Reconstruct 49 | uncond_scale=2.5 50 | precision_scope = autocast 51 | with precision_scope("cuda"): 52 | c, uc, inv_emb = load_model_and_get_prompt_embedding(model, uncond_scale, device, opt.prompt, inv=True) 53 | print("Condition shape", c.shape,uc.shape,inv_emb.shape) 54 | 55 | T1 = time.time() 56 | ref_1_latent = model.get_first_stage_encoding(model.encode_first_stage(ref_1_image)) 57 | ref_2_latent = model.get_first_stage_encoding(model.encode_first_stage(ref_2_image)) 58 | composite_latent = model.get_first_stage_encoding(model.encode_first_stage(composite_image)) 59 | shape = ref_1_latent.shape[1:] 60 | z_ref_1_enc, _ = sampler.sample(steps=opt.total_steps, 61 | inv_emb=inv_emb, 62 | unconditional_conditioning=uc, 63 | conditioning=c, 64 | batch_size=1, 65 | shape=shape, 66 | verbose=False, 67 | unconditional_guidance_scale=uncond_scale, 68 | eta=0, 69 | order=opt.order, 70 | x_T=ref_1_latent, 71 | width=width, 72 | height=height, 73 | DPMencode=True, 74 | ) 75 | 76 | z_ref_2_enc, _ = sampler.sample(steps=opt.total_steps, 77 | inv_emb=inv_emb, 78 | unconditional_conditioning=uc, 79 | conditioning=c, 80 | batch_size=1, 81 | shape=shape, 82 | verbose=False, 83 | unconditional_guidance_scale=uncond_scale, 84 | eta=0, 85 | order=opt.order, 86 | x_T=ref_2_latent, 87 | DPMencode=True, 88 | width=width, 89 | height=height, 90 | ref=True, 91 | ) 92 | z_composite_enc, _ = sampler.sample(steps=opt.total_steps, 93 | inv_emb=inv_emb, 94 | unconditional_conditioning=uc, 95 | conditioning=c, 96 | batch_size=1, 97 | shape=shape, 98 | verbose=False, 99 | unconditional_guidance_scale=uncond_scale, 100 | eta=0, 101 | order=opt.order, 102 | x_T=composite_latent, 103 | DPMencode=True, 104 | width=width, 105 | height=height, 106 | ref=True, 107 | ) 108 | 109 | samples = sampler.sample(steps=opt.total_steps, 110 | inv_emb=torch.cat([inv_emb,inv_emb,inv_emb]), 111 | conditioning=torch.cat([c,c,c]), 112 | shape=shape, 113 | verbose=False, 114 | unconditional_guidance_scale=uncond_scale, 115 | unconditional_conditioning=torch.cat([uc,uc,uc]), 116 | eta=0, 117 | order=opt.order, 118 | x_T=torch.cat([z_ref_1_enc,z_ref_2_enc,z_composite_enc]), 119 | width=width, 120 | height=height, 121 | ) 122 | 123 | x_reconstruct = model.decode_first_stage(samples) 124 | x_reconstruct = torch.clamp((x_reconstruct + 1.0) / 2.0, min=0.0, max=1.0) 125 | 126 | T2 = time.time() 127 | print('Running Time: %s s' % ((T2 - T1))) 128 | names=["ref_1_recon.png","ref_2_recon.png","composite_recon.png"] 129 | exp_count=len(os.listdir(opt.outdir)) 130 | exp_path=os.path.join(opt.outdir,f"exp_{exp_count:04d}") 131 | os.mkdir(exp_path) 132 | reconstruct_path=os.path.join(exp_path,"reconstruct") 133 | os.mkdir(reconstruct_path) 134 | for x_sample,sample_name in zip(x_reconstruct,names): 135 | x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 136 | img = Image.fromarray(x_sample.astype(np.uint8)) 137 | img.save(os.path.join(reconstruct_path, sample_name)) 138 | print(f"Reconstruct result saving at {reconstruct_path}") 139 | 140 | print("##----Reconstuct Success---##") 141 | 142 | 143 | sim_scales = torch.tensor([opt.scale_alpha,opt.scale_beta]).to(device) 144 | 145 | with precision_scope("cuda"): 146 | # hijack the attention module (sclaed share) 147 | editor = ShareSelfAttentionControl(opt.share_step, opt.share_layer,scales=sim_scales,total_steps=opt.total_steps) 148 | regiter_attention_editor_ldm(model, editor) 149 | latents_harmonized = sampler.sample(steps=opt.total_steps, 150 | inv_emb=torch.cat([inv_emb,inv_emb,inv_emb]), 151 | conditioning=torch.cat([c,c,c]), 152 | shape=shape, 153 | verbose=False, 154 | unconditional_guidance_scale=uncond_scale, 155 | unconditional_conditioning=torch.cat([uc,uc,uc]), 156 | eta=0, 157 | order=opt.order, 158 | x_T=torch.cat([z_ref_1_enc,z_ref_2_enc,z_composite_enc]), 159 | width=width, 160 | height=height, 161 | ) 162 | x_harmonized = model.decode_first_stage(latents_harmonized) 163 | x_harmonized = torch.clamp((x_harmonized + 1.0) / 2.0, min=0.0, max=1.0) 164 | names=["ref_1_harmonized.png","ref_2_harmonized.png","composite_harmonized.png"] 165 | 166 | harmonized_path=os.path.join(exp_path, "harmonized") 167 | os.mkdir(harmonized_path) 168 | for x_sample,sample_name in zip(x_harmonized,names): 169 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 170 | img = Image.fromarray(x_sample.astype(np.uint8)) 171 | img.save(os.path.join(harmonized_path, sample_name)) 172 | 173 | composite_name=opt.comp.split("/")[-1][:-4] 174 | file_name=os.path.join(exp_path, f"{composite_name}_{opt.share_step}S_{opt.share_layer}L_total.png") 175 | save_image(torch.cat([x_reconstruct,x_harmonized]), file_name,nrow=3) 176 | print(f"Harmonized Result saved in {harmonized_path}") 177 | print("Success !!! 
Enjoy your results ~") 178 | 179 | 180 | 181 | if __name__ == "__main__": 182 | # TFPHD Hyper Parameter 183 | parser = argparse.ArgumentParser(description='Hyper-parameter of TFPHD') 184 | parser.add_argument('--ref1', type=str, default="./inputs/demo_input/kangaroo.jpg", help='The path of FIRST reference image (Foreground)') 185 | parser.add_argument('--ref2', type=str, default="./inputs/demo_input/starry_night.jpg", help='The path of SECOND reference image (Background)') 186 | parser.add_argument('--comp', type=str, default="./inputs/demo_input/kangaroo_starry.jpg", help='The path of COMPOSITE reference image (Copy-and-Paste)') 187 | parser.add_argument('--order', type=int, default=2, help='order=2-->DPM++, order=1-->DDIM ') 188 | parser.add_argument('--total_steps', type=int, default=20, help='Total Steps of DPM++ or DDIM ') 189 | parser.add_argument('--share_step', type=int, default=15, help='Which STEP to start share attention module') 190 | parser.add_argument('--share_layer', type=int, default=12, help='Which LAYER to start share attention module') 191 | parser.add_argument('--scale_alpha', type=float, default=0.8, help='Strength of rescale REF1') 192 | parser.add_argument('--scale_beta', type=float, default=1.2, help='Strength of rescale REF2') 193 | # TF-ICON Hyper Parameter 194 | parser.add_argument('--outdir', type=str, default="./outputs", help='Directory for saving result') 195 | parser.add_argument('--seed', type=float, default=7414, help='Radom seed of diffusion model') 196 | parser.add_argument('--ckpt', type=str, default="v2-1_512-ema-pruned.ckpt", help='ckpt of stable diffusion model') 197 | parser.add_argument('--config', type=str, default="./configs/stable-diffusion/v2-inference.yaml", help='config of stable diffusion model') 198 | parser.add_argument('--prompt', type=str, default="", help='prompt for CFG (TFPHD is prompt-free but you can add prompt if you want)') 199 | 200 | opt = parser.parse_args() 201 | 202 | tfphd_main(opt) --------------------------------------------------------------------------------