├── LICENSE ├── assets └── thumbnail.png ├── inference_single.sh ├── readme.md ├── requirements.txt ├── sample_data ├── cat │ └── cat_1.png ├── dog │ └── dog_1.png ├── horse │ └── horse_1.png ├── tree │ └── tree_1.png └── zebra │ └── zebra_1.png └── src ├── edit_once.py ├── inversion.py └── utils ├── __pycache__ ├── base_pipeline.cpython-310.pyc ├── base_pipeline.cpython-38.pyc ├── base_pipeline.cpython-39.pyc ├── cross_attention.cpython-310.pyc ├── cross_attention.cpython-38.pyc ├── cross_attention.cpython-39.pyc ├── ddim_inv.cpython-310.pyc ├── edit_directions.cpython-310.pyc ├── edit_pipeline.cpython-310.pyc ├── edit_pipeline_hdir.cpython-310.pyc ├── edit_pipeline_hdir2.cpython-310.pyc ├── edit_pipeline_hdir_noc.cpython-310.pyc ├── edit_pipeline_hdir_reg.cpython-310.pyc ├── edit_pipeline_hdir_textedit.cpython-310.pyc ├── edit_pipeline_nocross.cpython-310.pyc ├── pic_pipeline.cpython-310.pyc ├── pic_pipeline.cpython-38.pyc ├── pic_pipeline.cpython-39.pyc ├── scheduler.cpython-310.pyc ├── scheduler.cpython-38.pyc ├── scheduler.cpython-39.pyc ├── sdedit.cpython-310.pyc └── sdedit.cpython-38.pyc ├── base_pipeline.py ├── cross_attention.py ├── ddim_inv.py ├── pic_pipeline.py └── scheduler.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 junsung_kr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/assets/thumbnail.png -------------------------------------------------------------------------------- /inference_single.sh: -------------------------------------------------------------------------------- 1 | device_num="0" 2 | num_ddim_steps=50 3 | tau=25 4 | beta=0.8 5 | gamma=0.2 6 | category=cat 7 | input_image=cat_1 8 | task="cat2cat wearing glasses" 9 | 10 | CUDA_VISIBLE_DEVICES=${device_num} python src/edit_once.py \ 11 | --input_image "sample_data/${category}/${input_image}.png" \ 12 | --task_name "${task}" \ 13 | --results_folder "output" \ 14 | --num_ddim_steps "${num_ddim_steps}" \ 15 | --negative_guidance_scale 5.0 \ 16 | --tau "${tau}" \ 17 | --beta "${beta}" \ 18 | --gamma "${gamma}" \ 19 | --use_float_16 20 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Diffusion-based Image-to-Image Translation by Noise Correction via Prompt Interpolation (ECCV 2024 Poster) 2 | 3 | This is the official code of the paper "Diffusion-based Image-to-Image Translation by Noise Correction via Prompt Interpolation", published at ECCV 2024. 4 | 5 | ![thumbnail](assets/thumbnail.png) 6 | 7 | ## News 8 | 9 | :star: [2024. July] Our paper has been accepted to ECCV 2024! \ 10 | :star: [2024. Sep] We've uploaded our video & poster for ECCV 2024! You can check them via this [link](https://eccv.ecva.net/virtual/2024/poster/2134). Also, the official code of our paper has been released! We are still updating the code for better performance. 11 | 12 | 13 | ## Getting Started 14 | 15 | ### Installing 16 | 17 | ``` 18 | git clone https://github.com/JS-Lee525/PIC.git 19 | ``` 20 | 21 | ``` 22 | conda create -n [your_env] python=3.9 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ### Structures 27 | 28 | . 29 | ├── assets 30 | ├── thumbnail.png 31 | ├── sample_data 32 | ├── cat 33 | ├── cat_1.png 34 | ├── dog 35 | ├── dog_1.png 36 | ├── horse 37 | ├── horse_1.png 38 | ├── tree 39 | ├── tree_1.png 40 | ├── zebra 41 | ├── zebra_1.png 42 | ├── src 43 | ├── ... 44 | ├── inference_single.sh # Command File for editing images 45 | ├── requirements.txt 46 | ├── LICENSE 47 | └── readme.md 48 | 49 | ### Execution 50 | 51 | ``` 52 | sh inference_single.sh 53 | ``` 54 | 55 | You can adjust the following options in this script: 56 | - device_num: the GPU index to use 57 | - num_ddim_steps: number of DDIM steps for inference (default = 50) 58 | - tau: number of editing steps in the reverse process (default = 25) 59 | - beta: hyperparameter used in the initialization of prompt interpolation (default = 0.3 for word swap / 0.8 for adding phrases) 60 | - gamma: hyperparameter controlling the correction term (default = 0.2; gamma * negative_guidance_scale is used in the code) 61 | - task: the translation task in the format "{source phrase}2{target phrase}". For example, if the task is set to 'cat2dog' and the source prompt is generated as 'a cat is lying on the grass', the target prompt will be set to 'a dog is lying on the grass'. Note that the source phrase must appear in the source prompt, which is printed when you execute this command (see the sketch below).
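For reference, here is a minimal sketch of how the "{source phrase}2{target phrase}" convention maps the BLIP-generated source prompt to the target prompt. The actual parsing is done by the repository's `text_swap` routine (called in `src/utils/pic_pipeline.py`); the helper name and error handling below are illustrative only.

```
# Illustrative sketch of the task-string convention; the repository's own logic lives in
# text_swap (called in src/utils/pic_pipeline.py) and may differ in its details.
def build_target_prompt(task_name: str, source_prompt: str) -> str:
    source_phrase, target_phrase = task_name.split("2", 1)
    if source_phrase not in source_prompt:
        raise ValueError("the source phrase must appear in the source prompt")
    return source_prompt.replace(source_phrase, target_phrase)

# "cat2dog" + "a cat is lying on the grass" -> "a dog is lying on the grass"
print(build_target_prompt("cat2dog", "a cat is lying on the grass"))
```

The same convention covers phrase additions such as task="cat2cat wearing glasses" in `inference_single.sh`, where 'cat' is replaced by 'cat wearing glasses' in the source prompt.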
62 | 63 | 64 | ## Citation 65 | ``` 66 | @article{lee2024diffusion, 67 | title={Diffusion-Based Image-to-Image Translation by Noise Correction via Prompt Interpolation}, 68 | author={Lee, Junsung and Kang, Minsoo and Han, Bohyung}, 69 | journal={arXiv preprint arXiv:2409.08077}, 70 | year={2024} 71 | } 72 | ``` 73 | ## License 74 | 75 | This project is licensed under the MIT License. 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | altair==5.2.0 5 | annotated-types==0.6.0 6 | antlr4-python3-runtime==4.9.3 7 | asttokens==2.4.1 8 | async-timeout==4.0.3 9 | attrs==23.2.0 10 | bleach==6.1.0 11 | blinker==1.7.0 12 | blis==0.7.11 13 | braceexpand==0.1.7 14 | cachetools==5.3.3 15 | catalogue==2.0.10 16 | certifi==2024.2.2 17 | cfgv==3.4.0 18 | charset-normalizer==3.3.2 19 | click==8.1.7 20 | clip==0.2.0 21 | cloudpathlib==0.16.0 22 | confection==0.1.4 23 | contexttimer==0.3.3 24 | contourpy==1.2.0 25 | cycler==0.12.1 26 | cymem==2.0.8 27 | decorator==5.1.1 28 | decord==0.6.0 29 | diffusers==0.12.1 30 | distlib==0.3.8 31 | einops==0.7.0 32 | exceptiongroup==1.2.0 33 | executing==2.0.1 34 | fairscale==0.4.4 35 | filelock==3.13.1 36 | fonttools==4.49.0 37 | frozenlist==1.4.1 38 | fsspec==2024.2.0 39 | ftfy==6.1.3 40 | gitdb==4.0.11 41 | GitPython==3.1.42 42 | huggingface-hub==0.21.4 43 | identify==2.5.35 44 | idna==3.6 45 | imageio==2.34.0 46 | importlib_metadata==7.0.2 47 | iopath==0.1.10 48 | jedi==0.19.1 49 | Jinja2==3.1.3 50 | jsonschema==4.21.1 51 | jsonschema-specifications==2023.12.1 52 | kaggle==1.6.6 53 | kiwisolver==1.4.5 54 | langcodes==3.3.0 55 | lazy_loader==0.3 56 | lightning-utilities==0.10.1 57 | markdown-it-py==3.0.0 58 | MarkupSafe==2.1.5 59 | matplotlib==3.8.3 60 | matplotlib-inline==0.1.6 61 | mdurl==0.1.2 62 | mpmath==1.3.0 63 | multidict==6.0.5 64 | murmurhash==1.0.10 65 | networkx==3.2.1 66 | nodeenv==1.8.0 67 | numpy==1.26.4 68 | nvidia-cublas-cu12==12.1.3.1 69 | nvidia-cuda-cupti-cu12==12.1.105 70 | nvidia-cuda-nvrtc-cu12==12.1.105 71 | nvidia-cuda-runtime-cu12==12.1.105 72 | nvidia-cudnn-cu12==8.9.2.26 73 | nvidia-cufft-cu12==11.0.2.54 74 | nvidia-curand-cu12==10.3.2.106 75 | nvidia-cusolver-cu12==11.4.5.107 76 | nvidia-cusparse-cu12==12.1.0.106 77 | nvidia-nccl-cu12==2.19.3 78 | nvidia-nvjitlink-cu12==12.4.99 79 | nvidia-nvtx-cu12==12.1.105 80 | omegaconf==2.3.0 81 | opencv-python==4.8.0.74 82 | opencv-python-headless==4.5.5.64 83 | opendatasets==0.1.22 84 | packaging==23.2 85 | pandas==2.2.1 86 | parso==0.8.3 87 | pexpect==4.9.0 88 | pillow==10.2.0 89 | platformdirs==4.2.0 90 | plotly==5.19.0 91 | portalocker==2.8.2 92 | pre-commit==3.6.2 93 | preshed==3.0.9 94 | prompt-toolkit==3.0.43 95 | protobuf==4.25.3 96 | psutil==5.9.8 97 | ptyprocess==0.7.0 98 | pure-eval==0.2.2 99 | pyarrow==15.0.1 100 | pycocoevalcap==1.2 101 | pycocotools==2.0.7 102 | pydantic==2.6.3 103 | pydantic_core==2.16.3 104 | pydeck==0.8.1b0 105 | Pygments==2.17.2 106 | pyparsing==3.1.2 107 | python-dateutil==2.9.0.post0 108 | python-magic==0.4.27 109 | python-slugify==8.0.4 110 | pytorch-lightning==2.2.1 111 | pytz==2024.1 112 | PyYAML==6.0.1 113 | referencing==0.33.0 114 | regex==2023.12.25 115 | requests==2.31.0 116 | rich==13.7.1 117 | rpds-py==0.18.0 118 | safetensors==0.4.2 119 | salesforce-lavis==1.0.2 120 | scikit-image==0.22.0 121 | scipy==1.12.0 122 | sentencepiece==0.2.0 123 | six==1.16.0 124 
| smart-open==6.4.0 125 | smmap==5.0.1 126 | spacy==3.7.4 127 | spacy-legacy==3.0.12 128 | spacy-loggers==1.0.5 129 | srsly==2.4.8 130 | stack-data==0.6.3 131 | streamlit==1.32.0 132 | sympy==1.12 133 | tenacity==8.2.3 134 | text-unidecode==1.3 135 | thinc==8.2.3 136 | tifffile==2024.2.12 137 | timm==0.4.12 138 | tokenizers==0.13.3 139 | toml==0.10.2 140 | toolz==0.12.1 141 | torch==2.2.1 142 | torchmetrics==1.3.1 143 | torchvision==0.17.1 144 | tornado==6.4 145 | tqdm==4.66.2 146 | traitlets==5.14.1 147 | transformers==4.26.1 148 | triton==2.2.0 149 | typer==0.9.0 150 | typing_extensions==4.10.0 151 | tzdata==2024.1 152 | urllib3==2.2.1 153 | virtualenv==20.25.1 154 | wasabi==1.1.2 155 | watchdog==4.0.0 156 | wcwidth==0.2.13 157 | weasel==0.3.4 158 | webdataset==0.2.86 159 | webencodings==0.5.1 160 | yarl==1.9.4 161 | zipp==3.17.0 162 | -------------------------------------------------------------------------------- /sample_data/cat/cat_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/sample_data/cat/cat_1.png -------------------------------------------------------------------------------- /sample_data/dog/dog_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/sample_data/dog/dog_1.png -------------------------------------------------------------------------------- /sample_data/horse/horse_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/sample_data/horse/horse_1.png -------------------------------------------------------------------------------- /sample_data/tree/tree_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/sample_data/tree/tree_1.png -------------------------------------------------------------------------------- /sample_data/zebra/zebra_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/sample_data/zebra/zebra_1.png -------------------------------------------------------------------------------- /src/edit_once.py: -------------------------------------------------------------------------------- 1 | import os, pdb 2 | 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import requests 7 | import glob 8 | from PIL import Image 9 | import argparse 10 | import os 11 | import copy 12 | import os, sys 13 | 14 | import torch 15 | from PIL import Image 16 | import glob 17 | from tqdm.autonotebook import tqdm 18 | 19 | 20 | import cv2 21 | import matplotlib.pyplot as plt 22 | 23 | from lavis.models import load_model_and_preprocess 24 | from diffusers import DDIMScheduler 25 | from utils.pic_pipeline import PicPipeline 26 | from pytorch_lightning import seed_everything 27 | from utils.scheduler import DDIMInverseScheduler 28 | 29 | if torch.cuda.is_available(): 30 | device = "cuda" 31 | else: 32 | device = "cpu" 33 | 34 | import torch 35 | from PIL import Image, ImageDraw, ImageFont 36 | import glob 37 | from tqdm.autonotebook import tqdm 38 | 39 | if __name__=="__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--input_image', 
type=str, default='assets/test_images/cat_a.png') 42 | parser.add_argument('--task_name', type=str, default='cat2dog') 43 | parser.add_argument('--results_folder', type=str, default='output/test_cat') 44 | parser.add_argument('--num_ddim_steps', type=int, default=50) 45 | parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') 46 | parser.add_argument('--negative_guidance_scale', default=5.0, type=float) 47 | parser.add_argument('--use_float_16', action='store_true') 48 | parser.add_argument('--tau', type=int, default=25) 49 | parser.add_argument('--beta', type=float, default=0.0) 50 | parser.add_argument('--gamma', type=float, default=0.0) 51 | parser.add_argument('--use_wordswap', action='store_true') 52 | args = parser.parse_args() 53 | 54 | seed_everything(42) 55 | os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True) 56 | os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True) 57 | 58 | if args.use_float_16: 59 | torch_dtype = torch.float16 60 | else: 61 | torch_dtype = torch.float32 62 | 63 | model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device(device)) 64 | 65 | pipe = PicPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to(device) 66 | pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) 67 | 68 | if args.use_float_16: 69 | torch_dtype = torch.float16 70 | else: 71 | torch_dtype = torch.float32 72 | 73 | if os.path.isdir(args.input_image): 74 | l_img_paths = sorted(glob.glob(os.path.join(args.input_image, "*.png"))) 75 | else: 76 | l_img_paths = [args.input_image] 77 | 78 | for img_path in l_img_paths: 79 | bname = os.path.basename(img_path).split(".")[0] 80 | img_num = int(img_path.split('/')[-1].split('.')[0].split('_')[-1]) 81 | img = Image.open(img_path).convert('RGB').resize((512,512), Image.Resampling.LANCZOS) 82 | # generate the caption 83 | _image = vis_processors["eval"](img).unsqueeze(0).to(device) 84 | prompt_str = model_blip.generate({"image": _image})[0] 85 | 86 | x_img, x_rec = pipe( 87 | prompt_str, 88 | guidance_scale_for=1, 89 | guidance_scale_rev=args.negative_guidance_scale, 90 | num_inversion_steps=args.num_ddim_steps, 91 | img=img, 92 | torch_dtype=torch_dtype, 93 | tau=args.tau, 94 | task_name=args.task_name, 95 | use_wordswap=args.use_wordswap, 96 | beta=args.beta, 97 | gamma=args.gamma, 98 | ) 99 | 100 | bname = os.path.basename(img_path).split(".")[0] 101 | x_img[0].save(os.path.join(args.results_folder, f"edit/{bname}.png")) 102 | x_rec[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png")) 103 | -------------------------------------------------------------------------------- /src/inversion.py: -------------------------------------------------------------------------------- 1 | import os, pdb 2 | from glob import glob 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import requests 7 | from PIL import Image 8 | 9 | from lavis.models import load_model_and_preprocess 10 | from pytorch_lightning import seed_everything 11 | from utils.ddim_inv import DDIMInversion 12 | from utils.scheduler import DDIMInverseScheduler 13 | 14 | if torch.cuda.is_available(): 15 | device = "cuda" 16 | else: 17 | device = "cpu" 18 | 19 | if __name__=="__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png') 22 | parser.add_argument('--results_folder', type=str, 
default='output/test_cat') 23 | parser.add_argument('--num_ddim_steps', type=int, default=50) 24 | parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') 25 | parser.add_argument('--use_float_16', action='store_true') 26 | args = parser.parse_args() 27 | 28 | # make the output folders 29 | os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True) 30 | os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True) 31 | 32 | if args.use_float_16: 33 | torch_dtype = torch.float16 34 | else: 35 | torch_dtype = torch.float32 36 | 37 | seed_everything(42) 38 | # load the BLIP model 39 | model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device(device)) 40 | # make the DDIM inversion pipeline 41 | pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to(device) 42 | pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) 43 | 44 | 45 | # if the input is a folder, collect all the images as a list 46 | if os.path.isdir(args.input_image): 47 | l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png"))) 48 | else: 49 | l_img_paths = [args.input_image] 50 | 51 | 52 | for img_path in l_img_paths: 53 | bname = os.path.basename(img_path).split(".")[0] 54 | img = Image.open(img_path).convert('RGB').resize((512,512), Image.Resampling.LANCZOS) 55 | # generate the caption 56 | _image = vis_processors["eval"](img).unsqueeze(0).to(device) 57 | prompt_str = model_blip.generate({"image": _image})[0] 58 | x_inv, x_inv_image, x_dec_img = pipe( 59 | prompt_str, 60 | guidance_scale=1, 61 | num_inversion_steps=args.num_ddim_steps, 62 | img=img, 63 | torch_dtype=torch_dtype 64 | ) 65 | # save the inversion 66 | torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt")) 67 | # save the prompt string 68 | with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f: 69 | f.write(prompt_str) 70 | -------------------------------------------------------------------------------- /src/utils/__pycache__/base_pipeline.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/base_pipeline.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/base_pipeline.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/base_pipeline.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/base_pipeline.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/base_pipeline.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/cross_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/cross_attention.cpython-310.pyc -------------------------------------------------------------------------------- 
/src/utils/__pycache__/cross_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/cross_attention.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/cross_attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/cross_attention.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/ddim_inv.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/ddim_inv.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_directions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_directions.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline_hdir.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline_hdir.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline_hdir2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline_hdir2.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline_hdir_noc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline_hdir_noc.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline_hdir_reg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline_hdir_reg.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline_hdir_textedit.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline_hdir_textedit.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/edit_pipeline_nocross.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/edit_pipeline_nocross.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/pic_pipeline.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/pic_pipeline.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/pic_pipeline.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/pic_pipeline.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/pic_pipeline.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/pic_pipeline.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/scheduler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/scheduler.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/sdedit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/sdedit.cpython-310.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/sdedit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jslee525/PIC/c7cb95a6ea16c4c0ca43c4fc46c1df4a2d426da4/src/utils/__pycache__/sdedit.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/base_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import inspect 4 | from 
packaging import version 5 | from typing import Any, Callable, Dict, List, Optional, Union 6 | 7 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 8 | from diffusers import DiffusionPipeline 9 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 10 | from diffusers.schedulers import KarrasDiffusionSchedulers 11 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring 12 | from diffusers import StableDiffusionPipeline 13 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 14 | 15 | 16 | 17 | class BasePipeline(DiffusionPipeline): 18 | _optional_components = ["safety_checker", "feature_extractor"] 19 | def __init__( 20 | self, 21 | vae: AutoencoderKL, 22 | text_encoder: CLIPTextModel, 23 | tokenizer: CLIPTokenizer, 24 | unet: UNet2DConditionModel, 25 | scheduler: KarrasDiffusionSchedulers, 26 | safety_checker: StableDiffusionSafetyChecker, 27 | feature_extractor: CLIPFeatureExtractor, 28 | requires_safety_checker: bool = True, 29 | ): 30 | super().__init__() 31 | 32 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 33 | deprecation_message = ( 34 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 35 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 36 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 37 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 38 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 39 | " file" 40 | ) 41 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 42 | new_config = dict(scheduler.config) 43 | new_config["steps_offset"] = 1 44 | scheduler._internal_dict = FrozenDict(new_config) 45 | 46 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 47 | deprecation_message = ( 48 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 49 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 50 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 51 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 52 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 53 | ) 54 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 55 | new_config = dict(scheduler.config) 56 | new_config["clip_sample"] = False 57 | scheduler._internal_dict = FrozenDict(new_config) 58 | 59 | if safety_checker is None and requires_safety_checker: 60 | logger.warning( 61 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 62 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 63 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 64 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 65 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 66 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
67 | ) 68 | 69 | if safety_checker is not None and feature_extractor is None: 70 | raise ValueError( 71 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 72 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 73 | ) 74 | 75 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 76 | version.parse(unet.config._diffusers_version).base_version 77 | ) < version.parse("0.9.0.dev0") 78 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 79 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 80 | deprecation_message = ( 81 | "The configuration file of the unet has set the default `sample_size` to smaller than" 82 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 83 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 84 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 85 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 86 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 87 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 88 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 89 | " the `unet/config.json` file" 90 | ) 91 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 92 | new_config = dict(unet.config) 93 | new_config["sample_size"] = 64 94 | unet._internal_dict = FrozenDict(new_config) 95 | 96 | self.register_modules( 97 | vae=vae, 98 | text_encoder=text_encoder, 99 | tokenizer=tokenizer, 100 | unet=unet, 101 | scheduler=scheduler, 102 | safety_checker=safety_checker, 103 | feature_extractor=feature_extractor, 104 | ) 105 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 106 | self.register_to_config(requires_safety_checker=requires_safety_checker) 107 | 108 | @property 109 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device 110 | def _execution_device(self): 111 | r""" 112 | Returns the device on which the pipeline's models will be executed. After calling 113 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 114 | hooks. 115 | """ 116 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 117 | return self.device 118 | for module in self.unet.modules(): 119 | if ( 120 | hasattr(module, "_hf_hook") 121 | and hasattr(module._hf_hook, "execution_device") 122 | and module._hf_hook.execution_device is not None 123 | ): 124 | return torch.device(module._hf_hook.execution_device) 125 | return self.device 126 | 127 | 128 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 129 | def _encode_prompt( 130 | self, 131 | prompt, 132 | device, 133 | num_images_per_prompt, 134 | do_classifier_free_guidance, 135 | negative_prompt=None, 136 | prompt_embeds: Optional[torch.FloatTensor] = None, 137 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 138 | ): 139 | r""" 140 | Encodes the prompt into text encoder hidden states. 
141 | 142 | Args: 143 | prompt (`str` or `List[str]`, *optional*): 144 | prompt to be encoded 145 | device: (`torch.device`): 146 | torch device 147 | num_images_per_prompt (`int`): 148 | number of images that should be generated per prompt 149 | do_classifier_free_guidance (`bool`): 150 | whether to use classifier free guidance or not 151 | negative_ prompt (`str` or `List[str]`, *optional*): 152 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 153 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 154 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 155 | prompt_embeds (`torch.FloatTensor`, *optional*): 156 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 157 | provided, text embeddings will be generated from `prompt` input argument. 158 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 159 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 160 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 161 | argument. 162 | """ 163 | if prompt is not None and isinstance(prompt, str): 164 | batch_size = 1 165 | elif prompt is not None and isinstance(prompt, list): 166 | batch_size = len(prompt) 167 | else: 168 | batch_size = prompt_embeds.shape[0] 169 | 170 | if prompt_embeds is None: 171 | text_inputs = self.tokenizer( 172 | prompt, 173 | padding="max_length", 174 | max_length=self.tokenizer.model_max_length, 175 | truncation=True, 176 | return_tensors="pt", 177 | ) 178 | text_input_ids = text_inputs.input_ids 179 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 180 | 181 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 182 | text_input_ids, untruncated_ids 183 | ): 184 | removed_text = self.tokenizer.batch_decode( 185 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 186 | ) 187 | logger.warning( 188 | "The following part of your input was truncated because CLIP can only handle sequences up to" 189 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 190 | ) 191 | 192 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 193 | attention_mask = text_inputs.attention_mask.to(device) 194 | else: 195 | attention_mask = None 196 | 197 | prompt_embeds = self.text_encoder( 198 | text_input_ids.to(device), 199 | attention_mask=attention_mask, 200 | ) 201 | prompt_embeds = prompt_embeds[0] 202 | 203 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 204 | 205 | bs_embed, seq_len, _ = prompt_embeds.shape 206 | # duplicate text embeddings for each generation per prompt, using mps friendly method 207 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 208 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 209 | 210 | # get unconditional embeddings for classifier free guidance 211 | if do_classifier_free_guidance and negative_prompt_embeds is None: 212 | uncond_tokens: List[str] 213 | if negative_prompt is None: 214 | uncond_tokens = [""] * batch_size 215 | elif type(prompt) is not type(negative_prompt): 216 | raise TypeError( 217 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 218 | f" {type(prompt)}." 
219 | ) 220 | elif isinstance(negative_prompt, str): 221 | uncond_tokens = [negative_prompt] 222 | elif batch_size != len(negative_prompt): 223 | raise ValueError( 224 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 225 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 226 | " the batch size of `prompt`." 227 | ) 228 | else: 229 | uncond_tokens = negative_prompt 230 | 231 | max_length = prompt_embeds.shape[1] 232 | uncond_input = self.tokenizer( 233 | uncond_tokens, 234 | padding="max_length", 235 | max_length=max_length, 236 | truncation=True, 237 | return_tensors="pt", 238 | ) 239 | 240 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 241 | attention_mask = uncond_input.attention_mask.to(device) 242 | else: 243 | attention_mask = None 244 | 245 | negative_prompt_embeds = self.text_encoder( 246 | uncond_input.input_ids.to(device), 247 | attention_mask=attention_mask, 248 | ) 249 | negative_prompt_embeds = negative_prompt_embeds[0] 250 | 251 | if do_classifier_free_guidance: 252 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 253 | seq_len = negative_prompt_embeds.shape[1] 254 | 255 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 256 | 257 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 258 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 259 | 260 | # For classifier free guidance, we need to do two forward passes. 261 | # Here we concatenate the unconditional and text embeddings into a single batch 262 | # to avoid doing two forward passes 263 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 264 | 265 | return prompt_embeds 266 | 267 | 268 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 269 | def decode_latents(self, latents): 270 | latents = 1 / 0.18215 * latents 271 | image = self.vae.decode(latents).sample 272 | image = (image / 2 + 0.5).clamp(0, 1) 273 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 274 | image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy() 275 | return image 276 | 277 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 278 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 279 | if isinstance(generator, list) and len(generator) != batch_size: 280 | raise ValueError( 281 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 282 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
283 | ) 284 | 285 | if latents is None: 286 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 287 | else: 288 | latents = latents.to(device) 289 | 290 | # scale the initial noise by the standard deviation required by the scheduler 291 | latents = latents * self.scheduler.init_noise_sigma 292 | return latents 293 | 294 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 295 | def prepare_extra_step_kwargs(self, generator, eta): 296 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 297 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 298 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 299 | # and should be between [0, 1] 300 | 301 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 302 | extra_step_kwargs = {} 303 | if accepts_eta: 304 | extra_step_kwargs["eta"] = eta 305 | 306 | # check if the scheduler accepts generator 307 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 308 | if accepts_generator: 309 | extra_step_kwargs["generator"] = generator 310 | return extra_step_kwargs 311 | 312 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 313 | def run_safety_checker(self, image, device, dtype): 314 | if self.safety_checker is not None: 315 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) 316 | image, has_nsfw_concept = self.safety_checker( 317 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 318 | ) 319 | else: 320 | has_nsfw_concept = None 321 | return image, has_nsfw_concept 322 | 323 | -------------------------------------------------------------------------------- /src/utils/cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffusers.models.attention import CrossAttention 4 | 5 | class MyConv2D(nn.Module): 6 | def __init__(self, c): 7 | super(). 
__init__() 8 | self.c = c 9 | 10 | def forward(self, input): 11 | self.h = self.c(input) 12 | 13 | return self.h 14 | 15 | class MyCrossAttnProcessor: 16 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 17 | batch_size, sequence_length, _ = hidden_states.shape 18 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 19 | 20 | query = attn.to_q(hidden_states) 21 | 22 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 23 | key = attn.to_k(encoder_hidden_states) 24 | value = attn.to_v(encoder_hidden_states) 25 | 26 | query = attn.head_to_batch_dim(query) 27 | key = attn.head_to_batch_dim(key) 28 | value = attn.head_to_batch_dim(value) 29 | 30 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 31 | # new bookkeeping to save the attn probs 32 | attn.attn_probs = attention_probs 33 | 34 | hidden_states = torch.bmm(attention_probs, value) 35 | hidden_states = attn.batch_to_head_dim(hidden_states) 36 | 37 | # linear proj 38 | hidden_states = attn.to_out[0](hidden_states) 39 | # dropout 40 | hidden_states = attn.to_out[1](hidden_states) 41 | 42 | return hidden_states 43 | 44 | 45 | """ 46 | A function that prepares a U-Net model by freezing all of its parameters, installing a 47 | custom cross-attention processor that records the attention probabilities, and wrapping 48 | mid_block.resnets[1].conv2 so that its output feature map can be read back. 49 | 50 | Parameters: 51 | unet: A U-Net model. 52 | 53 | Returns: 54 | unet: The prepared U-Net model. 55 | """ 56 | def prep_unet(unet): 57 | 58 | # freeze all U-Net parameters (no gradients are needed here) 59 | for name, params in unet.named_parameters(): 60 | if 'attn2' in name: 61 | params.requires_grad = False 62 | else: 63 | params.requires_grad = False 64 | # replace the fwd function 65 | for name, module in unet.named_modules(): 66 | module_name = type(module).__name__ 67 | if module_name == "CrossAttention": 68 | module.set_processor(MyCrossAttnProcessor()) 69 | if name == 'mid_block.resnets.1.conv2': 70 | c = unet.mid_block.resnets[1].conv2 71 | unet.mid_block.resnets[1].conv2 = MyConv2D(c) 72 | return unet -------------------------------------------------------------------------------- /src/utils/ddim_inv.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from random import randrange 6 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 7 | from diffusers import DDIMScheduler 8 | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput 9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 10 | sys.path.insert(0, "src/utils") 11 | from base_pipeline import BasePipeline 12 | from cross_attention import prep_unet 13 | 14 | 15 | if torch.cuda.is_available(): 16 | device = "cuda" 17 | else: 18 | device = "cpu" 19 | 20 | 21 | class DDIMInversion(BasePipeline): 22 | 23 | def auto_corr_loss(self, x, random_shift=True): 24 | B,C,H,W = x.shape 25 | assert B==1 26 | x = x.squeeze(0) 27 | # x must be shape [C,H,W] now 28 | reg_loss = 0.0 29 | for ch_idx in range(x.shape[0]): 30 | noise = x[ch_idx][None, None,:,:] 31 | while True: 32 | if random_shift: roll_amount = randrange(noise.shape[2]//2) 33 | else: roll_amount = 1 34 | reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2 35 | reg_loss += (noise*torch.roll(noise, shifts=roll_amount,
dims=3)).mean()**2 36 | if noise.shape[2] <= 8: 37 | break 38 | noise = F.avg_pool2d(noise, kernel_size=2) 39 | return reg_loss 40 | 41 | def kl_divergence(self, x): 42 | _mu = x.mean() 43 | _var = x.var() 44 | return _var + _mu**2 - 1 - torch.log(_var+1e-7) 45 | 46 | 47 | def __call__( 48 | self, 49 | prompt: Union[str, List[str]] = None, 50 | num_inversion_steps: int = 50, 51 | guidance_scale: float = 7.5, 52 | negative_prompt: Optional[Union[str, List[str]]] = None, 53 | num_images_per_prompt: Optional[int] = 1, 54 | eta: float = 0.0, 55 | output_type: Optional[str] = "pil", 56 | return_dict: bool = True, 57 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 58 | img=None, # the input image as a PIL image 59 | torch_dtype=torch.float32, 60 | 61 | # inversion regularization parameters 62 | lambda_ac: float = 20.0, 63 | lambda_kl: float = 20.0, 64 | num_reg_steps: int = 5, 65 | num_ac_rolls: int = 5, 66 | ): 67 | 68 | # 0. modify the unet to be useful :D 69 | self.unet = prep_unet(self.unet) 70 | 71 | # set the scheduler to be the Inverse DDIM scheduler 72 | # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config) 73 | 74 | device = self._execution_device 75 | do_classifier_free_guidance = guidance_scale > 1.0 76 | self.scheduler.set_timesteps(num_inversion_steps, device=device) 77 | timesteps = self.scheduler.timesteps 78 | 79 | # Encode the input image with the first stage model 80 | x0 = np.array(img)/255 81 | x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).to(device) 82 | x0 = (x0 - 0.5) * 2. 83 | with torch.no_grad(): 84 | x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype) 85 | latents = x0_enc = 0.18215 * x0_enc 86 | 87 | # Decode and return the image 88 | with torch.no_grad(): 89 | x0_dec = self.decode_latents(x0_enc.detach()) 90 | image_x0_dec = self.numpy_to_pil(x0_dec) 91 | 92 | with torch.no_grad(): 93 | prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device) 94 | extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta) 95 | 96 | # Do the inversion 97 | num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0? 
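# The loop below performs the DDIM inversion (x_0 -> x_T): at each timestep the UNet predicts
# the noise residual conditioned on the BLIP caption, classifier-free guidance is applied when
# guidance_scale > 1, and the prediction is then refined for num_reg_steps using the
# auto-correlation (lambda_ac) and KL (lambda_kl) regularizers defined above, before the
# inverse scheduler step moves the latents one step closer to x_T.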
98 | with self.progress_bar(total=num_inversion_steps) as progress_bar: 99 | for i, t in enumerate(timesteps.flip(0)[1:-1]): 100 | # expand the latents if we are doing classifier free guidance 101 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 102 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 103 | 104 | # predict the noise residual 105 | with torch.no_grad(): 106 | noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample 107 | 108 | # perform guidance 109 | if do_classifier_free_guidance: 110 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 111 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 112 | 113 | # regularization of the noise prediction 114 | e_t = noise_pred 115 | for _outer in range(num_reg_steps): 116 | if lambda_ac>0: 117 | for _inner in range(num_ac_rolls): 118 | _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True) 119 | l_ac = self.auto_corr_loss(_var) 120 | l_ac.backward() 121 | _grad = _var.grad.detach()/num_ac_rolls 122 | e_t = e_t - lambda_ac*_grad 123 | if lambda_kl>0: 124 | _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True) 125 | l_kld = self.kl_divergence(_var) 126 | l_kld.backward() 127 | _grad = _var.grad.detach() 128 | e_t = e_t - lambda_kl*_grad 129 | e_t = e_t.detach() 130 | noise_pred = e_t 131 | 132 | # compute the previous noisy sample x_t -> x_t-1 133 | latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample 134 | 135 | # call the callback, if provided 136 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 137 | progress_bar.update() 138 | 139 | 140 | x_inv = latents.detach().clone() 141 | # reconstruct the image 142 | 143 | # 8. 
Post-processing 144 | image = self.decode_latents(latents.detach()) 145 | image = self.numpy_to_pil(image) 146 | return x_inv, image, image_x0_dec 147 | -------------------------------------------------------------------------------- /src/utils/pic_pipeline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision.transforms as T 6 | from torch.optim.optimizer import Optimizer 7 | import math 8 | 9 | from random import randrange 10 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 11 | from diffusers import DDIMScheduler 12 | import time 13 | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput 14 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 15 | sys.path.insert(0, "src/utils") 16 | from base_pipeline import BasePipeline 17 | from cross_attention import prep_unet 18 | from tqdm.autonotebook import tqdm 19 | import clip 20 | import torchvision.models as models 21 | from pytorch_lightning import seed_everything 22 | # from vgg import Vgg16 23 | 24 | from PIL import Image 25 | from utils.scheduler import DDIMInverseScheduler 26 | 27 | if torch.cuda.is_available(): 28 | device = "cuda" 29 | else: 30 | device = "cpu" 31 | 32 | class PicPipeline(BasePipeline): 33 | 34 | def __call__( 35 | self, 36 | prompt: Union[str, List[str]] = None, 37 | num_inversion_steps: int = 50, 38 | guidance_scale_for: float = 7.5, 39 | guidance_scale_rev: float = 7.5, 40 | task_name = None, 41 | negative_prompt: Optional[Union[str, List[str]]] = None, 42 | num_images_per_prompt: Optional[int] = 1, 43 | eta: float = 0.0, 44 | output_type: Optional[str] = "pil", 45 | return_dict: bool = True, 46 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 47 | prompt_embeds: Optional[torch.FloatTensor] = None, 48 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 49 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 50 | height: Optional[int] = None, 51 | width: Optional[int] = None, 52 | img=None, # the input image as a PIL image 53 | torch_dtype=torch.float32, 54 | 55 | use_wordswap: bool = False, 56 | use_lowvram: bool = False, 57 | is_synthetic: bool = False, 58 | beta: float = 0.0, 59 | tau: int = 50, 60 | gamma: float = 0.0, 61 | ): 62 | print(f"Use Low VRAM? (Computational speed can be slowed down.): {use_lowvram}") 63 | 64 | device = self._execution_device 65 | seed_everything(42) 66 | 67 | self.unet = prep_unet(self.unet) 68 | 69 | do_classifier_free_guidance = guidance_scale_for > 1.0 70 | 71 | self.scheduler = DDIMInverseScheduler.from_config(self.scheduler.config) 72 | self.scheduler.set_timesteps(num_inversion_steps, device=device) 73 | timesteps = self.scheduler.timesteps 74 | 75 | negative_prompt = prompt 76 | 77 | x0 = np.array(img)/255 78 | x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).to(device) 79 | x0 = (x0 - 0.5) * 2. 
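# x0 has been rescaled from [0, 1] to [-1, 1]; the block below encodes it with the VAE and
# multiplies the sampled latent by 0.18215 (the Stable Diffusion latent scaling factor, the
# inverse of the 1/0.18215 used in decode_latents) to obtain the starting latents.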
80 | 81 | with torch.no_grad(): 82 | x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype) 83 | 84 | latents = x0_enc = 0.18215 * x0_enc 85 | 86 | with torch.no_grad(): 87 | x0_dec = self.decode_latents(x0_enc.detach()) 88 | 89 | with torch.no_grad(): 90 | prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=None, negative_prompt_embeds=negative_prompt_embeds,) 91 | 92 | 93 | prompt_change, idx_list, is_added = text_swap(task_name, prompt, device) 94 | 95 | print(idx_list) 96 | print(f'Source Prompt: {prompt}') 97 | print(f'Target Prompt: {prompt_change}') 98 | 99 | with torch.no_grad(): 100 | prompt_to = self._encode_prompt(prompt_change, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt) 101 | 102 | extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta) 103 | 104 | latent_save = {} 105 | eps_save = {} 106 | 107 | num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order 108 | 109 | with self.progress_bar(total=num_inversion_steps) as progress_bar: 110 | for i, t in enumerate(timesteps.flip(0)[1:-1]): 111 | 112 | latent_save[t.item()] = latents.detach().clone() 113 | 114 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 115 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 116 | 117 | with torch.no_grad(): 118 | noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample 119 | 120 | if do_classifier_free_guidance: 121 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 122 | noise_pred = noise_pred_uncond + guidance_scale_for * (noise_pred_text - noise_pred_uncond) 123 | 124 | latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample 125 | 126 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 127 | progress_bar.update() 128 | 129 | x_inv = latents.detach().clone() 130 | num_inference_steps = num_inversion_steps 131 | self.scheduler = DDIMScheduler.from_config(self.scheduler.config) 132 | height = height or self.unet.config.sample_size * self.vae_scale_factor 133 | width = width or self.unet.config.sample_size * self.vae_scale_factor 134 | 135 | if prompt is not None and isinstance(prompt, str): 136 | batch_size = 1 137 | elif prompt is not None and isinstance(prompt, list): 138 | batch_size = len(prompt) 139 | else: 140 | batch_size = prompt_embeds.shape[0] 141 | 142 | device = self._execution_device 143 | do_classifier_free_guidance = guidance_scale_rev > 1.0 144 | 145 | if is_synthetic: 146 | x_in = torch.randn((1,4,64,64)).to(dtype=self.unet.dtype, device=self._execution_device) 147 | prompt = input('[Synthetic Mode] Prompt: ') 148 | negative_prompt = prompt 149 | if use_wordswap: 150 | prompt_change, idx_list = text_swap(task_name, prompt, device) 151 | # prompt_change = input(f'prompt: {prompt} -> ') 152 | 153 | print(f'{prompt} -> {prompt_change}, {idx_list[0]}') 154 | else: 155 | prompt_change = input(f'prompt: {prompt} -> ') 156 | 157 | else: 158 | x_in = x_inv.to(dtype=self.unet.dtype, device=self._execution_device) 159 | 160 | del latents, x_inv 161 | 162 | self.scheduler.set_timesteps(num_inference_steps, device=device) 163 | timesteps = self.scheduler.timesteps 164 | 165 | num_channels_latents = self.unet.in_channels 166 | 167 | latents = self.prepare_latents(batch_size * 
num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,) 168 | latents_init = latents.clone() 169 | 170 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 171 | 172 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 173 | 174 | with torch.no_grad(): 175 | prompt_to = self._encode_prompt(prompt_change, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=None, negative_prompt_embeds=negative_prompt_embeds,) 176 | 177 | with torch.no_grad(): 178 | prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=None, negative_prompt_embeds=negative_prompt_embeds,) 179 | 180 | prompt_embeds_edit = prompt_embeds.clone() 181 | 182 | prompt_embeds_edit[1:2] = prompt_to[1:2].clone() 183 | 184 | with torch.no_grad(): 185 | with self.progress_bar(total=num_inference_steps) as progress_bar: 186 | for i, t in enumerate(timesteps): 187 | 188 | 189 | alpha = (1-beta) * ((i+1)/tau) + beta 190 | 191 | noise_pred = pred_noise(self, latents, t, prompt_embeds[0:1], False, guidance_scale_rev, cross_attention_kwargs, use_lowvram) 192 | 193 | 194 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 195 | 196 | if i < tau: 197 | eps_save[t.item()] = noise_pred.detach().clone() 198 | 199 | 200 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 201 | progress_bar.update() 202 | 203 | image_rec = self.numpy_to_pil(self.decode_latents(latents.detach())) 204 | latents = latents_init.clone() 205 | 206 | with self.progress_bar(total=num_inference_steps) as progress_bar: 207 | for i, t in enumerate(timesteps): 208 | 209 | alpha = (1-beta) * ((i+1)/tau) + beta 210 | with torch.no_grad(): 211 | 212 | if i < tau: 213 | 214 | noise_src = eps_save[t.item()].clone() 215 | 216 | prompt_embeds_star = interpolate_text(alpha, prompt_embeds, prompt_embeds_edit, idx_list).clone() 217 | 218 | noise_pred_star = pred_noise(self, latents, t, prompt_embeds_star, do_classifier_free_guidance, guidance_scale_rev, cross_attention_kwargs, use_lowvram, batched=True) 219 | noise_pred_star_pred, noise_pred_star = noise_pred_star.chunk(2) 220 | 221 | text_gui = (noise_pred_star - noise_pred_star_pred) * guidance_scale_rev 222 | noise_delta = gamma * text_gui 223 | noise_pred = noise_src + noise_delta 224 | 225 | else: 226 | noise_pred = pred_noise(self, latents, t, prompt_embeds_edit, do_classifier_free_guidance, guidance_scale_rev, cross_attention_kwargs, use_lowvram) 227 | 228 | 229 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 230 | 231 | with torch.no_grad(): 232 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 233 | progress_bar.update() 234 | 235 | with torch.no_grad(): 236 | image = self.decode_latents(latents.detach()) 237 | 238 | x_img = self.numpy_to_pil(image) 239 | 240 | return x_img, image_rec 241 | 242 | 243 | def interpolate_text(scale, src_embeds, tgt_embeds, idx_list, is_added=False): 244 | 245 | assert src_embeds.shape == tgt_embeds.shape 246 | 247 | _src_embeds = src_embeds.clone() 248 | _tgt_embeds = tgt_embeds.clone() 249 | temp = tgt_embeds[1].clone() 250 | 251 | if not is_added: 252 | temp = _src_embeds[1].clone() * (1-scale) + _tgt_embeds[1].clone() * scale 253 | 254 | else: 255 | cnt = 0 256 | for idx, v in enumerate(idx_list): 257 | if v 
== "*": 258 | cnt += 1 259 | else: 260 | temp[idx] = scale * _tgt_embeds[1][idx] + (1-scale) * _src_embeds[1][v] 261 | 262 | temp = torch.stack([_tgt_embeds[0].clone(), temp], dim = 0) 263 | return temp.clone() 264 | 265 | 266 | def pred_noise(self, latents, t, prompt_emb, do_classifier_free_guidance, guidance_scale_rev, cross_attention_kwargs, use_lowvram, batched=False): 267 | with torch.no_grad(): 268 | if use_lowvram: 269 | latent_uncond_input = self.scheduler.scale_model_input(latents, t) 270 | noise_pred_uncond = self.unet(latent_uncond_input,t,encoder_hidden_states=prompt_emb[0:1], cross_attention_kwargs=cross_attention_kwargs,).sample 271 | 272 | latent_cond_input = self.scheduler.scale_model_input(latents, t) 273 | noise_pred_cond = self.unet(latent_cond_input,t,encoder_hidden_states=prompt_emb[1:2], cross_attention_kwargs=cross_attention_kwargs,).sample 274 | 275 | if batched: 276 | return torch.cat([noise_pred_uncond, noise_pred_cond], dim=0) 277 | noise_pred = noise_pred_uncond + guidance_scale_rev * (noise_pred_cond - noise_pred_uncond) 278 | 279 | else: 280 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 281 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 282 | noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_emb, cross_attention_kwargs=cross_attention_kwargs,).sample 283 | 284 | if batched: 285 | return noise_pred 286 | latents = latent_model_input.detach().chunk(2)[0] 287 | 288 | # perform guidance 289 | if do_classifier_free_guidance: 290 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 291 | noise_pred = noise_pred_uncond + guidance_scale_rev * (noise_pred_text - noise_pred_uncond) 292 | 293 | return noise_pred 294 | 295 | def text_swap(task, prompt, device): 296 | word_f, word_t = task.split('2') 297 | 298 | with torch.no_grad(): 299 | from torch.nn import CosineSimilarity 300 | from transformers import CLIPTokenizer, CLIPModel, CLIPTextModel 301 | cossim = CosineSimilarity(dim=0, eps=1e-6) 302 | 303 | def dist(v1, v2): 304 | return cossim(v1, v2) 305 | 306 | tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32') 307 | text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch32').to(device) 308 | 309 | prompts = list(prompt.split(' ')) 310 | 311 | text_inputs = tokenizer( 312 | prompts, 313 | padding="max_length", 314 | return_tensors="pt", 315 | ).to(device) 316 | 317 | text_f = tokenizer( 318 | word_f, 319 | padding="max_length", 320 | return_tensors="pt", 321 | ).to(device) 322 | 323 | wordf_embeddings = torch.flatten(text_encoder(text_f.input_ids.to(device))['last_hidden_state'],1,-1) 324 | text_embeddings = torch.flatten(text_encoder(text_inputs.input_ids.to(device))['last_hidden_state'],1,-1) 325 | 326 | temp = [] 327 | for i in range(len(prompts)): temp.append(dist(text_embeddings[i], wordf_embeddings[0]).item()) 328 | idx_start = temp.index(max(temp)) 329 | 330 | prompt_tgt = prompt.replace(prompts[idx_start], word_t, 1) 331 | 332 | idx_list = [i for i in range(77)] 333 | 334 | is_added = False if len(word_t.split(' ')) == len(word_f.split(' ')) else True 335 | 336 | cnt = 0 337 | 338 | if is_added: 339 | idx_list = idx_list[:idx_start+1] + ["*"] * len(word_t.split(' ')) + idx_list[idx_start+1:] 340 | idx_list = idx_list[:len(prompt_tgt.split(' '))+1] 341 | 342 | else: 343 | idx_list = idx_list[:idx_start+1] + ["*"] + idx_list[idx_start+2:] 344 | idx_list = idx_list[:len(prompt_tgt.split(' '))+1] 345 | 346 | return 
prompt_tgt, idx_list, is_added 347 | -------------------------------------------------------------------------------- /src/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion 16 | # and https://github.com/hojonathanho/diffusion 17 | import os, sys, pdb 18 | import math 19 | from dataclasses import dataclass 20 | from typing import List, Optional, Tuple, Union 21 | 22 | import numpy as np 23 | import torch 24 | 25 | from diffusers.configuration_utils import ConfigMixin, register_to_config 26 | from diffusers.utils import BaseOutput, randn_tensor 27 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 28 | 29 | 30 | @dataclass 31 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM 32 | class DDIMSchedulerOutput(BaseOutput): 33 | """ 34 | Output class for the scheduler's step function output. 35 | 36 | Args: 37 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 38 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 39 | denoising loop. 40 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 41 | The predicted denoised sample (x_{0}) based on the model output from the current timestep. 42 | `pred_original_sample` can be used to preview progress or for guidance. 43 | """ 44 | 45 | prev_sample: torch.FloatTensor 46 | pred_original_sample: Optional[torch.FloatTensor] = None 47 | 48 | 49 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: 50 | """ 51 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 52 | (1-beta) over time from t = [0,1]. 53 | 54 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 55 | to that part of the diffusion process. 56 | 57 | 58 | Args: 59 | num_diffusion_timesteps (`int`): the number of betas to produce. 60 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 61 | prevent singularities. 
62 | 63 | Returns: 64 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 65 | """ 66 | 67 | def alpha_bar(time_step): 68 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 69 | 70 | betas = [] 71 | for i in range(num_diffusion_timesteps): 72 | t1 = i / num_diffusion_timesteps 73 | t2 = (i + 1) / num_diffusion_timesteps 74 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 75 | return torch.tensor(betas) 76 | 77 | 78 | class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): 79 | """ 80 | Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising 81 | diffusion probabilistic models (DDPMs) with non-Markovian guidance. 82 | 83 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 84 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 85 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 86 | [`~SchedulerMixin.from_pretrained`] functions. 87 | 88 | For more details, see the original paper: https://arxiv.org/abs/2010.02502 89 | 90 | Args: 91 | num_train_timesteps (`int`): number of diffusion steps used to train the model. 92 | beta_start (`float`): the starting `beta` value of inference. 93 | beta_end (`float`): the final `beta` value. 94 | beta_schedule (`str`): 95 | the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 96 | `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 97 | trained_betas (`np.ndarray`, optional): 98 | option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 99 | clip_sample (`bool`, default `True`): 100 | option to clip predicted sample between -1 and 1 for numerical stability. 101 | set_alpha_to_one (`bool`, default `True`): 102 | each diffusion step uses the value of alphas product at that step and at the previous one. For the final 103 | step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, 104 | otherwise it uses the value of alpha at step 0. 105 | steps_offset (`int`, default `0`): 106 | an offset added to the inference steps. You can use a combination of `offset=1` and 107 | `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in 108 | stable diffusion. 
109 | prediction_type (`str`, default `epsilon`, optional): 110 | prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion 111 | process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 112 | https://imagen.research.google/video/paper.pdf) 113 | """ 114 | 115 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 116 | order = 1 117 | 118 | @register_to_config 119 | def __init__( 120 | self, 121 | num_train_timesteps: int = 1000, 122 | beta_start: float = 0.0001, 123 | beta_end: float = 0.02, 124 | beta_schedule: str = "linear", 125 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 126 | clip_sample: bool = True, 127 | set_alpha_to_one: bool = True, 128 | steps_offset: int = 0, 129 | prediction_type: str = "epsilon", 130 | ): 131 | if trained_betas is not None: 132 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 133 | elif beta_schedule == "linear": 134 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 135 | elif beta_schedule == "scaled_linear": 136 | # this schedule is very specific to the latent diffusion model. 137 | self.betas = ( 138 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 139 | ) 140 | elif beta_schedule == "squaredcos_cap_v2": 141 | # Glide cosine schedule 142 | self.betas = betas_for_alpha_bar(num_train_timesteps) 143 | else: 144 | raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") 145 | 146 | self.alphas = 1.0 - self.betas 147 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 148 | 149 | # At every step in ddim, we are looking into the previous alphas_cumprod 150 | # For the final step, there is no previous alphas_cumprod because we are already at 0 151 | # `set_alpha_to_one` decides whether we set this parameter simply to one or 152 | # whether we use the final alpha of the "non-previous" one. 153 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] 154 | 155 | # standard deviation of the initial noise distribution 156 | self.init_noise_sigma = 1.0 157 | 158 | # setable values 159 | self.num_inference_steps = None 160 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) 161 | 162 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: 163 | """ 164 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 165 | current timestep. 166 | 167 | Args: 168 | sample (`torch.FloatTensor`): input sample 169 | timestep (`int`, optional): current timestep 170 | 171 | Returns: 172 | `torch.FloatTensor`: scaled input sample 173 | """ 174 | return sample 175 | 176 | def _get_variance(self, timestep, prev_timestep): 177 | alpha_prod_t = self.alphas_cumprod[timestep] 178 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 179 | beta_prod_t = 1 - alpha_prod_t 180 | beta_prod_t_prev = 1 - alpha_prod_t_prev 181 | 182 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) 183 | 184 | return variance 185 | 186 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 187 | """ 188 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
189 | 190 | Args: 191 | num_inference_steps (`int`): 192 | the number of diffusion steps used when generating samples with a pre-trained model. 193 | """ 194 | 195 | if num_inference_steps > self.config.num_train_timesteps: 196 | raise ValueError( 197 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 198 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 199 | f" maximal {self.config.num_train_timesteps} timesteps." 200 | ) 201 | 202 | self.num_inference_steps = num_inference_steps 203 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 204 | # creates integer timesteps by multiplying by ratio 205 | # casting to int to avoid issues when num_inference_step is power of 3 206 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 207 | self.timesteps = torch.from_numpy(timesteps).to(device) 208 | self.timesteps += self.config.steps_offset 209 | 210 | def step( 211 | self, 212 | model_output: torch.FloatTensor, 213 | timestep: int, 214 | sample: torch.FloatTensor, 215 | eta: float = 0.0, 216 | use_clipped_model_output: bool = False, 217 | generator=None, 218 | variance_noise: Optional[torch.FloatTensor] = None, 219 | return_dict: bool = True, 220 | reverse=False 221 | ) -> Union[DDIMSchedulerOutput, Tuple]: 222 | 223 | 224 | e_t = model_output 225 | 226 | x = sample 227 | prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps 228 | # print(timestep, prev_timestep) 229 | a_t = alpha_prod_t = self.alphas_cumprod[timestep-1] 230 | a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod 231 | beta_prod_t = 1 - alpha_prod_t 232 | 233 | pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt() 234 | # direction pointing to x_t 235 | dir_xt = (1. 
- a_prev).sqrt() * e_t 236 | x = a_prev.sqrt()*pred_x0 + dir_xt 237 | if not return_dict: 238 | return (x,) 239 | return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0) 240 | 241 | 242 | 243 | 244 | 245 | def add_noise( 246 | self, 247 | original_samples: torch.FloatTensor, 248 | noise: torch.FloatTensor, 249 | timesteps: torch.IntTensor, 250 | ) -> torch.FloatTensor: 251 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 252 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) 253 | timesteps = timesteps.to(original_samples.device) 254 | 255 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 256 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 257 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 258 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 259 | 260 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 261 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 262 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 263 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 264 | 265 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 266 | return noisy_samples 267 | 268 | def get_velocity( 269 | self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor 270 | ) -> torch.FloatTensor: 271 | # Make sure alphas_cumprod and timestep have same device and dtype as sample 272 | self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) 273 | timesteps = timesteps.to(sample.device) 274 | 275 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 276 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 277 | while len(sqrt_alpha_prod.shape) < len(sample.shape): 278 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 279 | 280 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 281 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 282 | while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): 283 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 284 | 285 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 286 | return velocity 287 | 288 | def __len__(self): 289 | return self.config.num_train_timesteps 290 | --------------------------------------------------------------------------------
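For reference, the sketch below shows one way `PicPipeline.__call__` (defined in `src/utils/pic_pipeline.py`) could be driven directly instead of going through `inference_single.sh` and `src/edit_once.py`. It is a minimal sketch, not the official entry point: it assumes `BasePipeline` follows the standard diffusers `StableDiffusionPipeline` loading interface (so `from_pretrained` accepts a Stable Diffusion 1.x checkpoint such as `CompVis/stable-diffusion-v1-4`), it assumes `guidance_scale_rev` is the counterpart of `--negative_guidance_scale` in the shell script, and the source prompt is written by hand here, whereas `edit_once.py` prints the prompt it actually uses.

```
from PIL import Image
from diffusers import DDIMScheduler

from src.utils.pic_pipeline import PicPipeline

# Assumption: BasePipeline subclasses the standard diffusers Stable Diffusion pipeline,
# so an SD 1.x checkpoint with the usual components can be loaded directly.
pipe = PicPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # optional: __call__ rebuilds DDIMInverse/DDIM from this config

img = Image.open("sample_data/cat/cat_1.png").convert("RGB").resize((512, 512))

# tau, beta and gamma control when and how strongly the interpolated-prompt
# correction term is applied in the reverse process (values taken from the readme defaults).
edited, reconstruction = pipe(
    prompt="a cat is lying on the grass",  # hand-written source prompt for this sketch
    img=img,
    task_name="cat2dog",
    num_inversion_steps=50,
    guidance_scale_rev=5.0,                # assumed counterpart of --negative_guidance_scale
    tau=25,
    beta=0.3,
    gamma=0.2,
)
edited[0].save("edited.png")
reconstruction[0].save("reconstruction.png")
```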