├── README.md
├── __pycache__
│   └── config.cpython-38.pyc
├── config.py
├── outputs
│   ├── 1.png
│   ├── 10-100
│   │   ├── 1.png
│   │   ├── 1_canvas.png
│   │   ├── 2.png
│   │   ├── 2_canvas.png
│   │   ├── 3.png
│   │   ├── 3_canvas.png
│   │   ├── 4.png
│   │   ├── 4_canvas.png
│   │   ├── 5.png
│   │   ├── 5_canvas.png
│   │   ├── 6.png
│   │   ├── 6_canvas.png
│   │   ├── 7.png
│   │   ├── 7_canvas.png
│   │   ├── 8.png
│   │   ├── 8_canvas.png
│   │   ├── 9.png
│   │   └── 9_canvas.png
│   ├── 10-200
│   │   ├── 1.png
│   │   ├── 1_canvas.png
│   │   ├── 2.png
│   │   ├── 2_canvas.png
│   │   ├── 3.png
│   │   ├── 3_canvas.png
│   │   ├── 4.png
│   │   ├── 4_canvas.png
│   │   ├── 5.png
│   │   ├── 5_canvas.png
│   │   ├── 6.png
│   │   ├── 6_canvas.png
│   │   ├── 7.png
│   │   ├── 7_canvas.png
│   │   ├── 8.png
│   │   ├── 8_canvas.png
│   │   ├── 9.png
│   │   └── 9_canvas.png
│   ├── 10-300
│   │   ├── 1.png
│   │   ├── 1_canvas.png
│   │   ├── 2.png
│   │   ├── 2_canvas.png
│   │   ├── 3.png
│   │   ├── 3_canvas.png
│   │   ├── 4.png
│   │   ├── 4_canvas.png
│   │   ├── 5.png
│   │   ├── 5_canvas.png
│   │   ├── 6.png
│   │   ├── 6_canvas.png
│   │   ├── 7.png
│   │   ├── 7_canvas.png
│   │   ├── 8.png
│   │   ├── 8_canvas.png
│   │   ├── 9.png
│   │   └── 9_canvas.png
│   ├── 25-100
│   │   ├── 1.png
│   │   ├── 1_canvas.png
│   │   ├── 2.png
│   │   ├── 2_canvas.png
│   │   ├── 3.png
│   │   ├── 3_canvas.png
│   │   ├── 4.png
│   │   ├── 4_canvas.png
│   │   ├── 5.png
│   │   ├── 5_canvas.png
│   │   ├── 6.png
│   │   ├── 6_canvas.png
│   │   ├── 7.png
│   │   ├── 7_canvas.png
│   │   ├── 8.png
│   │   ├── 8_canvas.png
│   │   ├── 9.png
│   │   └── 9_canvas.png
│   ├── 50-100
│   │   ├── 1.png
│   │   ├── 1_canvas.png
│   │   ├── 2.png
│   │   ├── 2_canvas.png
│   │   ├── 3.png
│   │   ├── 3_canvas.png
│   │   ├── 4.png
│   │   ├── 4_canvas.png
│   │   ├── 5.png
│   │   ├── 5_canvas.png
│   │   ├── 6.png
│   │   ├── 6_canvas.png
│   │   ├── 7.png
│   │   ├── 7_canvas.png
│   │   ├── 8.png
│   │   ├── 8_canvas.png
│   │   ├── 9.png
│   │   └── 9_canvas.png
│   ├── A rabbit wearing sunglasses looks very proud.png
│   └── baseline-sdxl
│       ├── 1.png
│       ├── 1_canvas.png
│       ├── 2.png
│       ├── 2_canvas.png
│       ├── 3.png
│       ├── 3_canvas.png
│       ├── 4.png
│       ├── 4_canvas.png
│       ├── 5.png
│       ├── 5_canvas.png
│       ├── 6.png
│       ├── 6_canvas.png
│       ├── 7.png
│       ├── 7_canvas.png
│       ├── 8.png
│       ├── 8_canvas.png
│       ├── 9.png
│       └── 9_canvas.png
├── pipeline
│   ├── __pycache__
│   │   ├── processors.cpython-38.pyc
│   │   └── sdxl_pipeline_boxdiff.cpython-38.pyc
│   ├── processors.py
│   └── sdxl_pipeline_boxdiff.py
├── run_sd_boxdiff.py
└── utils
    ├── __pycache__
    │   ├── drawer.cpython-38.pyc
    │   ├── gaussian_smoothing.cpython-38.pyc
    │   ├── ptp_utils.cpython-38.pyc
    │   └── vis_utils.cpython-38.pyc
    ├── drawer.py
    ├── gaussian_smoothing.py
    ├── ptp_utils.py
    └── vis_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # BoxDiff-XL
2 | 
3 | Extends [BoxDiff](https://github.com/showlab/BoxDiff) to Stable Diffusion XL.
4 | 
5 | ```shell
6 | CUDA_VISIBLE_DEVICES=4 python run_sd_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4] --bbox [[67,87,366,512],[66,130,364,262]] --sd_xl True
7 | ```
8 | 
9 | 
10 | # Some Experiments
11 | 
12 | 
13 | BoxDiff on SDXL is quite sensitive to `--max_iter_to_alter` and `weight_loss`.
14 | 
15 | I ran some simple experiments with the prompt above. The results are in the `outputs` directory.
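A minimal way to probe that sensitivity is to sweep the two settings from the shell. The loop below is only a sketch: it assumes `run_sd_boxdiff.py` exposes the remaining `RunConfig` fields (`max_iter_to_alter`, `weight_loss`, and `output_path`) as command-line flags in the same style as the flags used above; adjust the flag names if the argument parser differs.

```shell
# Sketch: sweep max_iter_to_alter x weight_loss, writing each run to its own folder.
for iters in 10 25 50; do
  for w in 100 200 300; do
    CUDA_VISIBLE_DEVICES=4 python run_sd_boxdiff.py \
      --prompt "A rabbit wearing sunglasses looks very proud" \
      --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] \
      --token_indices [2,4] --bbox [[67,87,366,512],[66,130,364,262]] \
      --sd_xl True \
      --max_iter_to_alter $iters --weight_loss $w \
      --output_path ./outputs/$iters-$w
  done
done
```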
16 | -------------------------------------------------------------------------------- /__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Dict, List 4 | 5 | 6 | @dataclass 7 | class RunConfig: 8 | # Guiding text prompt 9 | prompt: str 10 | # Whether to use Stable Diffusion XL 11 | sd_xl: bool = False 12 | # Whether to use Stable Diffusion v2.1 13 | sd_2_1: bool = False 14 | # Which token indices to alter with attend-and-excite 15 | token_indices: List[int] = None 16 | # Which random seeds to use when generating 17 | seeds: List[int] = field(default_factory=lambda: [42]) 18 | # Path to save all outputs to 19 | output_path: Path = Path('./outputs') 20 | # Number of denoising steps 21 | n_inference_steps: int = 50 22 | # Text guidance scale 23 | guidance_scale: float = 7.5 24 | # Number of denoising steps to apply attend-and-excite 25 | max_iter_to_alter: int = 25 26 | # Resolution of UNet to compute attention maps over 27 | attention_res: int = 16 28 | # Whether to run standard SD or attend-and-excite 29 | run_standard_sd: bool = False 30 | # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in 31 | thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8}) 32 | # Scale factor for updating the denoised latent z_t 33 | scale_factor: int = 20 34 | # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep 35 | scale_range: tuple = field(default_factory=lambda: (1.0, 0.5)) 36 | # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token 37 | smooth_attentions: bool = True 38 | # Standard deviation for the Gaussian smoothing 39 | sigma: float = 0.5 40 | # Kernel size for the Gaussian smoothing 41 | kernel_size: int = 3 42 | # Whether to save cross attention maps for the final results 43 | save_cross_attention_maps: bool = False 44 | 45 | # BoxDiff 46 | bbox: List[list] = field(default_factory=lambda: [[], []]) 47 | color: List[str] = field(default_factory=lambda: ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black']) 48 | P: float = 0.2 49 | # number of pixels around the corner to be selected 50 | L: int = 1 51 | refine: bool = True 52 | gligen_phrases: List[str] = field(default_factory=lambda: ['', '']) 53 | n_splits: int = 4 54 | which_one: int = 1 55 | eval_output_path: Path = Path('./outputs/eval') 56 | weight_loss: int = 100 57 | 58 | 59 | def __post_init__(self): 60 | self.output_path.mkdir(exist_ok=True, parents=True) 61 | -------------------------------------------------------------------------------- /outputs/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/1.png -------------------------------------------------------------------------------- /outputs/10-100/1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/1.png -------------------------------------------------------------------------------- /outputs/10-100/1_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/1_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/2.png -------------------------------------------------------------------------------- /outputs/10-100/2_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/2_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/3.png -------------------------------------------------------------------------------- /outputs/10-100/3_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/3_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/4.png -------------------------------------------------------------------------------- /outputs/10-100/4_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/4_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/5.png -------------------------------------------------------------------------------- /outputs/10-100/5_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/5_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/6.png -------------------------------------------------------------------------------- /outputs/10-100/6_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/6_canvas.png 
-------------------------------------------------------------------------------- /outputs/10-100/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/7.png -------------------------------------------------------------------------------- /outputs/10-100/7_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/7_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/8.png -------------------------------------------------------------------------------- /outputs/10-100/8_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/8_canvas.png -------------------------------------------------------------------------------- /outputs/10-100/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/9.png -------------------------------------------------------------------------------- /outputs/10-100/9_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-100/9_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/1.png -------------------------------------------------------------------------------- /outputs/10-200/1_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/1_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/2.png -------------------------------------------------------------------------------- /outputs/10-200/2_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/2_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/3.png -------------------------------------------------------------------------------- /outputs/10-200/3_canvas.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/3_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/4.png -------------------------------------------------------------------------------- /outputs/10-200/4_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/4_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/5.png -------------------------------------------------------------------------------- /outputs/10-200/5_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/5_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/6.png -------------------------------------------------------------------------------- /outputs/10-200/6_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/6_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/7.png -------------------------------------------------------------------------------- /outputs/10-200/7_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/7_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/8.png -------------------------------------------------------------------------------- /outputs/10-200/8_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/8_canvas.png -------------------------------------------------------------------------------- /outputs/10-200/9.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/9.png -------------------------------------------------------------------------------- /outputs/10-200/9_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-200/9_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/1.png -------------------------------------------------------------------------------- /outputs/10-300/1_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/1_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/2.png -------------------------------------------------------------------------------- /outputs/10-300/2_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/2_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/3.png -------------------------------------------------------------------------------- /outputs/10-300/3_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/3_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/4.png -------------------------------------------------------------------------------- /outputs/10-300/4_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/4_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/5.png -------------------------------------------------------------------------------- /outputs/10-300/5_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/5_canvas.png 
-------------------------------------------------------------------------------- /outputs/10-300/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/6.png -------------------------------------------------------------------------------- /outputs/10-300/6_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/6_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/7.png -------------------------------------------------------------------------------- /outputs/10-300/7_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/7_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/8.png -------------------------------------------------------------------------------- /outputs/10-300/8_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/8_canvas.png -------------------------------------------------------------------------------- /outputs/10-300/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/9.png -------------------------------------------------------------------------------- /outputs/10-300/9_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/10-300/9_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/1.png -------------------------------------------------------------------------------- /outputs/25-100/1_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/1_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/2.png -------------------------------------------------------------------------------- /outputs/25-100/2_canvas.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/2_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/3.png -------------------------------------------------------------------------------- /outputs/25-100/3_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/3_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/4.png -------------------------------------------------------------------------------- /outputs/25-100/4_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/4_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/5.png -------------------------------------------------------------------------------- /outputs/25-100/5_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/5_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/6.png -------------------------------------------------------------------------------- /outputs/25-100/6_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/6_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/7.png -------------------------------------------------------------------------------- /outputs/25-100/7_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/7_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/8.png -------------------------------------------------------------------------------- /outputs/25-100/8_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/8_canvas.png -------------------------------------------------------------------------------- /outputs/25-100/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/9.png -------------------------------------------------------------------------------- /outputs/25-100/9_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/25-100/9_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/1.png -------------------------------------------------------------------------------- /outputs/50-100/1_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/1_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/2.png -------------------------------------------------------------------------------- /outputs/50-100/2_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/2_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/3.png -------------------------------------------------------------------------------- /outputs/50-100/3_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/3_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/4.png -------------------------------------------------------------------------------- /outputs/50-100/4_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/4_canvas.png 
-------------------------------------------------------------------------------- /outputs/50-100/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/5.png -------------------------------------------------------------------------------- /outputs/50-100/5_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/5_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/6.png -------------------------------------------------------------------------------- /outputs/50-100/6_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/6_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/7.png -------------------------------------------------------------------------------- /outputs/50-100/7_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/7_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/8.png -------------------------------------------------------------------------------- /outputs/50-100/8_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/8_canvas.png -------------------------------------------------------------------------------- /outputs/50-100/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/9.png -------------------------------------------------------------------------------- /outputs/50-100/9_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/50-100/9_canvas.png -------------------------------------------------------------------------------- /outputs/A rabbit wearing sunglasses looks very proud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/A rabbit wearing sunglasses looks very proud.png -------------------------------------------------------------------------------- 
/outputs/baseline-sdxl/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/1.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/1_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/1_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/2.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/2_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/2_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/3.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/3_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/3_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/4.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/4_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/4_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/5.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/5_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/5_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/6.png -------------------------------------------------------------------------------- 
/outputs/baseline-sdxl/6_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/6_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/7.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/7_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/7_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/8.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/8_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/8_canvas.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/9.png -------------------------------------------------------------------------------- /outputs/baseline-sdxl/9_canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/outputs/baseline-sdxl/9_canvas.png -------------------------------------------------------------------------------- /pipeline/__pycache__/processors.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/pipeline/__pycache__/processors.cpython-38.pyc -------------------------------------------------------------------------------- /pipeline/__pycache__/sdxl_pipeline_boxdiff.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/pipeline/__pycache__/sdxl_pipeline_boxdiff.cpython-38.pyc -------------------------------------------------------------------------------- /pipeline/processors.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | from typing import Dict, List, Optional, Tuple, Union 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from diffusers.models.attention import Attention 9 | 10 | 11 | class P2PCrossAttnProcessor: 12 | def __init__(self, controller, place_in_unet): 13 | super().__init__() 14 | self.controller = controller 15 | self.place_in_unet = place_in_unet 16 | 17 | 
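    # This processor follows the standard diffusers attention flow (q/k/v projections,
    # attention scores, weighted sum, output projection). The only difference is that
    # the computed attention probabilities are handed to `self.controller` (the
    # "one line change" marked below), so an attention store / edit controller can
    # record or modify the cross- and self-attention maps at every layer.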
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): 18 | batch_size, sequence_length, _ = hidden_states.shape 19 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 20 | 21 | query = attn.to_q(hidden_states) 22 | 23 | is_cross = encoder_hidden_states is not None 24 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 25 | key = attn.to_k(encoder_hidden_states) 26 | value = attn.to_v(encoder_hidden_states) 27 | 28 | query = attn.head_to_batch_dim(query) 29 | key = attn.head_to_batch_dim(key) 30 | value = attn.head_to_batch_dim(value) 31 | 32 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 33 | 34 | # one line change 35 | self.controller(attention_probs, is_cross, self.place_in_unet) 36 | 37 | hidden_states = torch.bmm(attention_probs, value) 38 | hidden_states = attn.batch_to_head_dim(hidden_states) 39 | 40 | # linear proj 41 | hidden_states = attn.to_out[0](hidden_states) 42 | # dropout 43 | hidden_states = attn.to_out[1](hidden_states) 44 | 45 | return hidden_states 46 | 47 | 48 | def create_controller( 49 | prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device, attn_res 50 | ) -> AttentionControl: 51 | edit_type = cross_attention_kwargs.get("edit_type", None) 52 | local_blend_words = cross_attention_kwargs.get("local_blend_words", None) 53 | equalizer_words = cross_attention_kwargs.get("equalizer_words", None) 54 | equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None) 55 | n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4) 56 | n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4) 57 | 58 | # only replace 59 | if edit_type == "replace" and local_blend_words is None: 60 | return AttentionReplace( 61 | [prompts, prompts], num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, attn_res=attn_res 62 | ) 63 | 64 | # replace + localblend 65 | if edit_type == "replace" and local_blend_words is not None: 66 | lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) 67 | return AttentionReplace( 68 | prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res 69 | ) 70 | 71 | # only refine 72 | if edit_type == "refine" and local_blend_words is None: 73 | return AttentionRefine( 74 | prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, attn_res=attn_res 75 | ) 76 | 77 | # refine + localblend 78 | if edit_type == "refine" and local_blend_words is not None: 79 | lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) 80 | return AttentionRefine( 81 | prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res 82 | ) 83 | 84 | # only reweight 85 | if edit_type == "reweight" and local_blend_words is None: 86 | assert ( 87 | equalizer_words is not None and equalizer_strengths is not None 88 | ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." 89 | assert len(equalizer_words) == len( 90 | equalizer_strengths 91 | ), "equalizer_words and equalizer_strengths must be of same length." 
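        # Example with hypothetical values: cross_attention_kwargs such as
        #   {"edit_type": "reweight", "equalizer_words": ["sunglasses"], "equalizer_strengths": [2.0]}
        # yield an AttentionReweight controller. get_equalizer below converts the selected
        # words and strengths into a (len(equalizer_strengths), 77) per-token scaling matrix,
        # which AttentionReweight multiplies into the cross-attention maps.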
92 | equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) 93 | return AttentionReweight( 94 | prompts, 95 | num_inference_steps, 96 | n_cross_replace, 97 | n_self_replace, 98 | tokenizer=tokenizer, 99 | device=device, 100 | equalizer=equalizer, 101 | attn_res=attn_res, 102 | ) 103 | 104 | # reweight and localblend 105 | if edit_type == "reweight" and local_blend_words: 106 | assert ( 107 | equalizer_words is not None and equalizer_strengths is not None 108 | ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." 109 | assert len(equalizer_words) == len( 110 | equalizer_strengths 111 | ), "equalizer_words and equalizer_strengths must be of same length." 112 | equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) 113 | lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) 114 | return AttentionReweight( 115 | prompts, 116 | num_inference_steps, 117 | n_cross_replace, 118 | n_self_replace, 119 | tokenizer=tokenizer, 120 | device=device, 121 | equalizer=equalizer, 122 | attn_res=attn_res, 123 | local_blend=lb, 124 | ) 125 | 126 | raise ValueError(f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.") 127 | 128 | 129 | class AttentionControl(abc.ABC): 130 | def step_callback(self, x_t): 131 | return x_t 132 | 133 | def between_steps(self): 134 | return 135 | 136 | @property 137 | def num_uncond_att_layers(self): 138 | return 0 139 | 140 | @abc.abstractmethod 141 | def forward(self, attn, is_cross: bool, place_in_unet: str): 142 | raise NotImplementedError 143 | 144 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 145 | if self.cur_att_layer >= self.num_uncond_att_layers: 146 | h = attn.shape[0] 147 | attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) 148 | self.cur_att_layer += 1 149 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 150 | self.cur_att_layer = 0 151 | self.cur_step += 1 152 | self.between_steps() 153 | return attn 154 | 155 | def reset(self): 156 | self.cur_step = 0 157 | self.cur_att_layer = 0 158 | 159 | def __init__(self, attn_res=None): 160 | self.cur_step = 0 161 | self.num_att_layers = -1 162 | self.cur_att_layer = 0 163 | self.attn_res = attn_res 164 | 165 | 166 | class EmptyControl(AttentionControl): 167 | def forward(self, attn, is_cross: bool, place_in_unet: str): 168 | return attn 169 | 170 | 171 | class AttentionStore(AttentionControl): 172 | @staticmethod 173 | def get_empty_store(): 174 | return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} 175 | 176 | def forward(self, attn, is_cross: bool, place_in_unet: str): 177 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 178 | if attn.shape[1] <= 32**2: # avoid memory overhead 179 | self.step_store[key].append(attn) 180 | return attn 181 | 182 | def between_steps(self): 183 | if len(self.attention_store) == 0: 184 | self.attention_store = self.step_store 185 | else: 186 | for key in self.attention_store: 187 | for i in range(len(self.attention_store[key])): 188 | self.attention_store[key][i] += self.step_store[key][i] 189 | self.step_store = self.get_empty_store() 190 | 191 | def get_average_attention(self): 192 | average_attention = { 193 | key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store 194 | } 195 | return average_attention 196 | 197 | def reset(self): 198 | 
super(AttentionStore, self).reset() 199 | self.step_store = self.get_empty_store() 200 | self.attention_store = {} 201 | 202 | def __init__(self, attn_res=None): 203 | super(AttentionStore, self).__init__(attn_res) 204 | self.step_store = self.get_empty_store() 205 | self.attention_store = {} 206 | 207 | 208 | 209 | class LocalBlend: 210 | def __call__(self, x_t, attention_store): 211 | # note that this code works on the latent level! 212 | k = 1 213 | # maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] # These are the numbers because we want to take layers that are 256 x 256, I think this can be changed to something smarter...like, get all attentions where thesecond dim is self.attn_res[0] * self.attn_res[1] in up and down cross. 214 | maps = [m for m in attention_store["down_cross"] + attention_store["mid_cross"] + attention_store["up_cross"] if m.shape[1] == self.attn_res[0] * self.attn_res[1]] 215 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, self.attn_res[0], self.attn_res[1], self.max_num_words) for item in maps] 216 | maps = torch.cat(maps, dim=1) 217 | maps = (maps * self.alpha_layers).sum(-1).mean(1) # since alpha_layers is all 0s except where we edit, the product zeroes out all but what we change. Then, the sum adds the values of the original and what we edit. Then, we average across dim=1, which is the number of layers. 218 | mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) 219 | mask = F.interpolate(mask, size=(x_t.shape[2:])) 220 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 221 | mask = mask.gt(self.threshold) 222 | 223 | mask = mask[:1] + mask[1:] 224 | mask = mask.to(torch.float16) 225 | 226 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) # x_t[:1] is the original image. mask*(x_t - x_t[:1]) zeroes out the original image and removes the difference between the original and each image we are generating (mostly just one). Then, it applies the mask on the image. That is, it's only keeping the cells we want to generate. 
227 | return x_t 228 | 229 | def __init__( 230 | self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, attn_res=None 231 | ): 232 | self.max_num_words = 77 233 | self.attn_res = attn_res 234 | 235 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words) 236 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 237 | if isinstance(words_, str): 238 | words_ = [words_] 239 | for word in words_: 240 | ind = get_word_inds(prompt, word, tokenizer) 241 | alpha_layers[i, :, :, :, :, ind] = 1 242 | self.alpha_layers = alpha_layers.to(device) # a one-hot vector where the 1s are the words we modify (source and target) 243 | self.threshold = threshold 244 | 245 | 246 | class AttentionControlEdit(AttentionStore, abc.ABC): 247 | def step_callback(self, x_t): 248 | if self.local_blend is not None: 249 | x_t = self.local_blend(x_t, self.attention_store) 250 | return x_t 251 | 252 | def replace_self_attention(self, attn_base, att_replace): 253 | if att_replace.shape[2] <= self.attn_res[0]**2: 254 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 255 | else: 256 | return att_replace 257 | 258 | @abc.abstractmethod 259 | def replace_cross_attention(self, attn_base, att_replace): 260 | raise NotImplementedError 261 | 262 | def forward(self, attn, is_cross: bool, place_in_unet: str): 263 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 264 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 265 | h = attn.shape[0] // (self.batch_size) 266 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 267 | attn_base, attn_replace = attn[0], attn[1:] 268 | if is_cross: 269 | alpha_words = self.cross_replace_alpha[self.cur_step] 270 | attn_replace_new = ( 271 | self.replace_cross_attention(attn_base, attn_replace) * alpha_words 272 | + (1 - alpha_words) * attn_replace 273 | ) 274 | attn[1:] = attn_replace_new 275 | else: 276 | attn[1:] = self.replace_self_attention(attn_base, attn_replace) 277 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 278 | return attn 279 | 280 | def __init__( 281 | self, 282 | prompts, 283 | num_steps: int, 284 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 285 | self_replace_steps: Union[float, Tuple[float, float]], 286 | local_blend: Optional[LocalBlend], 287 | tokenizer, 288 | device, 289 | attn_res=None, 290 | ): 291 | super(AttentionControlEdit, self).__init__(attn_res=attn_res) 292 | # add tokenizer and device here 293 | 294 | self.tokenizer = tokenizer 295 | self.device = device 296 | 297 | self.batch_size = len(prompts) 298 | self.cross_replace_alpha = get_time_words_attention_alpha( 299 | prompts, num_steps, cross_replace_steps, self.tokenizer 300 | ).to(self.device) 301 | if isinstance(self_replace_steps, float): 302 | self_replace_steps = 0, self_replace_steps 303 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 304 | self.local_blend = local_blend 305 | 306 | 307 | class AttentionReplace(AttentionControlEdit): 308 | def replace_cross_attention(self, attn_base, att_replace): 309 | return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) 310 | 311 | def __init__( 312 | self, 313 | prompts, 314 | num_steps: int, 315 | cross_replace_steps: float, 316 | self_replace_steps: float, 317 | local_blend: Optional[LocalBlend] = None, 318 | tokenizer=None, 319 | device=None, 320 | attn_res=None, 321 | ): 322 | 
super(AttentionReplace, self).__init__( 323 | prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res 324 | ) 325 | self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device) 326 | 327 | 328 | class AttentionRefine(AttentionControlEdit): 329 | def replace_cross_attention(self, attn_base, att_replace): 330 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 331 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 332 | return attn_replace 333 | 334 | def __init__( 335 | self, 336 | prompts, 337 | num_steps: int, 338 | cross_replace_steps: float, 339 | self_replace_steps: float, 340 | local_blend: Optional[LocalBlend] = None, 341 | tokenizer=None, 342 | device=None, 343 | attn_res=None 344 | ): 345 | super(AttentionRefine, self).__init__( 346 | prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res 347 | ) 348 | self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer) 349 | self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device) 350 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 351 | 352 | 353 | class AttentionReweight(AttentionControlEdit): 354 | def replace_cross_attention(self, attn_base, att_replace): 355 | if self.prev_controller is not None: 356 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 357 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 358 | return attn_replace 359 | 360 | def __init__( 361 | self, 362 | prompts, 363 | num_steps: int, 364 | cross_replace_steps: float, 365 | self_replace_steps: float, 366 | equalizer, 367 | local_blend: Optional[LocalBlend] = None, 368 | controller: Optional[AttentionControlEdit] = None, 369 | tokenizer=None, 370 | device=None, 371 | attn_res=None, 372 | ): 373 | super(AttentionReweight, self).__init__( 374 | prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res 375 | ) 376 | self.equalizer = equalizer.to(self.device) 377 | self.prev_controller = controller 378 | 379 | 380 | ### util functions for all Edits 381 | def update_alpha_time_word( 382 | alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None 383 | ): 384 | if isinstance(bounds, float): 385 | bounds = 0, bounds 386 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 387 | if word_inds is None: 388 | word_inds = torch.arange(alpha.shape[2]) 389 | alpha[:start, prompt_ind, word_inds] = 0 390 | alpha[start:end, prompt_ind, word_inds] = 1 391 | alpha[end:, prompt_ind, word_inds] = 0 392 | return alpha 393 | 394 | 395 | def get_time_words_attention_alpha( 396 | prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77 397 | ): 398 | if not isinstance(cross_replace_steps, dict): 399 | cross_replace_steps = {"default_": cross_replace_steps} 400 | if "default_" not in cross_replace_steps: 401 | cross_replace_steps["default_"] = (0.0, 1.0) 402 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 403 | for i in range(len(prompts) - 1): 404 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i) 405 | for key, item in cross_replace_steps.items(): 406 | if key != "default_": 407 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 408 | for 
i, ind in enumerate(inds): 409 | if len(ind) > 0: 410 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 411 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 412 | return alpha_time_words 413 | 414 | 415 | ### util functions for LocalBlend and ReplacementEdit 416 | def get_word_inds(text: str, word_place: int, tokenizer): 417 | split_text = text.split(" ") 418 | if isinstance(word_place, str): 419 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 420 | elif isinstance(word_place, int): 421 | word_place = [word_place] 422 | out = [] 423 | if len(word_place) > 0: 424 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 425 | cur_len, ptr = 0, 0 426 | 427 | for i in range(len(words_encode)): 428 | cur_len += len(words_encode[i]) 429 | if ptr in word_place: 430 | out.append(i + 1) 431 | if cur_len >= len(split_text[ptr]): 432 | ptr += 1 433 | cur_len = 0 434 | return np.array(out) 435 | 436 | 437 | ### util functions for ReplacementEdit 438 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 439 | words_x = x.split(" ") 440 | words_y = y.split(" ") 441 | if len(words_x) != len(words_y): 442 | raise ValueError( 443 | f"attention replacement edit can only be applied on prompts with the same length" 444 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words." 445 | ) 446 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 447 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 448 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 449 | mapper = np.zeros((max_len, max_len)) 450 | i = j = 0 451 | cur_inds = 0 452 | while i < max_len and j < max_len: 453 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 454 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 455 | if len(inds_source_) == len(inds_target_): 456 | mapper[inds_source_, inds_target_] = 1 457 | else: 458 | ratio = 1 / len(inds_target_) 459 | for i_t in inds_target_: 460 | mapper[inds_source_, i_t] = ratio 461 | cur_inds += 1 462 | i += len(inds_source_) 463 | j += len(inds_target_) 464 | elif cur_inds < len(inds_source): 465 | mapper[i, j] = 1 466 | i += 1 467 | j += 1 468 | else: 469 | mapper[j, j] = 1 470 | i += 1 471 | j += 1 472 | 473 | # return torch.from_numpy(mapper).float() 474 | return torch.from_numpy(mapper).to(torch.float16) 475 | 476 | 477 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 478 | x_seq = prompts[0] 479 | mappers = [] 480 | for i in range(1, len(prompts)): 481 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 482 | mappers.append(mapper) 483 | return torch.stack(mappers) 484 | 485 | 486 | ### util functions for ReweightEdit 487 | def get_equalizer( 488 | text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer 489 | ): 490 | if isinstance(word_select, (int, str)): 491 | word_select = (word_select,) 492 | equalizer = torch.ones(len(values), 77) 493 | values = torch.tensor(values, dtype=torch.float32) 494 | for i, word in enumerate(word_select): 495 | inds = get_word_inds(text, word, tokenizer) 496 | equalizer[:, inds] = torch.FloatTensor(values[i]) 497 | return equalizer 498 | 499 | 500 | ### util functions for RefinementEdit 501 | class ScoreParams: 502 | def __init__(self, gap, match, mismatch): 503 | self.gap = gap 504 | 
self.match = match 505 | self.mismatch = mismatch 506 | 507 | def mis_match_char(self, x, y): 508 | if x != y: 509 | return self.mismatch 510 | else: 511 | return self.match 512 | 513 | 514 | def get_matrix(size_x, size_y, gap): 515 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 516 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 517 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 518 | return matrix 519 | 520 | 521 | def get_traceback_matrix(size_x, size_y): 522 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 523 | matrix[0, 1:] = 1 524 | matrix[1:, 0] = 2 525 | matrix[0, 0] = 4 526 | return matrix 527 | 528 | 529 | def global_align(x, y, score): 530 | matrix = get_matrix(len(x), len(y), score.gap) 531 | trace_back = get_traceback_matrix(len(x), len(y)) 532 | for i in range(1, len(x) + 1): 533 | for j in range(1, len(y) + 1): 534 | left = matrix[i, j - 1] + score.gap 535 | up = matrix[i - 1, j] + score.gap 536 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 537 | matrix[i, j] = max(left, up, diag) 538 | if matrix[i, j] == left: 539 | trace_back[i, j] = 1 540 | elif matrix[i, j] == up: 541 | trace_back[i, j] = 2 542 | else: 543 | trace_back[i, j] = 3 544 | return matrix, trace_back 545 | 546 | 547 | def get_aligned_sequences(x, y, trace_back): 548 | x_seq = [] 549 | y_seq = [] 550 | i = len(x) 551 | j = len(y) 552 | mapper_y_to_x = [] 553 | while i > 0 or j > 0: 554 | if trace_back[i, j] == 3: 555 | x_seq.append(x[i - 1]) 556 | y_seq.append(y[j - 1]) 557 | i = i - 1 558 | j = j - 1 559 | mapper_y_to_x.append((j, i)) 560 | elif trace_back[i][j] == 1: 561 | x_seq.append("-") 562 | y_seq.append(y[j - 1]) 563 | j = j - 1 564 | mapper_y_to_x.append((j, -1)) 565 | elif trace_back[i][j] == 2: 566 | x_seq.append(x[i - 1]) 567 | y_seq.append("-") 568 | i = i - 1 569 | elif trace_back[i][j] == 4: 570 | break 571 | mapper_y_to_x.reverse() 572 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 573 | 574 | 575 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 576 | x_seq = tokenizer.encode(x) 577 | y_seq = tokenizer.encode(y) 578 | score = ScoreParams(0, 1, -1) 579 | matrix, trace_back = global_align(x_seq, y_seq, score) 580 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 581 | alphas = torch.ones(max_len) 582 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 583 | mapper = torch.zeros(max_len, dtype=torch.int64) 584 | mapper[: mapper_base.shape[0]] = mapper_base[:, 1] 585 | mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq)) 586 | return mapper, alphas 587 | 588 | 589 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 590 | x_seq = prompts[0] 591 | mappers, alphas = [], [] 592 | for i in range(1, len(prompts)): 593 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 594 | mappers.append(mapper) 595 | alphas.append(alpha) 596 | return torch.stack(mappers), torch.stack(alphas) 597 | -------------------------------------------------------------------------------- /pipeline/sdxl_pipeline_boxdiff.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 2 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline 3 | from utils.gaussian_smoothing import GaussianSmoothing 4 | from .processors import * 5 | from dataclasses import fields, dataclass 6 | from collections import OrderedDict 7 | import PIL 8 | 9 | # Copied 
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg 10 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 11 | """ 12 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 13 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 14 | """ 15 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 16 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 17 | # rescale the results from guidance (fixes overexposure) 18 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 19 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 20 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 21 | return noise_cfg 22 | 23 | def aggregate_attention(attention_store: AttentionStore, 24 | res: int, 25 | from_where: List[str], 26 | is_cross: bool, 27 | select: int) -> torch.Tensor: 28 | """ Aggregates the attention across the different layers and heads at the specified resolution. """ 29 | out = [] 30 | attention_maps = attention_store.get_average_attention() 31 | 32 | num_pixels = res ** 2 33 | for location in from_where: 34 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 35 | if item.shape[1] == num_pixels: 36 | # head = 20 37 | item = item[20:,...] # remove unconditional branch 38 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 39 | out.append(cross_maps) 40 | out = torch.cat(out, dim=0) 41 | out = out.sum(0) / out.shape[0] 42 | return out 43 | 44 | 45 | class BaseOutput(OrderedDict): 46 | def __post_init__(self): 47 | class_fields = fields(self) 48 | 49 | # Safety and consistency checks 50 | if not len(class_fields): 51 | raise ValueError(f"{self.__class__.__name__} has no fields.") 52 | 53 | first_field = getattr(self, class_fields[0].name) 54 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) 55 | 56 | if other_fields_are_none and isinstance(first_field, dict): 57 | for key, value in first_field.items(): 58 | self[key] = value 59 | else: 60 | for field in class_fields: 61 | v = getattr(self, field.name) 62 | if v is not None: 63 | self[field.name] = v 64 | 65 | def __delitem__(self, *args, **kwargs): 66 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") 67 | 68 | def setdefault(self, *args, **kwargs): 69 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") 70 | 71 | def pop(self, *args, **kwargs): 72 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") 73 | 74 | def update(self, *args, **kwargs): 75 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") 76 | 77 | def __getitem__(self, k): 78 | if isinstance(k, str): 79 | inner_dict = dict(self.items()) 80 | return inner_dict[k] 81 | else: 82 | return self.to_tuple()[k] 83 | 84 | def __setattr__(self, name, value): 85 | if name in self.keys() and value is not None: 86 | # Don't call self.__setitem__ to avoid recursion errors 87 | super().__setitem__(name, value) 88 | super().__setattr__(name, value) 89 | 90 | def __setitem__(self, key, value): 91 | # Will raise a KeyException if needed 92 | super().__setitem__(key, value) 93 | # Don't call self.__setattr__ to avoid recursion errors 94 | super().__setattr__(key, value) 95 | 
96 | def to_tuple(self) -> Tuple[Any]: 97 | return tuple(self[k] for k in self.keys()) 98 | 99 | @dataclass 100 | class StableDiffusionXLPipelineOutput(BaseOutput): 101 | images: Union[List[PIL.Image.Image], np.ndarray] 102 | 103 | 104 | class BoxDiffPipeline(StableDiffusionXLPipeline): 105 | r""" 106 | Args: 107 | Prompt-to-Prompt-Pipeline for text-to-image generation using Stable Diffusion. This model inherits from 108 | [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for 109 | all the pipelines (such as downloading or saving, running on a particular device, etc.) 110 | vae ([`AutoencoderKL`]): 111 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 112 | text_encoder ([`CLIPTextModel`]): 113 | Frozen text-encoder. Stable Diffusion uses the text portion of 114 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 115 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 116 | tokenizer (`CLIPTokenizer`): 117 | Tokenizer of class 118 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 119 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler 120 | ([`SchedulerMixin`]): 121 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 122 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 123 | safety_checker ([`StableDiffusionSafetyChecker`]): 124 | Classification module that estimates whether generated images could be considered offensive or harmful. 125 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 126 | feature_extractor ([`CLIPFeatureExtractor`]): 127 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 128 | """ 129 | 130 | _optional_components = ["safety_checker", "feature_extractor"] 131 | 132 | def check_inputs( 133 | self, 134 | prompt, 135 | prompt_2, 136 | height, 137 | width, 138 | callback_steps, 139 | negative_prompt=None, 140 | negative_prompt_2=None, 141 | prompt_embeds=None, 142 | negative_prompt_embeds=None, 143 | pooled_prompt_embeds=None, 144 | negative_pooled_prompt_embeds=None, 145 | ): 146 | if height % 8 != 0 or width % 8 != 0: 147 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 148 | 149 | if (callback_steps is None) or ( 150 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 151 | ): 152 | raise ValueError( 153 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 154 | f" {type(callback_steps)}." 155 | ) 156 | 157 | if prompt is not None and prompt_embeds is not None: 158 | raise ValueError( 159 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 160 | " only forward one of the two." 161 | ) 162 | elif prompt_2 is not None and prompt_embeds is not None: 163 | raise ValueError( 164 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 165 | " only forward one of the two." 166 | ) 167 | elif prompt is None and prompt_embeds is None: 168 | raise ValueError( 169 | "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." 170 | ) 171 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 172 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 173 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 174 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 175 | 176 | if negative_prompt is not None and negative_prompt_embeds is not None: 177 | raise ValueError( 178 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 179 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 180 | ) 181 | elif negative_prompt_2 is not None and negative_prompt_embeds is not None: 182 | raise ValueError( 183 | f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" 184 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 185 | ) 186 | 187 | if prompt_embeds is not None and negative_prompt_embeds is not None: 188 | if prompt_embeds.shape != negative_prompt_embeds.shape: 189 | raise ValueError( 190 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 191 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 192 | f" {negative_prompt_embeds.shape}." 193 | ) 194 | 195 | if prompt_embeds is not None and pooled_prompt_embeds is None: 196 | raise ValueError( 197 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 198 | ) 199 | 200 | if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: 201 | raise ValueError( 202 | "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 203 | ) 204 | 205 | def _aggregate_and_get_attention_maps_per_token(self, with_softmax): 206 | attention_maps = self.controller.aggregate_attention( 207 | from_where=("up_cross", "down_cross", "mid_cross"), 208 | # from_where=("up", "down"), 209 | # from_where=("down",) 210 | ) 211 | attention_maps_list = self._get_attention_maps_list( 212 | attention_maps=attention_maps, with_softmax=with_softmax 213 | ) 214 | return attention_maps_list 215 | 216 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore, 217 | indices_to_alter: List[int], 218 | attention_res: int = 16, 219 | smooth_attentions: bool = False, 220 | sigma: float = 0.5, 221 | kernel_size: int = 3, 222 | normalize_eot: bool = False, 223 | bbox: List[int] = None, 224 | config=None, 225 | ): 226 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. 
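        For each entry of `indices_to_alter`, this returns the mean of the top `P` fraction of attention values
        inside its box (the Inner-Box statistic), the same statistic over the background (Outer-Box), and the
        masked L1 distances `dist_x` / `dist_y` between the attention map's axis-wise maxima and the box's own
        projections within `L` pixels of the box edges (Corner constraint); `_compute_loss` combines these into
        the BoxDiff objective.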
""" 227 | attention_maps = aggregate_attention( 228 | attention_store=attention_store, 229 | res=attention_res, 230 | from_where=("up", "down", "mid"), 231 | is_cross=True, 232 | select=0) 233 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index( 234 | attention_maps=attention_maps, 235 | indices_to_alter=indices_to_alter, 236 | smooth_attentions=smooth_attentions, 237 | sigma=sigma, 238 | kernel_size=kernel_size, 239 | normalize_eot=normalize_eot, 240 | bbox=bbox, 241 | config=config, 242 | ) 243 | return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y 244 | 245 | def _compute_max_attention_per_index(self, 246 | attention_maps: torch.Tensor, 247 | indices_to_alter: List[int], 248 | smooth_attentions: bool = False, 249 | sigma: float = 0.5, 250 | kernel_size: int = 3, 251 | normalize_eot: bool = False, 252 | bbox: List[int] = None, 253 | config=None, 254 | ) -> List[torch.Tensor]: 255 | """ Computes the maximum attention value for each of the tokens we wish to alter. """ 256 | last_idx = -1 257 | if normalize_eot: 258 | prompt = self.prompt 259 | if isinstance(self.prompt, list): 260 | prompt = self.prompt[0] 261 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1 262 | attention_for_text = attention_maps[:, :, 1:last_idx] 263 | attention_for_text *= 100 264 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) 265 | 266 | # Shift indices since we removed the first token 267 | indices_to_alter = [index - 1 for index in indices_to_alter] 268 | 269 | # Extract the maximum values 270 | max_indices_list_fg = [] 271 | max_indices_list_bg = [] 272 | dist_x = [] 273 | dist_y = [] 274 | 275 | cnt = 0 276 | for i in indices_to_alter: 277 | image = attention_for_text[:, :, i] 278 | 279 | box = [max(round(b / (512 / image.shape[0])), 0) for b in bbox[cnt]] 280 | x1, y1, x2, y2 = box 281 | cnt += 1 282 | 283 | # coordinates to masks 284 | obj_mask = torch.zeros_like(image) 285 | ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device) 286 | obj_mask[y1:y2, x1:x2] = ones_mask 287 | bg_mask = 1 - obj_mask 288 | 289 | if smooth_attentions: 290 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda() 291 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect') 292 | image = smoothing(input).squeeze(0).squeeze(0) 293 | 294 | # Inner-Box constraint 295 | k = (obj_mask.sum() * config.P).long() 296 | max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean()) 297 | 298 | # Outer-Box constraint 299 | k = (bg_mask.sum() * config.P).long() 300 | max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean()) 301 | 302 | # Corner Constraint 303 | gt_proj_x = torch.max(obj_mask, dim=0)[0] 304 | gt_proj_y = torch.max(obj_mask, dim=1)[0] 305 | corner_mask_x = torch.zeros_like(gt_proj_x) 306 | corner_mask_y = torch.zeros_like(gt_proj_y) 307 | 308 | # create gt according to the number config.L 309 | N = gt_proj_x.shape[0] 310 | corner_mask_x[max(box[0] - config.L, 0): min(box[0] + config.L + 1, N)] = 1. 311 | corner_mask_x[max(box[2] - config.L, 0): min(box[2] + config.L + 1, N)] = 1. 312 | corner_mask_y[max(box[1] - config.L, 0): min(box[1] + config.L + 1, N)] = 1. 313 | corner_mask_y[max(box[3] - config.L, 0): min(box[3] + config.L + 1, N)] = 1. 
314 | dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction='none') * corner_mask_x).mean()) 315 | dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction='none') * corner_mask_y).mean()) 316 | 317 | return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y 318 | 319 | @staticmethod 320 | def _get_attention_maps_list( 321 | attention_maps: torch.Tensor, with_softmax 322 | ) -> List[torch.Tensor]: 323 | attention_maps *= 100 324 | 325 | if with_softmax: 326 | attention_maps = torch.nn.functional.softmax(attention_maps, dim=-1) 327 | 328 | attention_maps_list = [ 329 | attention_maps[:, :, i] for i in range(attention_maps.shape[2]) 330 | ] 331 | return attention_maps_list 332 | 333 | @staticmethod 334 | def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor], 335 | dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor: 336 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """ 337 | losses_fg = [max(0, 1. - curr_max) for curr_max in max_attention_per_index_fg] 338 | losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg] 339 | loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y) 340 | if return_losses: 341 | return max(losses_fg), losses_fg 342 | else: 343 | return max(losses_fg), loss 344 | 345 | def _perform_iterative_refinement_step(self, 346 | latents: torch.Tensor, 347 | indices_to_alter: List[int], 348 | loss_fg: torch.Tensor, 349 | threshold: float, 350 | text_embeddings: torch.Tensor, 351 | attention_store: AttentionStore, 352 | step_size: float, 353 | t: int, 354 | attention_res: int = 16, 355 | smooth_attentions: bool = True, 356 | sigma: float = 0.5, 357 | kernel_size: int = 3, 358 | max_refinement_steps: int = 20, 359 | normalize_eot: bool = False, 360 | bbox: List[int] = None, 361 | config=None, 362 | ): 363 | """ 364 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent 365 | code according to our loss objective until the given threshold is reached for all tokens. 366 | """ 367 | iteration = 0 368 | target_loss = max(0, 1. 
- threshold) 369 | while loss_fg > target_loss: 370 | iteration += 1 371 | 372 | latents = latents.clone().detach().requires_grad_(True) 373 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 374 | self.unet.zero_grad() 375 | 376 | # Get max activation value for each subject token 377 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 378 | attention_store=attention_store, 379 | indices_to_alter=indices_to_alter, 380 | attention_res=attention_res, 381 | smooth_attentions=smooth_attentions, 382 | sigma=sigma, 383 | kernel_size=kernel_size, 384 | normalize_eot=normalize_eot, 385 | bbox=bbox, 386 | config=config, 387 | ) 388 | 389 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True) 390 | 391 | if loss_fg != 0: 392 | latents = self._update_latent(latents, loss_fg, step_size) 393 | 394 | with torch.no_grad(): 395 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample 396 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 397 | 398 | try: 399 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses_fg]) 400 | except Exception as e: 401 | print(e) # catch edge case :) 402 | 403 | low_token = np.argmax(losses_fg) 404 | 405 | if iteration >= max_refinement_steps: 406 | # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' 407 | # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}') 408 | break 409 | 410 | # Run one more time but don't compute gradients and update the latents. 411 | # We just need to compute the new loss - the grad update will occur below 412 | latents = latents.clone().detach().requires_grad_(True) 413 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 414 | self.unet.zero_grad() 415 | 416 | # Get max activation value for each subject token 417 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 418 | attention_store=attention_store, 419 | indices_to_alter=indices_to_alter, 420 | attention_res=attention_res, 421 | smooth_attentions=smooth_attentions, 422 | sigma=sigma, 423 | kernel_size=kernel_size, 424 | normalize_eot=normalize_eot, 425 | bbox=bbox, 426 | config=config, 427 | ) 428 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True) 429 | # print(f"\t Finished with loss of: {loss_fg}") 430 | return loss_fg, latents, max_attention_per_index_fg 431 | 432 | @staticmethod 433 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: 434 | """ Update the latent according to the computed loss. 
""" 435 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] 436 | latents = latents - step_size * grad_cond 437 | return latents 438 | 439 | 440 | @torch.no_grad() 441 | def __call__( 442 | self, 443 | attention_store: AttentionStore, 444 | prompt: Union[str, List[str]], 445 | prompt_2: Optional[Union[str, List[str]]] = None, 446 | height: Optional[int] = None, 447 | width: Optional[int] = None, 448 | num_inference_steps: int = 50, 449 | denoising_end: Optional[float] = None, 450 | guidance_scale: float = 7.5, 451 | negative_prompt: Optional[Union[str, List[str]]] = None, 452 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 453 | num_images_per_prompt: Optional[int] = 1, 454 | eta: float = 0.0, 455 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 456 | latents: Optional[torch.FloatTensor] = None, 457 | prompt_embeds: Optional[torch.FloatTensor] = None, 458 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 459 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 460 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 461 | output_type: Optional[str] = "pil", 462 | return_dict: bool = True, 463 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 464 | callback_steps: Optional[int] = 1, 465 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 466 | guidance_rescale: float = 0.0, 467 | original_size: Optional[Tuple[int, int]] = None, 468 | crops_coords_top_left: Tuple[int, int] = (0, 0), 469 | target_size: Optional[Tuple[int, int]] = None, 470 | negative_original_size: Optional[Tuple[int, int]] = None, 471 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 472 | negative_target_size: Optional[Tuple[int, int]] = None, 473 | attn_res=None, 474 | indices_to_alter: List[int] = None, 475 | attention_res: int = 16, 476 | max_iter_to_alter: Optional[int] = 25, 477 | run_standard_sd: bool = False, 478 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8}, 479 | scale_factor: int = 20, 480 | scale_range: Tuple[float, float] = (1., 0.5), 481 | smooth_attentions: bool = True, 482 | sigma: float = 0.5, 483 | kernel_size: int = 3, 484 | sd_2_1: bool = False, 485 | bbox: List[int] = None, 486 | weight_loss: int = 100, 487 | config = None, 488 | ): 489 | r""" 490 | Function invoked when calling the pipeline for generation. 491 | 492 | Args: 493 | prompt (`str` or `List[str]`): 494 | The prompt or prompts to guide the image generation. 495 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 496 | The height in pixels of the generated image. 497 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 498 | The width in pixels of the generated image. 499 | num_inference_steps (`int`, *optional*, defaults to 50): 500 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 501 | expense of slower inference. 502 | guidance_scale (`float`, *optional*, defaults to 7.5): 503 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 504 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 505 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 506 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 507 | usually at the expense of lower image quality. 
508 | negative_prompt (`str` or `List[str]`, *optional*): 509 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 510 | if `guidance_scale` is less than `1`). 511 | num_images_per_prompt (`int`, *optional*, defaults to 1): 512 | The number of images to generate per prompt. 513 | eta (`float`, *optional*, defaults to 0.0): 514 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 515 | [`schedulers.DDIMScheduler`], will be ignored for others. 516 | generator (`torch.Generator`, *optional*): 517 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 518 | to make generation deterministic. 519 | latents (`torch.FloatTensor`, *optional*): 520 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 521 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 522 | tensor will ge generated by sampling using the supplied random `generator`. 523 | output_type (`str`, *optional*, defaults to `"pil"`): 524 | The output format of the generate image. Choose between 525 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 526 | return_dict (`bool`, *optional*, defaults to `True`): 527 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 528 | plain tuple. 529 | callback (`Callable`, *optional*): 530 | A function that will be called every `callback_steps` steps during inference. The function will be 531 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 532 | callback_steps (`int`, *optional*, defaults to 1): 533 | The frequency at which the `callback` function will be called. If not specified, the callback will be 534 | called at every step. 535 | cross_attention_kwargs (`dict`, *optional*): 536 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in 537 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 538 | 539 | The keyword arguments to configure the edit are: 540 | - edit_type (`str`). The edit type to apply. Can be either of `replace`, `refine`, `reweight`. 541 | - n_cross_replace (`int`): Number of diffusion steps in which cross attention should be replaced 542 | - n_self_replace (`int`): Number of diffusion steps in which self attention should be replaced 543 | - local_blend_words(`List[str]`, *optional*, default to `None`): Determines which area should be 544 | changed. If None, then the whole image can be changed. 545 | - equalizer_words(`List[str]`, *optional*, default to `None`): Required for edit type `reweight`. 546 | Determines which words should be enhanced. 547 | - equalizer_strengths (`List[float]`, *optional*, default to `None`) Required for edit type `reweight`. 548 | Determines which how much the words in `equalizer_words` should be enhanced. 549 | 550 | guidance_rescale (`float`, *optional*, defaults to 0.0): 551 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are 552 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when 553 | using zero terminal SNR. 
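            attention_store (`AttentionStore`):
                Collects the cross- and self-attention maps from the registered attention processors; the
                BoxDiff constraints are computed from these maps via `_aggregate_and_get_max_attention_per_token`.
            indices_to_alter (`List[int]`):
                Positions of the prompt tokens whose cross-attention maps are constrained (the start-of-text
                token counts as position 0).
            bbox (`List[List[int]]`):
                One `[x1, y1, x2, y2]` box per altered token, given on the 512x512 canvas assumed by
                `_compute_max_attention_per_index`.
            max_iter_to_alter (`int`, *optional*, defaults to 25):
                Number of initial denoising steps in which the latent is updated with the BoxDiff loss.
            weight_loss (`int`, *optional*, defaults to 100):
                Multiplier applied to the BoxDiff loss before each latent update.
            config (`RunConfig`):
                Carries the remaining BoxDiff hyper-parameters, in particular `P`, `L` and `refine`.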
554 | 555 | Returns: 556 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 557 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 558 | When returning a tuple, the first element is a list with the generated images, and the second element is a 559 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 560 | (nsfw) content, according to the `safety_checker`. 561 | """ 562 | 563 | # 0. Default height and width to unet 564 | height = height or self.unet.config.sample_size * self.vae_scale_factor 565 | width = width or self.unet.config.sample_size * self.vae_scale_factor 566 | 567 | original_size = original_size or (height, width) 568 | target_size = target_size or (height, width) 569 | 570 | if attn_res is None: 571 | attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) 572 | self.attn_res = attn_res 573 | 574 | self.controller = attention_store 575 | self.register_attention_control(self.controller) # add attention controller 576 | 577 | # 1. Check inputs. Raise error if not correct 578 | self.check_inputs( 579 | prompt, 580 | prompt_2, 581 | height, 582 | width, 583 | callback_steps, 584 | negative_prompt, 585 | negative_prompt_2, 586 | prompt_embeds, 587 | negative_prompt_embeds, 588 | pooled_prompt_embeds, 589 | negative_pooled_prompt_embeds, 590 | ) 591 | 592 | # 2. Define call parameters 593 | if prompt is not None and isinstance(prompt, str): 594 | batch_size = 1 595 | elif prompt is not None and isinstance(prompt, list): 596 | batch_size = len(prompt) 597 | else: 598 | batch_size = prompt_embeds.shape[0] 599 | 600 | device = self._execution_device 601 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 602 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 603 | # corresponds to doing no classifier free guidance. 604 | do_classifier_free_guidance = guidance_scale > 1.0 605 | 606 | # 3. Encode input prompt 607 | text_encoder_lora_scale = ( 608 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 609 | ) 610 | ( 611 | prompt_embeds, 612 | negative_prompt_embeds, 613 | pooled_prompt_embeds, 614 | negative_pooled_prompt_embeds, 615 | ) = self.encode_prompt( 616 | prompt=prompt, 617 | prompt_2=prompt_2, 618 | device=device, 619 | num_images_per_prompt=num_images_per_prompt, 620 | do_classifier_free_guidance=do_classifier_free_guidance, 621 | negative_prompt=negative_prompt, 622 | negative_prompt_2=negative_prompt_2, 623 | prompt_embeds=prompt_embeds, 624 | negative_prompt_embeds=negative_prompt_embeds, 625 | pooled_prompt_embeds=pooled_prompt_embeds, 626 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 627 | lora_scale=text_encoder_lora_scale, 628 | ) 629 | 630 | # 4. Prepare timesteps 631 | self.scheduler.set_timesteps(num_inference_steps, device=device) 632 | timesteps = self.scheduler.timesteps 633 | 634 | # 5. Prepare latent variables 635 | num_channels_latents = self.unet.config.in_channels 636 | latents = self.prepare_latents( 637 | batch_size * num_images_per_prompt, 638 | num_channels_latents, 639 | height, 640 | width, 641 | prompt_embeds.dtype, 642 | device, 643 | generator, 644 | latents, 645 | ) 646 | # latents[1] = latents[0] 647 | 648 | 649 | # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 650 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 651 | 652 | # 7. Prepare added time ids & embeddings 653 | add_text_embeds = pooled_prompt_embeds 654 | add_time_ids = self._get_add_time_ids( 655 | original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, 656 | ) 657 | if negative_original_size is not None and negative_target_size is not None: 658 | negative_add_time_ids = self._get_add_time_ids( 659 | negative_original_size, 660 | negative_crops_coords_top_left, 661 | negative_target_size, 662 | dtype=prompt_embeds.dtype, 663 | ) 664 | else: 665 | negative_add_time_ids = add_time_ids 666 | 667 | if do_classifier_free_guidance: 668 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 669 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 670 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 671 | 672 | prompt_embeds = prompt_embeds.to(device) 673 | add_text_embeds = add_text_embeds.to(device) 674 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 675 | 676 | # 8. Denoising loop 677 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 678 | 679 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) 680 | 681 | # 7.1 Apply denoising_end 682 | if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: 683 | discrete_timestep_cutoff = int( 684 | round( 685 | self.scheduler.config.num_train_timesteps 686 | - (denoising_end * self.scheduler.config.num_train_timesteps) 687 | ) 688 | ) 689 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 690 | timesteps = timesteps[:num_inference_steps] 691 | 692 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 693 | with self.progress_bar(total=num_inference_steps) as progress_bar: 694 | for i, t in enumerate(timesteps): 695 | with torch.enable_grad(): 696 | 697 | latents = latents.clone().detach().requires_grad_(True) 698 | latent_model_input = torch.cat([latents] * 2) 699 | 700 | # Forward pass of denoising with text conditioning 701 | noise_pred_text = self.unet(latent_model_input, t, 702 | encoder_hidden_states=prompt_embeds, cross_attention_kwargs=None, added_cond_kwargs=added_cond_kwargs, return_dict=True,).sample 703 | self.unet.zero_grad() 704 | 705 | # Get max activation value for each subject token 706 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token( 707 | attention_store=attention_store, 708 | indices_to_alter=indices_to_alter, 709 | attention_res=attention_res, 710 | smooth_attentions=smooth_attentions, 711 | sigma=sigma, 712 | kernel_size=kernel_size, 713 | normalize_eot=sd_2_1, 714 | bbox=bbox, 715 | config=config, 716 | ) 717 | 718 | run_standard_sd = False 719 | if run_standard_sd: 720 | 721 | loss_fg, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y) 722 | 723 | # Refinement from attend-and-excite (not necessary) 724 | if i in thresholds.keys() and loss_fg > 1. 
- thresholds[i] and config.refine: 725 | del noise_pred_text 726 | torch.cuda.empty_cache() 727 | loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step( 728 | latents=latents, 729 | indices_to_alter=indices_to_alter, 730 | loss_fg=loss_fg, 731 | threshold=thresholds[i], 732 | text_embeddings=prompt_embeds, 733 | attention_store=attention_store, 734 | step_size=scale_factor * np.sqrt(scale_range[i]), 735 | t=t, 736 | attention_res=attention_res, 737 | smooth_attentions=smooth_attentions, 738 | sigma=sigma, 739 | kernel_size=kernel_size, 740 | normalize_eot=sd_2_1, 741 | bbox=bbox, 742 | config=config, 743 | ) 744 | 745 | # Perform gradient update 746 | if i < max_iter_to_alter: 747 | _, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y) 748 | if loss != 0: 749 | loss = loss * weight_loss 750 | latent_model_input = self._update_latent(latents=latent_model_input, loss=loss, 751 | step_size=scale_factor * np.sqrt(scale_range[i])) 752 | latents = latent_model_input[1:,...] 753 | 754 | print(f'Iteration {i} | Loss: {loss:0.4f}') 755 | # expand the latents if we are doing classifier free guidance 756 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 757 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 758 | 759 | # predict the noise residual 760 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, 761 | added_cond_kwargs=added_cond_kwargs, ).sample 762 | 763 | # perform guidance 764 | if do_classifier_free_guidance: 765 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 766 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 767 | 768 | if do_classifier_free_guidance and guidance_rescale > 0.0: 769 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 770 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 771 | 772 | # compute the previous noisy sample x_t -> x_t-1 773 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 774 | 775 | # step callback 776 | latents = self.controller.step_callback(latents) 777 | 778 | # call the callback, if provided 779 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 780 | progress_bar.update() 781 | if callback is not None and i % callback_steps == 0: 782 | step_idx = i // getattr(self.scheduler, "order", 1) 783 | callback(step_idx, t, latents) 784 | 785 | # 8. 
Post-processing 786 | if not output_type == "latent": 787 | # make sure the VAE is in float32 mode, as it overflows in float16 788 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 789 | 790 | if needs_upcasting: 791 | self.upcast_vae() 792 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 793 | 794 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 795 | 796 | # cast back to fp16 if needed 797 | if needs_upcasting: 798 | self.vae.to(dtype=torch.float16) 799 | else: 800 | image = latents 801 | 802 | if not output_type == "latent": 803 | # apply watermark if available 804 | if self.watermark is not None: 805 | image = self.watermark.apply_watermark(image) 806 | 807 | image = self.image_processor.postprocess(image, output_type=output_type) 808 | 809 | # Offload all models 810 | self.maybe_free_model_hooks() 811 | 812 | if not return_dict: 813 | return (image,) 814 | 815 | return StableDiffusionXLPipelineOutput(images=image) 816 | 817 | 818 | def register_attention_control(self, controller): 819 | attn_procs = {} 820 | cross_att_count = 0 821 | for name in self.unet.attn_processors.keys(): 822 | None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim 823 | if name.startswith("mid_block"): 824 | self.unet.config.block_out_channels[-1] 825 | place_in_unet = "mid" 826 | elif name.startswith("up_blocks"): 827 | block_id = int(name[len("up_blocks.")]) 828 | list(reversed(self.unet.config.block_out_channels))[block_id] 829 | place_in_unet = "up" 830 | elif name.startswith("down_blocks"): 831 | block_id = int(name[len("down_blocks.")]) 832 | self.unet.config.block_out_channels[block_id] 833 | place_in_unet = "down" 834 | else: 835 | continue 836 | cross_att_count += 1 837 | attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet) 838 | 839 | self.unet.set_attn_processor(attn_procs) 840 | controller.num_att_layers = cross_att_count 841 | -------------------------------------------------------------------------------- /run_sd_boxdiff.py: -------------------------------------------------------------------------------- 1 | 2 | import pprint 3 | from typing import List 4 | 5 | import pyrallis 6 | import torch 7 | from PIL import Image 8 | from config import RunConfig 9 | from pipeline.sdxl_pipeline_boxdiff import BoxDiffPipeline 10 | from utils import ptp_utils, vis_utils 11 | from utils.ptp_utils import AttentionStore 12 | 13 | import numpy as np 14 | from utils.drawer import draw_rectangle, DashedImageDraw 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore", category=UserWarning) 18 | 19 | 20 | def load_model(config: RunConfig): 21 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 22 | 23 | if config.sd_xl: 24 | stable_diffusion_version = "stabilityai/stable-diffusion-xl-base-1.0" 25 | elif config.sd_2_1: 26 | stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base" 27 | else: 28 | stable_diffusion_version = "CompVis/stable-diffusion-v1-4" 29 | # If you cannot access the huggingface on your server, you can use the local prepared one. 
30 | # stable_diffusion_version = "../../packages/huggingface/hub/stable-diffusion-v1-4" 31 | print(f"Loading model from {stable_diffusion_version}") 32 | stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version,torch_dtype=torch.float16 ).to(device) 33 | 34 | return stable 35 | 36 | 37 | def get_indices_to_alter(stable, prompt: str) -> List[int]: 38 | token_idx_to_word = {idx: stable.tokenizer.decode(t) 39 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids']) 40 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1} 41 | pprint.pprint(token_idx_to_word) 42 | token_indices = input("Please enter the a comma-separated list indices of the tokens you wish to " 43 | "alter (e.g., 2,5): ") 44 | token_indices = [int(i) for i in token_indices.split(",")] 45 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}") 46 | return token_indices 47 | 48 | 49 | def run_on_prompt(prompt: List[str], 50 | model: BoxDiffPipeline, 51 | controller: AttentionStore, 52 | token_indices: List[int], 53 | seed: torch.Generator, 54 | config: RunConfig) -> Image.Image: 55 | # if controller is not None: 56 | # ptp_utils.register_attention_control(model, controller) 57 | outputs = model(prompt=prompt, 58 | height= 512, 59 | width=512, 60 | attention_store=controller, 61 | indices_to_alter=token_indices, 62 | attention_res=config.attention_res, 63 | guidance_scale=config.guidance_scale, 64 | generator=seed, 65 | num_inference_steps=config.n_inference_steps, 66 | max_iter_to_alter=config.max_iter_to_alter, 67 | run_standard_sd=config.run_standard_sd, 68 | thresholds=config.thresholds, 69 | scale_factor=config.scale_factor, 70 | scale_range=config.scale_range, 71 | smooth_attentions=config.smooth_attentions, 72 | sigma=config.sigma, 73 | kernel_size=config.kernel_size, 74 | sd_2_1=config.sd_2_1, 75 | bbox=config.bbox, 76 | weight_loss= config.weight_loss, 77 | config=config) 78 | image = outputs.images[0] 79 | return image 80 | 81 | 82 | @pyrallis.wrap() 83 | def main(config: RunConfig): 84 | stable = load_model(config) 85 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices 86 | 87 | if len(config.bbox[0]) == 0: 88 | config.bbox = draw_rectangle() 89 | 90 | images = [] 91 | for seed in config.seeds: 92 | print(f"Current seed is : {seed}") 93 | g = torch.Generator('cuda').manual_seed(seed) 94 | controller = AttentionStore() 95 | image = run_on_prompt(prompt=config.prompt, 96 | model=stable, 97 | controller=controller, 98 | token_indices=token_indices, 99 | seed=g, 100 | config=config) 101 | prompt_output_path = config.output_path / config.prompt[:100] 102 | prompt_output_path.mkdir(exist_ok=True, parents=True) 103 | image.save(prompt_output_path / f'{seed}.png') 104 | images.append(image) 105 | 106 | canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220) 107 | draw = DashedImageDraw(canvas) 108 | 109 | for i in range(len(config.bbox)): 110 | x1, y1, x2, y2 = config.bbox[i] 111 | draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5) 112 | canvas.save(prompt_output_path / f'{seed}_canvas.png') 113 | 114 | # save a grid of results across all seeds 115 | joined_image = vis_utils.get_image_grid(images) 116 | joined_image.save(config.output_path / f'{config.prompt}.png') 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- 
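For completeness, the entry points above can also be driven from Python rather than through the pyrallis CLI. The sketch below is illustrative only and is not a file in this repository; the prompt, token indices, boxes, seed, and output filename are placeholder values (token positions should be checked with `get_indices_to_alter`, and each box is `[x1, y1, x2, y2]` on the 512x512 canvas used by `run_on_prompt`).

```python
# Illustrative sketch only -- not a file in this repository. The prompt, token
# indices, boxes, seed, and output name below are placeholders to replace.
import torch

from config import RunConfig
from run_sd_boxdiff import load_model, run_on_prompt
from utils.ptp_utils import AttentionStore

config = RunConfig(
    prompt="A photo of a cat and a dog",                # example prompt
    sd_xl=True,                                         # use the SDXL pipeline
    token_indices=[5, 8],                               # positions of "cat" and "dog"; verify with get_indices_to_alter
    bbox=[[40, 200, 240, 480], [280, 200, 480, 480]],   # one [x1, y1, x2, y2] box per token
    seeds=[42],
)

stable = load_model(config)                             # stabilityai/stable-diffusion-xl-base-1.0 when sd_xl=True
generator = torch.Generator('cuda').manual_seed(config.seeds[0])
image = run_on_prompt(prompt=config.prompt,
                      model=stable,
                      controller=AttentionStore(),
                      token_indices=config.token_indices,
                      seed=generator,
                      config=config)
image.save('boxdiff_xl_example.png')
```

Since `run_on_prompt` fixes `height` and `width` to 512, the boxes are interpreted in that coordinate system; when `bbox` is left empty, `draw_rectangle()` from `utils/drawer.py` below opens a 512x512 Tkinter canvas for drawing the boxes interactively.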
/utils/__pycache__/drawer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/utils/__pycache__/drawer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/gaussian_smoothing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/utils/__pycache__/gaussian_smoothing.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/ptp_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/utils/__pycache__/ptp_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/vis_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cominclip/BoxDiff-XL/a15a60449385b451e0b5b1ffce8e4dfbc8c2bb52/utils/__pycache__/vis_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/drawer.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | from PIL import ImageDraw as D 3 | 4 | import math 5 | class DashedImageDraw(D.ImageDraw): 6 | 7 | def thick_line(self, xy, direction, fill=None, width=0): 8 | 9 | if xy[0] != xy[1]: 10 | self.line(xy, fill=fill, width=width) 11 | else: 12 | x1, y1 = xy[0] 13 | dx1, dy1 = direction[0] 14 | dx2, dy2 = direction[1] 15 | if dy2 - dy1 < 0: 16 | x1 -= 1 17 | if dx2 - dx1 < 0: 18 | y1 -= 1 19 | if dy2 - dy1 != 0: 20 | if dx2 - dx1 != 0: 21 | k = - (dx2 - dx1) / (dy2 - dy1) 22 | a = 1 / math.sqrt(1 + k ** 2) 23 | b = (width * a - 1) / 2 24 | else: 25 | k = 0 26 | b = (width - 1) / 2 27 | x3 = x1 - math.floor(b) 28 | y3 = y1 - int(k * b) 29 | x4 = x1 + math.ceil(b) 30 | y4 = y1 + int(k * b) 31 | else: 32 | x3 = x1 33 | y3 = y1 - math.floor((width - 1) / 2) 34 | x4 = x1 35 | y4 = y1 + math.ceil((width - 1) / 2) 36 | self.line([(x3, y3), (x4, y4)], fill=fill, width=1) 37 | return 38 | 39 | def dashed_line(self, xy, dash=(2, 2), fill=None, width=0): 40 | for i in range(len(xy) - 1): 41 | x1, y1 = xy[i] 42 | x2, y2 = xy[i + 1] 43 | x_length = x2 - x1 44 | y_length = y2 - y1 45 | length = math.sqrt(x_length ** 2 + y_length ** 2) 46 | dash_enabled = True 47 | postion = 0 48 | while postion <= length: 49 | for dash_step in dash: 50 | if postion > length: 51 | break 52 | if dash_enabled: 53 | start = postion / length 54 | end = min((postion + dash_step - 1) / length, 1) 55 | self.thick_line([(round(x1 + start * x_length), 56 | round(y1 + start * y_length)), 57 | (round(x1 + end * x_length), 58 | round(y1 + end * y_length))], 59 | xy, fill, width) 60 | dash_enabled = not dash_enabled 61 | postion += dash_step 62 | return 63 | 64 | def dashed_rectangle(self, xy, dash=(2, 2), outline=None, width=0): 65 | x1, y1 = xy[0] 66 | x2, y2 = xy[1] 67 | halfwidth1 = math.floor((width - 1) / 2) 68 | halfwidth2 = math.ceil((width - 1) / 2) 69 | min_dash_gap = min(dash[1::2]) 70 | end_change1 = halfwidth1 + min_dash_gap + 1 71 | end_change2 = halfwidth2 + min_dash_gap + 1 72 | odd_width_change = (width - 1) % 2 73 | 
self.dashed_line([(x1 - halfwidth1, y1), (x2 - end_change1, y1)], 74 | dash, outline, width) 75 | self.dashed_line([(x2, y1 - halfwidth1), (x2, y2 - end_change1)], 76 | dash, outline, width) 77 | self.dashed_line([(x2 + halfwidth2, y2 + odd_width_change), 78 | (x1 + end_change2, y2 + odd_width_change)], 79 | dash, outline, width) 80 | self.dashed_line([(x1 + odd_width_change, y2 + halfwidth2), 81 | (x1 + odd_width_change, y1 + end_change2)], 82 | dash, outline, width) 83 | return 84 | 85 | class RectangleDrawer: 86 | def __init__(self, master): 87 | self.master = master 88 | width, height = 512, 512 89 | self.canvas = Canvas(self.master, bg='#F0FFF0', width=width, height=height) 90 | self.canvas.pack() 91 | 92 | self.rectangles = [] 93 | self.colors = ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black'] 94 | 95 | self.canvas.bind("", self.on_button_press) 96 | self.canvas.bind("", self.on_move_press) 97 | self.canvas.bind("", self.on_button_release) 98 | self.start_x = None 99 | self.start_y = None 100 | self.cur_rect = None 101 | self.master.update() 102 | width = self.master.winfo_width() 103 | height = self.master.winfo_height() 104 | x = (self.master.winfo_screenwidth() // 2) - (width // 2) 105 | y = (self.master.winfo_screenheight() // 2) - (height // 2) 106 | self.master.geometry('{}x{}+{}+{}'.format(width, height, x, y)) 107 | 108 | 109 | def on_button_press(self, event): 110 | self.start_x = event.x 111 | self.start_y = event.y 112 | self.cur_rect = self.canvas.create_rectangle(self.start_x, self.start_y, self.start_x, self.start_y, outline=self.colors[len(self.rectangles)%len(self.colors)], width=5, dash=(4, 4)) 113 | 114 | def on_move_press(self, event): 115 | cur_x, cur_y = (event.x, event.y) 116 | self.canvas.coords(self.cur_rect, self.start_x, self.start_y, cur_x, cur_y) 117 | 118 | def on_button_release(self, event): 119 | cur_x, cur_y = (event.x, event.y) 120 | self.rectangles.append([self.start_x, self.start_y, cur_x, cur_y]) 121 | self.cur_rect = None 122 | 123 | def get_rectangles(self): 124 | return self.rectangles 125 | 126 | 127 | def draw_rectangle(): 128 | root = Tk() 129 | root.title("Rectangle Drawer") 130 | 131 | drawer = RectangleDrawer(root) 132 | 133 | def on_enter_press(event): 134 | root.quit() 135 | 136 | root.bind('', on_enter_press) 137 | 138 | root.mainloop() 139 | rectangles = drawer.get_rectangles() 140 | 141 | new_rects = [] 142 | for r in rectangles: 143 | new_rects.extend(r) 144 | 145 | return new_rects 146 | 147 | if __name__ == '__main__': 148 | root = Tk() 149 | root.title("Rectangle Drawer") 150 | 151 | drawer = RectangleDrawer(root) 152 | 153 | def on_enter_press(event): 154 | root.quit() 155 | 156 | root.bind('', on_enter_press) 157 | 158 | root.mainloop() 159 | rectangles = drawer.get_rectangles() 160 | 161 | string = '[' 162 | for r in rectangles: 163 | string += '[' 164 | for n in r: 165 | string += str(n) 166 | string += ',' 167 | string = string[:-1] 168 | string += '],' 169 | string = string[:-1] 170 | string += ']' 171 | print("Rectangles:", string) -------------------------------------------------------------------------------- /utils/gaussian_smoothing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class GaussianSmoothing(nn.Module): 9 | """ 10 | Apply gaussian smoothing on a 11 | 1d, 2d or 3d tensor. 
Filtering is performed seperately for each channel 12 | in the input using a depthwise convolution. 13 | Arguments: 14 | channels (int, sequence): Number of channels of the input tensors. Output will 15 | have this number of channels as well. 16 | kernel_size (int, sequence): Size of the gaussian kernel. 17 | sigma (float, sequence): Standard deviation of the gaussian kernel. 18 | dim (int, optional): The number of dimensions of the data. 19 | Default value is 2 (spatial). 20 | """ 21 | def __init__(self, channels, kernel_size, sigma, dim=2): 22 | super(GaussianSmoothing, self).__init__() 23 | if isinstance(kernel_size, numbers.Number): 24 | kernel_size = [kernel_size] * dim 25 | if isinstance(sigma, numbers.Number): 26 | sigma = [sigma] * dim 27 | 28 | # The gaussian kernel is the product of the 29 | # gaussian function of each dimension. 30 | kernel = 1 31 | meshgrids = torch.meshgrid( 32 | [ 33 | torch.arange(size, dtype=torch.float32) 34 | for size in kernel_size 35 | ] 36 | ) 37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 38 | mean = (size - 1) / 2 39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 40 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 41 | 42 | # Make sure sum of values in gaussian kernel equals 1. 43 | kernel = kernel / torch.sum(kernel) 44 | 45 | # Reshape to depthwise convolutional weight 46 | kernel = kernel.view(1, 1, *kernel.size()) 47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 48 | 49 | self.register_buffer('weight', kernel) 50 | self.groups = channels 51 | 52 | if dim == 1: 53 | self.conv = F.conv1d 54 | elif dim == 2: 55 | self.conv = F.conv2d 56 | elif dim == 3: 57 | self.conv = F.conv3d 58 | else: 59 | raise RuntimeError( 60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 61 | ) 62 | 63 | def forward(self, input): 64 | """ 65 | Apply gaussian filter to input. 66 | Arguments: 67 | input (torch.Tensor): Input to apply gaussian filter on. 68 | Returns: 69 | filtered (torch.Tensor): Filtered output. 70 | """ 71 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) 72 | 73 | 74 | class AverageSmoothing(nn.Module): 75 | """ 76 | Apply average smoothing on a 77 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 78 | in the input using a depthwise convolution. 79 | Arguments: 80 | channels (int, sequence): Number of channels of the input tensors. Output will 81 | have this number of channels as well. 82 | kernel_size (int, sequence): Size of the average kernel. 83 | sigma (float, sequence): Standard deviation of the rage kernel. 84 | dim (int, optional): The number of dimensions of the data. 85 | Default value is 2 (spatial). 86 | """ 87 | def __init__(self, channels, kernel_size, dim=2): 88 | super(AverageSmoothing, self).__init__() 89 | 90 | # Make sure sum of values in gaussian kernel equals 1. 91 | kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size) 92 | 93 | # Reshape to depthwise convolutional weight 94 | kernel = kernel.view(1, 1, *kernel.size()) 95 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 96 | 97 | self.register_buffer('weight', kernel) 98 | self.groups = channels 99 | 100 | if dim == 1: 101 | self.conv = F.conv1d 102 | elif dim == 2: 103 | self.conv = F.conv2d 104 | elif dim == 3: 105 | self.conv = F.conv3d 106 | else: 107 | raise RuntimeError( 108 | 'Only 1, 2 and 3 dimensions are supported. 
Received {}.'.format(dim) 109 | ) 110 | 111 | def forward(self, input): 112 | """ 113 | Apply average filter to input. 114 | Arguments: 115 | input (torch.Tensor): Input to apply average filter on. 116 | Returns: 117 | filtered (torch.Tensor): Filtered output. 118 | """ 119 | return self.conv(input, weight=self.weight, groups=self.groups) 120 | -------------------------------------------------------------------------------- /utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from IPython.display import display 7 | from PIL import Image 8 | from typing import Union, Tuple, List 9 | from diffusers.models.attention import Attention as CrossAttention 10 | 11 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: 12 | h, w, c = image.shape 13 | offset = int(h * .2) 14 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 15 | font = cv2.FONT_HERSHEY_SIMPLEX 16 | img[:h] = image 17 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 18 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 19 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 20 | return img 21 | 22 | 23 | def view_images(images: Union[np.ndarray, List], 24 | num_rows: int = 1, 25 | offset_ratio: float = 0.02, 26 | display_image: bool = True) -> Image.Image: 27 | """ Displays a list of images in a grid. """ 28 | if type(images) is list: 29 | num_empty = len(images) % num_rows 30 | elif images.ndim == 4: 31 | num_empty = images.shape[0] % num_rows 32 | else: 33 | images = [images] 34 | num_empty = 0 35 | 36 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 37 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 38 | num_items = len(images) 39 | 40 | h, w, c = images[0].shape 41 | offset = int(h * offset_ratio) 42 | num_cols = num_items // num_rows 43 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 44 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 45 | for i in range(num_rows): 46 | for j in range(num_cols): 47 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 48 | i * num_cols + j] 49 | 50 | pil_img = Image.fromarray(image_) 51 | if display_image: 52 | display(pil_img) 53 | return pil_img 54 | 55 | 56 | class AttendExciteCrossAttnProcessor: 57 | 58 | def __init__(self, attnstore, place_in_unet): 59 | super().__init__() 60 | self.attnstore = attnstore 61 | self.place_in_unet = place_in_unet 62 | 63 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 64 | batch_size, sequence_length, _ = hidden_states.shape 65 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1) 66 | query = attn.to_q(hidden_states) 67 | 68 | is_cross = encoder_hidden_states is not None 69 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 70 | key = attn.to_k(encoder_hidden_states) 71 | value = attn.to_v(encoder_hidden_states) 72 | 73 | query = attn.head_to_batch_dim(query) 74 | key = attn.head_to_batch_dim(key) 75 | value = attn.head_to_batch_dim(value) 76 | 77 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 78 | 79 | self.attnstore(attention_probs, is_cross, self.place_in_unet) 80 | 81 | hidden_states = torch.bmm(attention_probs, value) 82 | 
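# attention_probs has shape (batch * heads, query_len, key_len); for cross-attention the key axis
# indexes the prompt tokens, which is exactly what self.attnstore recorded above and what
# aggregate_attention later reads back out of the AttentionStore.
# batch_to_head_dim below folds the per-head outputs back into (batch, seq_len, inner_dim).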
hidden_states = attn.batch_to_head_dim(hidden_states) 83 | 84 | # linear proj 85 | hidden_states = attn.to_out[0](hidden_states) 86 | # dropout 87 | hidden_states = attn.to_out[1](hidden_states) 88 | 89 | return hidden_states 90 | 91 | 92 | def register_attention_control(model, controller): 93 | 94 | attn_procs = {} 95 | cross_att_count = 0 96 | for name in model.unet.attn_processors.keys(): 97 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 98 | if name.startswith("mid_block"): 99 | hidden_size = model.unet.config.block_out_channels[-1] 100 | place_in_unet = "mid" 101 | elif name.startswith("up_blocks"): 102 | block_id = int(name[len("up_blocks.")]) 103 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 104 | place_in_unet = "up" 105 | elif name.startswith("down_blocks"): 106 | block_id = int(name[len("down_blocks.")]) 107 | hidden_size = model.unet.config.block_out_channels[block_id] 108 | place_in_unet = "down" 109 | else: 110 | continue 111 | 112 | cross_att_count += 1 113 | attn_procs[name] = AttendExciteCrossAttnProcessor( 114 | attnstore=controller, place_in_unet=place_in_unet 115 | ) 116 | model.unet.set_attn_processor(attn_procs) 117 | controller.num_att_layers = cross_att_count 118 | 119 | class AttentionControl(abc.ABC): 120 | 121 | def step_callback(self, x_t): 122 | return x_t 123 | 124 | def between_steps(self): 125 | return 126 | 127 | # @property 128 | # def num_uncond_att_layers(self): 129 | # return 0 130 | 131 | @abc.abstractmethod 132 | def forward(self, attn, is_cross: bool, place_in_unet: str): 133 | raise NotImplementedError 134 | 135 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 136 | if self.cur_att_layer >= self.num_uncond_att_layers: 137 | self.forward(attn, is_cross, place_in_unet) 138 | self.cur_att_layer += 1 139 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 140 | self.cur_att_layer = 0 141 | self.cur_step += 1 142 | self.between_steps() 143 | 144 | def reset(self): 145 | self.cur_step = 0 146 | self.cur_att_layer = 0 147 | 148 | def __init__(self): 149 | self.cur_step = 0 150 | self.num_att_layers = -1 151 | self.cur_att_layer = 0 152 | 153 | 154 | class EmptyControl(AttentionControl): 155 | 156 | def forward(self, attn, is_cross: bool, place_in_unet: str): 157 | return attn 158 | 159 | 160 | class AttentionStore(AttentionControl): 161 | 162 | @staticmethod 163 | def get_empty_store(): 164 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 165 | "down_self": [], "mid_self": [], "up_self": []} 166 | 167 | def forward(self, attn, is_cross: bool, place_in_unet: str): 168 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 169 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 170 | self.step_store[key].append(attn) 171 | return attn 172 | 173 | def between_steps(self): 174 | self.attention_store = self.step_store 175 | if self.save_global_store: 176 | with torch.no_grad(): 177 | if len(self.global_store) == 0: 178 | self.global_store = self.step_store 179 | else: 180 | for key in self.global_store: 181 | for i in range(len(self.global_store[key])): 182 | self.global_store[key][i] += self.step_store[key][i].detach() 183 | self.step_store = self.get_empty_store() 184 | self.step_store = self.get_empty_store() 185 | 186 | def get_average_attention(self): 187 | average_attention = self.attention_store 188 | return average_attention 189 | 190 | def get_average_global_attention(self): 191 | average_attention = 
{key: [item / self.cur_step for item in self.global_store[key]] for key in 192 | self.attention_store} 193 | return average_attention 194 | 195 | def reset(self): 196 | super(AttentionStore, self).reset() 197 | self.step_store = self.get_empty_store() 198 | self.attention_store = {} 199 | self.global_store = {} 200 | 201 | def __init__(self, save_global_store=False): 202 | ''' 203 | Initialize an empty AttentionStore 204 | :param step_index: used to visualize only a specific step in the diffusion process 205 | ''' 206 | super(AttentionStore, self).__init__() 207 | self.save_global_store = save_global_store 208 | self.step_store = self.get_empty_store() 209 | self.attention_store = {} 210 | self.global_store = {} 211 | self.curr_step_index = 0 212 | self.num_uncond_att_layers = 0 213 | 214 | 215 | def aggregate_attention(attention_store: AttentionStore, 216 | res: int, 217 | from_where: List[str], 218 | is_cross: bool, 219 | select: int) -> torch.Tensor: 220 | """ Aggregates the attention across the different layers and heads at the specified resolution. """ 221 | out = [] 222 | attention_maps = attention_store.get_average_attention() 223 | 224 | num_pixels = res ** 2 225 | for location in from_where: 226 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 227 | if item.shape[1] == num_pixels: 228 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 229 | out.append(cross_maps) 230 | out = torch.cat(out, dim=0) 231 | out = out.sum(0) / out.shape[0] 232 | return out 233 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from PIL import Image 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from utils import ptp_utils 9 | from utils.ptp_utils import AttentionStore, aggregate_attention 10 | 11 | 12 | def show_cross_attention(prompt: str, 13 | attention_store: AttentionStore, 14 | tokenizer, 15 | indices_to_alter: List[int], 16 | res: int, 17 | from_where: List[str], 18 | select: int = 0, 19 | orig_image=None): 20 | tokens = tokenizer.encode(prompt) 21 | decoder = tokenizer.decode 22 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu() 23 | images = [] 24 | 25 | # show spatial attention for indices of tokens to strengthen 26 | for i in range(len(tokens)): 27 | image = attention_maps[:, :, i] 28 | if i in indices_to_alter: 29 | image = show_image_relevance(image, orig_image) 30 | image = image.astype(np.uint8) 31 | image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2))) 32 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 33 | images.append(image) 34 | 35 | ptp_utils.view_images(np.stack(images, axis=0)) 36 | 37 | 38 | def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16): 39 | # create heatmap from mask on image 40 | def show_cam_on_image(img, mask): 41 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 42 | heatmap = np.float32(heatmap) / 255 43 | cam = heatmap + np.float32(img) 44 | cam = cam / np.max(cam) 45 | return cam 46 | 47 | image = image.resize((relevnace_res ** 2, relevnace_res ** 2)) 48 | image = np.array(image) 49 | 50 | image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1]) 51 | image_relevance = image_relevance.cuda() # because float16 precision interpolation is not 
supported on cpu 52 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear') 53 | image_relevance = image_relevance.cpu() # send it back to cpu 54 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min()) 55 | image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2) 56 | image = (image - image.min()) / (image.max() - image.min()) 57 | vis = show_cam_on_image(image, image_relevance) 58 | vis = np.uint8(255 * vis) 59 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) 60 | return vis 61 | 62 | 63 | def get_image_grid(images: List[Image.Image]) -> Image.Image: 64 | num_images = len(images) 65 | cols = int(math.ceil(math.sqrt(num_images))) 66 | rows = int(math.ceil(num_images / cols)) 67 | width, height = images[0].size 68 | grid_image = Image.new('RGB', (cols * width, rows * height)) 69 | for i, img in enumerate(images): 70 | x = i % cols 71 | y = i // cols 72 | grid_image.paste(img, (x * width, y * height)) 73 | return grid_image 74 | --------------------------------------------------------------------------------
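The utilities above are meant to be used together: `ptp_utils.register_attention_control` swaps `AttendExciteCrossAttnProcessor` into every attention layer and records the maps in an `AttentionStore`, and `vis_utils.show_cross_attention` turns the stored cross-attention into per-token heatmaps. The sketch below shows one way to wire them up in a notebook. It is illustrative only: the checkpoint id and prompt are arbitrary, `res=16` matches a 512x512 SD v1 UNet rather than SDXL, and it assumes a diffusers release contemporary with this code, i.e. one that still calls attention processors with the `(attn, hidden_states, encoder_hidden_states, attention_mask)` signature used above. With classifier-free guidance the stored maps also include the unconditional branch, so the overlay is a rough diagnostic rather than an exact attribution.

```python
# Illustrative wiring of utils/ptp_utils.py and utils/vis_utils.py; run from the repository root.
# Assumptions (not part of this repo): an SD v1 checkpoint and a diffusers version whose attention
# processors are called as processor(attn, hidden_states, encoder_hidden_states, attention_mask).
from diffusers import StableDiffusionPipeline

from utils.ptp_utils import AttentionStore, register_attention_control
from utils.vis_utils import show_cross_attention

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

controller = AttentionStore()
register_attention_control(pipe, controller)  # installs AttendExciteCrossAttnProcessor on the UNet

prompt = "a dog and a cat"                    # CLIP tokens: <sot> a dog and a cat <eot>
image = pipe(prompt).images[0]

# Overlay the 16x16 cross-attention maps of "dog" (token 2) and "cat" (token 5) on the sample.
show_cross_attention(prompt=prompt,
                     attention_store=controller,
                     tokenizer=pipe.tokenizer,
                     indices_to_alter=[2, 5],
                     res=16,
                     from_where=["up", "down", "mid"],
                     orig_image=image)
```

The boxes themselves can be sketched interactively with `draw_rectangle()` from `utils/drawer.py`: it opens a Tk canvas, draws one dashed rectangle per mouse drag, and returns the flattened `[x1, y1, x2, y2, ...]` coordinates once Enter is pressed.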