├── method.png ├── results.png ├── motivation.png ├── results1.png ├── training-free ├── infer_sd1-5_x0_optim_mask_fnal_para.sh ├── hpsv2_loss.py ├── infer_sd1-5_x0_optim_mask_final_para.py ├── sd_utils_x0_dpm.py └── sd_utils_x0_dpm_syn.py ├── training ├── infer_sd1-5_hardmask.sh ├── train_sd_lora.sh ├── train_sd_full.sh ├── train_hyperunet.sh ├── train_dreambooth_maskunet.sh ├── train_dreambooth.sh ├── infer_sd1-5_hardmask.py ├── hyperunet.py └── train_sd_lora.py ├── README.md └── environment.yaml /method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/method.png -------------------------------------------------------------------------------- /results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/results.png -------------------------------------------------------------------------------- /motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/motivation.png -------------------------------------------------------------------------------- /results1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/results1.png -------------------------------------------------------------------------------- /training-free/infer_sd1-5_x0_optim_mask_fnal_para.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --multi_gpu --num_processes 6 --gpu_ids='all' infer_sd1-5_x0_optim_mask_final_para.py -------------------------------------------------------------------------------- /training/infer_sd1-5_hardmask.sh: -------------------------------------------------------------------------------- 1 | torchrun --master-port=29506 --nproc_per_node=8 infer_sd1-5_hardmask.py \ 2 | --pretrained_model_name_or_path "PixArt-alpha/PixArt-XL-2-512x512" \ 3 | --captions_file "captions_2017.txt" \ 4 | --world_size 8 \ 5 | --batch_size 16 -------------------------------------------------------------------------------- /training/train_sd_lora.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5" 2 | export DATASET_NAME="../fantasyfish/laion-art" 3 | 4 | accelerate launch --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_sd_lora.py \ 5 | --pretrained_model_name_or_path=$MODEL_NAME \ 6 | --dataset_name=$DATASET_NAME \ 7 | --image_column="image" \ 8 | --caption_column="text" \ 9 | --resolution=512 --center_crop --random_flip \ 10 | --train_batch_size=4 \ 11 | --gradient_accumulation_steps=1 \ 12 | --gradient_checkpointing \ 13 | --num_train_epochs 12 --checkpointing_steps=500 \ 14 | --max_grad_norm=1 \ 15 | --use_8bit_adam \ 16 | --learning_rate=1e-05 --lr_scheduler="constant" --lr_warmup_steps=0 \ 17 | --seed=42 \ 18 | --enable_xformers_memory_efficient_attention \ 19 | --output_dir="sd-naruto-model-lora" \ 20 | --validation_prompt="cute cat" -------------------------------------------------------------------------------- /training/train_sd_full.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5" 2 | export DATASET_NAME="../fantasyfish/laion-art" 3 | # --use_ema 
\ 4 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_sd_full.py \ 5 | --pretrained_model_name_or_path=$MODEL_NAME \ 6 | --dataset_name=$DATASET_NAME \ 7 | --image_column="image" \ 8 | --caption_column="text" \ 9 | --resolution=512 --center_crop --random_flip \ 10 | --train_batch_size=4 \ 11 | --gradient_accumulation_steps=1 \ 12 | --gradient_checkpointing \ 13 | --num_train_epochs 12 \ 14 | --checkpointing_steps=500 \ 15 | --learning_rate=1e-05 \ 16 | --max_grad_norm=1 \ 17 | --use_8bit_adam \ 18 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 19 | --seed=42 \ 20 | --enable_xformers_memory_efficient_attention \ 21 | --output_dir="sd-naruto-model_full" 22 | 23 | # # --max_train_steps=5000 \ -------------------------------------------------------------------------------- /training/train_hyperunet.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5" 2 | export DATASET_NAME="../fantasyfish/laion-art" 3 | # --use_ema \ 4 | #- 5 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 8 --gpu_ids='all' train_hyperunet.py \ 6 | --pretrained_model_name_or_path=$MODEL_NAME \ 7 | --dataset_name=$DATASET_NAME \ 8 | --image_column="image" \ 9 | --caption_column="text" \ 10 | --resolution=512 --center_crop --random_flip \ 11 | --train_batch_size=4 \ 12 | --gradient_accumulation_steps=1 \ 13 | --gradient_checkpointing \ 14 | --num_train_epochs 15 \ 15 | --checkpointing_steps=500 \ 16 | --learning_rate=1e-05 \ 17 | --max_grad_norm=1 \ 18 | --use_8bit_adam \ 19 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 20 | --seed=42 \ 21 | --enable_xformers_memory_efficient_attention \ 22 | --output_dir="sd-naruto-model_hard_improve1_factor4_trans_decoder_sd1.5_factorablation_epoch15" 23 | 24 | # # --max_train_steps=5000 \ -------------------------------------------------------------------------------- /training/train_dreambooth_maskunet.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5" 2 | export INSTANCE_DIR="./dreambooth/dataset/dog2" 3 | export CLASS_DIR="dream_booth_class_image" 4 | export OUTPUT_DIR="model_path_subject_dream_booth_dog" 5 | 6 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_dreambooth_maskunet.py \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --instance_data_dir=$INSTANCE_DIR \ 9 | --output_dir=$OUTPUT_DIR \ 10 | --train_batch_size=4 \ 11 | --gradient_accumulation_steps=1 \ 12 | --learning_rate=1e-3 \ 13 | --lr_scheduler="constant" \ 14 | --lr_warmup_steps=0 \ 15 | --instance_prompt="a photo of sks dog" \ 16 | --resolution=512 \ 17 | --max_train_steps=5000 18 | # --class_data_dir=$CLASS_DIR \ 19 | # --with_prior_preservation --prior_loss_weight=1.0 \ 20 | 21 | # --class_prompt="a photo of backpack" \ 22 | 23 | 24 | # --num_class_images=200 \ 25 | 26 | # --push_to_hub -------------------------------------------------------------------------------- /training/train_dreambooth.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5" 2 | # export INSTANCE_DIR="dog" 3 | export OUTPUT_DIR_root="dreambooth_model_ori" 4 | export CLASS_DIR_root="dreambooth_class-images" 5 | export DATA_SET_root="./dreambooth/dataset" 6 | 7 
| export subject_name=("backpack_dog" "backpack" "bear_plushie" "berry_bowl" "can" "candle" "cat" "cat2" "clock" "colorful_sneaker" "dog" "dog2" "dog3" "dog5" "dog6" "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie" "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon" "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie") 8 | export class_name=("backpack" "backpack" "stuffed animal" "bowl" "can" "candle" "cat" "cat" "clock" "sneaker" "dog" "dog" "dog" "dog" "dog" "dog" "dog" "toy" "boot" "stuffed animal" "toy" "glasses" "toy" "toy" "cartoon" "toy" "sneaker" "teapot" "vase" "stuffed animal") 9 | 10 | for i in "${!subject_name[@]}"; do 11 | export DATA_SET="$DATA_SET_root/${subject_name[i]}" # Append each subject to the dataset path 12 | class="${class_name[i]}" 13 | export CLASS_DIR="$CLASS_DIR_root/$class" 14 | export OUTPUT_DIR="$OUTPUT_DIR_root/${subject_name[i]}" 15 | # echo "DataSet: $DATA_SET, class: $class, subject: ${subject_name[i]}, class_dir: $CLASS_DIR, output_dir: $OUTPUT_DIR" 16 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_dreambooth.py \ 17 | --pretrained_model_name_or_path=$MODEL_NAME \ 18 | --instance_data_dir=$DATA_SET \ 19 | --class_data_dir="$CLASS_DIR" \ 20 | --output_dir="$OUTPUT_DIR" \ 21 | --with_prior_preservation --prior_loss_weight=1.0 \ 22 | --instance_prompt="a photo of sks ${subject_name[i]}" \ 23 | --class_prompt="a photo of $class" \ 24 | --resolution=512 \ 25 | --train_batch_size=1 \ 26 | --gradient_accumulation_steps=1 \ 27 | --learning_rate=5e-6 \ 28 | --lr_scheduler="constant" \ 29 | --lr_warmup_steps=0 \ 30 | --num_class_images=200 \ 31 | --max_train_steps=800 \ 32 | # --push_to_hub 33 | done 34 | 35 | -------------------------------------------------------------------------------- /training-free/hpsv2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from hpsv2.src.open_clip import create_model, get_tokenizer 3 | import torch.nn as nn 4 | import huggingface_hub 5 | 6 | class HPSV2Loss(nn.Module): 7 | """HPS reward loss function for optimization.""" 8 | 9 | def __init__( 10 | self, 11 | dtype: torch.dtype, 12 | device: torch.device, 13 | cache_dir: str, 14 | memsave: bool = False, 15 | ): 16 | super(HPSV2Loss, self).__init__() # 先调用父类的初始化方法 17 | self.hps_model = create_model( 18 | "ViT-H-14", 19 | "laion2B-s32B-b79K", 20 | precision=dtype, 21 | device=device, 22 | cache_dir=cache_dir, 23 | ) 24 | checkpoint_path = huggingface_hub.hf_hub_download( 25 | "xswu/HPSv2", "HPS_v2.1_compressed.pt", cache_dir=cache_dir 26 | ) 27 | self.hps_model.load_state_dict( 28 | torch.load(checkpoint_path, map_location=device)["state_dict"] 29 | ) 30 | self.hps_tokenizer = get_tokenizer("ViT-H-14") 31 | if memsave: 32 | import memsave_torch.nn 33 | 34 | self.hps_model = memsave_torch.nn.convert_to_memory_saving(self.hps_model) 35 | self.hps_model = self.hps_model.to(device, dtype=dtype) 36 | self.hps_model.eval() 37 | self.freeze_parameters(self.hps_model.parameters()) 38 | # super().__init__("HPS") 39 | self.hps_model.set_grad_checkpointing(True) 40 | 41 | @staticmethod 42 | def freeze_parameters(params: torch.nn.ParameterList): 43 | for param in params: 44 | param.requires_grad = False 45 | 46 | def get_image_features(self, image: torch.Tensor) -> torch.Tensor: 47 | hps_image_features = self.hps_model.encode_image(image) 48 | return hps_image_features 49 | 50 | def get_text_features(self, prompt: str) -> 
torch.Tensor: 51 | hps_text = self.hps_tokenizer(prompt).to("cuda") 52 | hps_text_features = self.hps_model.encode_text(hps_text) 53 | return hps_text_features 54 | 55 | def compute_loss( 56 | self, image_features: torch.Tensor, text_features: torch.Tensor 57 | ) -> torch.Tensor: 58 | logits_per_image = image_features @ text_features.T 59 | hps_loss = 1 - torch.diagonal(logits_per_image)[0] 60 | return hps_loss 61 | 62 | def process_features(self, features: torch.Tensor) -> torch.Tensor: 63 | features_normed = features / features.norm(dim=-1, keepdim=True) 64 | return features_normed 65 | 66 | def score_grad(self, image: torch.Tensor, prompt: str) -> torch.Tensor: 67 | image_features = self.get_image_features(image) 68 | text_features = self.get_text_features(prompt) 69 | 70 | image_features_normed = self.process_features(image_features) 71 | text_features_normed = self.process_features(text_features) 72 | 73 | loss = self.compute_loss(image_features_normed, text_features_normed) 74 | return loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚀 [CVPR 2025] [Not All Parameters Matter: Masking Diffusion Models for Enhancing Generation Ability](https://arxiv.org/pdf/2505.03097) 2 | 3 |
4 | demo 5 |
6 | 7 | Analysis of parameter distributions and denoising effects across different time steps for Stable Diffusion (SD) 1.5 with and without random masking. The first column shows the parameter distribution of SD 1.5, while the second to fifth columns display the distributions of parameters removed by the random mask. The last two columns compare samples generated by SD 1.5 with and without the random mask. 8 | 9 |
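The random-masking probe behind this analysis can be sketched in a few lines of diffusers code. The snippet below is a minimal illustration rather than the exact experiment script: the `mask_ratio` value, the restriction to up-block linear layers, and the prompt are assumptions made for brevity.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

mask_ratio = 0.1  # illustrative fraction of weights to zero out
with torch.no_grad():
    for module in pipe.unet.up_blocks.modules():
        if isinstance(module, torch.nn.Linear):
            keep = torch.rand_like(module.weight, dtype=torch.float32) > mask_ratio
            module.weight.mul_(keep.to(module.weight.dtype))  # randomly zero ~10% of this layer's weights

image = pipe("a photograph of an astronaut riding a horse", num_inference_steps=50).images[0]
image.save("random_mask_sample.png")
```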
10 | 11 | ## 📘 Introduction 12 | Diffusion models focus on constructing basic image structures in the early denoising stages, while refined details, including local features and textures, are generated in later stages. The same network layers are therefore forced to learn both structural and textural information simultaneously, which differs significantly from traditional deep learning architectures (e.g., ResNet or GANs) that capture or generate semantic information at different layers. This difference inspires us to explore diffusion models along the time dimension. We first investigate how the U-Net parameters contribute to the denoising process and identify that properly zeroing out certain parameters (including large ones) benefits denoising, substantially improving generation quality on the fly. Capitalizing on this discovery, we propose a simple yet effective method, termed "MaskUNet", that enhances generation quality with a negligible number of additional parameters. Our method fully leverages timestep- and sample-dependent effective U-Net parameters. To optimize MaskUNet, we offer two fine-tuning strategies: a training-based approach and a training-free approach, each with tailored networks and optimization functions. In zero-shot inference on the COCO dataset, MaskUNet achieves the best FID score and further demonstrates its effectiveness in downstream task evaluations. 13 | 14 | method 15 | 16 |
17 | The pipeline of MaskUNet. G-Sig denotes the Gumbel-Sigmoid activation function; GAP is global average pooling. 18 | 19 |
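The G-Sig block corresponds to the `gumbel_sigmoid` helper defined in `training-free/infer_sd1-5_x0_optim_mask_final_para.py` and `training/hyperunet.py`. The condensed sketch below shows how such a hard, straight-through mask gates a linear layer's weights; the layer size and threshold here are illustrative, not the actual U-Net shapes.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def gumbel_sigmoid(logits: Tensor, tau: float = 1.0, hard: bool = True, threshold: float = 0.5) -> Tensor:
    # Shift the logits by Gumbel(0, 1) noise and squash with a sigmoid.
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = ((logits + gumbels) / tau).sigmoid()
    if not hard:
        return y_soft
    # Straight-through estimator: binary values on the forward pass, soft gradients on the backward pass.
    y_hard = (y_soft > threshold).to(y_soft.dtype)
    return y_hard - y_soft.detach() + y_soft

# Illustrative sizes: gate the weights of a single 320x320 linear layer.
layer = torch.nn.Linear(320, 320)
mask_logits = torch.nn.Parameter(torch.ones_like(layer.weight))  # learnable mask logits

x = torch.randn(4, 320)
mask = gumbel_sigmoid(mask_logits, hard=True)       # binary 0/1 mask with the same shape as the weights
out = F.linear(x, layer.weight * mask, layer.bias)  # masked forward pass; gradients reach mask_logits
```

In the training-free setting these mask logits are optimized per prompt with reward feedback (ImageReward and HPSv2), while in the training-based setting a small generator network predicts the mask from the timestep embedding and the current sample, as in `hyperunet.py`.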
20 | 21 | ## Training 22 | ### Datasets 23 | fantasyfish/laion-art [link1](https://huggingface.co/datasets/fantasyfish/laion-art) [link2](https://hf-mirror.com/datasets/fantasyfish/laion-art) 24 | 25 | ### Installation 26 | ```shell 27 | conda env create -f environment.yaml 28 | ``` 29 | ### Training-based 30 | #### train 31 | ```shell 32 | ./training/train_hyperunet.sh 33 | ``` 34 | #### inference 35 | ```shell 36 | ./training/infer_sd1-5_hardmask.sh 37 | ``` 38 | 39 | ### Training-free 40 | ```shell 41 | ./training-free/infer_sd1-5_x0_optim_mask_fnal_para.sh 42 | ``` 43 | ## ✨ Qualitative results 44 | 45 |
46 | 47 | Qualitative results compared with other methods. 48 | 49 |
50 | sd-ddim50 51 | 52 | ## 📈 Quantitative results 53 |

54 | origin 55 |

56 | 57 | ## Citation 58 | 59 | ``` 60 | @inproceedings{wang2025not, 61 | title={Not All Parameters Matter: Masking Diffusion Models for Enhancing Generation Ability}, 62 | author={Wang, Lei and Li, Senmao and Yang, Fei and Wang, Jianye and Zhang, Ziheng and Liu, Yuhan and Wang, Yaxing and Yang, Jian}, 63 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, 64 | pages={12880--12890}, 65 | year={2025} 66 | } 67 | 68 | ``` 69 | ## Acknowledgement 70 | 71 | This project is based on [Diffusers](https://github.com/huggingface/diffusers). Thanks for their awesome works. 72 | ## Contact 73 | If you have any questions, please feel free to reach out to me at `scitop1998@gmail.com`. 74 | -------------------------------------------------------------------------------- /training/infer_sd1-5_hardmask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from diffusers import StableDiffusionPipeline, UNet2DConditionModel 4 | from hyper_unet6 import FineGrainedUNet2DConditionModel 5 | from diffusers import DPMSolverMultistepScheduler, DDIMScheduler 6 | import argparse 7 | import torch 8 | import torch.distributed as dist 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 11 | from diffusers import PixArtAlphaPipeline 12 | from safetensors.torch import load_file 13 | import os 14 | 15 | class PromptDataset(Dataset): 16 | def __init__(self, prompts): 17 | self.prompts = prompts 18 | 19 | def __len__(self): 20 | return len(self.prompts) 21 | 22 | def __getitem__(self, idx): 23 | return self.prompts[idx], idx # Return prompt and its index 24 | 25 | def setup(rank, world_size): 26 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 27 | 28 | def cleanup(): 29 | dist.destroy_process_group() 30 | # Set a fixed seed for reproducibility 31 | 32 | 33 | def main_worker(rank, args): 34 | setup(rank, args.world_size) 35 | 36 | seed =42 37 | 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | 41 | with open(args.captions_file, 'r', encoding='utf-8') as f: 42 | prompts = f.readlines() 43 | model_path = "path_to_saved_model" 44 | unet = FineGrainedUNet2DConditionModel.from_pretrained( 45 | "../share/runwayml/stable-diffusion-v1-5", subfolder="unet", 46 | low_cpu_mem_usage=False, 47 | device_map=None, 48 | torch_dtype=torch.float16 49 | ) 50 | 51 | input_dir = "./sd-naruto-model_hard_improve1_factor4_trans_decoder_sd1.5_factorablation_epoch15/checkpoint-9000" 52 | 53 | load_mask_path = os.path.join(input_dir, "mask_generators.pth") 54 | loaded_mask_state_dict = torch.load(load_mask_path) 55 | 56 | # Load weights for mask_generators 57 | unet.mask_generators.load_state_dict(loaded_mask_state_dict) 58 | 59 | # Load weights for adapters 60 | load_path_adapters = os.path.join(input_dir, "adapters.pth") 61 | adapters_state_dict = torch.load(load_path_adapters) 62 | unet.adapters.load_state_dict(adapters_state_dict) 63 | 64 | print("load success") 65 | scheduler = DDIMScheduler.from_pretrained("../share/runwayml/stable-diffusion-v1-5", subfolder="scheduler") 66 | # Load the Stable Diffusion pipeline with the modified UNet model 67 | pipe = StableDiffusionPipeline.from_pretrained( 68 | "../share/runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, 69 | scheduler=scheduler, 70 | unet = unet 71 | ) 72 | pipe.to(f"cuda:{rank}") 73 | 74 | output_dir = "fake_images_sd1.5_mask_vis_9000_final" 75 | 76 | if 
rank == 0: 77 | os.makedirs(output_dir, exist_ok=True) 78 | 79 | generator = torch.Generator(device=f"cuda:{rank}").manual_seed(args.seed) if args.seed else None 80 | def dummy_checker(images, **kwargs): 81 | # Return the images as is, and set NSFW detection results to False (not NSFW) 82 | return images, [False] * len(images) 83 | 84 | pipe.safety_checker = dummy_checker 85 | 86 | dataset = PromptDataset(prompts) 87 | sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=False) 88 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, sampler=sampler) 89 | 90 | for i, (batch_prompts, indices) in enumerate(dataloader): 91 | # Ensure batch_prompts is a list of strings 92 | batch_prompts = [prompt for prompt in batch_prompts] 93 | images = pipe(prompt=batch_prompts, generator=generator, num_inference_steps=50, guidance_scale=7.5).images 94 | 95 | # Save images immediately after generation 96 | for j, image in enumerate(images): 97 | output_filename = f"{output_dir}/{indices[j]:05}.png" 98 | image.save(output_filename) 99 | 100 | # Set the seed in the pipeline for deterministic image generation 101 | # generator = torch.Generator(device="cuda").manual_seed(seed) 102 | 103 | # Generate the image with a fixed seed #cute dragon creature 104 | # image = pipe(prompt="A naruto with green eyes and red legs.", generator=generator).images[0] 105 | # image = pipe(prompt="digital art of a little cat traveling around forest, wearing a scarf around the neck, carrying a tiny backpack on his back", generator=generator, num_inference_steps=10, guidance_scale=7.5).images[0] 106 | # image = pipe(prompt="cute dragon creature", generator=generator).images[0] 107 | # image.save("yoda-naruto.png") 108 | 109 | cleanup() 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser(description="Simple example of an inference script.") 113 | parser.add_argument( 114 | "--pretrained_model_name_or_path", 115 | type=str, 116 | default="thuanz123/swiftbrush", 117 | required=True, 118 | help="Path to pretrained model or model identifier from huggingface.co/models.", 119 | ) 120 | parser.add_argument( 121 | "--captions_file", 122 | type=str, 123 | default="captions.txt", 124 | required=True, 125 | help="Path to the captions file.", 126 | ) 127 | 128 | parser.add_argument( 129 | "--seed", 130 | type=int, 131 | default=42, 132 | required=False, 133 | help="Random seed used for inference.", 134 | ) 135 | 136 | parser.add_argument( 137 | "--batch_size", 138 | type=int, 139 | default=16, 140 | required=False, 141 | help="Batch size for inference.", 142 | ) 143 | 144 | parser.add_argument( 145 | "--world_size", 146 | type=int, 147 | default=torch.cuda.device_count(), 148 | help="Number of GPUs to use.", 149 | ) 150 | 151 | args = parser.parse_args() 152 | 153 | # Get the rank from the environment variable set by torchrun 154 | rank = int(os.environ["LOCAL_RANK"]) 155 | 156 | main_worker(rank, args) -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: hyperunet 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - bzip2=1.0.8=h5eee18b_6 10 | - ca-certificates=2024.9.24=h06a4308_0 11 | - ld_impl_linux-64=2.40=h12ee557_0 12 | - 
libffi=3.4.4=h6a678d5_1 13 | - libgcc-ng=11.2.0=h1234567_1 14 | - libgomp=11.2.0=h1234567_1 15 | - libstdcxx-ng=11.2.0=h1234567_1 16 | - libuuid=1.41.5=h5eee18b_0 17 | - ncurses=6.4=h6a678d5_0 18 | - openssl=3.0.15=h5eee18b_0 19 | - pip=24.2=py310h06a4308_0 20 | - python=3.10.15=he870216_1 21 | - readline=8.2=h5eee18b_0 22 | - sqlite=3.45.3=h5eee18b_0 23 | - tk=8.6.14=h39e8969_0 24 | - wheel=0.44.0=py310h06a4308_0 25 | - xz=5.4.6=h5eee18b_1 26 | - zlib=1.2.13=h5eee18b_1 27 | - pip: 28 | - accelerate==0.34.2 29 | - addict==2.4.0 30 | - aiohappyeyeballs==2.4.3 31 | - aiohttp==3.10.9 32 | - aiohttp-retry==2.8.3 33 | - aiosignal==1.3.1 34 | - aliyun-python-sdk-core==2.16.0 35 | - aliyun-python-sdk-kms==2.16.5 36 | - amqp==5.2.0 37 | - annotated-types==0.7.0 38 | - antlr4-python3-runtime==4.9.3 39 | - appdirs==1.4.4 40 | - args==0.1.0 41 | - async-timeout==4.0.3 42 | - asyncssh==2.17.0 43 | - atpublic==5.0 44 | - attrs==24.2.0 45 | - billiard==4.2.1 46 | - bitsandbytes==0.42.0 47 | - blis==1.0.1 48 | - braceexpand==0.1.7 49 | - catalogue==2.0.10 50 | - celery==5.4.0 51 | - certifi==2024.8.30 52 | - cffi==1.17.1 53 | - charset-normalizer==3.3.2 54 | - clean-fid==0.1.35 55 | - click==8.1.7 56 | - click-didyoumean==0.3.1 57 | - click-plugins==1.1.1 58 | - click-repl==0.3.0 59 | - clint==0.5.1 60 | - clip==1.0 61 | - clip-benchmark==1.6.1 62 | - cloudpathlib==0.20.0 63 | - cmake==3.30.4 64 | - colorama==0.4.6 65 | - confection==0.1.5 66 | - configobj==5.0.9 67 | - contourpy==1.3.0 68 | - controlnet-aux==0.0.5 69 | - crcmod==1.7 70 | - cryptography==43.0.1 71 | - curated-tokenizers==0.0.9 72 | - curated-transformers==0.1.1 73 | - cycler==0.12.1 74 | - cymem==2.0.8 75 | - datasets==3.0.2 76 | - dictdiffer==0.9.0 77 | - diffusers==0.30.3 78 | - dill==0.3.7 79 | - diskcache==5.6.3 80 | - distro==1.9.0 81 | - docker-pycreds==0.4.0 82 | - dpath==2.2.0 83 | - dulwich==0.22.1 84 | - dvc==3.55.2 85 | - dvc-data==3.16.6 86 | - dvc-http==2.32.0 87 | - dvc-objects==5.1.0 88 | - dvc-render==1.0.2 89 | - dvc-studio-client==0.21.0 90 | - dvc-task==0.40.2 91 | - einops==0.8.0 92 | - en-core-web-trf==3.8.0 93 | - entrypoints==0.4 94 | - exceptiongroup==1.2.2 95 | - fairscale==0.4.13 96 | - filelock==3.14.0 97 | - flatten-dict==0.4.2 98 | - flufl-lock==8.1.0 99 | - fonttools==4.54.1 100 | - frozenlist==1.4.1 101 | - fsspec==2024.6.1 102 | - ftfy==6.3.0 103 | - funcy==2.0 104 | - gitdb==4.0.11 105 | - gitpython==3.1.43 106 | - grandalf==0.8 107 | - gto==1.7.1 108 | - hpsv2==1.2.0 109 | - huggingface-hub==0.26.2 110 | - hydra-core==1.3.2 111 | - idna==3.10 112 | - image-reward==1.5 113 | - imageio==2.35.1 114 | - imageio-ffmpeg==0.5.1 115 | - importlib-metadata==8.5.0 116 | - iniconfig==2.0.0 117 | - iterative-telemetry==0.0.9 118 | - jinja2==3.1.4 119 | - jmespath==0.10.0 120 | - joblib==1.4.2 121 | - kiwisolver==1.4.7 122 | - kombu==5.4.2 123 | - langcodes==3.4.1 124 | - language-data==1.2.0 125 | - lazy-loader==0.4 126 | - lit==18.1.8 127 | - marisa-trie==1.2.1 128 | - markdown==3.7 129 | - markdown-it-py==3.0.0 130 | - markupsafe==2.1.5 131 | - mat4py==0.6.0 132 | - matplotlib==3.8.0 133 | - mdurl==0.1.2 134 | - memsave-torch==1.0.0 135 | - mmcv==2.2.0 136 | - mmdet==3.3.0 137 | - mmengine==0.10.5 138 | - mmpretrain==1.2.0 139 | - model-index==0.1.11 140 | - modelindex==0.0.2 141 | - motmetrics==1.4.0 142 | - mpmath==1.3.0 143 | - multidict==6.1.0 144 | - multiprocess==0.70.15 145 | - munch==4.0.0 146 | - murmurhash==1.0.10 147 | - networkx==3.3 148 | - numpy==1.23.5 149 | - nvidia-cublas-cu12==12.1.3.1 150 | - 
nvidia-cuda-cupti-cu12==12.1.105 151 | - nvidia-cuda-nvrtc-cu12==12.1.105 152 | - nvidia-cuda-runtime-cu12==12.1.105 153 | - nvidia-cudnn-cu12==9.1.0.70 154 | - nvidia-cufft-cu12==11.0.2.54 155 | - nvidia-curand-cu12==10.3.2.106 156 | - nvidia-cusolver-cu12==11.4.5.107 157 | - nvidia-cusparse-cu12==12.1.0.106 158 | - nvidia-nccl-cu12==2.20.5 159 | - nvidia-nvjitlink-cu12==12.6.77 160 | - nvidia-nvtx-cu12==12.1.105 161 | - omegaconf==2.3.0 162 | - open-clip-torch==2.26.1 163 | - opencv-python==4.8.0.76 164 | - opendatalab==0.0.10 165 | - openmim==0.3.9 166 | - openxlab==0.1.2 167 | - ordered-set==4.1.0 168 | - orjson==3.10.7 169 | - oss2==2.17.0 170 | - packaging==24.1 171 | - pandas==2.2.3 172 | - pathspec==0.12.1 173 | - peft==0.13.2 174 | - pillow==10.0.1 175 | - platformdirs==3.11.0 176 | - pluggy==1.5.0 177 | - preshed==3.0.9 178 | - prompt-toolkit==3.0.48 179 | - protobuf==3.20.3 180 | - psutil==6.0.0 181 | - pyarrow==17.0.0 182 | - pycocoevalcap==1.2 183 | - pycocotools==2.0.8 184 | - pycparser==2.22 185 | - pycryptodome==3.21.0 186 | - pydantic==1.10.18 187 | - pydantic-core==2.23.4 188 | - pydot==3.0.2 189 | - pygit2==1.16.0 190 | - pygments==2.18.0 191 | - pygtrie==2.5.0 192 | - pyparsing==3.1.4 193 | - pytest==7.2.0 194 | - pytest-split==0.8.0 195 | - python-dateutil==2.9.0.post0 196 | - pytz==2023.4 197 | - pyyaml==6.0.1 198 | - regex==2024.9.11 199 | - requests==2.32.2 200 | - rich==13.4.2 201 | - ruamel-yaml==0.18.6 202 | - ruamel-yaml-clib==0.2.8 203 | - safetensors==0.4.5 204 | - scikit-image==0.24.0 205 | - scikit-learn==1.5.2 206 | - scipy==1.9.3 207 | - scmrepo==3.3.8 208 | - seaborn==0.13.2 209 | - semver==3.0.2 210 | - sentencepiece==0.2.0 211 | - sentry-sdk==2.16.0 212 | - setproctitle==1.3.3 213 | - setuptools==60.2.0 214 | - shapely==2.0.6 215 | - shellingham==1.5.4 216 | - shortuuid==1.0.13 217 | - shtab==1.7.1 218 | - six==1.16.0 219 | - smart-open==7.0.5 220 | - smmap==5.0.1 221 | - spacy==3.8.2 222 | - spacy-curated-transformers==0.3.0 223 | - spacy-legacy==3.0.12 224 | - spacy-loggers==1.0.5 225 | - sqltrie==0.11.1 226 | - srsly==2.4.8 227 | - sympy==1.13.3 228 | - tabulate==0.9.0 229 | - termcolor==2.5.0 230 | - terminaltables==3.1.10 231 | - thinc==8.3.2 232 | - threadpoolctl==3.5.0 233 | - tifffile==2024.9.20 234 | - timm==0.6.13 235 | - tokenizers==0.20.1 236 | - tomli==2.0.2 237 | - tomlkit==0.13.2 238 | - torch==2.4.0+cu121 239 | - torchaudio==2.4.0+cu121 240 | - torchvision==0.19.0+cu121 241 | - tqdm==4.66.3 242 | - transformers==4.46.0 243 | - triton==3.0.0 244 | - typer==0.12.5 245 | - typing-extensions==4.12.2 246 | - tzdata==2024.2 247 | - urllib3==1.26.20 248 | - vine==5.1.0 249 | - voluptuous==0.15.2 250 | - wandb==0.18.3 251 | - wasabi==1.1.3 252 | - wcwidth==0.2.13 253 | - weasel==0.4.1 254 | - webdataset==0.2.100 255 | - wrapt==1.16.0 256 | - xformers==0.0.27.post2 257 | - xmltodict==0.14.2 258 | - xxhash==3.5.0 259 | - yapf==0.40.2 260 | - yarl==1.13.1 261 | - zc-lockfile==3.0.post1 262 | - zipp==3.20.2 263 | prefix: /home/u1120240347/miniconda3/envs/hyperunet 264 | -------------------------------------------------------------------------------- /training-free/infer_sd1-5_x0_optim_mask_final_para.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from diffusers import StableDiffusionPipeline, UNet2DConditionModel 5 | from diffusers import DPMSolverMultistepScheduler, DDIMScheduler 6 | from sd_utils_x0_dpm import register_sd_forward, 
register_sdschedule_step 7 | import ImageReward as RM 8 | from torch import Tensor 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | try: 12 | from torchvision.transforms import InterpolationMode 13 | BICUBIC = InterpolationMode.BICUBIC 14 | except ImportError: 15 | BICUBIC = Image.BICUBIC 16 | 17 | from torchvision.transforms import ToPILImage 18 | from torchvision.transforms import Compose, Resize, CenterCrop, Normalize 19 | from accelerate import Accelerator # 导入 Accelerator 20 | import gc 21 | from hpsv2_loss import HPSV2Loss 22 | import xformers 23 | import time # 导入 time 模块以记录时间 24 | import json 25 | 26 | def gumbel_sigmoid(logits: Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> Tensor: 27 | """ 28 | Samples from the Gumbel-Sigmoid distribution and optionally discretizes. 29 | The discretization converts the values greater than threshold to 1 and the rest to 0. 30 | The code is adapted from the official PyTorch implementation of gumbel_softmax: 31 | https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax 32 | 33 | Args: 34 | logits: [..., num_features] unnormalized log probabilities 35 | tau: non-negative scalar temperature 36 | hard: if True, the returned samples will be discretized, 37 | but will be differentiated as if it is the soft sample in autograd 38 | threshold: threshold for the discretization, 39 | values greater than this will be set to 1 and the rest to 0 40 | 41 | Returns: 42 | Sampled tensor of same shape as logits from the Gumbel-Sigmoid distribution. 43 | If hard=True, the returned samples are descretized according to threshold, otherwise they will 44 | be probability distributions. 45 | 46 | """ 47 | gumbels = ( 48 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() 49 | ) # ~Gumbel(0, 1) 50 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits, tau) 51 | y_soft = gumbels.sigmoid() 52 | 53 | if hard: 54 | # Straight through. 55 | indices = (y_soft > threshold).nonzero(as_tuple=True) 56 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) 57 | y_hard[indices[0], indices[1]] = 1.0 58 | ret = y_hard - y_soft.detach() + y_soft 59 | else: 60 | # Reparametrization trick. 
61 | ret = y_soft 62 | return ret 63 | 64 | class MaskApplier: 65 | def __init__(self, unet): 66 | self.unet = unet 67 | self.masks = nn.ParameterDict() 68 | self.hooks = [] 69 | self.mask_probs = {} 70 | self.mask_samples = {} 71 | self.initial_masks = {} # 保存初始mask状态 72 | self.module_dict = {} # 添加此字典来存储 safe_name 到 module 的映射 73 | 74 | # Initialize mask logits 75 | skip_layers = ["time_emb_proj", "ff", "conv_shortcut", "proj_in", "proj_out"] 76 | for name, module in self.unet.up_blocks.named_modules(): 77 | if isinstance(module, (torch.nn.Linear)) and not any(skip_layer in name for skip_layer in skip_layers): 78 | safe_name = name.replace('.', '_') 79 | mask_shape = module.weight.shape 80 | # Initialize logits to a high value to start with masks mostly turned on 81 | mask_logit = torch.ones(mask_shape, device=module.weight.device, dtype=torch.float32) 82 | self.masks[safe_name] = nn.Parameter(mask_logit) 83 | self.initial_masks[safe_name] = mask_logit.clone() # 保存初始状态 84 | self.module_dict[safe_name] = module # 保存 safe_name 到 module 的映射 85 | hook = module.register_forward_hook(self.hook_fn(safe_name)) 86 | self.hooks.append(hook) 87 | 88 | def hook_fn(self, layer_name): 89 | def apply_mask(module, input, output): 90 | # Use gumbel_sigmoid to sample mask 91 | logits = self.masks[layer_name].to(module.weight.device) 92 | mask = gumbel_sigmoid(logits, hard=True) 93 | batch_size = input[0].size(0) 94 | # Store the probabilities and samples for policy gradient update 95 | # with torch.no_grad(): 96 | # probs = torch.sigmoid(logits) 97 | # self.mask_probs[layer_name] = probs.detach() 98 | # self.mask_samples[layer_name] = mask.detach() 99 | # Apply the binary mask to the weights 100 | masked_weight = module.weight * mask 101 | # Recompute the output using the masked weights without modifying module.weight 102 | if isinstance(module, nn.Conv2d): 103 | bias = module.bias 104 | stride = module.stride 105 | padding = module.padding 106 | dilation = module.dilation 107 | groups = module.groups 108 | # Recompute the output 109 | return F.conv2d(input[0], masked_weight, bias, stride, padding, dilation, groups) 110 | elif isinstance(module, nn.Linear): 111 | weight_shape = module.weight.shape 112 | # print("input", input[0].shape) 113 | 114 | output = F.linear(input[0], masked_weight) 115 | if module.bias is not None: 116 | output += module.bias 117 | return output 118 | else: 119 | return output # For layers that are not Conv2d or Linear 120 | return apply_mask 121 | 122 | def reset_masks(self): 123 | # 重置所有 mask 到初始状态 124 | for name in self.masks.keys(): 125 | self.masks[name].data = self.initial_masks[name].clone() 126 | 127 | 128 | class MaskOptimizer: 129 | def __init__(self, unet, mask_applier, reward_model, hpsv2_model, accelerator, num_iters=20): 130 | self.unet = unet 131 | self.reward_model = reward_model 132 | self.hpsv2_model = hpsv2_model 133 | self.mask_applier = mask_applier 134 | self.num_iters = num_iters 135 | self.latent_cache = None 136 | self.accelerator = accelerator # 保存 accelerator 实例 137 | 138 | def optimize_masks(self, prompt, pipe, num_iterations=20, seed=42, num_inference_steps=10): 139 | rm_input_ids = self.reward_model.module.blip.tokenizer( 140 | prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt" 141 | ).input_ids.to(self.accelerator.device) 142 | rm_attention_mask = self.reward_model.module.blip.tokenizer( 143 | prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt" 144 | 
).attention_mask.to(self.accelerator.device) 145 | gen_image = [] 146 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, self.mask_applier.masks.parameters()), lr=1e-2) 147 | optimizer = self.accelerator.prepare(optimizer) 148 | # optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.mask_applier.masks.parameters()), lr=1e-2) 149 | # 对每个时间步执行独立的 mask 优化 150 | for stop_step in range(0, num_inference_steps): 151 | pipe.stop_step = stop_step 152 | if stop_step <= 6: 153 | num_iterations = 1 154 | else: 155 | num_iterations = 10 156 | for iteration in range(num_iterations): 157 | # generator = torch.Generator(device="cuda").manual_seed(seed) 158 | generator = None 159 | 160 | # Generate image 161 | with self.accelerator.autocast(): 162 | image, cur_latents = pipe.forward(prompt=prompt, num_inference_steps=num_inference_steps, 163 | guidance_scale=7.5, generator=generator, 164 | latents = self.latent_cache, return_dict=False, 165 | widh=512,height=512) 166 | x_0_per_step = image 167 | 168 | # 保存生成的图像 169 | if stop_step == num_inference_steps-1 and iteration == num_iterations-1: 170 | save_imge = x_0_per_step.squeeze(0) 171 | to_pil = ToPILImage() 172 | image_save = to_pil(save_imge.cpu()) 173 | gen_image.append(image_save) 174 | # image_save.save(f"{save_path}/{prompt}/seed{seed}_step{stop_step+1}_iter{iteration+1}_x0.png") 175 | # gc.collect() 176 | # torch.cuda.empty_cache() # 清理 177 | # 预处理图像 178 | rm_preprocess = Compose([ 179 | Resize(224, interpolation=BICUBIC), 180 | CenterCrop(224), 181 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 182 | ]) 183 | x_0_per_step = rm_preprocess(x_0_per_step) 184 | x_0_per_step = x_0_per_step.to(self.accelerator.device) 185 | 186 | 187 | 188 | # 准备输入数据 189 | 190 | # torch.autograd.set_detect_anomaly(True) 191 | # Compute reward and loss 192 | with self.accelerator.autocast(): 193 | hpsv2_loss = self.hpsv2_model.score_grad(x_0_per_step, prompt) 194 | # print("hpsv2", hpsv2_loss) 195 | reward = self.reward_model.module.score_gard(rm_input_ids, rm_attention_mask, x_0_per_step) 196 | loss = F.relu(-reward + 2).mean() 197 | loss = loss + 5.0*hpsv2_loss 198 | # print(loss) 199 | 200 | 201 | 202 | 203 | self.accelerator.backward(loss) # 使用 accelerator 进行反向传播 204 | 205 | # for name, param in self.mask_applier.masks.items(): 206 | # if param.grad is not None: 207 | # print(f"Mask logits for {name}, grad mean: {param.grad.mean().item()}, grad std: {param.grad.std().item()}") 208 | 209 | # 可选:梯度裁剪 210 | # self.accelerator.clip_grad_norm_(self.unet.parameters(), max_norm=1.0) 211 | # self.accelerator.clip_grad_norm_(self.mask_applier.masks.parameters(), max_norm=1.0) 212 | 213 | optimizer.step() # 更新参数 214 | # for name, param in self.unet.named_parameters(): 215 | # if param.requires_grad: 216 | # mask = gumbel_sigmoid(param.data, tau=1.0, hard=True) 217 | # zeros = (mask == 0).sum().item() 218 | # ones = (mask == 1).sum().item() 219 | # print(f"After update - {name}: max {param.data.max()}, min {param.data.min()}") 220 | # print(f"Gumbel Sigmoid applied on {name}: Zeros={zeros}, Ones={ones}") 221 | optimizer.zero_grad() # 重置优化器梯度 222 | 223 | # 打印参数更新信息 224 | # for name, param in self.unet.named_parameters(): 225 | # if param.requires_grad: 226 | # print(f"After update - {name}: max {param.data.max()}, min {param.data.min()}") 227 | # for name, param in self.unet.named_parameters(): 228 | # if param.requires_grad: 229 | # mask = gumbel_sigmoid(param.data, tau=1.0, hard=True) 230 | # zeros = (mask == 
0).sum().item() 231 | # ones = (mask == 1).sum().item() 232 | # print(f"After update - {name}: max {param.data.max()}, min {param.data.min()}") 233 | # print(f"Gumbel Sigmoid applied on {name}: Zeros={zeros}, Ones={ones}") 234 | 235 | # 清理内存 236 | if stop_step == num_inference_steps-1 and iteration == num_iterations-1: 237 | print(f"Step {stop_step+1}/{num_inference_steps}, Iteration {iteration + 1}/{num_iterations}, Reward: {reward.item()}") 238 | del image, x_0_per_step, reward, loss, hpsv2_loss 239 | gc.collect() 240 | torch.cuda.empty_cache() # 添加这行代码释放显存 241 | 242 | self.latent_cache = cur_latents.detach().clone() 243 | # 每个时间步结束后重置mask 244 | # gc.collect() 245 | self.mask_applier.reset_masks() 246 | torch.cuda.empty_cache() 247 | # print(f"Mask reset after step {stop_step+1}") 248 | return gen_image 249 | 250 | def main(): 251 | # Set a fixed seed for reproducibility 252 | # seed = 42 253 | # torch.manual_seed(seed) 254 | # torch.cuda.manual_seed_all(seed) 255 | 256 | # 初始化 Accelerator 257 | accelerator = Accelerator(device_placement=True, mixed_precision='fp16') # 使用混合精度 258 | 259 | # Load the UNet model 260 | # unet = UNet2DConditionModel.from_pretrained( 261 | # "../share/runwayml/stable-diffusion-v1-5", subfolder="unet", 262 | # torch_dtype=torch.float32 263 | # ) 264 | 265 | # 创建 MaskApplier 实例 266 | 267 | 268 | # scheduler = DDIMScheduler.from_pretrained("../share/runwayml/stable-diffusion-v1-5", subfolder="scheduler") 269 | scheduler = DPMSolverMultistepScheduler.from_pretrained("../share/runwayml/stable-diffusion-v1-5", subfolder="scheduler", algorithm_type="dpmsolver++") 270 | # Load the Stable Diffusion pipeline with the modified UNet model 271 | pipe = StableDiffusionPipeline.from_pretrained( 272 | "../share/runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, 273 | scheduler=scheduler 274 | ) 275 | 276 | mask_applier = MaskApplier(pipe.unet) 277 | 278 | # 使用 accelerator.prepare() 一起包装模型和 MaskApplier 279 | unet, mask_applier = accelerator.prepare(pipe.unet, mask_applier) 280 | 281 | pipe.to(accelerator.device) 282 | 283 | # pipe.unet.enable_gradient_checkpointing() 284 | pipe.unet.enable_xformers_memory_efficient_attention() 285 | 286 | register_sd_forward(pipe) 287 | register_sdschedule_step(pipe.scheduler) 288 | 289 | def dummy_checker(images, **kwargs): 290 | # Return the images as is, and set NSFW detection results to False (not NSFW) 291 | return images, [False] * len(images) 292 | 293 | pipe.safety_checker = dummy_checker 294 | 295 | save_dir = "training-free_mask_fix_stop7_iter10" 296 | 297 | # Create MaskApplier and ImageReward instances 298 | reward_model = RM.load("ImageReward-v1.0").to(device = accelerator.device, dtype = torch.float16) 299 | reward_model = accelerator.prepare(reward_model) 300 | #HPSV2 model 301 | hpsv2_model = HPSV2Loss( 302 | dtype = torch.float16, 303 | device = accelerator.device, 304 | cache_dir = "./HPSV2_checkpoint" 305 | # memsave = True 306 | ) 307 | hpsv2_model = accelerator.prepare(hpsv2_model) 308 | 309 | 310 | prompt_list_file = "./geneval/prompts/evaluation_metadata.jsonl" 311 | with open(prompt_list_file) as fp: 312 | metadatas = [json.loads(line) for line in fp] 313 | 314 | 315 | total_prompts = len(metadatas) 316 | num_processes = accelerator.num_processes 317 | prompts_per_process = total_prompts // num_processes 318 | start_index = accelerator.process_index * prompts_per_process 319 | end_index = start_index + prompts_per_process if accelerator.process_index != num_processes - 1 else total_prompts 320 | 321 | # Process 
prompts assigned to this process 322 | for idx in range(start_index, end_index): 323 | prompt = metadatas[idx]["prompt"] 324 | 325 | # Create output directory for each prompt 326 | outdir = f"{save_dir}/{idx:0>5}" 327 | os.makedirs(f"{outdir}/samples", exist_ok=True) 328 | 329 | # Save metadata for each prompt 330 | with open(f"{outdir}/metadata.jsonl", "w") as fp: 331 | json.dump(metadatas[idx], fp) 332 | 333 | # Create Mask Optimizer 334 | mask_optimizer = MaskOptimizer(unet, mask_applier, reward_model, hpsv2_model, accelerator) 335 | 336 | # Run the mask optimization process 337 | image = mask_optimizer.optimize_masks([prompt], pipe, seed=None, num_inference_steps=15) 338 | 339 | # Save the generated image 340 | for i, img in enumerate(image): 341 | img.save(f"{outdir}/samples/{i:05}.png") 342 | 343 | del image 344 | torch.cuda.empty_cache() # 清空显存缓存 345 | 346 | accelerator.wait_for_everyone() 347 | print(f"Process {accelerator.process_index}: Finished generating images.") 348 | # sample_out = 0 349 | # for index, metadata in enumerate(metadatas): 350 | # prompt = metadata["prompt"] 351 | # outdir = f"{save_dir}" 352 | # outpath = f"{outdir}/{index:0>5}" 353 | # os.makedirs(f"{outpath}/samples", exist_ok=True) 354 | # with open(f"{outpath}/metadata.jsonl", "w") as fp: 355 | # json.dump(metadata, fp) 356 | # # prompt = "a cat wearing a hat and a dog wearing a glasses" 357 | # # Create Mask Optimizer 358 | # mask_optimizer = MaskOptimizer(unet, mask_applier, reward_model, hpsv2_model, accelerator) 359 | 360 | # # Optimize masks 361 | # # start_time = time.time() 362 | # batch_prompts = [prompt for _ in range(1)] # 假设每次处理1个 prompt 363 | # image = mask_optimizer.optimize_masks(batch_prompts, pipe, seed = seed, num_inference_steps=15) 364 | # # end_time = time.time() # 记录图像生成结束时间 365 | # # duration = end_time - start_time # 计算生成时间 366 | # for i, img in enumerate(image): 367 | # img.save(f"{outpath}/samples/{sample_out + i:05}.png") 368 | # sample_out += len(image) 369 | # # print(f"Time taken for mask gen: {duration:.2f} seconds") 370 | # print("Image generation with optimized mask completed and saved.") 371 | 372 | if __name__ == "__main__": 373 | main() -------------------------------------------------------------------------------- /training/hyperunet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | from typing import Optional, Union, Tuple, List, Callable, Dict, Any 5 | from copy import deepcopy 6 | from accelerate import Accelerator 7 | 8 | from diffusers.utils import USE_PEFT_BACKEND, deprecate, scale_lora_layers, unscale_lora_layers, BaseOutput 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from diffusers import UNet2DConditionModel 12 | from dataclasses import dataclass 13 | from functools import partial 14 | from einops import rearrange 15 | from torch.cuda.amp import autocast, GradScaler 16 | from torch import Tensor 17 | 18 | @dataclass 19 | class UNet2DConditionOutput(BaseOutput): 20 | """ 21 | The output of [`UNet2DConditionModel`]. 22 | 23 | Args: 24 | sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): 25 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 26 | """ 27 | 28 | sample: torch.Tensor = None 29 | 30 | def weights_init_kaiming(m): 31 | """ 32 | Kaiming (He) Initialization for Conv2D and Linear layers. 
33 | """ 34 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 35 | nn.init.constant_(m.weight, 0.01) # 初始化为接近恒等映射的小值 36 | if m.bias is not None: 37 | nn.init.constant_(m.bias, 0) # 偏置设置为0 38 | 39 | class DecoderWeightsMaskGenerator(nn.Module): 40 | def __init__(self, in_channels, out_channels, hidden_dim=64, factor=16): 41 | super().__init__() 42 | 43 | new_out = out_channels // factor 44 | new_in = in_channels // factor 45 | # new_out = 128 46 | # new_in = 128 47 | self.conv_kernel_mask = nn.Sequential( 48 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(hidden_dim, new_in * new_out, kernel_size=1), 51 | nn.ReLU(inplace=True) 52 | ) 53 | 54 | self.proj_temb = nn.Linear(1280, in_channels, bias=False) 55 | # self.proj_pemb = nn.Linear(768, in_channels, bias=False) 56 | 57 | self.in_c = new_in 58 | self.out_c = new_out 59 | self.apply(self.weights_init) 60 | 61 | def weights_init(self, m): 62 | if isinstance(m, nn.Conv2d): 63 | if m.kernel_size == (1, 1) and m.out_channels == self.in_c * self.out_c: 64 | nn.init.constant_(m.weight, 0) 65 | if m.bias is not None: 66 | nn.init.constant_(m.bias, 0) 67 | else: 68 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='relu') 69 | if m.bias is not None: 70 | nn.init.constant_(m.bias, 0) 71 | 72 | def forward(self, weight_shape, sample, res_sample, temb, encoder_hidden_states): 73 | batch_size = sample.size(0) 74 | flag = len(weight_shape) 75 | #emb [n,c] 76 | temb = self.proj_temb(temb).unsqueeze(-1).unsqueeze(-1) 77 | # pemb = self.proj_pemb(encoder_hidden_states).permute(0, 2, 1).contiguous().mean(-1).unsqueeze(-1).unsqueeze(-1) 78 | if res_sample is None: 79 | x = sample 80 | else: 81 | x = torch.cat([sample, res_sample], dim=1) 82 | if flag == 4: 83 | _, _, k_h, k_w = weight_shape 84 | x = F.adaptive_avg_pool2d(x, (k_h, k_w)) 85 | else: 86 | x = F.adaptive_avg_pool2d(x, (1, 1)) 87 | 88 | x = x + temb 89 | # x = temb 90 | 91 | mask = self.conv_kernel_mask(x) 92 | if flag == 4: 93 | mask = mask.view(mask.size(0), self.out_c, self.in_c, k_h, k_w) 94 | else: 95 | mask = mask.view(mask.size(0), self.out_c, self.in_c) 96 | 97 | return mask 98 | 99 | 100 | def gumbel_sigmoid(logits: Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> Tensor: 101 | """ 102 | Samples from the Gumbel-Sigmoid distribution and optionally discretizes. 103 | The discretization converts the values greater than `threshold` to 1 and the rest to 0. 104 | The code is adapted from the official PyTorch implementation of gumbel_softmax: 105 | https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax 106 | 107 | Args: 108 | logits: `[..., num_features]` unnormalized log probabilities 109 | tau: non-negative scalar temperature 110 | hard: if ``True``, the returned samples will be discretized, 111 | but will be differentiated as if it is the soft sample in autograd 112 | threshold: threshold for the discretization, 113 | values greater than this will be set to 1 and the rest to 0 114 | 115 | Returns: 116 | Sampled tensor of same shape as `logits` from the Gumbel-Sigmoid distribution. 117 | If ``hard=True``, the returned samples are descretized according to `threshold`, otherwise they will 118 | be probability distributions. 
119 | 120 | """ 121 | gumbels = ( 122 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() 123 | ) # ~Gumbel(0, 1) 124 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits, tau) 125 | y_soft = gumbels.sigmoid() 126 | 127 | if hard: 128 | # Straight through. 129 | indices = (y_soft > threshold).nonzero(as_tuple=True) 130 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) 131 | y_hard[indices[0], indices[1]] = 1.0 132 | ret = y_hard - y_soft.detach() + y_soft 133 | else: 134 | # Reparametrization trick. 135 | ret = y_soft 136 | return ret 137 | 138 | 139 | class Adapter(nn.Module): 140 | def __init__(self, out_c, in_c, new_out_c, new_in_c, tau = 1.0): 141 | super(Adapter, self).__init__() 142 | # Use 1x1 convolutions for channel adaptation 143 | self.conv_in = nn.Conv2d(out_c, new_out_c, kernel_size=1) 144 | self.conv_out = nn.Conv2d(in_c, new_in_c, kernel_size=1) 145 | self.tau = tau 146 | 147 | self.apply(self.weights_init) 148 | 149 | def weights_init(self, m): 150 | """ 151 | Kaiming (He) Initialization for Conv2D and Linear layers. 152 | """ 153 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 154 | nn.init.constant_(m.weight, 0) # 初始化为0 155 | if m.bias is not None: 156 | nn.init.constant_(m.bias, 5.0) # 偏置设置为5.0 157 | 158 | def forward(self, input_tensor_3x3): 159 | # The input tensor shape is [batch_size, out_c, in_c, k_h, k_w] 160 | 161 | # Step 1: Adapt the in_channels using a 1x1 convolution 162 | # We rearrange the tensor to merge the out_c dimension into the batch size for efficient processing 163 | if input_tensor_3x3.dim() == 3: 164 | input_tensor_3x3 = input_tensor_3x3.unsqueeze(-1).unsqueeze(-1) 165 | input_adapted = rearrange(input_tensor_3x3, 'b out_c in_c h w -> (b in_c) out_c h w') 166 | 167 | # if self.conv_in.bias is not None: 168 | # self.conv_in.bias.data = self.conv_in.bias.data.to(input_adapted.dtype) 169 | # Apply the in-channel adaptation 170 | input_adapted = self.conv_in(input_adapted) 171 | 172 | # Step 2: Adapt the out_channels 173 | # Rearrange to bring the adapted in_channels back into the batch dimension for processing 174 | input_adapted = rearrange(input_adapted, '(b in_c) new_out_c h w -> b new_out_c in_c h w', b=input_tensor_3x3.shape[0]) 175 | input_adapted = rearrange(input_adapted, 'b new_out_c in_c h w -> (b new_out_c) in_c h w') 176 | 177 | # if self.conv_out.bias is not None: 178 | # self.conv_out.bias.data = self.conv_out.bias.data.to(input_adapted.dtype) 179 | # Apply the out-channel adaptation 180 | output = self.conv_out(input_adapted) 181 | output = gumbel_sigmoid(output, tau=self.tau, hard=True) 182 | # Step 3: Reshape the output back to the original format with new dimensions 183 | output = rearrange(output, '(b new_out_c) new_in_c h w -> b new_out_c new_in_c h w', b=input_tensor_3x3.shape[0]) 184 | 185 | if output.size(-1)==1: 186 | output = output.squeeze(-1).squeeze(-1) 187 | return output 188 | 189 | 190 | def get_hook_fn(sample, temb, res_sample, mask_generator, adapter, encoder_hidden_states): 191 | def hook_fn(module, input, output): 192 | batch_size = input[0].size(0) 193 | 194 | if isinstance(module, nn.Conv2d): 195 | weight_shape = module.weight.shape 196 | mask = mask_generator(weight_shape, sample, res_sample, temb, encoder_hidden_states).to(module.weight.device) 197 | # print(module.weight.shape) 198 | # print(mask.shape) 199 | # print() 200 | mask = adapter(mask) 201 | masked_weight = module.weight * mask 202 | masked_weight = 
masked_weight.reshape(-1, *module.weight.shape[1:]).contiguous() # 使用 reshape 替代 view 203 | input_reshaped = input[0].reshape(1, -1, *input[0].shape[2:]).contiguous() # 使用 reshape 替代 view 204 | 205 | # 如果有 bias,则需要处理 bias 的形状 206 | if module.bias is not None: 207 | # 确保 bias 的形状正确,并直接扩展到 batch_size 维度 208 | masked_bias = module.bias.repeat(batch_size).contiguous() # 扩展 bias 到与分组卷积匹配的大小 209 | else: 210 | masked_bias = None 211 | 212 | output = F.conv2d( 213 | input_reshaped, 214 | masked_weight, 215 | bias=masked_bias, 216 | stride=module.stride, 217 | padding=module.padding, 218 | dilation=module.dilation, 219 | groups=batch_size 220 | ) 221 | output = output.reshape(batch_size, -1, *output.shape[2:]).contiguous() # 使用 reshape 替代 view 222 | return output 223 | 224 | elif isinstance(module, nn.Linear): 225 | if module.weight.dim() == 2: 226 | weight_shape = module.weight.shape 227 | mask = mask_generator(weight_shape, sample, res_sample, temb, encoder_hidden_states).to(module.weight.device) 228 | mask = adapter(mask) 229 | masked_weight = module.weight * mask 230 | if input[0].dim() == 2: 231 | input_batched = input[0].unsqueeze(1) # (batch_size, 1, in_features) 232 | else: 233 | input_batched = input[0] 234 | # print(input_batched.shape) 235 | # assert 2==1 236 | output = torch.bmm(input_batched, masked_weight.permute(0, 2, 1).contiguous()).squeeze(1) 237 | if module.bias is not None: 238 | output += module.bias.unsqueeze(0).expand_as(output) 239 | return output 240 | return hook_fn 241 | 242 | # def __init__(self, *args, cross_attention_dim=768, **kwargs): 243 | # super(FineGrainedUNet2DConditionModel, self).__init__(*args, cross_attention_dim=cross_attention_dim, **kwargs) 244 | class FineGrainedUNet2DConditionModel(UNet2DConditionModel): 245 | def __init__(self, *args, cross_attention_dim=768, use_linear_projection=False,**kwargs): 246 | super(FineGrainedUNet2DConditionModel, self).__init__(*args, cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, **kwargs) 247 | self.mask_generators = nn.ModuleList() 248 | # self.mask_generators_down = nn.ModuleList() 249 | # self.layer_adapters = {} 250 | self.adapters = nn.ModuleList() # 使用 ModuleList 代替普通 list 251 | # self.adapters_down = nn.ModuleList() # 使用 ModuleList 代替普通 list 252 | # self.decoder_weights_mask_generator.apply(self.init_weights) 253 | 254 | self._hooks = [] 255 | 256 | # for name, module in self.up_blocks.named_modules(): 257 | # if 'proj_out' in name: 258 | # print(f"{name} layer structure: {module}") 259 | # print(f"Shape of {name} weights: {module.weight.shape}") 260 | 261 | # assert 2==1 262 | 263 | skip_layers = ["time_emb_proj", "ff", "conv_shortcut", "proj_in", "proj_out"] 264 | 265 | # for i, downsample_block in enumerate(self.down_blocks): 266 | # print(f"--- Upsample Block {i} ---") 267 | # first_resnet_conv1 = downsample_block.resnets[0].conv1 # Get resnets.0.conv1 layer 268 | # if isinstance(first_resnet_conv1, nn.Conv2d): 269 | # in_channels = first_resnet_conv1.in_channels 270 | # out_channels = first_resnet_conv1.out_channels 271 | # # out_channels = 256 272 | # print(f"Block {i}: in_channels={in_channels}, out_channels={out_channels}") 273 | 274 | # # Initialize mask_generator for this block 275 | # mask_generator = DecoderWeightsMaskGenerator(in_channels, out_channels) 276 | # self.mask_generators_down.append(mask_generator) 277 | # block_adapters = nn.ModuleDict() # 使用 ModuleDict 来存储 block 的 adapter 278 | # # Use these channels for all adapters in this block 279 | # for name, 
sub_module in downsample_block.named_modules(): 280 | # # if isinstance(sub_module, nn.Conv2d) and not any(skip_layer in name for skip_layer in skip_layers): 281 | # # sub_in_channels = sub_module.in_channels 282 | # # sub_out_channels = sub_module.out_channels 283 | # if isinstance(sub_module, nn.Linear) and not any(skip_layer in name for skip_layer in skip_layers): 284 | # sub_in_channels = sub_module.in_features 285 | # sub_out_channels = sub_module.out_features 286 | # else: 287 | # continue 288 | 289 | # # 替换 name 中的 "." 为 "_" 290 | # sanitized_name = name.replace(".", "_") 291 | # print(f"Adapter for Layer {sanitized_name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}") 292 | 293 | # # print(f"Adapter for Layer {name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}") 294 | 295 | # # Initialize Adapter using first_resnet_conv1 channels and layer's channels 296 | # # factor 297 | # factor = 8 298 | # adapter = Adapter(out_channels // factor, in_channels // factor, sub_out_channels, sub_in_channels) 299 | # block_adapters[sanitized_name] = adapter 300 | # self.adapters_down.append(block_adapters) 301 | # factor = [16, 8, 4, 2] 302 | for i, upsample_block in enumerate(self.up_blocks): 303 | print(f"--- Upsample Block {i} ---") 304 | first_resnet_conv1 = upsample_block.resnets[0].conv1 # Get resnets.0.conv1 layer 305 | if isinstance(first_resnet_conv1, nn.Conv2d): 306 | in_channels = first_resnet_conv1.in_channels 307 | out_channels = first_resnet_conv1.out_channels 308 | # out_channels = 256 309 | # print(f"Block {i}: in_channels={in_channels}, out_channels={out_channels}") 310 | factor = 4 311 | # Initialize mask_generator for this block 312 | mask_generator = DecoderWeightsMaskGenerator(in_channels, out_channels, factor = factor) 313 | self.mask_generators.append(mask_generator) 314 | block_adapters = nn.ModuleDict() # 使用 ModuleDict 来存储 block 的 adapter 315 | # Use these channels for all adapters in this block 316 | for name, sub_module in upsample_block.named_modules(): 317 | # if isinstance(sub_module, nn.Conv2d) and not any(skip_layer in name for skip_layer in skip_layers): 318 | # sub_in_channels = sub_module.in_channels 319 | # sub_out_channels = sub_module.out_channels 320 | if isinstance(sub_module, nn.Linear) and not any(skip_layer in name for skip_layer in skip_layers): 321 | sub_in_channels = sub_module.in_features 322 | sub_out_channels = sub_module.out_features 323 | else: 324 | continue 325 | 326 | # 替换 name 中的 "." 
为 "_" 327 | sanitized_name = name.replace(".", "_") 328 | # print(f"Adapter for Layer {sanitized_name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}") 329 | 330 | # print(f"Adapter for Layer {name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}") 331 | 332 | # Initialize Adapter using first_resnet_conv1 channels and layer's channels 333 | # factor 334 | # factor = 16 335 | adapter = Adapter(out_channels // factor, in_channels // factor, sub_out_channels, sub_in_channels) 336 | # adapter = Adapter(128, 128, sub_out_channels, sub_in_channels) 337 | block_adapters[sanitized_name] = adapter 338 | self.adapters.append(block_adapters) 339 | 340 | # assert 2==1 341 | 342 | # for i, upsample_block in enumerate(self.up_blocks): 343 | # print(f"--- Upsample Block {i} ---") 344 | # first_resnet_conv1 = upsample_block.resnets[0].conv1 # 获取 resnets.0.conv1 卷积层 345 | # if isinstance(first_resnet_conv1, nn.Conv2d): 346 | # in_channels = first_resnet_conv1.in_channels 347 | # out_channels = first_resnet_conv1.out_channels 348 | # self.mask_generators.append(DecoderWeightsMaskGenerator(in_channels, out_channels)) 349 | 350 | # assert 2==1 351 | 352 | @staticmethod 353 | def init_weights(module): 354 | # 初始化 nn.Linear 层 355 | if isinstance(module, nn.Conv2d): 356 | nn.init.xavier_uniform_(module.weight) 357 | if module.bias is not None: 358 | nn.init.constant_(module.bias, 2.0) 359 | 360 | # # 初始化 nn.Parameter 参数 361 | elif isinstance(module, nn.Parameter): 362 | nn.init.constant_(module, 0) # 初始化为0 363 | 364 | def forward( 365 | self, 366 | sample: torch.Tensor, 367 | timestep: Union[torch.Tensor, float, int], 368 | encoder_hidden_states: torch.Tensor, 369 | class_labels: Optional[torch.Tensor] = None, 370 | timestep_cond: Optional[torch.Tensor] = None, 371 | attention_mask: Optional[torch.Tensor] = None, 372 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 373 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 374 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 375 | mid_block_additional_residual: Optional[torch.Tensor] = None, 376 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 377 | encoder_attention_mask: Optional[torch.Tensor] = None, 378 | return_dict: bool = True, 379 | ) -> Union[UNet2DConditionOutput, Tuple]: 380 | 381 | if self.config.center_input_sample: 382 | sample = 2 * sample - 1.0 383 | 384 | t_emb = self.get_time_embed(sample=sample, timestep=timestep) 385 | emb = self.time_embedding(t_emb, timestep_cond) 386 | aug_emb = None 387 | 388 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) 389 | if class_emb is not None: 390 | if self.config.class_embeddings_concat: 391 | emb = torch.cat([emb, class_emb], dim=-1) 392 | else: 393 | emb = emb + class_emb 394 | 395 | aug_emb = self.get_aug_embed( 396 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 397 | ) 398 | if self.config.addition_embed_type == "image_hint": 399 | aug_emb, hint = aug_emb 400 | sample = torch.cat([sample, hint], dim=1) 401 | 402 | emb = emb + aug_emb if aug_emb is not None else emb 403 | 404 | if self.time_embed_act is not None: 405 | emb = self.time_embed_act(emb) 406 | 407 | encoder_hidden_states = self.process_encoder_hidden_states( 408 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 409 | ) 410 | 411 | # 2. 
pre-process 412 | # hook = self.conv_in.register_forward_hook(get_hook_fn(sample, emb)) 413 | # self._hooks.append(hook) 414 | # self.conv_in.decoder_weights_mask_generator = self.decoder_weights_mask_generator 415 | sample = self.conv_in(sample) 416 | 417 | down_block_res_samples = (sample,) 418 | 419 | for i, downsample_block in enumerate(self.down_blocks): 420 | 421 | # mask_generator = self.mask_generators_down[i] 422 | # adapters = self.adapters_down[i] 423 | 424 | # self._register_hooks(downsample_block, sample, None, emb, mask_generator, adapters) 425 | # self._register_hooks(downsample_block, sample, emb) 426 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 427 | sample, res_samples = downsample_block( 428 | hidden_states=sample, 429 | temb=emb, 430 | encoder_hidden_states=encoder_hidden_states, 431 | attention_mask=attention_mask, 432 | cross_attention_kwargs=cross_attention_kwargs, 433 | encoder_attention_mask=encoder_attention_mask, 434 | ) 435 | else: 436 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 437 | down_block_res_samples += res_samples 438 | 439 | # 4. mid 440 | if self.mid_block is not None: 441 | # self._register_hooks(self.mid_block, sample, emb) 442 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 443 | sample = self.mid_block( 444 | sample, 445 | emb, 446 | encoder_hidden_states=encoder_hidden_states, 447 | attention_mask=attention_mask, 448 | cross_attention_kwargs=cross_attention_kwargs, 449 | encoder_attention_mask=encoder_attention_mask, 450 | ) 451 | else: 452 | sample = self.mid_block(sample, emb) 453 | 454 | # 5. up 455 | for i, upsample_block in enumerate(self.up_blocks): 456 | # print("sample", sample.shape) 457 | mask_generator = self.mask_generators[i] 458 | adapters = self.adapters[i] 459 | is_final_block = i == len(self.up_blocks) - 1 460 | 461 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 462 | down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)] 463 | # Print the shapes of res_samples 464 | # print(res_samples[-1].shape) 465 | # for i, res_sample in enumerate(res_samples): 466 | # print(f"Shape of res_sample {i}: {res_sample.shape}") 467 | 468 | if not is_final_block: 469 | upsample_size = down_block_res_samples[-1].shape[2:] 470 | else: 471 | upsample_size = None 472 | 473 | self._register_hooks(upsample_block, sample, res_samples[-1], emb, mask_generator, adapters, encoder_hidden_states) 474 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 475 | sample = upsample_block( 476 | hidden_states=sample, 477 | temb=emb, 478 | res_hidden_states_tuple=res_samples, 479 | encoder_hidden_states=encoder_hidden_states, 480 | cross_attention_kwargs=cross_attention_kwargs, 481 | upsample_size=upsample_size, 482 | attention_mask=attention_mask, 483 | encoder_attention_mask=encoder_attention_mask, 484 | ) 485 | else: 486 | sample = upsample_block( 487 | hidden_states=sample, 488 | temb=emb, 489 | res_hidden_states_tuple=res_samples, 490 | upsample_size=upsample_size, 491 | ) 492 | 493 | # 6.
post-process 494 | if self.conv_norm_out: 495 | # self._register_hooks(self.conv_norm_out, sample, emb) 496 | sample = self.conv_norm_out(sample) 497 | sample = self.conv_act(sample) 498 | # self._register_hooks(self.conv_out, sample, emb) 499 | sample = self.conv_out(sample) 500 | 501 | # Remove hooks 502 | for h in self._hooks: 503 | h.remove() 504 | self._hooks = [] 505 | 506 | if not return_dict: 507 | return (sample,) 508 | 509 | return UNet2DConditionOutput(sample=sample) 510 | 511 | def _register_hooks(self, module, sample, res_sample, temb, mask_generators, adapters, encoder_hidden_states): 512 | skip_layers = ["time_emb_proj", "ff", "conv_shortcut", "proj_in", "proj_out"] 513 | for name, sub_module in module.named_modules(): 514 | #nn.Conv2d, 515 | if isinstance(sub_module, (nn.Linear)) and not any(skip_layer in name for skip_layer in skip_layers): 516 | # Use sanitized_name, replacing "." with "_" 517 | sanitized_name = name.replace(".", "_") 518 | 519 | # Use "in" to check whether the key exists 520 | if sanitized_name in adapters: 521 | adapter = adapters[sanitized_name] 522 | if adapter is None: 523 | continue 524 | hook = sub_module.register_forward_hook(get_hook_fn(sample, temb, res_sample, mask_generators, adapter, encoder_hidden_states)) 525 | self._hooks.append(hook) 526 | # print(f"Hook registered for layer: {name}") 527 | 528 | 529 | # def main(): 530 | # # Example training code 531 | # unet = FineGrainedUNet2DConditionModel.from_pretrained( 532 | # "../share/runwayml/stable-diffusion-v1-5", subfolder="unet", 533 | # low_cpu_mem_usage=False, 534 | # device_map=None 535 | # ) 536 | 537 | # # print(unet) 538 | 539 | # # Freeze the other parameters and train only decoder_weights_mask_generator 540 | # for param in unet.parameters(): 541 | # param.requires_grad = False 542 | 543 | # for param in unet.mask_generators.parameters(): 544 | # param.requires_grad = True 545 | 546 | # for block_adapters in unet.adapters: 547 | # for adapter in block_adapters.values(): 548 | # for param in adapter.parameters(): 549 | # param.requires_grad = True 550 | 551 | # sample = torch.randn(1, 4, 64, 64) 552 | # timestep = torch.tensor([50]) 553 | # encoder_hidden_states = torch.randn(1, 77, 768) 554 | 555 | # output = unet(sample, timestep, encoder_hidden_states)['sample'] 556 | 557 | # # Assume a simple loss function 558 | # target = torch.randn_like(output) 559 | # loss = F.mse_loss(output, target) 560 | # loss.backward() 561 | 562 | # # Check gradients 563 | # for i, mask_generator in enumerate(unet.mask_generators): 564 | # for param_name, param in mask_generator.named_parameters(): 565 | # if param.grad is None: 566 | # print(f"No gradient for mask_generator in block {i}, parameter: {param_name}") 567 | # else: 568 | # print(f"Gradient for mask_generator in block {i}, parameter {param_name}: {param.grad.norm()}") 569 | 570 | 571 | # for i, block_adapters in enumerate(unet.adapters): 572 | # for name, adapter in block_adapters.items(): 573 | # for param_name, param in adapter.named_parameters(): 574 | # if param.grad is None: 575 | # print(f"No gradient for adapter {name} in block {i}, parameter: {param_name}") 576 | # else: 577 | # print(f"Gradient for adapter {name} in block {i}, parameter {param_name}: {param.grad.norm()}") 578 | 579 | def main(): 580 | # Initialize the accelerator with FP16 precision 581 | accelerator = Accelerator(mixed_precision="fp16") 582 | 583 | # Initialize model 584 | unet = FineGrainedUNet2DConditionModel.from_pretrained( 585 | "../share/runwayml/stable-diffusion-v1-5", 586 | subfolder="unet", 587 | low_cpu_mem_usage=False, 588 | device_map=None, 589 | ) 590 | 591 | #
Freeze other parameters, only train mask_generators and adapters 592 | for param in unet.parameters(): 593 | param.requires_grad = False 594 | for param in unet.mask_generators.parameters(): 595 | param.requires_grad = True 596 | for block_adapters in unet.adapters: 597 | for adapter in block_adapters.values(): 598 | for param in adapter.parameters(): 599 | param.requires_grad = True 600 | 601 | # Initialize optimizer 602 | optimizer = torch.optim.Adam( 603 | filter(lambda p: p.requires_grad, unet.parameters()), lr=1e-4 604 | ) 605 | 606 | for name, param in unet.named_parameters(): 607 | if param.requires_grad: 608 | print(f"Parameter {name} dtype: {param.dtype}") 609 | if param.dtype == torch.float32: # compare with the dtype object, not a string 610 | print("**********************") 611 | 612 | # Prepare model and optimizer with accelerator 613 | unet, optimizer = accelerator.prepare(unet, optimizer) 614 | 615 | # Simulate input (keep as FP32) 616 | sample = torch.randn(1, 4, 64, 64).to(accelerator.device) 617 | timestep = torch.tensor([50]).to(accelerator.device) 618 | encoder_hidden_states = torch.randn(1, 77, 768).to(accelerator.device) 619 | 620 | # Forward pass 621 | optimizer.zero_grad() 622 | with accelerator.autocast(): 623 | output = unet(sample, timestep, encoder_hidden_states)["sample"] 624 | # Compute loss 625 | target = torch.randn_like(output) 626 | loss = F.mse_loss(output, target) 627 | 628 | # Backward pass 629 | accelerator.backward(loss) 630 | 631 | # Clip gradients (no need to unscale manually) 632 | params_to_clip = [p for p in unet.parameters() if p.requires_grad] 633 | accelerator.clip_grad_norm_(params_to_clip, max_norm=1.0) 634 | 635 | # Optimizer step 636 | optimizer.step() 637 | optimizer.zero_grad() 638 | 639 | 640 | if __name__ == "__main__": 641 | main() 642 | 643 | 644 | # if __name__ == "__main__": 645 | # main() 646 | -------------------------------------------------------------------------------- /training-free/sd_utils_x0_dpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
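# The helpers in this module monkey-patch a StableDiffusionPipeline so that sampling can stop
# early at a chosen denoising step and return the x0 prediction together with the current
# latents. A minimal usage sketch follows; it is an assumption-laden illustration, not part of
# the original file: the `stop_step` attribute and the tuple return value are inferred from the
# call sites below, the model path, prompt, and "cuda" device are placeholders, and the helper
# name `_example_stop_step_usage` is hypothetical.
def _example_stop_step_usage(model_path="runwayml/stable-diffusion-v1-5", stop_step=5):
    import torch
    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    register_sd_forward(pipe)                 # installs the early-exit forward defined below
    register_sdschedule_step(pipe.scheduler)  # patched step() also returns the converted model output
    pipe.stop_step = stop_step                # the patched forward breaks out of the denoising loop here
    output, latents_cur = pipe.forward("a photo of an astronaut riding a horse on mars")
    # `output.images` here is the decoded, clamped image tensor (the patched forward skips PIL
    # post-processing); `latents_cur` is the noisy latent at the stop step.
    return output.images, latents_cur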
14 | import inspect 15 | from typing import Any, Callable, Dict, List, Optional, Union 16 | 17 | import torch 18 | from packaging import version 19 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 20 | 21 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 22 | from diffusers.configuration_utils import FrozenDict 23 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 24 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin 25 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel 26 | from diffusers.models.lora import adjust_lora_scale_text_encoder 27 | from diffusers.schedulers import KarrasDiffusionSchedulers 28 | from diffusers.utils import ( 29 | USE_PEFT_BACKEND, 30 | deprecate, 31 | logging, 32 | replace_example_docstring, 33 | scale_lora_layers, 34 | unscale_lora_layers, 35 | ) 36 | from diffusers.utils.torch_utils import randn_tensor 37 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin 38 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput 39 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 40 | import torch.nn.functional as F 41 | import torch.optim as optim 42 | from diffusers.pipelines.stable_diffusion_attend_and_excite.pipeline_stable_diffusion_attend_and_excite import ( 43 | AttentionStore, 44 | AttendExciteAttnProcessor 45 | ) 46 | from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, align_wordpieces_indices, extract_attribution_indices, extract_attribution_indices_with_verbs, extract_attribution_indices_with_verb_root, extract_entities_only 47 | 48 | 49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 50 | 51 | EXAMPLE_DOC_STRING = """ 52 | Examples: 53 | ```py 54 | >>> import torch 55 | >>> from diffusers import StableDiffusionPipeline 56 | 57 | >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) 58 | >>> pipe = pipe.to("cuda") 59 | 60 | >>> prompt = "a photo of an astronaut riding a horse on mars" 61 | >>> image = pipe(prompt).images[0] 62 | ``` 63 | """ 64 | 65 | 66 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 67 | """ 68 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 69 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 70 | """ 71 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 72 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 73 | # rescale the results from guidance (fixes overexposure) 74 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 75 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 76 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 77 | return noise_cfg 78 | 79 | 80 | def retrieve_timesteps( 81 | scheduler, 82 | num_inference_steps: Optional[int] = None, 83 | device: Optional[Union[str, torch.device]] = None, 84 | timesteps: Optional[List[int]] = None, 85 | sigmas: Optional[List[float]] = None, 86 | **kwargs, 87 | ): 88 | """ 89 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 90 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 91 | 92 | Args: 93 | scheduler (`SchedulerMixin`): 94 | The scheduler to get timesteps from. 95 | num_inference_steps (`int`): 96 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 97 | must be `None`. 98 | device (`str` or `torch.device`, *optional*): 99 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 100 | timesteps (`List[int]`, *optional*): 101 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 102 | `num_inference_steps` and `sigmas` must be `None`. 103 | sigmas (`List[float]`, *optional*): 104 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 105 | `num_inference_steps` and `timesteps` must be `None`. 106 | 107 | Returns: 108 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 109 | second element is the number of inference steps. 110 | """ 111 | if timesteps is not None and sigmas is not None: 112 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 113 | if timesteps is not None: 114 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 115 | if not accepts_timesteps: 116 | raise ValueError( 117 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 118 | f" timestep schedules. Please check whether you are using the correct scheduler." 119 | ) 120 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 121 | timesteps = scheduler.timesteps 122 | num_inference_steps = len(timesteps) 123 | elif sigmas is not None: 124 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 125 | if not accept_sigmas: 126 | raise ValueError( 127 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 128 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
129 | ) 130 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 131 | timesteps = scheduler.timesteps 132 | num_inference_steps = len(timesteps) 133 | else: 134 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 135 | timesteps = scheduler.timesteps 136 | return timesteps, num_inference_steps 137 | 138 | 139 | def register_sd_forward(model): 140 | def sd_forward(self): 141 | # @torch.no_grad() 142 | @replace_example_docstring(EXAMPLE_DOC_STRING) 143 | def forward( 144 | prompt: Union[str, List[str]] = None, 145 | height: Optional[int] = None, 146 | width: Optional[int] = None, 147 | num_inference_steps: int = 50, 148 | timesteps: List[int] = None, 149 | sigmas: List[float] = None, 150 | guidance_scale: float = 7.5, 151 | negative_prompt: Optional[Union[str, List[str]]] = None, 152 | num_images_per_prompt: Optional[int] = 1, 153 | eta: float = 0.0, 154 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 155 | latents: Optional[torch.Tensor] = None, 156 | prompt_embeds: Optional[torch.Tensor] = None, 157 | negative_prompt_embeds: Optional[torch.Tensor] = None, 158 | ip_adapter_image: Optional[PipelineImageInput] = None, 159 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 160 | output_type: Optional[str] = "pil", 161 | return_dict: bool = True, 162 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 163 | guidance_rescale: float = 0.0, 164 | clip_skip: Optional[int] = None, 165 | callback_on_step_end: Optional[ 166 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 167 | ] = None, 168 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 169 | **kwargs, 170 | ): 171 | r""" 172 | The call function to the pipeline for generation. 173 | 174 | Args: 175 | prompt (`str` or `List[str]`, *optional*): 176 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 177 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 178 | The height in pixels of the generated image. 179 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 180 | The width in pixels of the generated image. 181 | num_inference_steps (`int`, *optional*, defaults to 50): 182 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 183 | expense of slower inference. 184 | timesteps (`List[int]`, *optional*): 185 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 186 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 187 | passed will be used. Must be in descending order. 188 | sigmas (`List[float]`, *optional*): 189 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 190 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 191 | will be used. 192 | guidance_scale (`float`, *optional*, defaults to 7.5): 193 | A higher guidance scale value encourages the model to generate images closely linked to the text 194 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 195 | negative_prompt (`str` or `List[str]`, *optional*): 196 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 197 | pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale < 1`). 198 | num_images_per_prompt (`int`, *optional*, defaults to 1): 199 | The number of images to generate per prompt. 200 | eta (`float`, *optional*, defaults to 0.0): 201 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 202 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 203 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 204 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 205 | generation deterministic. 206 | latents (`torch.Tensor`, *optional*): 207 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 208 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 209 | tensor is generated by sampling using the supplied random `generator`. 210 | prompt_embeds (`torch.Tensor`, *optional*): 211 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 212 | provided, text embeddings are generated from the `prompt` input argument. 213 | negative_prompt_embeds (`torch.Tensor`, *optional*): 214 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 215 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 216 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 217 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 218 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of 219 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should 220 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not 221 | provided, embeddings are computed from the `ip_adapter_image` input argument. 222 | output_type (`str`, *optional*, defaults to `"pil"`): 223 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 224 | return_dict (`bool`, *optional*, defaults to `True`): 225 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 226 | plain tuple. 227 | cross_attention_kwargs (`dict`, *optional*): 228 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in 229 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 230 | guidance_rescale (`float`, *optional*, defaults to 0.0): 231 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are 232 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when 233 | using zero terminal SNR. 234 | clip_skip (`int`, *optional*): 235 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 236 | the output of the pre-final layer will be used for computing the prompt embeddings. 237 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 238 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 239 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 240 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. 
`callback_kwargs` will include a 241 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 242 | callback_on_step_end_tensor_inputs (`List`, *optional*): 243 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 244 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 245 | `._callback_tensor_inputs` attribute of your pipeline class. 246 | 247 | Examples: 248 | 249 | Returns: 250 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 251 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 252 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 253 | second element is a list of `bool`s indicating whether the corresponding generated image contains 254 | "not-safe-for-work" (nsfw) content. 255 | """ 256 | 257 | callback = kwargs.pop("callback", None) 258 | callback_steps = kwargs.pop("callback_steps", None) 259 | 260 | if callback is not None: 261 | deprecate( 262 | "callback", 263 | "1.0.0", 264 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 265 | ) 266 | if callback_steps is not None: 267 | deprecate( 268 | "callback_steps", 269 | "1.0.0", 270 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 271 | ) 272 | 273 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 274 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 275 | 276 | # 0. Default height and width to unet 277 | height = height or self.unet.config.sample_size * self.vae_scale_factor 278 | width = width or self.unet.config.sample_size * self.vae_scale_factor 279 | # to deal with lora scaling and other possible forward hooks 280 | 281 | # 1. Check inputs. Raise error if not correct 282 | self.check_inputs( 283 | prompt, 284 | height, 285 | width, 286 | callback_steps, 287 | negative_prompt, 288 | prompt_embeds, 289 | negative_prompt_embeds, 290 | ip_adapter_image, 291 | ip_adapter_image_embeds, 292 | callback_on_step_end_tensor_inputs, 293 | ) 294 | 295 | self._guidance_scale = guidance_scale 296 | self._guidance_rescale = guidance_rescale 297 | self._clip_skip = clip_skip 298 | self._cross_attention_kwargs = cross_attention_kwargs 299 | self._interrupt = False 300 | 301 | # 2. Define call parameters 302 | if prompt is not None and isinstance(prompt, str): 303 | batch_size = 1 304 | elif prompt is not None and isinstance(prompt, list): 305 | batch_size = len(prompt) 306 | else: 307 | batch_size = prompt_embeds.shape[0] 308 | 309 | device = self._execution_device 310 | 311 | # 3. Encode input prompt 312 | lora_scale = ( 313 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 314 | ) 315 | 316 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 317 | prompt, 318 | device, 319 | num_images_per_prompt, 320 | self.do_classifier_free_guidance, 321 | negative_prompt, 322 | prompt_embeds=prompt_embeds, 323 | negative_prompt_embeds=negative_prompt_embeds, 324 | lora_scale=lora_scale, 325 | clip_skip=self.clip_skip, 326 | ) 327 | # print(prompt_embeds.size()) 328 | # assert 2==1 329 | 330 | # For classifier free guidance, we need to do two forward passes. 
331 | # Here we concatenate the unconditional and text embeddings into a single batch 332 | # to avoid doing two forward passes 333 | if self.do_classifier_free_guidance: 334 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 335 | 336 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 337 | image_embeds = self.prepare_ip_adapter_image_embeds( 338 | ip_adapter_image, 339 | ip_adapter_image_embeds, 340 | device, 341 | batch_size * num_images_per_prompt, 342 | self.do_classifier_free_guidance, 343 | ) 344 | 345 | # 4. Prepare timesteps 346 | timesteps, num_inference_steps = retrieve_timesteps( 347 | self.scheduler, num_inference_steps, device, timesteps, sigmas 348 | ) 349 | 350 | # 5. Prepare latent variables 351 | num_channels_latents = self.unet.config.in_channels 352 | latents = self.prepare_latents( 353 | batch_size * num_images_per_prompt, 354 | num_channels_latents, 355 | height, 356 | width, 357 | prompt_embeds.dtype, 358 | device, 359 | generator, 360 | latents, 361 | ) 362 | 363 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 364 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 365 | 366 | # 6.1 Add image embeds for IP-Adapter 367 | added_cond_kwargs = ( 368 | {"image_embeds": image_embeds} 369 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) 370 | else None 371 | ) 372 | 373 | # 6.2 Optionally get Guidance Scale Embedding 374 | timestep_cond = None 375 | if self.unet.config.time_cond_proj_dim is not None: 376 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 377 | timestep_cond = self.get_guidance_scale_embedding( 378 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 379 | ).to(device=device, dtype=latents.dtype) 380 | 381 | # 7. Denoising loop 382 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 383 | self._num_timesteps = len(timesteps) 384 | with self.progress_bar(total=num_inference_steps) as progress_bar: 385 | for i, t in enumerate(timesteps): 386 | if self.interrupt: 387 | continue 388 | 389 | if i < self.stop_step: 390 | continue 391 | 392 | 393 | # expand the latents if we are doing classifier free guidance 394 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 395 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 396 | # print(self.unet) 397 | # assert 2==1 398 | # predict the noise residual 399 | noise_pred = self.unet( 400 | latent_model_input, 401 | t, 402 | encoder_hidden_states=prompt_embeds, 403 | timestep_cond=timestep_cond, 404 | cross_attention_kwargs=self.cross_attention_kwargs, 405 | added_cond_kwargs=added_cond_kwargs, 406 | return_dict=False, 407 | )[0] 408 | 409 | # perform guidance 410 | if self.do_classifier_free_guidance: 411 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 412 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 413 | 414 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: 415 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 416 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) 417 | 418 | # compute the previous noisy sample x_t -> x_t-1 419 | latents, pred_original_sample = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False) 420 | 421 | 422 | if callback_on_step_end is not None: 423 | callback_kwargs = {} 424 | for k in callback_on_step_end_tensor_inputs: 425 | callback_kwargs[k] = locals()[k] 426 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 427 | 428 | latents = callback_outputs.pop("latents", latents) 429 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 430 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 431 | 432 | # call the callback, if provided 433 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 434 | progress_bar.update() 435 | if callback is not None and i % callback_steps == 0: 436 | step_idx = i // getattr(self.scheduler, "order", 1) 437 | callback(step_idx, t, latents) 438 | 439 | if i >= self.stop_step: 440 | # alpha_t = self.alphas[int(t.item())] ** 0.5 441 | # sigma_t = (1 - self.alphas[int(t.item())]) ** 0.5 442 | # latents = (latents - sigma_t * noise_pred) / alpha_t 443 | latents_cur = latents 444 | latents = pred_original_sample 445 | break 446 | if not output_type == "latent": 447 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 448 | 0 449 | ] 450 | has_nsfw_concept = None # safety checker disabled; was: image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 451 | else: 452 | has_nsfw_concept = None 453 | 454 | 455 | # if has_nsfw_concept is None: 456 | # do_denormalize = [True] * image.shape[0] 457 | # else: 458 | # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 459 | 460 | # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 461 | # image = (image / 2 + 0.5).clamp(0, 1) 462 | image = (image / 2 + 0.5).clamp(0, 1) 463 | # optimizer = optim.Adam(self.unet.parameters(), lr=1e-1) 464 | # input_tensor = image # simulate an input image 465 | # target = torch.randn_like(input_tensor).to(noise_pred.device).half() # target has the same shape as the input 466 | # loss = F.mse_loss(image, target) 467 | # optimizer.zero_grad() 468 | # loss.backward() 469 | # for name, param in self.unet.named_parameters(): 470 | # if param.grad is not None: 471 | # print(f"Gradient for {name}: {param.grad}") 472 | # else: 473 | # print(f"No gradient for {name}") 474 | 475 | # assert 2==1 476 | 477 | # Offload all models 478 | self.maybe_free_model_hooks() 479 | 480 | if not return_dict: 481 | # return (image, has_nsfw_concept) 482 | return (image, latents_cur) 483 | 484 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), latents_cur 485 | return forward 486 | if model.__class__.__name__ == 'StableDiffusionPipeline': 487 | # model.__call__ = sdturbo_forward(model) 488 | model.forward = sd_forward(model) 489 | 490 | from typing import Tuple 491 | from dataclasses import dataclass 492 | from diffusers.schedulers.scheduling_lcm import LCMSchedulerOutput 493 | from diffusers.utils.torch_utils import randn_tensor 494 | from diffusers.utils import BaseOutput 495 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput 496 | 497 | def
sd_schedule_step(self): 499 | def step( 500 | model_output: torch.Tensor, 501 | timestep: Union[int, torch.Tensor], 502 | sample: torch.Tensor, 503 | generator=None, 504 | variance_noise: Optional[torch.Tensor] = None, 505 | return_dict: bool = True, 506 | ) -> Union[SchedulerOutput, Tuple]: 507 | """ 508 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with 509 | the multistep DPMSolver. 510 | 511 | Args: 512 | model_output (`torch.Tensor`): 513 | The direct output from learned diffusion model. 514 | timestep (`int`): 515 | The current discrete timestep in the diffusion chain. 516 | sample (`torch.Tensor`): 517 | A current instance of a sample created by the diffusion process. 518 | generator (`torch.Generator`, *optional*): 519 | A random number generator. 520 | variance_noise (`torch.Tensor`): 521 | Alternative to generating noise with `generator` by directly providing the noise for the variance 522 | itself. Useful for methods such as [`LEdits++`]. 523 | return_dict (`bool`): 524 | Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 525 | 526 | Returns: 527 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: 528 | If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a 529 | tuple is returned where the first element is the sample tensor. 530 | 531 | """ 532 | if self.num_inference_steps is None: 533 | raise ValueError( 534 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 535 | ) 536 | 537 | if self.step_index is None: 538 | self._init_step_index(timestep) 539 | 540 | # Improve numerical stability for small number of steps 541 | lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( 542 | self.config.euler_at_final 543 | or (self.config.lower_order_final and len(self.timesteps) < 15) 544 | or self.config.final_sigmas_type == "zero" 545 | ) 546 | lower_order_second = ( 547 | (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 548 | ) 549 | 550 | model_output = self.convert_model_output(model_output, sample=sample) 551 | for i in range(self.config.solver_order - 1): 552 | self.model_outputs[i] = self.model_outputs[i + 1] 553 | self.model_outputs[-1] = model_output 554 | 555 | # Upcast to avoid precision issues when computing prev_sample 556 | sample = sample.to(torch.float32) 557 | if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: 558 | noise = randn_tensor( 559 | model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 560 | ) 561 | elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: 562 | noise = variance_noise.to(device=model_output.device, dtype=torch.float32) 563 | else: 564 | noise = None 565 | 566 | if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: 567 | prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) 568 | elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: 569 | prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) 570 | else: 571 | prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) 572 | 573 | if self.lower_order_nums < self.config.solver_order: 574 | self.lower_order_nums += 1 575 | 576 | # Cast sample 
back to expected dtype 577 | prev_sample = prev_sample.to(model_output.dtype) 578 | 579 | # upon completion increase step index by one 580 | self._step_index += 1 581 | 582 | if not return_dict: 583 | return (prev_sample, model_output) 584 | 585 | return SchedulerOutput(prev_sample=prev_sample) 586 | 587 | return step 588 | if model.__class__.__name__ == 'DPMSolverMultistepScheduler': 589 | model.step = sd_schedule_step(model) 590 | -------------------------------------------------------------------------------- /training/train_sd_lora.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" 17 | 18 | import argparse 19 | import logging 20 | import math 21 | import os 22 | import random 23 | import shutil 24 | from contextlib import nullcontext 25 | from pathlib import Path 26 | 27 | import datasets 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from datasets import load_dataset 37 | from huggingface_hub import create_repo, upload_folder 38 | from packaging import version 39 | from peft import LoraConfig 40 | from peft.utils import get_peft_model_state_dict 41 | from torchvision import transforms 42 | from tqdm.auto import tqdm 43 | from transformers import CLIPTextModel, CLIPTokenizer 44 | 45 | import diffusers 46 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel 47 | from diffusers.optimization import get_scheduler 48 | from diffusers.training_utils import cast_training_params, compute_snr 49 | from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available 50 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 51 | from diffusers.utils.import_utils import is_xformers_available 52 | from diffusers.utils.torch_utils import is_compiled_module 53 | 54 | 55 | if is_wandb_available(): 56 | import wandb 57 | 58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
59 | # check_min_version("0.31.0.dev0") 60 | 61 | logger = get_logger(__name__, log_level="INFO") 62 | 63 | 64 | def save_model_card( 65 | repo_id: str, 66 | images: list = None, 67 | base_model: str = None, 68 | dataset_name: str = None, 69 | repo_folder: str = None, 70 | ): 71 | img_str = "" 72 | if images is not None: 73 | for i, image in enumerate(images): 74 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 75 | img_str += f"![img_{i}](./image_{i}.png)\n" 76 | 77 | model_description = f""" 78 | # LoRA text2image fine-tuning - {repo_id} 79 | These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n 80 | {img_str} 81 | """ 82 | 83 | model_card = load_or_create_model_card( 84 | repo_id_or_path=repo_id, 85 | from_training=True, 86 | license="creativeml-openrail-m", 87 | base_model=base_model, 88 | model_description=model_description, 89 | inference=True, 90 | ) 91 | 92 | tags = [ 93 | "stable-diffusion", 94 | "stable-diffusion-diffusers", 95 | "text-to-image", 96 | "diffusers", 97 | "diffusers-training", 98 | "lora", 99 | ] 100 | model_card = populate_model_card(model_card, tags=tags) 101 | 102 | model_card.save(os.path.join(repo_folder, "README.md")) 103 | 104 | 105 | def log_validation( 106 | pipeline, 107 | args, 108 | accelerator, 109 | epoch, 110 | is_final_validation=False, 111 | ): 112 | logger.info( 113 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 114 | f" {args.validation_prompt}." 115 | ) 116 | pipeline = pipeline.to(accelerator.device) 117 | pipeline.set_progress_bar_config(disable=True) 118 | generator = torch.Generator(device=accelerator.device) 119 | if args.seed is not None: 120 | generator = generator.manual_seed(args.seed) 121 | images = [] 122 | if torch.backends.mps.is_available(): 123 | autocast_ctx = nullcontext() 124 | else: 125 | autocast_ctx = torch.autocast(accelerator.device.type) 126 | 127 | with autocast_ctx: 128 | for _ in range(args.num_validation_images): 129 | images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) 130 | 131 | for tracker in accelerator.trackers: 132 | phase_name = "test" if is_final_validation else "validation" 133 | if tracker.name == "tensorboard": 134 | np_images = np.stack([np.asarray(img) for img in images]) 135 | tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") 136 | if tracker.name == "wandb": 137 | tracker.log( 138 | { 139 | phase_name: [ 140 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 141 | ] 142 | } 143 | ) 144 | return images 145 | 146 | 147 | def parse_args(): 148 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 149 | parser.add_argument( 150 | "--pretrained_model_name_or_path", 151 | type=str, 152 | default=None, 153 | required=True, 154 | help="Path to pretrained model or model identifier from huggingface.co/models.", 155 | ) 156 | parser.add_argument( 157 | "--revision", 158 | type=str, 159 | default=None, 160 | required=False, 161 | help="Revision of pretrained model identifier from huggingface.co/models.", 162 | ) 163 | parser.add_argument( 164 | "--variant", 165 | type=str, 166 | default=None, 167 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 168 | ) 169 | parser.add_argument( 170 | "--dataset_name", 171 | type=str, 172 | default=None, 173 | help=( 174 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 175 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 176 | " or to a folder containing files that 🤗 Datasets can understand." 177 | ), 178 | ) 179 | parser.add_argument( 180 | "--dataset_config_name", 181 | type=str, 182 | default=None, 183 | help="The config of the Dataset, leave as None if there's only one config.", 184 | ) 185 | parser.add_argument( 186 | "--train_data_dir", 187 | type=str, 188 | default=None, 189 | help=( 190 | "A folder containing the training data. Folder contents must follow the structure described in" 191 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 192 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 193 | ), 194 | ) 195 | parser.add_argument( 196 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 197 | ) 198 | parser.add_argument( 199 | "--caption_column", 200 | type=str, 201 | default="text", 202 | help="The column of the dataset containing a caption or a list of captions.", 203 | ) 204 | parser.add_argument( 205 | "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." 206 | ) 207 | parser.add_argument( 208 | "--num_validation_images", 209 | type=int, 210 | default=4, 211 | help="Number of images that should be generated during validation with `validation_prompt`.", 212 | ) 213 | parser.add_argument( 214 | "--validation_epochs", 215 | type=int, 216 | default=1, 217 | help=( 218 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" 219 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 220 | ), 221 | ) 222 | parser.add_argument( 223 | "--max_train_samples", 224 | type=int, 225 | default=None, 226 | help=( 227 | "For debugging purposes or quicker training, truncate the number of training examples to this " 228 | "value if set." 229 | ), 230 | ) 231 | parser.add_argument( 232 | "--output_dir", 233 | type=str, 234 | default="sd-model-finetuned-lora", 235 | help="The output directory where the model predictions and checkpoints will be written.", 236 | ) 237 | parser.add_argument( 238 | "--cache_dir", 239 | type=str, 240 | default=None, 241 | help="The directory where the downloaded models and datasets will be stored.", 242 | ) 243 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 244 | parser.add_argument( 245 | "--resolution", 246 | type=int, 247 | default=512, 248 | help=( 249 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 250 | " resolution" 251 | ), 252 | ) 253 | parser.add_argument( 254 | "--center_crop", 255 | default=False, 256 | action="store_true", 257 | help=( 258 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 259 | " cropped. The images will be resized to the resolution first before cropping." 
260 | ), 261 | ) 262 | parser.add_argument( 263 | "--random_flip", 264 | action="store_true", 265 | help="whether to randomly flip images horizontally", 266 | ) 267 | parser.add_argument( 268 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 269 | ) 270 | parser.add_argument("--num_train_epochs", type=int, default=100) 271 | parser.add_argument( 272 | "--max_train_steps", 273 | type=int, 274 | default=None, 275 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 276 | ) 277 | parser.add_argument( 278 | "--gradient_accumulation_steps", 279 | type=int, 280 | default=1, 281 | help="Number of updates steps to accumulate before performing a backward/update pass.", 282 | ) 283 | parser.add_argument( 284 | "--gradient_checkpointing", 285 | action="store_true", 286 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 287 | ) 288 | parser.add_argument( 289 | "--learning_rate", 290 | type=float, 291 | default=1e-4, 292 | help="Initial learning rate (after the potential warmup period) to use.", 293 | ) 294 | parser.add_argument( 295 | "--scale_lr", 296 | action="store_true", 297 | default=False, 298 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 299 | ) 300 | parser.add_argument( 301 | "--lr_scheduler", 302 | type=str, 303 | default="constant", 304 | help=( 305 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 306 | ' "constant", "constant_with_warmup"]' 307 | ), 308 | ) 309 | parser.add_argument( 310 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 311 | ) 312 | parser.add_argument( 313 | "--snr_gamma", 314 | type=float, 315 | default=None, 316 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 317 | "More details here: https://arxiv.org/abs/2303.09556.", 318 | ) 319 | parser.add_argument( 320 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 321 | ) 322 | parser.add_argument( 323 | "--allow_tf32", 324 | action="store_true", 325 | help=( 326 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 327 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 328 | ), 329 | ) 330 | parser.add_argument( 331 | "--dataloader_num_workers", 332 | type=int, 333 | default=0, 334 | help=( 335 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
336 | ), 337 | ) 338 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 339 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 340 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 341 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 342 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 343 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 344 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 345 | parser.add_argument( 346 | "--prediction_type", 347 | type=str, 348 | default=None, 349 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.", 350 | ) 351 | parser.add_argument( 352 | "--hub_model_id", 353 | type=str, 354 | default=None, 355 | help="The name of the repository to keep in sync with the local `output_dir`.", 356 | ) 357 | parser.add_argument( 358 | "--logging_dir", 359 | type=str, 360 | default="logs", 361 | help=( 362 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 363 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 364 | ), 365 | ) 366 | parser.add_argument( 367 | "--mixed_precision", 368 | type=str, 369 | default=None, 370 | choices=["no", "fp16", "bf16"], 371 | help=( 372 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 373 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 374 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 375 | ), 376 | ) 377 | parser.add_argument( 378 | "--report_to", 379 | type=str, 380 | default="tensorboard", 381 | help=( 382 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 383 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 384 | ), 385 | ) 386 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 387 | parser.add_argument( 388 | "--checkpointing_steps", 389 | type=int, 390 | default=500, 391 | help=( 392 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 393 | " training using `--resume_from_checkpoint`." 394 | ), 395 | ) 396 | parser.add_argument( 397 | "--checkpoints_total_limit", 398 | type=int, 399 | default=None, 400 | help=("Max number of checkpoints to store."), 401 | ) 402 | parser.add_argument( 403 | "--resume_from_checkpoint", 404 | type=str, 405 | default=None, 406 | help=( 407 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 408 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 409 | ), 410 | ) 411 | parser.add_argument( 412 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
413 | ) 414 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 415 | parser.add_argument( 416 | "--rank", 417 | type=int, 418 | default=4, 419 | help=("The dimension of the LoRA update matrices."), 420 | ) 421 | 422 | args = parser.parse_args() 423 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 424 | if env_local_rank != -1 and env_local_rank != args.local_rank: 425 | args.local_rank = env_local_rank 426 | 427 | # Sanity checks 428 | if args.dataset_name is None and args.train_data_dir is None: 429 | raise ValueError("Need either a dataset name or a training folder.") 430 | 431 | return args 432 | 433 | 434 | DATASET_NAME_MAPPING = { 435 | "lambdalabs/naruto-blip-captions": ("image", "text"), 436 | } 437 | 438 | 439 | def main(): 440 | args = parse_args() 441 | if args.report_to == "wandb" and args.hub_token is not None: 442 | raise ValueError( 443 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 444 | " Please use `huggingface-cli login` to authenticate with the Hub." 445 | ) 446 | 447 | logging_dir = Path(args.output_dir, args.logging_dir) 448 | 449 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 450 | 451 | accelerator = Accelerator( 452 | gradient_accumulation_steps=args.gradient_accumulation_steps, 453 | mixed_precision=args.mixed_precision, 454 | log_with=args.report_to, 455 | project_config=accelerator_project_config, 456 | ) 457 | 458 | # Disable AMP for MPS. 459 | if torch.backends.mps.is_available(): 460 | accelerator.native_amp = False 461 | 462 | # Make one log on every process with the configuration for debugging. 463 | logging.basicConfig( 464 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 465 | datefmt="%m/%d/%Y %H:%M:%S", 466 | level=logging.INFO, 467 | ) 468 | logger.info(accelerator.state, main_process_only=False) 469 | if accelerator.is_local_main_process: 470 | datasets.utils.logging.set_verbosity_warning() 471 | transformers.utils.logging.set_verbosity_warning() 472 | diffusers.utils.logging.set_verbosity_info() 473 | else: 474 | datasets.utils.logging.set_verbosity_error() 475 | transformers.utils.logging.set_verbosity_error() 476 | diffusers.utils.logging.set_verbosity_error() 477 | 478 | # If passed along, set the training seed now. 479 | if args.seed is not None: 480 | set_seed(args.seed) 481 | 482 | # Handle the repository creation 483 | if accelerator.is_main_process: 484 | if args.output_dir is not None: 485 | os.makedirs(args.output_dir, exist_ok=True) 486 | 487 | if args.push_to_hub: 488 | repo_id = create_repo( 489 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 490 | ).repo_id 491 | # Load scheduler, tokenizer and models. 
492 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 493 | tokenizer = CLIPTokenizer.from_pretrained( 494 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision 495 | ) 496 | text_encoder = CLIPTextModel.from_pretrained( 497 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 498 | ) 499 | vae = AutoencoderKL.from_pretrained( 500 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 501 | ) 502 | unet = UNet2DConditionModel.from_pretrained( 503 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 504 | ) 505 | # freeze parameters of models to save more memory 506 | unet.requires_grad_(False) 507 | vae.requires_grad_(False) 508 | text_encoder.requires_grad_(False) 509 | 510 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision 511 | # as these weights are only used for inference, keeping weights in full precision is not required. 512 | weight_dtype = torch.float32 513 | if accelerator.mixed_precision == "fp16": 514 | weight_dtype = torch.float16 515 | elif accelerator.mixed_precision == "bf16": 516 | weight_dtype = torch.bfloat16 517 | 518 | # Freeze the unet parameters before adding adapters 519 | for param in unet.parameters(): 520 | param.requires_grad_(False) 521 | 522 | unet_lora_config = LoraConfig( 523 | r=256, 524 | lora_alpha=512, 525 | init_lora_weights="gaussian", 526 | target_modules=["to_k", "to_q", "to_v", "to_out.0"], 527 | ) 528 | 529 | # Move unet, vae and text_encoder to device and cast to weight_dtype 530 | unet.to(accelerator.device, dtype=weight_dtype) 531 | vae.to(accelerator.device, dtype=weight_dtype) 532 | text_encoder.to(accelerator.device, dtype=weight_dtype) 533 | 534 | # Add adapter and make sure the trainable params are in float32. 535 | unet.add_adapter(unet_lora_config) 536 | if args.mixed_precision == "fp16": 537 | # only upcast trainable parameters (LoRA) into fp32 538 | cast_training_params(unet, dtype=torch.float32) 539 | 540 | if args.enable_xformers_memory_efficient_attention: 541 | if is_xformers_available(): 542 | import xformers 543 | 544 | xformers_version = version.parse(xformers.__version__) 545 | if xformers_version == version.parse("0.0.16"): 546 | logger.warning( 547 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 548 | ) 549 | unet.enable_xformers_memory_efficient_attention() 550 | else: 551 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 552 | 553 | 554 | for n, p in unet.named_parameters(): 555 | if not p.requires_grad: 556 | continue # frozen weights 557 | p.data = p.data.to(torch.float32) 558 | 559 | 560 | lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) 561 | 562 | 563 | 564 | # for i, param in enumerate(lora_layers): 565 | # print(f"Layer {i}: dtype = {param.dtype}") 566 | 567 | if args.gradient_checkpointing: 568 | unet.enable_gradient_checkpointing() 569 | 570 | # Enable TF32 for faster training on Ampere GPUs, 571 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 572 | if args.allow_tf32: 573 | torch.backends.cuda.matmul.allow_tf32 = True 574 | 575 | if args.scale_lr: 576 | args.learning_rate = ( 577 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 578 | ) 579 | 580 | # Initialize the optimizer 581 | if args.use_8bit_adam: 582 | try: 583 | import bitsandbytes as bnb 584 | except ImportError: 585 | raise ImportError( 586 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 587 | ) 588 | 589 | optimizer_cls = bnb.optim.AdamW8bit 590 | else: 591 | optimizer_cls = torch.optim.AdamW 592 | 593 | optimizer = optimizer_cls( 594 | lora_layers, 595 | lr=args.learning_rate, 596 | betas=(args.adam_beta1, args.adam_beta2), 597 | weight_decay=args.adam_weight_decay, 598 | eps=args.adam_epsilon, 599 | ) 600 | 601 | # Get the datasets: you can either provide your own training and evaluation files (see below) 602 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 603 | 604 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 605 | # download the dataset. 606 | if args.dataset_name is not None: 607 | # Downloading and loading a dataset from the hub. 608 | dataset = load_dataset( 609 | args.dataset_name, 610 | args.dataset_config_name, 611 | cache_dir=args.cache_dir, 612 | data_dir=args.train_data_dir, 613 | ) 614 | else: 615 | data_files = {} 616 | if args.train_data_dir is not None: 617 | data_files["train"] = os.path.join(args.train_data_dir, "**") 618 | dataset = load_dataset( 619 | "imagefolder", 620 | data_files=data_files, 621 | cache_dir=args.cache_dir, 622 | ) 623 | # See more about loading custom images at 624 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 625 | 626 | # Preprocessing the datasets. 627 | # We need to tokenize inputs and targets. 628 | column_names = dataset["train"].column_names 629 | 630 | # 6. Get the column names for input/target. 631 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 632 | if args.image_column is None: 633 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 634 | else: 635 | image_column = args.image_column 636 | if image_column not in column_names: 637 | raise ValueError( 638 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 639 | ) 640 | if args.caption_column is None: 641 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 642 | else: 643 | caption_column = args.caption_column 644 | if caption_column not in column_names: 645 | raise ValueError( 646 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 647 | ) 648 | 649 | # Preprocessing the datasets. 
650 | # We need to tokenize input captions and transform the images. 651 | def tokenize_captions(examples, is_train=True): 652 | captions = [] 653 | for caption in examples[caption_column]: 654 | if isinstance(caption, str): 655 | captions.append(caption) 656 | elif isinstance(caption, (list, np.ndarray)): 657 | # take a random caption if there are multiple 658 | captions.append(random.choice(caption) if is_train else caption[0]) 659 | else: 660 | raise ValueError( 661 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 662 | ) 663 | inputs = tokenizer( 664 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 665 | ) 666 | return inputs.input_ids 667 | 668 | # Preprocessing the datasets. 669 | train_transforms = transforms.Compose( 670 | [ 671 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 672 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 673 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 674 | transforms.ToTensor(), 675 | transforms.Normalize([0.5], [0.5]), 676 | ] 677 | ) 678 | 679 | def unwrap_model(model): 680 | model = accelerator.unwrap_model(model) 681 | model = model._orig_mod if is_compiled_module(model) else model 682 | return model 683 | 684 | def preprocess_train(examples): 685 | images = [image.convert("RGB") for image in examples[image_column]] 686 | examples["pixel_values"] = [train_transforms(image) for image in images] 687 | examples["input_ids"] = tokenize_captions(examples) 688 | return examples 689 | 690 | with accelerator.main_process_first(): 691 | if args.max_train_samples is not None: 692 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 693 | # Set the training transforms 694 | train_dataset = dataset["train"].with_transform(preprocess_train) 695 | 696 | def collate_fn(examples): 697 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 698 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 699 | input_ids = torch.stack([example["input_ids"] for example in examples]) 700 | return {"pixel_values": pixel_values, "input_ids": input_ids} 701 | 702 | # DataLoaders creation: 703 | train_dataloader = torch.utils.data.DataLoader( 704 | train_dataset, 705 | shuffle=True, 706 | collate_fn=collate_fn, 707 | batch_size=args.train_batch_size, 708 | num_workers=args.dataloader_num_workers, 709 | ) 710 | 711 | # Scheduler and math around the number of training steps. 712 | # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. 
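# Worked example (hypothetical numbers, added for clarity; not part of the original script): with
# 6 processes, gradient_accumulation_steps=1 and a 1,002-batch dataloader before sharding, the
# per-process dataloader is ceil(1002 / 6) = 167 batches, so one epoch is 167 optimizer updates
# per process, and the scheduler is created with num_train_epochs * 167 * 6 total steps; warmup
# steps are scaled by the number of processes for the same reason (see the PR linked above).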
713 | num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes 714 | if args.max_train_steps is None: 715 | len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) 716 | num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) 717 | num_training_steps_for_scheduler = ( 718 | args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes 719 | ) 720 | else: 721 | num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes 722 | 723 | lr_scheduler = get_scheduler( 724 | args.lr_scheduler, 725 | optimizer=optimizer, 726 | num_warmup_steps=num_warmup_steps_for_scheduler, 727 | num_training_steps=num_training_steps_for_scheduler, 728 | ) 729 | 730 | # Prepare everything with our `accelerator`. 731 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 732 | unet, optimizer, train_dataloader, lr_scheduler 733 | ) 734 | 735 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 736 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 737 | if args.max_train_steps is None: 738 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 739 | if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: 740 | logger.warning( 741 | f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " 742 | f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " 743 | f"This inconsistency may result in the learning rate scheduler not functioning properly." 744 | ) 745 | # Afterwards we recalculate our number of training epochs 746 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 747 | 748 | # We need to initialize the trackers we use, and also store our configuration. 749 | # The trackers initializes automatically on the main process. 750 | if accelerator.is_main_process: 751 | accelerator.init_trackers("text2image-fine-tune", config=vars(args)) 752 | 753 | # Train! 754 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 755 | 756 | logger.info("***** Running training *****") 757 | logger.info(f" Num examples = {len(train_dataset)}") 758 | logger.info(f" Num Epochs = {args.num_train_epochs}") 759 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 760 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 761 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 762 | logger.info(f" Total optimization steps = {args.max_train_steps}") 763 | global_step = 0 764 | first_epoch = 0 765 | 766 | # Potentially load in the weights and states from a previous save 767 | if args.resume_from_checkpoint: 768 | if args.resume_from_checkpoint != "latest": 769 | path = os.path.basename(args.resume_from_checkpoint) 770 | else: 771 | # Get the most recent checkpoint 772 | dirs = os.listdir(args.output_dir) 773 | dirs = [d for d in dirs if d.startswith("checkpoint")] 774 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 775 | path = dirs[-1] if len(dirs) > 0 else None 776 | 777 | if path is None: 778 | accelerator.print( 779 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run." 780 | ) 781 | args.resume_from_checkpoint = None 782 | initial_global_step = 0 783 | else: 784 | accelerator.print(f"Resuming from checkpoint {path}") 785 | accelerator.load_state(os.path.join(args.output_dir, path)) 786 | global_step = int(path.split("-")[1]) 787 | 788 | initial_global_step = global_step 789 | first_epoch = global_step // num_update_steps_per_epoch 790 | else: 791 | initial_global_step = 0 792 | 793 | progress_bar = tqdm( 794 | range(0, args.max_train_steps), 795 | initial=initial_global_step, 796 | desc="Steps", 797 | # Only show the progress bar once on each machine. 798 | disable=not accelerator.is_local_main_process, 799 | ) 800 | 801 | for epoch in range(first_epoch, args.num_train_epochs): 802 | unet.train() 803 | train_loss = 0.0 804 | for step, batch in enumerate(train_dataloader): 805 | with accelerator.accumulate(unet): 806 | # Convert images to latent space 807 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 808 | latents = latents * vae.config.scaling_factor 809 | 810 | # Sample noise that we'll add to the latents 811 | noise = torch.randn_like(latents) 812 | if args.noise_offset: 813 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 814 | noise += args.noise_offset * torch.randn( 815 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 816 | ) 817 | 818 | bsz = latents.shape[0] 819 | # Sample a random timestep for each image 820 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 821 | timesteps = timesteps.long() 822 | 823 | # Add noise to the latents according to the noise magnitude at each timestep 824 | # (this is the forward diffusion process) 825 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 826 | 827 | # Get the text embedding for conditioning 828 | encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] 829 | 830 | # Get the target for loss depending on the prediction type 831 | if args.prediction_type is not None: 832 | # set prediction_type of scheduler if defined 833 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 834 | 835 | if noise_scheduler.config.prediction_type == "epsilon": 836 | target = noise 837 | elif noise_scheduler.config.prediction_type == "v_prediction": 838 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 839 | else: 840 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 841 | 842 | # Predict the noise residual and compute loss 843 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] 844 | 845 | if args.snr_gamma is None: 846 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 847 | else: 848 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 849 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 850 | # This is discussed in Section 4.2 of the same paper. 
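# Restated as a formula (added for clarity; gamma = args.snr_gamma, a paraphrase of the code below):
#   epsilon-prediction:  loss_weight(t) = min(SNR(t), gamma) / SNR(t)
#   v-prediction:        loss_weight(t) = min(SNR(t), gamma) / (SNR(t) + 1)
# so high-SNR (low-noise) timesteps are down-weighted rather than dominating the MSE loss.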
851 | snr = compute_snr(noise_scheduler, timesteps) 852 | mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( 853 | dim=1 854 | )[0] 855 | if noise_scheduler.config.prediction_type == "epsilon": 856 | mse_loss_weights = mse_loss_weights / snr 857 | elif noise_scheduler.config.prediction_type == "v_prediction": 858 | mse_loss_weights = mse_loss_weights / (snr + 1) 859 | 860 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 861 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 862 | loss = loss.mean() 863 | 864 | # Gather the losses across all processes for logging (if we use distributed training). 865 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 866 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 867 | 868 | # Backpropagate 869 | accelerator.backward(loss) 870 | if accelerator.sync_gradients: 871 | params_to_clip = lora_layers 872 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 873 | optimizer.step() 874 | lr_scheduler.step() 875 | optimizer.zero_grad() 876 | 877 | # Checks if the accelerator has performed an optimization step behind the scenes 878 | if accelerator.sync_gradients: 879 | progress_bar.update(1) 880 | global_step += 1 881 | accelerator.log({"train_loss": train_loss}, step=global_step) 882 | train_loss = 0.0 883 | 884 | if global_step % args.checkpointing_steps == 0: 885 | if accelerator.is_main_process: 886 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 887 | if args.checkpoints_total_limit is not None: 888 | checkpoints = os.listdir(args.output_dir) 889 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 890 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 891 | 892 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 893 | if len(checkpoints) >= args.checkpoints_total_limit: 894 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 895 | removing_checkpoints = checkpoints[0:num_to_remove] 896 | 897 | logger.info( 898 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 899 | ) 900 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 901 | 902 | for removing_checkpoint in removing_checkpoints: 903 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 904 | shutil.rmtree(removing_checkpoint) 905 | 906 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 907 | accelerator.save_state(save_path) 908 | 909 | unwrapped_unet = unwrap_model(unet) 910 | unet_lora_state_dict = convert_state_dict_to_diffusers( 911 | get_peft_model_state_dict(unwrapped_unet) 912 | ) 913 | 914 | StableDiffusionPipeline.save_lora_weights( 915 | save_directory=save_path, 916 | unet_lora_layers=unet_lora_state_dict, 917 | safe_serialization=True, 918 | ) 919 | 920 | logger.info(f"Saved state to {save_path}") 921 | 922 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 923 | progress_bar.set_postfix(**logs) 924 | 925 | if global_step >= args.max_train_steps: 926 | break 927 | 928 | if accelerator.is_main_process: 929 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0: 930 | # create pipeline 931 | pipeline = DiffusionPipeline.from_pretrained( 932 | args.pretrained_model_name_or_path, 933 | unet=unwrap_model(unet), 934 | 
revision=args.revision, 935 | variant=args.variant, 936 | torch_dtype=weight_dtype, 937 | ) 938 | images = log_validation(pipeline, args, accelerator, epoch) 939 | 940 | del pipeline 941 | torch.cuda.empty_cache() 942 | 943 | # Save the lora layers 944 | accelerator.wait_for_everyone() 945 | if accelerator.is_main_process: 946 | unet = unet.to(torch.float32) 947 | 948 | unwrapped_unet = unwrap_model(unet) 949 | unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet)) 950 | StableDiffusionPipeline.save_lora_weights( 951 | save_directory=args.output_dir, 952 | unet_lora_layers=unet_lora_state_dict, 953 | safe_serialization=True, 954 | ) 955 | 956 | # Final inference 957 | # Load previous pipeline 958 | if args.validation_prompt is not None: 959 | pipeline = DiffusionPipeline.from_pretrained( 960 | args.pretrained_model_name_or_path, 961 | revision=args.revision, 962 | variant=args.variant, 963 | torch_dtype=weight_dtype, 964 | ) 965 | 966 | # load attention processors 967 | pipeline.load_lora_weights(args.output_dir) 968 | 969 | # run inference 970 | images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True) 971 | 972 | if args.push_to_hub: 973 | save_model_card( 974 | repo_id, 975 | images=images, 976 | base_model=args.pretrained_model_name_or_path, 977 | dataset_name=args.dataset_name, 978 | repo_folder=args.output_dir, 979 | ) 980 | upload_folder( 981 | repo_id=repo_id, 982 | folder_path=args.output_dir, 983 | commit_message="End of training", 984 | ignore_patterns=["step_*", "epoch_*"], 985 | ) 986 | 987 | accelerator.end_training() 988 | 989 | 990 | if __name__ == "__main__": 991 | main() -------------------------------------------------------------------------------- /training-free/sd_utils_x0_dpm_syn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
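# Usage sketch (editorial comments, not part of the original module; the names below are defined
# further down in this file): SynGenMaskDiffusionPipeline subclasses StableDiffusionPipeline. Its
# __call__ skips ahead to denoising-step index `self.stop_step`, optionally applies a SynGen-style
# attention-binding update to the latents there (when stop_step < num_intervention_steps), runs a
# single guided denoising step, then sets `latents` to the scheduler's x0 prediction, decodes it,
# and returns the decoded image together with the intermediate latent (`latents_cur`) instead of
# completing the full loop. register_sdschedule_step() monkey-patches a DPMSolverMultistepScheduler
# so that step(..., return_dict=False) returns (prev_sample, model_output), which __call__ unpacks
# as `latents, pred_original_sample`. A minimal call might look like the following, assuming the
# caller sets the attributes __call__ expects (e.g. pipe.stop_step) before invoking it:
#
#   from diffusers import DPMSolverMultistepScheduler
#   pipe = SynGenMaskDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
#   register_sdschedule_step(pipe.scheduler)
#   pipe.stop_step = 0  # assumption: index of the step at which to intervene and stop
#   image, latents = pipe("a red bear and a green bench", return_dict=False)
#   # `image` is a [0, 1] tensor here, since image_processor.postprocess is commented out below.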
14 | import inspect 15 | from typing import Any, Callable, Dict, List, Optional, Union 16 | 17 | import torch 18 | from packaging import version 19 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 20 | 21 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 22 | from diffusers.configuration_utils import FrozenDict 23 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 24 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin 25 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel 26 | from diffusers.models.lora import adjust_lora_scale_text_encoder 27 | from diffusers.schedulers import KarrasDiffusionSchedulers 28 | from diffusers.utils import ( 29 | USE_PEFT_BACKEND, 30 | deprecate, 31 | logging, 32 | replace_example_docstring, 33 | scale_lora_layers, 34 | unscale_lora_layers, 35 | ) 36 | from diffusers.utils.torch_utils import randn_tensor 37 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin 38 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput 39 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 40 | import torch.nn.functional as F 41 | import torch.optim as optim 42 | from diffusers.pipelines.stable_diffusion_attend_and_excite.pipeline_stable_diffusion_attend_and_excite import ( 43 | AttentionStore, 44 | AttendExciteAttnProcessor 45 | ) 46 | from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, align_wordpieces_indices, extract_attribution_indices, extract_attribution_indices_with_verbs, extract_attribution_indices_with_verb_root, extract_entities_only 47 | 48 | 49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 50 | 51 | EXAMPLE_DOC_STRING = """ 52 | Examples: 53 | ```py 54 | >>> import torch 55 | >>> from diffusers import StableDiffusionPipeline 56 | 57 | >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) 58 | >>> pipe = pipe.to("cuda") 59 | 60 | >>> prompt = "a photo of an astronaut riding a horse on mars" 61 | >>> image = pipe(prompt).images[0] 62 | ``` 63 | """ 64 | 65 | 66 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 67 | """ 68 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 69 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 70 | """ 71 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 72 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 73 | # rescale the results from guidance (fixes overexposure) 74 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 75 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 76 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 77 | return noise_cfg 78 | 79 | 80 | def retrieve_timesteps( 81 | scheduler, 82 | num_inference_steps: Optional[int] = None, 83 | device: Optional[Union[str, torch.device]] = None, 84 | timesteps: Optional[List[int]] = None, 85 | sigmas: Optional[List[float]] = None, 86 | **kwargs, 87 | ): 88 | """ 89 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 90 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 91 | 92 | Args: 93 | scheduler (`SchedulerMixin`): 94 | The scheduler to get timesteps from. 95 | num_inference_steps (`int`): 96 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 97 | must be `None`. 98 | device (`str` or `torch.device`, *optional*): 99 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 100 | timesteps (`List[int]`, *optional*): 101 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 102 | `num_inference_steps` and `sigmas` must be `None`. 103 | sigmas (`List[float]`, *optional*): 104 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 105 | `num_inference_steps` and `timesteps` must be `None`. 106 | 107 | Returns: 108 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 109 | second element is the number of inference steps. 110 | """ 111 | if timesteps is not None and sigmas is not None: 112 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 113 | if timesteps is not None: 114 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 115 | if not accepts_timesteps: 116 | raise ValueError( 117 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 118 | f" timestep schedules. Please check whether you are using the correct scheduler." 119 | ) 120 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 121 | timesteps = scheduler.timesteps 122 | num_inference_steps = len(timesteps) 123 | elif sigmas is not None: 124 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 125 | if not accept_sigmas: 126 | raise ValueError( 127 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 128 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
129 | ) 130 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 131 | timesteps = scheduler.timesteps 132 | num_inference_steps = len(timesteps) 133 | else: 134 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 135 | timesteps = scheduler.timesteps 136 | return timesteps, num_inference_steps 137 | 138 | 139 | import itertools 140 | from typing import Any, Callable, Dict, Optional, Union, List 141 | 142 | import spacy 143 | import torch 144 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel 145 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker 146 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 147 | EXAMPLE_DOC_STRING, 148 | rescale_noise_cfg 149 | ) 150 | from diffusers.pipelines.stable_diffusion_attend_and_excite.pipeline_stable_diffusion_attend_and_excite import ( 151 | AttentionStore, 152 | AttendExciteAttnProcessor 153 | ) 154 | import numpy as np 155 | from diffusers.schedulers import KarrasDiffusionSchedulers 156 | from diffusers.utils import ( 157 | logging, 158 | replace_example_docstring, 159 | ) 160 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 161 | 162 | from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, align_wordpieces_indices, extract_attribution_indices, extract_attribution_indices_with_verbs, extract_attribution_indices_with_verb_root, extract_entities_only 163 | 164 | 165 | logger = logging.get_logger(__name__) 166 | 167 | 168 | class SynGenMaskDiffusionPipeline(StableDiffusionPipeline): 169 | def __init__(self, 170 | vae: AutoencoderKL, 171 | text_encoder: CLIPTextModel, 172 | tokenizer: CLIPTokenizer, 173 | unet: UNet2DConditionModel, 174 | scheduler: KarrasDiffusionSchedulers, 175 | safety_checker: StableDiffusionSafetyChecker, 176 | feature_extractor: CLIPImageProcessor, 177 | image_encoder: CLIPVisionModelWithProjection = None, 178 | requires_safety_checker: bool = True, 179 | include_entities: bool = False, 180 | ): 181 | super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor,image_encoder, 182 | requires_safety_checker) 183 | 184 | self.parser = spacy.load("en_core_web_trf") 185 | self.subtrees_indices = None 186 | self.doc = None 187 | self.include_entities = include_entities 188 | 189 | def _aggregate_and_get_attention_maps_per_token(self): 190 | attention_maps = self.attention_store.aggregate_attention( 191 | from_where=("up", "down", "mid"), 192 | ) 193 | attention_maps_list = _get_attention_maps_list( 194 | attention_maps=attention_maps 195 | ) 196 | return attention_maps_list 197 | 198 | @staticmethod 199 | def _update_latent( 200 | latents: torch.Tensor, loss: torch.Tensor, step_size: float 201 | ) -> torch.Tensor: 202 | """Update the latent according to the computed loss.""" 203 | grad_cond = torch.autograd.grad( 204 | loss.requires_grad_(True), [latents], retain_graph=True 205 | )[0] 206 | latents = latents - step_size * grad_cond 207 | return latents 208 | 209 | def register_attention_control(self): 210 | attn_procs = {} 211 | cross_att_count = 0 212 | for name in self.unet.attn_processors.keys(): 213 | if name.startswith("mid_block"): 214 | place_in_unet = "mid" 215 | elif name.startswith("up_blocks"): 216 | place_in_unet = "up" 217 | elif name.startswith("down_blocks"): 218 | place_in_unet = "down" 219 | else: 
220 | continue 221 | 222 | cross_att_count += 1 223 | attn_procs[name] = AttendExciteAttnProcessor( 224 | attnstore=self.attention_store, place_in_unet=place_in_unet 225 | ) 226 | 227 | self.unet.set_attn_processor(attn_procs) 228 | self.attention_store.num_att_layers = cross_att_count 229 | 230 | # @torch.no_grad() 231 | @replace_example_docstring(EXAMPLE_DOC_STRING) 232 | def __call__( 233 | self, 234 | prompt: Union[str, List[str]] = None, 235 | height: Optional[int] = None, 236 | width: Optional[int] = None, 237 | num_inference_steps: int = 50, 238 | timesteps: List[int] = None, 239 | sigmas: List[float] = None, 240 | guidance_scale: float = 7.5, 241 | negative_prompt: Optional[Union[str, List[str]]] = None, 242 | num_images_per_prompt: Optional[int] = 1, 243 | eta: float = 0.0, 244 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 245 | latents: Optional[torch.Tensor] = None, 246 | prompt_embeds: Optional[torch.Tensor] = None, 247 | negative_prompt_embeds: Optional[torch.Tensor] = None, 248 | ip_adapter_image: Optional[PipelineImageInput] = None, 249 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 250 | output_type: Optional[str] = "pil", 251 | return_dict: bool = True, 252 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 253 | guidance_rescale: float = 0.0, 254 | clip_skip: Optional[int] = None, 255 | callback_on_step_end: Optional[ 256 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 257 | ] = None, 258 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 259 | attn_res=None, 260 | syngen_step_size: float = 20.0, 261 | parsed_prompt: str = None, 262 | num_intervention_steps: int = 25, 263 | **kwargs, 264 | ): 265 | r""" 266 | The call function to the pipeline for generation. 267 | 268 | Args: 269 | prompt (`str` or `List[str]`, *optional*): 270 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 271 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 272 | The height in pixels of the generated image. 273 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 274 | The width in pixels of the generated image. 275 | num_inference_steps (`int`, *optional*, defaults to 50): 276 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 277 | expense of slower inference. 278 | timesteps (`List[int]`, *optional*): 279 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 280 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 281 | passed will be used. Must be in descending order. 282 | sigmas (`List[float]`, *optional*): 283 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 284 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 285 | will be used. 286 | guidance_scale (`float`, *optional*, defaults to 7.5): 287 | A higher guidance scale value encourages the model to generate images closely linked to the text 288 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 289 | negative_prompt (`str` or `List[str]`, *optional*): 290 | The prompt or prompts to guide what to not include in image generation. 
If not defined, you need to 291 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 292 | num_images_per_prompt (`int`, *optional*, defaults to 1): 293 | The number of images to generate per prompt. 294 | eta (`float`, *optional*, defaults to 0.0): 295 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 296 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 297 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 298 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 299 | generation deterministic. 300 | latents (`torch.Tensor`, *optional*): 301 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 302 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 303 | tensor is generated by sampling using the supplied random `generator`. 304 | prompt_embeds (`torch.Tensor`, *optional*): 305 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 306 | provided, text embeddings are generated from the `prompt` input argument. 307 | negative_prompt_embeds (`torch.Tensor`, *optional*): 308 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 309 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 310 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 311 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 312 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of 313 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should 314 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not 315 | provided, embeddings are computed from the `ip_adapter_image` input argument. 316 | output_type (`str`, *optional*, defaults to `"pil"`): 317 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 318 | return_dict (`bool`, *optional*, defaults to `True`): 319 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 320 | plain tuple. 321 | cross_attention_kwargs (`dict`, *optional*): 322 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in 323 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 324 | guidance_rescale (`float`, *optional*, defaults to 0.0): 325 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are 326 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when 327 | using zero terminal SNR. 328 | clip_skip (`int`, *optional*): 329 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 330 | the output of the pre-final layer will be used for computing the prompt embeddings. 331 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 332 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 333 | each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: 334 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 335 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 336 | callback_on_step_end_tensor_inputs (`List`, *optional*): 337 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 338 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 339 | `._callback_tensor_inputs` attribute of your pipeline class. 340 | 341 | Examples: 342 | 343 | Returns: 344 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 345 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 346 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 347 | second element is a list of `bool`s indicating whether the corresponding generated image contains 348 | "not-safe-for-work" (nsfw) content. 349 | """ 350 | 351 | callback = kwargs.pop("callback", None) 352 | callback_steps = kwargs.pop("callback_steps", None) 353 | 354 | if callback is not None: 355 | deprecate( 356 | "callback", 357 | "1.0.0", 358 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 359 | ) 360 | if callback_steps is not None: 361 | deprecate( 362 | "callback_steps", 363 | "1.0.0", 364 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 365 | ) 366 | 367 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 368 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 369 | if parsed_prompt: 370 | self.doc = parsed_prompt 371 | else: 372 | self.doc = self.parser(prompt) 373 | # 0. Default height and width to unet 374 | height = height or self.unet.config.sample_size * self.vae_scale_factor 375 | width = width or self.unet.config.sample_size * self.vae_scale_factor 376 | # to deal with lora scaling and other possible forward hooks 377 | 378 | # 1. Check inputs. Raise error if not correct 379 | self.check_inputs( 380 | prompt, 381 | height, 382 | width, 383 | callback_steps, 384 | negative_prompt, 385 | prompt_embeds, 386 | negative_prompt_embeds, 387 | ip_adapter_image, 388 | ip_adapter_image_embeds, 389 | callback_on_step_end_tensor_inputs, 390 | ) 391 | 392 | self._guidance_scale = guidance_scale 393 | self._guidance_rescale = guidance_rescale 394 | self._clip_skip = clip_skip 395 | self._cross_attention_kwargs = cross_attention_kwargs 396 | self._interrupt = False 397 | 398 | # 2. Define call parameters 399 | if prompt is not None and isinstance(prompt, str): 400 | batch_size = 1 401 | elif prompt is not None and isinstance(prompt, list): 402 | batch_size = len(prompt) 403 | else: 404 | batch_size = prompt_embeds.shape[0] 405 | 406 | device = self._execution_device 407 | 408 | # 3. 
Encode input prompt 409 | lora_scale = ( 410 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 411 | ) 412 | 413 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 414 | prompt, 415 | device, 416 | num_images_per_prompt, 417 | self.do_classifier_free_guidance, 418 | negative_prompt, 419 | prompt_embeds=prompt_embeds, 420 | negative_prompt_embeds=negative_prompt_embeds, 421 | lora_scale=lora_scale, 422 | clip_skip=self.clip_skip, 423 | ) 424 | # print(prompt_embeds.size()) 425 | # assert 2==1 426 | 427 | # For classifier free guidance, we need to do two forward passes. 428 | # Here we concatenate the unconditional and text embeddings into a single batch 429 | # to avoid doing two forward passes 430 | if self.do_classifier_free_guidance: 431 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 432 | 433 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 434 | image_embeds = self.prepare_ip_adapter_image_embeds( 435 | ip_adapter_image, 436 | ip_adapter_image_embeds, 437 | device, 438 | batch_size * num_images_per_prompt, 439 | self.do_classifier_free_guidance, 440 | ) 441 | 442 | # 4. Prepare timesteps 443 | timesteps, num_inference_steps = retrieve_timesteps( 444 | self.scheduler, num_inference_steps, device, timesteps, sigmas 445 | ) 446 | 447 | # 5. Prepare latent variables 448 | num_channels_latents = self.unet.config.in_channels 449 | latents = self.prepare_latents( 450 | batch_size * num_images_per_prompt, 451 | num_channels_latents, 452 | height, 453 | width, 454 | prompt_embeds.dtype, 455 | device, 456 | generator, 457 | latents, 458 | ) 459 | 460 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 461 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 462 | 463 | if attn_res is None: 464 | attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) 465 | self.attn_res = attn_res 466 | self.attention_store = AttentionStore(self.attn_res) 467 | self.register_attention_control() 468 | 469 | # 6.1 Add image embeds for IP-Adapter 470 | added_cond_kwargs = ( 471 | {"image_embeds": image_embeds} 472 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) 473 | else None 474 | ) 475 | 476 | # 6.2 Optionally get Guidance Scale Embedding 477 | timestep_cond = None 478 | if self.unet.config.time_cond_proj_dim is not None: 479 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 480 | timestep_cond = self.get_guidance_scale_embedding( 481 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 482 | ).to(device=device, dtype=latents.dtype) 483 | 484 | text_embeddings = ( 485 | prompt_embeds[batch_size * num_images_per_prompt:] if self.do_classifier_free_guidance else prompt_embeds 486 | ) 487 | # 7. 
Denoising loop 488 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 489 | self._num_timesteps = len(timesteps) 490 | with self.progress_bar(total=num_inference_steps) as progress_bar: 491 | for i, t in enumerate(timesteps): 492 | 493 | if self.interrupt: 494 | continue 495 | 496 | if i < self.stop_step: 497 | continue 498 | 499 | if self.stop_step < num_intervention_steps: 500 | latents = self._syngen_step( 501 | latents, 502 | text_embeddings, 503 | t, 504 | self.stop_step, 505 | syngen_step_size, 506 | cross_attention_kwargs, 507 | prompt, 508 | num_intervention_steps=num_intervention_steps, 509 | ) 510 | # expand the latents if we are doing classifier free guidance 511 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 512 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 513 | # print(self.unet) 514 | # assert 2==1 515 | # predict the noise residual 516 | noise_pred = self.unet( 517 | latent_model_input, 518 | t, 519 | encoder_hidden_states=prompt_embeds, 520 | timestep_cond=timestep_cond, 521 | cross_attention_kwargs=self.cross_attention_kwargs, 522 | added_cond_kwargs=added_cond_kwargs, 523 | return_dict=False, 524 | )[0] 525 | 526 | # perform guidance 527 | if self.do_classifier_free_guidance: 528 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 529 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 530 | 531 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: 532 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 533 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) 534 | 535 | # compute the previous noisy sample x_t -> x_t-1 536 | latents, pred_original_sample = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False) 537 | 538 | 539 | if callback_on_step_end is not None: 540 | callback_kwargs = {} 541 | for k in callback_on_step_end_tensor_inputs: 542 | callback_kwargs[k] = locals()[k] 543 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 544 | 545 | latents = callback_outputs.pop("latents", latents) 546 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 547 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 548 | 549 | # call the callback, if provided 550 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 551 | progress_bar.update() 552 | if callback is not None and i % callback_steps == 0: 553 | step_idx = i // getattr(self.scheduler, "order", 1) 554 | callback(step_idx, t, latents) 555 | 556 | if i >= self.stop_step: 557 | # alpha_t = self.alphas[int(t.item())] ** 0.5 558 | # sigma_t = (1 - self.alphas[int(t.item())]) ** 0.5 559 | # latents = (latents - sigma_t * noise_pred) / alpha_t 560 | latents_cur = latents 561 | latents = pred_original_sample 562 | break 563 | if not output_type == "latent": 564 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 565 | 0 566 | ] 567 | # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 568 | else: 569 | has_nsfw_concept = None 570 | 571 | 572 | # if has_nsfw_concept is None: 573 | # do_denormalize = [True] * image.shape[0] 574 | # else: 575 | # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 576 | 577 | # image = 
self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
578 | # image = (image / 2 + 0.5).clamp(0, 1)
579 | image = (image / 2 + 0.5).clamp(0, 1)
580 | # optimizer = optim.Adam(self.unet.parameters(), lr=1e-1)
581 | # input_tensor = image # simulate an input image
582 | # target = torch.randn_like(input_tensor).to(noise_pred.device).half() # target has the same shape as the input
583 | # loss = F.mse_loss(image, target)
584 | # optimizer.zero_grad()
585 | # loss.backward()
586 | # for name, param in self.unet.named_parameters():
587 | # if param.grad is not None:
588 | # print(f"Gradient for {name}: {param.grad}")
589 | # else:
590 | # print(f"No gradient for {name}")
591 |
592 | # assert 2==1
593 |
594 | # Offload all models
595 | self.maybe_free_model_hooks()
596 |
597 | if not return_dict:
598 | # return (image, has_nsfw_concept)
599 | return (image, latents_cur)
600 |
601 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), latents_cur
602 |
603 | def _syngen_step(
604 | self,
605 | latents,
606 | text_embeddings,
607 | t,
608 | i,
609 | step_size,
610 | cross_attention_kwargs,
611 | prompt,
612 | num_intervention_steps,
613 | ):
614 | with torch.enable_grad():
615 | latents = latents.clone().detach().requires_grad_(True)
616 | updated_latents = []
617 | for latent, text_embedding in zip(latents, text_embeddings):
618 | # Forward pass of denoising with text conditioning
619 | latent = latent.unsqueeze(0)
620 | text_embedding = text_embedding.unsqueeze(0)
621 |
622 | self.unet(
623 | latent,
624 | t,
625 | encoder_hidden_states=text_embedding,
626 | cross_attention_kwargs=cross_attention_kwargs,
627 | return_dict=False,
628 | )[0]
629 | self.unet.zero_grad()
630 | # Get attention maps
631 | attention_maps = self._aggregate_and_get_attention_maps_per_token()
632 | loss = self._compute_loss(attention_maps=attention_maps, prompt=prompt)
633 | # Perform gradient update
634 | if i < num_intervention_steps:
635 | if loss != 0:
636 | latent = self._update_latent(
637 | latents=latent, loss=loss, step_size=step_size
638 | )
639 | logger.info(f"Iteration {i} | Loss: {loss:0.4f}")
640 |
641 | updated_latents.append(latent)
642 |
643 | latents = torch.cat(updated_latents, dim=0)
644 |
645 | return latents
646 |
647 | def _compute_loss(
648 | self, attention_maps: List[torch.Tensor], prompt: Union[str, List[str]]
649 | ) -> torch.Tensor:
650 | attn_map_idx_to_wp = get_attention_map_index_to_wordpiece(self.tokenizer, prompt)
651 | loss = self._attribution_loss(attention_maps, prompt, attn_map_idx_to_wp)
652 |
653 | return loss
654 |
655 | def _attribution_loss(
656 | self,
657 | attention_maps: List[torch.Tensor],
658 | prompt: Union[str, List[str]],
659 | attn_map_idx_to_wp,
660 | ) -> torch.Tensor:
661 | if not self.subtrees_indices:
662 | self.subtrees_indices = self._extract_attribution_indices(prompt)
663 | subtrees_indices = self.subtrees_indices
664 |
665 | loss = 0
666 |
667 | for subtree_indices in subtrees_indices:
668 | noun, modifier = split_indices(subtree_indices)
669 | all_subtree_pairs = list(itertools.product(noun, modifier))
670 | if noun and not modifier:
671 | if isinstance(noun, list) and len(noun) == 1:
672 | processed_noun = noun[0]
673 | else:
674 | processed_noun = noun
675 | loss += calculate_negative_loss(
676 | attention_maps, modifier, processed_noun, subtree_indices, attn_map_idx_to_wp
677 | )
678 | else:
679 | positive_loss, negative_loss = self._calculate_losses(
680 | attention_maps,
681 | all_subtree_pairs,
682 | subtree_indices,
683 | attn_map_idx_to_wp,
684 | )
685 |
686 | loss += positive_loss
687 | loss += negative_loss
688 |
689 | return loss
690 |
691 | def _calculate_losses(
692 | self,
693 | attention_maps,
694 | all_subtree_pairs,
695 | subtree_indices,
696 | attn_map_idx_to_wp,
697 | ):
698 | positive_loss = []
699 | negative_loss = []
700 | for pair in all_subtree_pairs:
701 | noun, modifier = pair
702 | positive_loss.append(
703 | calculate_positive_loss(attention_maps, modifier, noun)
704 | )
705 | negative_loss.append(
706 | calculate_negative_loss(
707 | attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
708 | )
709 | )
710 |
711 | positive_loss = sum(positive_loss)
712 | negative_loss = sum(negative_loss)
713 |
714 | return positive_loss, negative_loss
715 |
716 | def _align_indices(self, prompt, spacy_pairs):
717 | wordpieces2indices = get_indices(self.tokenizer, prompt)
718 | paired_indices = []
719 | collected_spacy_indices = (
720 | set()
721 | ) # helps track recurring nouns across different relations (i.e., cases where there is more than one instance of the same word)
722 |
723 | for pair in spacy_pairs:
724 | curr_collected_wp_indices = (
725 | []
726 | ) # helps track which nouns and amods were added to the current pair (this is useful in sentences with repeating amod on the same relation (e.g., "a red red red bear"))
727 | for member in pair:
728 | for idx, wp in wordpieces2indices.items():
729 | if wp in [start_token, end_token]:
730 | continue
731 |
732 | wp = wp.replace("</w>", "")  # strip the CLIP BPE end-of-word suffix
733 | if member.text.lower() == wp.lower():
734 | if idx not in curr_collected_wp_indices and idx not in collected_spacy_indices:
735 | curr_collected_wp_indices.append(idx)
736 | break
737 | # take care of wordpieces that are split up
738 | elif member.text.lower().startswith(wp.lower()) and wp.lower() != member.text.lower(): # can maybe be while loop
739 | wp_indices = align_wordpieces_indices(
740 | wordpieces2indices, idx, member.text
741 | )
742 | # check if all wp_indices are not already in collected_spacy_indices
743 | if wp_indices and (wp_indices not in curr_collected_wp_indices) and all(
744 | [wp_idx not in collected_spacy_indices for wp_idx in wp_indices]):
745 | curr_collected_wp_indices.append(wp_indices)
746 | break
747 |
748 | for collected_idx in curr_collected_wp_indices:
749 | if isinstance(collected_idx, list):
750 | for idx in collected_idx:
751 | collected_spacy_indices.add(idx)
752 | else:
753 | collected_spacy_indices.add(collected_idx)
754 |
755 | if curr_collected_wp_indices:
756 | paired_indices.append(curr_collected_wp_indices)
757 | else:
758 | print(f"No wordpieces were aligned for {pair} in _align_indices")
759 |
760 | return paired_indices
761 |
762 | def _extract_attribution_indices(self, prompt):
763 | modifier_indices = []
764 | # extract standard attribution indices
765 | modifier_sets_1 = extract_attribution_indices(self.doc)
766 | modifier_indices_1 = self._align_indices(prompt, modifier_sets_1)
767 | if modifier_indices_1:
768 | modifier_indices.append(modifier_indices_1)
769 |
770 | # extract attribution indices with verbs in between
771 | modifier_sets_2 = extract_attribution_indices_with_verb_root(self.doc)
772 | modifier_indices_2 = self._align_indices(prompt, modifier_sets_2)
773 | if modifier_indices_2:
774 | modifier_indices.append(modifier_indices_2)
775 |
776 | modifier_sets_3 = extract_attribution_indices_with_verbs(self.doc)
777 | modifier_indices_3 = self._align_indices(prompt, modifier_sets_3)
778 | if modifier_indices_3:
779 |
modifier_indices.append(modifier_indices_3) 780 | 781 | # entities only 782 | if self.include_entities: 783 | modifier_sets_4 = extract_entities_only(self.doc) 784 | modifier_indices_4 = self._align_indices(prompt, modifier_sets_4) 785 | modifier_indices.append(modifier_indices_4) 786 | 787 | # make sure there are no duplicates 788 | modifier_indices = unify_lists(modifier_indices) 789 | print(f"Final modifier indices collected:{modifier_indices}") 790 | 791 | return modifier_indices 792 | 793 | 794 | def _get_attention_maps_list( 795 | attention_maps: torch.Tensor 796 | ) -> List[torch.Tensor]: 797 | attention_maps *= 100 798 | attention_maps_list = [ 799 | attention_maps[:, :, i] for i in range(attention_maps.shape[2]) 800 | ] 801 | 802 | return attention_maps_list 803 | 804 | 805 | def unify_lists(list_of_lists): 806 | def flatten(lst): 807 | for elem in lst: 808 | if isinstance(elem, list): 809 | yield from flatten(elem) 810 | else: 811 | yield elem 812 | 813 | def have_common_element(lst1, lst2): 814 | flat_list1 = set(flatten(lst1)) 815 | flat_list2 = set(flatten(lst2)) 816 | return not flat_list1.isdisjoint(flat_list2) 817 | 818 | lst = [] 819 | for l in list_of_lists: 820 | lst += l 821 | changed = True 822 | while changed: 823 | changed = False 824 | merged_list = [] 825 | while lst: 826 | first = lst.pop(0) 827 | was_merged = False 828 | for index, other in enumerate(lst): 829 | if have_common_element(first, other): 830 | # If we merge, we should flatten the other list but not first 831 | new_merged = first + [item for item in other if item not in first] 832 | lst[index] = new_merged 833 | changed = True 834 | was_merged = True 835 | break 836 | if not was_merged: 837 | merged_list.append(first) 838 | lst = merged_list 839 | 840 | return lst 841 | 842 | from typing import Tuple 843 | from dataclasses import dataclass 844 | from diffusers.schedulers.scheduling_lcm import LCMSchedulerOutput 845 | from diffusers.utils.torch_utils import randn_tensor 846 | from diffusers.utils import BaseOutput 847 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput 848 | 849 | def register_sdschedule_step(model): 850 | def sd_schedule_step(self): 851 | def step( 852 | model_output: torch.Tensor, 853 | timestep: Union[int, torch.Tensor], 854 | sample: torch.Tensor, 855 | generator=None, 856 | variance_noise: Optional[torch.Tensor] = None, 857 | return_dict: bool = True, 858 | ) -> Union[SchedulerOutput, Tuple]: 859 | """ 860 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with 861 | the multistep DPMSolver. 862 | 863 | Args: 864 | model_output (`torch.Tensor`): 865 | The direct output from learned diffusion model. 866 | timestep (`int`): 867 | The current discrete timestep in the diffusion chain. 868 | sample (`torch.Tensor`): 869 | A current instance of a sample created by the diffusion process. 870 | generator (`torch.Generator`, *optional*): 871 | A random number generator. 872 | variance_noise (`torch.Tensor`): 873 | Alternative to generating noise with `generator` by directly providing the noise for the variance 874 | itself. Useful for methods such as [`LEdits++`]. 875 | return_dict (`bool`): 876 | Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
877 | 878 | Returns: 879 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: 880 | If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a 881 | tuple is returned where the first element is the sample tensor. 882 | 883 | """ 884 | if self.num_inference_steps is None: 885 | raise ValueError( 886 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 887 | ) 888 | 889 | if self.step_index is None: 890 | self._init_step_index(timestep) 891 | 892 | # Improve numerical stability for small number of steps 893 | lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( 894 | self.config.euler_at_final 895 | or (self.config.lower_order_final and len(self.timesteps) < 15) 896 | or self.config.final_sigmas_type == "zero" 897 | ) 898 | lower_order_second = ( 899 | (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 900 | ) 901 | 902 | model_output = self.convert_model_output(model_output, sample=sample) 903 | for i in range(self.config.solver_order - 1): 904 | self.model_outputs[i] = self.model_outputs[i + 1] 905 | self.model_outputs[-1] = model_output 906 | 907 | # Upcast to avoid precision issues when computing prev_sample 908 | sample = sample.to(torch.float32) 909 | if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: 910 | noise = randn_tensor( 911 | model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 912 | ) 913 | elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: 914 | noise = variance_noise.to(device=model_output.device, dtype=torch.float32) 915 | else: 916 | noise = None 917 | 918 | if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: 919 | prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) 920 | elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: 921 | prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) 922 | else: 923 | prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) 924 | 925 | if self.lower_order_nums < self.config.solver_order: 926 | self.lower_order_nums += 1 927 | 928 | # Cast sample back to expected dtype 929 | prev_sample = prev_sample.to(model_output.dtype) 930 | 931 | # upon completion increase step index by one 932 | self._step_index += 1 933 | 934 | if not return_dict: 935 | return (prev_sample, model_output) 936 | 937 | return SchedulerOutput(prev_sample=prev_sample) 938 | 939 | return step 940 | if model.__class__.__name__ == 'DPMSolverMultistepScheduler': 941 | model.step = sd_schedule_step(model) --------------------------------------------------------------------------------