├── LICENSE
├── README.md
├── ckpt
│   ├── inpaint_ckpt.sh
│   └── runwayml_sd_v1_5.sh
├── diffedit_v1.sh
├── diffedit_v2.sh
├── diffedit_v3.sh
├── examples
│   ├── horse.png
│   └── oranges.png
├── get_edit_v3.py
├── get_mask_v1.py
├── get_mask_v2.py
├── inpaint.py
├── modules
│   ├── __pycache__
│   │   ├── diffedit_v1.cpython-39.pyc
│   │   └── diffedit_v3.cpython-39.pyc
│   ├── diffedit_v1.py
│   ├── diffedit_v2.py
│   └── diffedit_v3.py
└── png
    ├── demo.png
    ├── diffedit_v1.png
    ├── diffedit_v2.png
    └── diffedit_v3.png

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 wang-will

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Edit Images Based on a Mask Generated by Stable Diffusion Itself
## Introduction
This is an unofficial implementation of the paper [DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427), based on [Stable Diffusion](https://arxiv.org/abs/2112.10752).
* All weights and APIs are taken from [Hugging Face Diffusers](https://huggingface.co/docs/diffusers/index)
* Weights of the Stable Diffusion img2img pipeline are from [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5); you can get them with this command:
```shell
cd ckpt
bash runwayml_sd_v1_5.sh
```
* Weights of the Stable Diffusion inpainting pipeline are from [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting); you can get them with this command:
```shell
cd ckpt
bash inpaint_ckpt.sh
```
* The scheduler is [DDIMScheduler](https://huggingface.co/docs/diffusers/api/schedulers/ddim)
* Example images are from [Google TEDBench](https://github.com/imagic-editing/imagic-editing.github.io/tree/main/tedbench)
* Hugging Face Diffusers also provides an official API for this paper, [see this](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/diffedit); a minimal usage sketch follows below
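For reference, here is a minimal sketch of that official API, following the example in the linked docs. It is untested here and assumes a diffusers release newer than the 0.15.0 pinned in the Environment section (the DiffEdit pipeline landed later), so treat it as a sketch rather than a drop-in script:
```python
import torch
from diffusers import StableDiffusionDiffEditPipeline, DDIMScheduler, DDIMInverseScheduler
from diffusers.utils import load_image

# checkpoint choice follows the official docs example
pipe = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

image = load_image("./examples/horse.png").resize((768, 768))
# 1) estimate the edit mask, 2) DDIM-invert the image, 3) denoise with the mask
mask = pipe.generate_mask(image=image, source_prompt="a white horse", target_prompt="a zebra")
inv_latents = pipe.invert(prompt="a white horse", image=image).latents
result = pipe(prompt="a zebra", mask_image=mask, image_latents=inv_latents).images[0]
```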
## Environment
```
python == 3.9.12
torch == 1.13.1
pillow == 9.4.0
scikit-image == 0.19.2
diffusers == 0.15.0
xformers == 0.0.16
accelerate == 0.17.1
```
With the following calls, you can run the code on a GPU with 6 GB of memory and edit images of size 512x512:
```python
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()
```
## Method
### v1
The mask is computed from the images generated by the img2img pipeline, and the edit is performed by the inpainting pipeline based on that mask image, as shown in the figure below. The mask computation is sketched right after this subsection.
![](./png/diffedit_v1.png)
You can run the v1 method with this command:
```shell
bash diffedit_v1.sh
```
To use another image or other hyperparameters, edit diffedit_v1.sh.
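All three variants estimate the mask the same way: run the model on the noised input under the reference prompt and under the query prompt, then take the averaged, normalized, thresholded difference of the two estimates. A toy illustration with plain tensors follows (see modules/diffedit_v1.py and modules/diffedit_v2.py for the real pipeline code; v2 and v3 do this on the 64x64 latents, and v2 then resizes the mask to the image resolution):
```python
import torch

torch.manual_seed(0)
# stand-ins for per-run estimates under the reference and query prompts (1x4xhxw)
refer_preds = [torch.randn(1, 4, 64, 64) for _ in range(10)]
query_preds = [torch.randn(1, 4, 64, 64) for _ in range(10)]

# average the absolute differences over the runs, then over the channels
diff = torch.stack([(r - q).abs() for r, q in zip(refer_preds, query_preds)]).mean(0)
diff = diff.squeeze(0).mean(0)

# normalize to [0, 1] and binarize at 0.5 to get the edit mask
diff = (diff - diff.min()) / (diff.max() - diff.min())
mask = (diff > 0.5).float()
```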
### v2
The mask is computed from the noise residuals (**you can also use the noise latents; just toggle `--not_residual_guide` in the .sh files**) produced in latent space by the img2img pipeline, then resized to the size of the image; the edit is performed by the inpainting pipeline based on the resized mask image, as shown in the figure below.
![](./png/diffedit_v2.png)
You can run the v2 method with this command:
```shell
bash diffedit_v2.sh
```
To use another image or other hyperparameters, edit diffedit_v2.sh.
### v3
The mask is computed from the noise residuals (**you can also use the noise latents; just toggle `--not_residual_guide` in the .sh files**) in latent space by the img2img pipeline, and the edit is performed by the img2img pipeline itself, blending latents with the mask at every denoising step, as shown in the figure below. See the masked-update sketch after this subsection.
![](./png/diffedit_v3.png)
You can run the v3 method with this command:
```shell
bash diffedit_v3.sh
```
To use another image or other hyperparameters, edit diffedit_v3.sh.
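The v3 edit step is the masked update from the DiffEdit paper: inside the mask, follow the denoising trajectory driven by the query prompt; outside it, re-inject the original image latents noised to the current timestep, so the background is preserved exactly. A toy sketch with plain tensors (the real loop lives in modules/diffedit_v3.py, where latents_set holds the noised original latents for each timestep):
```python
import torch

torch.manual_seed(0)
mask = (torch.rand(64, 64) > 0.5).half()             # binary edit mask in latent space
latents_query = torch.randn(1, 4, 64, 64).half()     # x_{t-1} predicted under the query prompt
noised_original = torch.randn(1, 4, 64, 64).half()   # original latents noised to timestep t-1

# keep the query prediction inside the mask, the noised original outside it
latents = mask * latents_query + (1 - mask) * noised_original
```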
## Result
![](./png/demo.png)
## Reference
The following repositories also provide implementations of this paper:
* [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion)
* [aayushmnit/diffusion_playground](https://github.com/aayushmnit/diffusion_playground/blob/main/notebooks/4_DiffEdit_v4.ipynb)
* [johnrobinsn/diffusion_experiments](https://github.com/johnrobinsn/diffusion_experiments/blob/main/DiffEdit.ipynb)
* [daspartho/DiffEdit](https://github.com/daspartho/DiffEdit)
## What's more
If you have any questions, feel free to contact me at wangruilin.will@foxmail.com
--------------------------------------------------------------------------------
/ckpt/inpaint_ckpt.sh:
--------------------------------------------------------------------------------
mkdir inpainting-v1-2
cd inpainting-v1-2
mkdir feature_extractor
wget -P ./feature_extractor https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/feature_extractor/preprocessor_config.json

mkdir scheduler
wget -P ./scheduler https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/scheduler/scheduler_config.json

mkdir text_encoder
wget -P ./text_encoder https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/text_encoder/config.json
wget -P ./text_encoder https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/text_encoder/pytorch_model.bin

mkdir tokenizer
wget -P ./tokenizer https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/tokenizer/merges.txt
wget -P ./tokenizer https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/tokenizer/special_tokens_map.json
wget -P ./tokenizer https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/tokenizer/tokenizer_config.json
wget -P ./tokenizer https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/tokenizer/vocab.json

mkdir unet
wget -P ./unet https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/unet/config.json
wget -P ./unet https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/unet/diffusion_pytorch_model.bin

mkdir vae
wget -P ./vae https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/vae/config.json
wget -P ./vae https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/vae/diffusion_pytorch_model.bin

wget https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/model_index.json
--------------------------------------------------------------------------------
/ckpt/runwayml_sd_v1_5.sh:
--------------------------------------------------------------------------------
mkdir runwayml_sd_v1_5
cd runwayml_sd_v1_5
mkdir feature_extractor
wget -P ./feature_extractor https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/feature_extractor/preprocessor_config.json

mkdir scheduler
wget -P ./scheduler https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/scheduler/scheduler_config.json

mkdir text_encoder
wget -P ./text_encoder https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/config.json
wget -P ./text_encoder https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/pytorch_model.bin

mkdir tokenizer
wget -P ./tokenizer https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/tokenizer/merges.txt
wget -P ./tokenizer https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/tokenizer/special_tokens_map.json
wget -P ./tokenizer https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/tokenizer/tokenizer_config.json
wget -P ./tokenizer https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/tokenizer/vocab.json

mkdir unet
wget -P ./unet https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/config.json
wget -P ./unet https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin

mkdir vae
wget -P ./vae https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/config.json
wget -P ./vae https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/diffusion_pytorch_model.bin

wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/model_index.json
--------------------------------------------------------------------------------
/diffedit_v1.sh:
--------------------------------------------------------------------------------
# horse demo
python get_mask_v1.py \
    --ckpt_dir ./ckpt/runwayml_sd_v1_5 \
    --image_dir ./examples/horse.png \
    --avg_times 10 \
    --reference "a white horse" \
    --query "a zebra" \
    --output_dir ./v1_output \
    --strength 0.8 \
    --steps 50 \
    --scale 7.5

python inpaint.py \
    --ckpt_dir ./ckpt/inpainting-v1-2 \
    --image_dir ./examples/horse.png \
    --mask_dir ./v1_output/mask_v1.png \
    --query "a zebra" \
    --output_dir ./v1_output \
    --steps 50 \
    --scale 7.5

# orange demo
python get_mask_v1.py \
    --ckpt_dir ./ckpt/runwayml_sd_v1_5 \
    --image_dir ./examples/oranges.png \
    --avg_times 10 \
    --reference "A basket of oranges" \
    --query "A basket of apples" \
    --output_dir ./v1_output \
    --strength 0.8 \
    --steps 50 \
    --scale 7.5

python inpaint.py \
    --ckpt_dir ./ckpt/inpainting-v1-2 \
    --image_dir ./examples/oranges.png \
    --mask_dir ./v1_output/mask_v1.png \
    --query "A basket of apples" \
    --output_dir ./v1_output \
    --steps 50 \
    --scale 7.5
--------------------------------------------------------------------------------
/diffedit_v2.sh:
--------------------------------------------------------------------------------
# horse demo
python get_mask_v2.py \
    --ckpt_dir ./ckpt/runwayml_sd_v1_5 \
    --image_dir ./examples/horse.png \
    --avg_times 10 \
    --reference "a white horse" \
    --query "a zebra" \
    --output_dir ./v2_output \
    --strength 0.6 \
    --steps 25 \
    --scale 7.5 \
    --not_residual_guide

python inpaint.py \
    --ckpt_dir ./ckpt/inpainting-v1-2 \
    --image_dir ./examples/horse.png \
    --mask_dir ./v2_output/mask_v2.png \
    --query "a zebra" \
    --output_dir ./v2_output \
    --steps 25 \
    --scale 7.5

# orange demo
python get_mask_v2.py \
    --ckpt_dir ./ckpt/runwayml_sd_v1_5 \
    --image_dir ./examples/oranges.png \
    --avg_times 10 \
    --reference "A basket of oranges" \
    --query "A basket of apples" \
    --output_dir ./v2_output \
    --strength 0.6 \
    --steps 25 \
    --scale 7.5 \
    --not_residual_guide

python inpaint.py \
    --ckpt_dir ./ckpt/inpainting-v1-2 \
    --image_dir ./examples/oranges.png \
    --mask_dir ./v2_output/mask_v2.png \
    --query "A basket of apples" \
    --output_dir ./v2_output \
    --steps 25 \
    --scale 7.5
--------------------------------------------------------------------------------
/diffedit_v3.sh:
--------------------------------------------------------------------------------
# horse demo
python get_edit_v3.py \
    --ckpt_dir ./ckpt/runwayml_sd_v1_5 \
    --image_dir ./examples/horse.png \
    --avg_times 10 \
    --reference "a white horse" \
    --query "a zebra" \
    --output_dir ./v3_output \
    --strength 0.6 \
    --steps 25 \
    --scale 7.5 \
    --not_residual_guide

# orange demo
python get_edit_v3.py \
    --ckpt_dir ./ckpt/runwayml_sd_v1_5 \
    --image_dir ./examples/oranges.png \
    --avg_times 10 \
    --reference "A basket of oranges" \
    --query "A basket of apples" \
    --output_dir ./v3_output \
    --strength 0.6 \
    --steps 25 \
    --scale 7.5 \
    --not_residual_guide
--------------------------------------------------------------------------------
/examples/horse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/examples/horse.png
--------------------------------------------------------------------------------
/examples/oranges.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/examples/oranges.png
--------------------------------------------------------------------------------
/get_edit_v3.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
import numpy as np
from PIL import Image
from diffusers import DDIMScheduler
from modules.diffedit_v3 import DiffEdit_v3

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ckpt_dir',
        default=None,
        type=str,
        help='path to the Stable Diffusion weights used by the pipeline'
    )
    parser.add_argument(
        '--image_dir',
        default=None,
        type=str,
        help='path of the image to be edited'
    )
    parser.add_argument(
        '--avg_times',
        default=10,
        type=int,
        help='number of difference estimates averaged to obtain the mask'
    )
    parser.add_argument(
        '--reference',
        default=None,
        type=str,
        help='reference prompt'
    )
    parser.add_argument(
        '--query',
        default=None,
        type=str,
        help='edit prompt'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        type=str,
        help='path where the results are saved'
    )
    parser.add_argument(
        '--strength',
        default=0.8,
        type=float,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--steps',
        default=50,
        type=int,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--seed',
        default=2625,
        type=int,
        help='random seed'
    )
    parser.add_argument(
        '--scale',
        default=7.5,
        type=float,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--not_residual_guide',
        default=False,
        action="store_true",
        help="compute the mask from the noise latents instead of the noise residual"
    )
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = get_args()

    DDIM = DDIMScheduler.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        subfolder="scheduler"
    )
    pipe = DiffEdit_v3.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        safety_checker=None,
        torch_dtype=torch.float16,
        scheduler=DDIM,
    ).to("cuda")

    # save memory and speed up inference
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_attention_slicing()
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()

    image = Image.open(args.image_dir).convert('RGB').resize((512, 512))
    mask = pipe.get_mask(
        latents_num=args.avg_times,
        refer_prompt=args.reference,
        query_prompt=args.query,
        image=image,
        strength=args.strength,
        num_inference_steps=args.steps,
        guidance_scale=args.scale,
        seed=args.seed,
        residual_guide=not args.not_residual_guide
    )
    latents_set = pipe.get_latents(
        image=image,
        strength=args.strength,
        num_inference_steps=args.steps,
        generator=torch.Generator(device="cuda").manual_seed(args.seed)
    )
    result = pipe(
        query=args.query,
        latents_set=latents_set,
        mask=mask.half(),
        strength=args.strength,
        num_inference_steps=args.steps,
        guidance_scale=args.scale,
        generator=torch.Generator(device="cuda").manual_seed(args.seed),
    )[0]
    # convert the boolean mask to a 0/255 uint8 image so PIL can save it
    pil_mask = (mask.cpu().numpy() * 255).astype(np.uint8)
    pil_mask_unresized = Image.fromarray(pil_mask).convert('RGB')
    pil_mask_resized = pil_mask_unresized.resize((512, 512))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    result.save(f'{args.output_dir}/result_v3.png')
    pil_mask_unresized.save(f'{args.output_dir}/unresized_mask_v3.png')
    pil_mask_resized.save(f'{args.output_dir}/resized_mask_v3.png')
--------------------------------------------------------------------------------
/get_mask_v1.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
from PIL import Image
from diffusers import DDIMScheduler
from modules.diffedit_v1 import DiffEdit_v1

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ckpt_dir',
        default=None,
        type=str,
        help='path to the Stable Diffusion weights used by the pipeline'
    )
    parser.add_argument(
        '--image_dir',
        default=None,
        type=str,
        help='path of the image to be edited'
    )
    parser.add_argument(
        '--avg_times',
        default=10,
        type=int,
        help='number of difference estimates averaged to obtain the mask'
    )
    parser.add_argument(
        '--reference',
        default=None,
        type=str,
        help='reference prompt'
    )
    parser.add_argument(
        '--query',
        default=None,
        type=str,
        help='edit prompt'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        type=str,
        help='path where the results are saved'
    )
    parser.add_argument(
        '--strength',
        default=0.8,
        type=float,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--steps',
        default=50,
        type=int,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--seed',
        default=2625,
        type=int,
        help='random seed'
    )
    parser.add_argument(
        '--scale',
        default=7.5,
        type=float,
        help='hyperparameter of the pipeline'
    )
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = get_args()

    DDIM = DDIMScheduler.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        subfolder="scheduler"
    )
    pipe = DiffEdit_v1.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        safety_checker=None,
        torch_dtype=torch.float16,
        scheduler=DDIM,
    ).to("cuda")

    # save memory and speed up inference
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_attention_slicing()
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()

    image = Image.open(args.image_dir).convert('RGB').resize((512, 512))

    pil_mask = pipe.get_mask(
        latents_num=args.avg_times,
        refer_prompt=args.reference,
        query_prompt=args.query,
        image=image,
        strength=args.strength,
        num_inference_steps=args.steps,
        guidance_scale=args.scale,
        seed=args.seed
    )
    pil_mask = Image.fromarray(pil_mask).convert('RGB').resize((512, 512))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    pil_mask.save(f'{args.output_dir}/mask_v1.png')
--------------------------------------------------------------------------------
/get_mask_v2.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
import numpy as np
from PIL import Image
from skimage import morphology
from diffusers import DDIMScheduler
from modules.diffedit_v2 import DiffEdit_v2

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ckpt_dir',
        default=None,
        type=str,
        help='path to the Stable Diffusion weights used by the pipeline'
    )
    parser.add_argument(
        '--image_dir',
        default=None,
        type=str,
        help='path of the image to be edited'
    )
    parser.add_argument(
        '--avg_times',
        default=10,
        type=int,
        help='number of difference estimates averaged to obtain the mask'
    )
    parser.add_argument(
        '--reference',
        default=None,
        type=str,
        help='reference prompt'
    )
    parser.add_argument(
        '--query',
        default=None,
        type=str,
        help='edit prompt'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        type=str,
        help='path where the results are saved'
    )
    parser.add_argument(
        '--strength',
        default=0.6,
        type=float,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--steps',
        default=25,
        type=int,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--seed',
        default=2625,
        type=int,
        help='random seed'
    )
    parser.add_argument(
        '--scale',
        default=7.5,
        type=float,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--not_residual_guide',
        default=False,
        action="store_true",
        help="compute the mask from the noise latents instead of the noise residual"
    )
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = get_args()

    DDIM = DDIMScheduler.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        subfolder="scheduler"
    )
    pipe = DiffEdit_v2.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        safety_checker=None,
        torch_dtype=torch.float16,
        scheduler=DDIM,
    ).to("cuda")

    # save memory and speed up inference
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_attention_slicing()
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()

    image = Image.open(args.image_dir).convert('RGB').resize((512, 512))

    pil_mask = pipe.get_mask(
        latents_num=args.avg_times,
        refer_prompt=args.reference,
        query_prompt=args.query,
        image=image,
        strength=args.strength,
        num_inference_steps=args.steps,
        guidance_scale=args.scale,
        seed=args.seed,
        residual_guide=not args.not_residual_guide
    )
    # convert the boolean masks to 0/255 uint8 images so PIL can save them
    uncleaned_mask = Image.fromarray((pil_mask * 255).astype(np.uint8)).convert('RGB')
    # drop isolated small regions from the binary mask
    pil_mask = morphology.remove_small_objects(pil_mask, min_size=16)
    cleaned_mask = Image.fromarray((pil_mask * 255).astype(np.uint8)).convert('RGB')
    resized_cleaned_mask = cleaned_mask.resize((512, 512))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    uncleaned_mask.save(f'{args.output_dir}/unclean_mask_v2.png')
    cleaned_mask.save(f'{args.output_dir}/clean_mask_v2.png')
    resized_cleaned_mask.save(f'{args.output_dir}/mask_v2.png')
--------------------------------------------------------------------------------
/inpaint.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionInpaintPipeline

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ckpt_dir',
        default=None,
        type=str,
        help='path to the Stable Diffusion inpainting weights'
    )
    parser.add_argument(
        '--image_dir',
        default=None,
        type=str,
        help='path of the image to be edited'
    )
    parser.add_argument(
        '--mask_dir',
        default=None,
        type=str,
        help='path of the mask'
    )
    parser.add_argument(
        '--query',
        default=None,
        type=str,
        help='edit prompt'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        type=str,
        help='path where the results are saved'
    )
    parser.add_argument(
        '--steps',
        default=50,
        type=int,
        help='hyperparameter of the pipeline'
    )
    parser.add_argument(
        '--seed',
        default=2625,
        type=int,
        help='random seed'
    )
    parser.add_argument(
        '--scale',
        default=7.5,
        type=float,
        help='hyperparameter of the pipeline'
    )
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = get_args()

    DDIM = DDIMScheduler.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        subfolder="scheduler"
    )
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        pretrained_model_name_or_path=args.ckpt_dir,
        safety_checker=None,
        torch_dtype=torch.float16,
        scheduler=DDIM,
    ).to("cuda")

    # save memory and speed up inference
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_attention_slicing()
    # pipe.vae.enable_tiling()
    # pipe.enable_model_cpu_offload()

    image = Image.open(args.image_dir).convert('RGB').resize((512, 512))
    mask = Image.open(args.mask_dir).convert('RGB').resize((512, 512))

    result = pipe(
        prompt=args.query,
        image=image,
        mask_image=mask,
        num_inference_steps=args.steps,
        guidance_scale=args.scale,
        generator=torch.Generator(device="cuda").manual_seed(args.seed)
    ).images[0]
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    result.save(f'{args.output_dir}/result.png')
--------------------------------------------------------------------------------
/modules/__pycache__/diffedit_v1.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/modules/__pycache__/diffedit_v1.cpython-39.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/diffedit_v3.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/modules/__pycache__/diffedit_v3.cpython-39.pyc
--------------------------------------------------------------------------------
/modules/diffedit_v1.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from typing import List, Optional, Union

class DiffEdit_v1(StableDiffusionImg2ImgPipeline):
    @torch.no_grad()
    def get_estimate(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ):
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        # Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
        # Preprocess image
        image = self.image_processor.preprocess(image)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(1)

        # Prepare latent variables
        latents = self.prepare_latents(image, latent_timestep, 1, 1, prompt_embeds.dtype, device, generator)

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        # size: 1x3xhxw
        return image

    @torch.no_grad()
    def get_mask(
        self,
        latents_num: int = 10,
        refer_prompt: Union[str, List[str]] = None,
        query_prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, Image.Image] = None,
        strength: float = 0.5,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        seed: int = 2625
    ):
        diff_list = []
        # repeat for latents_num rounds
        for index in range(latents_num):
            # get the image estimate under the reference prompt
            refer_image = self.get_estimate(
                prompt=refer_prompt,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed * index)
            )
            # get the image estimate under the query prompt
            query_image = self.get_estimate(
                prompt=query_prompt,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed * index)
            )
            diff_list.append(refer_image - query_image)
        # Create a mask placeholder
        tensor_mask = torch.zeros_like(diff_list[0])
        # Average the absolute differences over latents_num iterations
        for index in range(latents_num):
            tensor_mask += torch.abs(diff_list[index])
        tensor_mask = tensor_mask / latents_num

        # Average over the channels
        tensor_mask = tensor_mask.squeeze(0).mean(0)

        # Normalize to [0, 1]
        tensor_mask = (tensor_mask - tensor_mask.min()) / (tensor_mask.max() - tensor_mask.min())

        # Binarize and return the mask as a 0/255 uint8 array
        pil_mask = (tensor_mask.cpu().numpy() > 0.5).astype(np.uint8) * 255

        return pil_mask
--------------------------------------------------------------------------------
/modules/diffedit_v2.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from typing import List, Optional, Union


class DiffEdit_v2(StableDiffusionImg2ImgPipeline):
    @torch.no_grad()
    def get_estimate(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ):
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        # Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
        # Preprocess image
        image = self.image_processor.preprocess(image)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(1)

        # Prepare latent variables
        latents = self.prepare_latents(image, latent_timestep, 1, 1, prompt_embeds.dtype, device, generator)
        noise_pred = torch.zeros_like(latents)
        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        # size: 1 x 4 x h/8 x w/8
        return latents, noise_pred

    @torch.no_grad()
    def get_mask(
        self,
        latents_num: int = 10,
        refer_prompt: Union[str, List[str]] = None,
        query_prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, Image.Image] = None,
        strength: float = 0.5,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        seed: int = 2625,
        residual_guide: bool = True
    ):
        diff_list = []

        # repeat for latents_num rounds
        for index in range(latents_num):
            # get the latents and noise prediction under the reference prompt
            refer_latents, refer_noise_pred = self.get_estimate(
                prompt=refer_prompt,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed * index)
            )

            # get the latents and noise prediction under the query prompt
            query_latents, query_noise_pred = self.get_estimate(
                prompt=query_prompt,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed * index)
            )
            if residual_guide:
                diff_list.append(refer_noise_pred - query_noise_pred)
            else:
                diff_list.append(refer_latents - query_latents)

        # Create a mask placeholder
        tensor_mask = torch.zeros_like(diff_list[0])

        # Average the absolute differences over latents_num iterations
        for index in range(latents_num):
            tensor_mask += torch.abs(diff_list[index])
        tensor_mask /= latents_num

        # Average over the channels
        tensor_mask = tensor_mask.squeeze(0).mean(0)

        # Normalize to [0, 1]
        tensor_mask = (tensor_mask - tensor_mask.min()) / (tensor_mask.max() - tensor_mask.min())

        # Binarize and return the mask as a boolean array
        tensor_mask = (tensor_mask > 0.5)
        pil_mask = tensor_mask.cpu().numpy()
        return pil_mask
--------------------------------------------------------------------------------
/modules/diffedit_v3.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler, DDIMInverseScheduler
from typing import List, Optional, Union
from diffusers.utils import randn_tensor

class DiffEdit_v3(StableDiffusionImg2ImgPipeline):
    @torch.no_grad()
    def get_estimate(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ):
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        # Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
        # Preprocess image
        image = self.image_processor.preprocess(image)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(1)

        # Prepare latent variables
        latents = self.prepare_latents(image, latent_timestep, 1, 1, prompt_embeds.dtype, device, generator)
        noise_pred = torch.zeros_like(latents)
        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        # size: 1 x 4 x h/8 x w/8
        return latents, noise_pred

    @torch.no_grad()
    def get_mask(
        self,
        latents_num: int = 10,
        refer_prompt: Union[str, List[str]] = None,
        query_prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, Image.Image] = None,
        strength: float = 0.5,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        seed: int = 2625,
        residual_guide: bool = True
    ):
        diff_list = []
        # repeat for latents_num rounds
        for index in range(latents_num):
            # get the latents and noise prediction under the reference prompt
            refer_latents, refer_noise_pred = self.get_estimate(
                prompt=refer_prompt,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed * index)
            )
            # get the latents and noise prediction under the query prompt
            query_latents, query_noise_pred = self.get_estimate(
                prompt=query_prompt,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed * index)
            )
            if residual_guide:
                diff_list.append(refer_noise_pred - query_noise_pred)
            else:
                diff_list.append(refer_latents - query_latents)
        # Create a mask placeholder
        tensor_mask = torch.zeros_like(diff_list[0])

        # Average the absolute differences over latents_num iterations
        for index in range(latents_num):
            tensor_mask += torch.abs(diff_list[index])
        tensor_mask /= latents_num

        # Average over the channels
        tensor_mask = tensor_mask.squeeze(0).mean(0)

        # Normalize to [0, 1]
        tensor_mask = (tensor_mask - tensor_mask.min()) / (tensor_mask.max() - tensor_mask.min())

        # Binarize and return the mask tensor
        tensor_mask = (tensor_mask > 0.5)

        return tensor_mask

    @torch.no_grad()
    def get_latents(
        self,
        image: Optional[Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ):
        device = self._execution_device

        # Preprocess image
        image = self.image_processor.preprocess(image)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        # Prepare latent variables
        latents_list = []
        image = image.to(device=device, dtype=self.unet.dtype)
        latents = self.vae.encode(image).latent_dist.sample(generator)
        latents = self.vae.config.scaling_factor * latents

        # sample the noise shared by all timesteps
        noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=self.unet.dtype)

        # build the list of original latents noised to each timestep
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                noise_latents = self.scheduler.add_noise(latents, noise, t)
                latents_list.append(noise_latents)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        return latents_list

    @torch.no_grad()
    def __call__(
        self,
        query: Optional[str] = None,
        latents_set: Optional[List[torch.FloatTensor]] = None,
        mask: Optional[torch.FloatTensor] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ):
        # get hyperparameters
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        # get the embeddings of the query prompt
        query_prompt_embeds = self._encode_prompt(
            query,
            device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
        # query_prompt_embeds consists of [uncond, query] embeddings

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        # Prepare latent variables and the mask
        latents = latents_set[0]

        if len(mask.shape) == 2:
            tensor_mask = torch.cat([mask.unsqueeze(0)] * 4).unsqueeze(0)
        elif len(mask.shape) == 3:
            tensor_mask = torch.cat([mask] * 4).unsqueeze(0)

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=query_prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_query = noise_pred.chunk(2)
                    query_noise_residual = noise_pred_uncond + guidance_scale * (noise_pred_query - noise_pred_uncond)
                else:
                    query_noise_residual = noise_pred

                # compute the previous noisy sample x_t -> x_t-1
                latents_query = self.scheduler.step(query_noise_residual, t, latents, **extra_step_kwargs).prev_sample

                # DiffEdit update: keep the query denoising result inside the mask
                # and re-inject the noised original latents outside it
                if i + 1 < len(latents_set):
                    latents = tensor_mask * latents_query + (1 - tensor_mask) * latents_set[i + 1]
                else:
                    latents = latents_query

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        image = self.decode_latents(latents)
        image = self.image_processor.postprocess(image, output_type='pil')

        return image
--------------------------------------------------------------------------------
/png/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/png/demo.png
--------------------------------------------------------------------------------
/png/diffedit_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/png/diffedit_v1.png
--------------------------------------------------------------------------------
/png/diffedit_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/png/diffedit_v2.png
--------------------------------------------------------------------------------
/png/diffedit_v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruilin19/DiffEdit-by-Stable-Diffusion/1366872092d78a5f87673730aeade7c3cb77fdc4/png/diffedit_v3.png
--------------------------------------------------------------------------------