├── method.png
├── results.png
├── motivation.png
├── results1.png
├── training-free
│   ├── infer_sd1-5_x0_optim_mask_fnal_para.sh
│   ├── hpsv2_loss.py
│   ├── infer_sd1-5_x0_optim_mask_final_para.py
│   ├── sd_utils_x0_dpm.py
│   └── sd_utils_x0_dpm_syn.py
├── training
│   ├── infer_sd1-5_hardmask.sh
│   ├── train_sd_lora.sh
│   ├── train_sd_full.sh
│   ├── train_hyperunet.sh
│   ├── train_dreambooth_maskunet.sh
│   ├── train_dreambooth.sh
│   ├── infer_sd1-5_hardmask.py
│   ├── hyperunet.py
│   └── train_sd_lora.py
├── README.md
└── environment.yaml
/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/method.png
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/results.png
--------------------------------------------------------------------------------
/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/motivation.png
--------------------------------------------------------------------------------
/results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gudaochangsheng/MaskUnet/HEAD/results1.png
--------------------------------------------------------------------------------
/training-free/infer_sd1-5_x0_optim_mask_fnal_para.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --multi_gpu --num_processes 6 --gpu_ids='all' infer_sd1-5_x0_optim_mask_final_para.py
--------------------------------------------------------------------------------
/training/infer_sd1-5_hardmask.sh:
--------------------------------------------------------------------------------
1 | torchrun --master-port=29506 --nproc_per_node=8 infer_sd1-5_hardmask.py \
2 | --pretrained_model_name_or_path "../share/runwayml/stable-diffusion-v1-5" \
3 | --captions_file "captions_2017.txt" \
4 | --world_size 8 \
5 | --batch_size 16
--------------------------------------------------------------------------------
/training/train_sd_lora.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5"
2 | export DATASET_NAME="../fantasyfish/laion-art"
3 |
4 | accelerate launch --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_sd_lora.py \
5 | --pretrained_model_name_or_path=$MODEL_NAME \
6 | --dataset_name=$DATASET_NAME \
7 | --image_column="image" \
8 | --caption_column="text" \
9 | --resolution=512 --center_crop --random_flip \
10 | --train_batch_size=4 \
11 | --gradient_accumulation_steps=1 \
12 | --gradient_checkpointing \
13 | --num_train_epochs 12 --checkpointing_steps=500 \
14 | --max_grad_norm=1 \
15 | --use_8bit_adam \
16 | --learning_rate=1e-05 --lr_scheduler="constant" --lr_warmup_steps=0 \
17 | --seed=42 \
18 | --enable_xformers_memory_efficient_attention \
19 | --output_dir="sd-naruto-model-lora" \
20 | --validation_prompt="cute cat"
--------------------------------------------------------------------------------
/training/train_sd_full.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5"
2 | export DATASET_NAME="../fantasyfish/laion-art"
3 | # --use_ema \
4 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_sd_full.py \
5 | --pretrained_model_name_or_path=$MODEL_NAME \
6 | --dataset_name=$DATASET_NAME \
7 | --image_column="image" \
8 | --caption_column="text" \
9 | --resolution=512 --center_crop --random_flip \
10 | --train_batch_size=4 \
11 | --gradient_accumulation_steps=1 \
12 | --gradient_checkpointing \
13 | --num_train_epochs 12 \
14 | --checkpointing_steps=500 \
15 | --learning_rate=1e-05 \
16 | --max_grad_norm=1 \
17 | --use_8bit_adam \
18 | --lr_scheduler="constant" --lr_warmup_steps=0 \
19 | --seed=42 \
20 | --enable_xformers_memory_efficient_attention \
21 | --output_dir="sd-naruto-model_full"
22 |
23 | # # --max_train_steps=5000 \
--------------------------------------------------------------------------------
/training/train_hyperunet.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5"
2 | export DATASET_NAME="../fantasyfish/laion-art"
3 | # --use_ema \
4 | #-
5 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 8 --gpu_ids='all' train_hyperunet.py \
6 | --pretrained_model_name_or_path=$MODEL_NAME \
7 | --dataset_name=$DATASET_NAME \
8 | --image_column="image" \
9 | --caption_column="text" \
10 | --resolution=512 --center_crop --random_flip \
11 | --train_batch_size=4 \
12 | --gradient_accumulation_steps=1 \
13 | --gradient_checkpointing \
14 | --num_train_epochs 15 \
15 | --checkpointing_steps=500 \
16 | --learning_rate=1e-05 \
17 | --max_grad_norm=1 \
18 | --use_8bit_adam \
19 | --lr_scheduler="constant" --lr_warmup_steps=0 \
20 | --seed=42 \
21 | --enable_xformers_memory_efficient_attention \
22 | --output_dir="sd-naruto-model_hard_improve1_factor4_trans_decoder_sd1.5_factorablation_epoch15"
23 |
24 | # # --max_train_steps=5000 \
--------------------------------------------------------------------------------
/training/train_dreambooth_maskunet.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5"
2 | export INSTANCE_DIR="./dreambooth/dataset/dog2"
3 | export CLASS_DIR="dream_booth_class_image"
4 | export OUTPUT_DIR="model_path_subject_dream_booth_dog"
5 |
6 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_dreambooth_maskunet.py \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --instance_data_dir=$INSTANCE_DIR \
9 | --output_dir=$OUTPUT_DIR \
10 | --train_batch_size=4 \
11 | --gradient_accumulation_steps=1 \
12 | --learning_rate=1e-3 \
13 | --lr_scheduler="constant" \
14 | --lr_warmup_steps=0 \
15 | --instance_prompt="a photo of sks dog" \
16 | --resolution=512 \
17 | --max_train_steps=5000
18 | # --class_data_dir=$CLASS_DIR \
19 | # --with_prior_preservation --prior_loss_weight=1.0 \
20 |
21 | # --class_prompt="a photo of backpack" \
22 |
23 |
24 | # --num_class_images=200 \
25 |
26 | # --push_to_hub
--------------------------------------------------------------------------------
/training/train_dreambooth.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="../share/runwayml/stable-diffusion-v1-5"
2 | # export INSTANCE_DIR="dog"
3 | export OUTPUT_DIR_root="dreambooth_model_ori"
4 | export CLASS_DIR_root="dreambooth_class-images"
5 | export DATA_SET_root="./dreambooth/dataset"
6 |
7 | export subject_name=("backpack_dog" "backpack" "bear_plushie" "berry_bowl" "can" "candle" "cat" "cat2" "clock" "colorful_sneaker" "dog" "dog2" "dog3" "dog5" "dog6" "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie" "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon" "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie")
8 | export class_name=("backpack" "backpack" "stuffed animal" "bowl" "can" "candle" "cat" "cat" "clock" "sneaker" "dog" "dog" "dog" "dog" "dog" "dog" "dog" "toy" "boot" "stuffed animal" "toy" "glasses" "toy" "toy" "cartoon" "toy" "sneaker" "teapot" "vase" "stuffed animal")
9 |
10 | for i in "${!subject_name[@]}"; do
11 | export DATA_SET="$DATA_SET_root/${subject_name[i]}" # Append each subject to the dataset path
12 | class="${class_name[i]}"
13 | export CLASS_DIR="$CLASS_DIR_root/$class"
14 | export OUTPUT_DIR="$OUTPUT_DIR_root/${subject_name[i]}"
15 | # echo "DataSet: $DATA_SET, class: $class, subject: ${subject_name[i]}, class_dir: $CLASS_DIR, output_dir: $OUTPUT_DIR"
16 | accelerate launch --main_process_port 29501 --mixed_precision="fp16" --multi_gpu --num_processes 6 --gpu_ids='all' train_dreambooth.py \
17 | --pretrained_model_name_or_path=$MODEL_NAME \
18 | --instance_data_dir=$DATA_SET \
19 | --class_data_dir="$CLASS_DIR" \
20 | --output_dir="$OUTPUT_DIR" \
21 | --with_prior_preservation --prior_loss_weight=1.0 \
22 | --instance_prompt="a photo of sks ${subject_name[i]}" \
23 | --class_prompt="a photo of $class" \
24 | --resolution=512 \
25 | --train_batch_size=1 \
26 | --gradient_accumulation_steps=1 \
27 | --learning_rate=5e-6 \
28 | --lr_scheduler="constant" \
29 | --lr_warmup_steps=0 \
30 | --num_class_images=200 \
31 | --max_train_steps=800 \
32 | # --push_to_hub
33 | done
34 |
35 |
--------------------------------------------------------------------------------
/training-free/hpsv2_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from hpsv2.src.open_clip import create_model, get_tokenizer
3 | import torch.nn as nn
4 | import huggingface_hub
5 |
6 | class HPSV2Loss(nn.Module):
7 | """HPS reward loss function for optimization."""
8 |
9 | def __init__(
10 | self,
11 | dtype: torch.dtype,
12 | device: torch.device,
13 | cache_dir: str,
14 | memsave: bool = False,
15 | ):
16 | super(HPSV2Loss, self).__init__()  # call the parent class initializer first
17 | self.hps_model = create_model(
18 | "ViT-H-14",
19 | "laion2B-s32B-b79K",
20 | precision=dtype,
21 | device=device,
22 | cache_dir=cache_dir,
23 | )
24 | checkpoint_path = huggingface_hub.hf_hub_download(
25 | "xswu/HPSv2", "HPS_v2.1_compressed.pt", cache_dir=cache_dir
26 | )
27 | self.hps_model.load_state_dict(
28 | torch.load(checkpoint_path, map_location=device)["state_dict"]
29 | )
30 | self.hps_tokenizer = get_tokenizer("ViT-H-14")
31 | if memsave:
32 | import memsave_torch.nn
33 |
34 | self.hps_model = memsave_torch.nn.convert_to_memory_saving(self.hps_model)
35 | self.hps_model = self.hps_model.to(device, dtype=dtype)
36 | self.hps_model.eval()
37 | self.freeze_parameters(self.hps_model.parameters())
38 | # super().__init__("HPS")
39 | self.hps_model.set_grad_checkpointing(True)
40 |
41 | @staticmethod
42 | def freeze_parameters(params: torch.nn.ParameterList):
43 | for param in params:
44 | param.requires_grad = False
45 |
46 | def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
47 | hps_image_features = self.hps_model.encode_image(image)
48 | return hps_image_features
49 |
50 | def get_text_features(self, prompt: str) -> torch.Tensor:
51 | hps_text = self.hps_tokenizer(prompt).to(next(self.hps_model.parameters()).device)  # keep tokens on the model's device
52 | hps_text_features = self.hps_model.encode_text(hps_text)
53 | return hps_text_features
54 |
55 | def compute_loss(
56 | self, image_features: torch.Tensor, text_features: torch.Tensor
57 | ) -> torch.Tensor:
58 | logits_per_image = image_features @ text_features.T
59 | hps_loss = 1 - torch.diagonal(logits_per_image)[0]
60 | return hps_loss
61 |
62 | def process_features(self, features: torch.Tensor) -> torch.Tensor:
63 | features_normed = features / features.norm(dim=-1, keepdim=True)
64 | return features_normed
65 |
66 | def score_grad(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
67 | image_features = self.get_image_features(image)
68 | text_features = self.get_text_features(prompt)
69 |
70 | image_features_normed = self.process_features(image_features)
71 | text_features_normed = self.process_features(text_features)
72 |
73 | loss = self.compute_loss(image_features_normed, text_features_normed)
74 | return loss
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🚀 [CVPR 2025] [Not All Parameters Matter: Masking Diffusion Models for Enhancing Generation Ability](https://arxiv.org/pdf/2505.03097)
2 |
3 |
4 | ![motivation](motivation.png)
5 |
6 |
7 | Analysis of parameter distributions and denoising effects across different timesteps for Stable Diffusion (SD) 1.5 with and without random masking. The first column shows the parameter distribution of SD 1.5, while the second to fifth columns display the distributions of parameters removed by the random mask. The last two columns compare samples generated by SD 1.5 with and without the random mask.
8 |
9 |
10 |
11 | ## 📘 Introduction
12 | Diffusion models focus on constructing basic image structures in the early denoising stages, while refined details, including local features and textures, are generated in later stages. The same network layers are therefore forced to learn both structural and textural information simultaneously, which differs significantly from traditional deep architectures (e.g., ResNet or GANs) that capture or generate image semantics at different layers. This difference inspires us to explore the time-wise behavior of diffusion models. We first investigate how the U-Net parameters contribute to the denoising process and find that properly zeroing out certain parameters (including large parameters) benefits denoising, substantially improving generation quality on the fly. Capitalizing on this discovery, we propose a simple yet effective method, termed "MaskUNet", that enhances generation quality with a negligible number of additional parameters. Our method fully leverages timestep- and sample-dependent effective U-Net parameters. To optimize MaskUNet, we offer two fine-tuning strategies: a training-based approach and a training-free approach, each with tailored networks and optimization functions. In zero-shot inference on the COCO dataset, MaskUNet achieves the best FID score and further demonstrates its effectiveness in downstream task evaluations.
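For intuition, the core operation can be sketched in a few lines: sample a near-binary mask for a layer's weights with a straight-through Gumbel-Sigmoid and apply it in a forward hook, as the scripts under `training-free/` do for the U-Net up-block linear layers. This is a minimal illustration rather than the full pipeline; the layer size below is arbitrary.

```python
# Minimal sketch of weight masking with a straight-through Gumbel-Sigmoid (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F

def gumbel_sigmoid(logits, tau=1.0):
    g = -torch.empty_like(logits).exponential_().log()     # ~ Gumbel(0, 1) noise
    y_soft = torch.sigmoid((logits + g) / tau)
    y_hard = (y_soft > 0.5).float()
    return y_hard - y_soft.detach() + y_soft                # straight-through estimator

layer = nn.Linear(320, 320)                                 # stands in for a U-Net linear layer
mask_logits = nn.Parameter(torch.ones_like(layer.weight))   # learnable mask logits

def apply_mask(module, inputs, output):
    mask = gumbel_sigmoid(mask_logits)                      # near-binary mask over the weights
    return F.linear(inputs[0], module.weight * mask, module.bias)

layer.register_forward_hook(apply_mask)                     # hook output replaces the layer's output
out = layer(torch.randn(2, 320))                            # masked forward pass
```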
13 |
14 |
15 | ![method](method.png)
16 |
17 | The pipeline of MaskUNet. G-Sig denotes the Gumbel-Sigmoid activation function; GAP is global average pooling.
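For the training-based variant, a small mask generator conditioned on the block's input features and the timestep embedding produces per-sample mask logits, which are then binarized by G-Sig. The sketch below is a simplified, illustrative version of the generator in `training/hyperunet.py`; the channel sizes and the single 1x1-conv head are assumptions made for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMaskGenerator(nn.Module):
    """Predict per-sample mask logits for an (out_c, in_c) weight from features and a timestep embedding."""
    def __init__(self, feat_c, temb_c, out_c, in_c, hidden=64):
        super().__init__()
        self.proj_temb = nn.Linear(temb_c, feat_c, bias=False)      # project timestep embedding to feature channels
        self.head = nn.Sequential(
            nn.Conv2d(feat_c, hidden, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_c * in_c, kernel_size=1),
        )
        self.out_c, self.in_c = out_c, in_c

    def forward(self, feats, temb):
        x = F.adaptive_avg_pool2d(feats, 1)                          # GAP over spatial dimensions
        x = x + self.proj_temb(temb).unsqueeze(-1).unsqueeze(-1)     # inject timestep information
        logits = self.head(x).view(feats.size(0), self.out_c, self.in_c)
        return logits                                                # feed these to Gumbel-Sigmoid to obtain masks

gen = TinyMaskGenerator(feat_c=640, temb_c=1280, out_c=32, in_c=32)
logits = gen(torch.randn(2, 640, 16, 16), torch.randn(2, 1280))      # -> (2, 32, 32) mask logits
```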
18 |
19 |
20 |
21 | ## Training
22 | ### Datasets
23 | fantasyfish/laion-art [link1](https://huggingface.co/datasets/fantasyfish/laion-art) [link2](https://hf-mirror.com/datasets/fantasyfish/laion-art)
24 |
25 | ### Installation
26 | ```shell
27 | conda env create -f environment.yaml
28 | ```
29 | ### Training-based
30 | #### Train
31 | ```shell
32 | ./training/train_hyperunet.sh
33 | ```
34 | #### Inference
35 | ```shell
36 | ./training/infer_sd1-5_hardmask.sh
37 | ```
38 |
39 | ### Training-free
40 | ```shell
41 | ./training-free/infer_sd1-5_x0_optim_mask_fnal_para.sh
42 | ```
43 | ## ✨ Qualitative results
44 |
45 | ![results](results.png)
46 |
47 | Qualitative results compared with other methods.
48 |
49 |
50 |
51 |
52 | ## 📈 Quantitative results
53 |
54 | ![Quantitative results](results1.png)
55 |
56 |
57 | ## Citation
58 |
59 | ```
60 | @inproceedings{wang2025not,
61 | title={Not All Parameters Matter: Masking Diffusion Models for Enhancing Generation Ability},
62 | author={Wang, Lei and Li, Senmao and Yang, Fei and Wang, Jianye and Zhang, Ziheng and Liu, Yuhan and Wang, Yaxing and Yang, Jian},
63 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
64 | pages={12880--12890},
65 | year={2025}
66 | }
67 |
68 | ```
69 | ## Acknowledgement
70 |
71 | This project is based on [Diffusers](https://github.com/huggingface/diffusers). Thanks for their awesome work.
72 | ## Contact
73 | If you have any questions, please feel free to reach out to me at `scitop1998@gmail.com`.
74 |
--------------------------------------------------------------------------------
/training/infer_sd1-5_hardmask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from diffusers import StableDiffusionPipeline, UNet2DConditionModel
4 | from hyperunet import FineGrainedUNet2DConditionModel  # defined in training/hyperunet.py
5 | from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
6 | import argparse
7 | import torch
8 | import torch.distributed as dist
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from torch.utils.data import DataLoader, Dataset, DistributedSampler
11 | from diffusers import PixArtAlphaPipeline
12 | from safetensors.torch import load_file
13 | import os
14 |
15 | class PromptDataset(Dataset):
16 | def __init__(self, prompts):
17 | self.prompts = prompts
18 |
19 | def __len__(self):
20 | return len(self.prompts)
21 |
22 | def __getitem__(self, idx):
23 | return self.prompts[idx], idx # Return prompt and its index
24 |
25 | def setup(rank, world_size):
26 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
27 |
28 | def cleanup():
29 | dist.destroy_process_group()
30 | # Set a fixed seed for reproducibility
31 |
32 |
33 | def main_worker(rank, args):
34 | setup(rank, args.world_size)
35 |
36 | seed = 42
37 |
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed)
40 |
41 | with open(args.captions_file, 'r', encoding='utf-8') as f:
42 | prompts = f.readlines()
43 | model_path = "path_to_saved_model"
44 | unet = FineGrainedUNet2DConditionModel.from_pretrained(
45 | "../share/runwayml/stable-diffusion-v1-5", subfolder="unet",
46 | low_cpu_mem_usage=False,
47 | device_map=None,
48 | torch_dtype=torch.float16
49 | )
50 |
51 | input_dir = "./sd-naruto-model_hard_improve1_factor4_trans_decoder_sd1.5_factorablation_epoch15/checkpoint-9000"
52 |
53 | load_mask_path = os.path.join(input_dir, "mask_generators.pth")
54 | loaded_mask_state_dict = torch.load(load_mask_path)
55 |
56 | # Load weights for mask_generators
57 | unet.mask_generators.load_state_dict(loaded_mask_state_dict)
58 |
59 | # Load weights for adapters
60 | load_path_adapters = os.path.join(input_dir, "adapters.pth")
61 | adapters_state_dict = torch.load(load_path_adapters)
62 | unet.adapters.load_state_dict(adapters_state_dict)
63 |
64 | print("load success")
65 | scheduler = DDIMScheduler.from_pretrained("../share/runwayml/stable-diffusion-v1-5", subfolder="scheduler")
66 | # Load the Stable Diffusion pipeline with the modified UNet model
67 | pipe = StableDiffusionPipeline.from_pretrained(
68 | "../share/runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16,
69 | scheduler=scheduler,
70 | unet = unet
71 | )
72 | pipe.to(f"cuda:{rank}")
73 |
74 | output_dir = "fake_images_sd1.5_mask_vis_9000_final"
75 |
76 | if rank == 0:
77 | os.makedirs(output_dir, exist_ok=True)
78 |
79 | generator = torch.Generator(device=f"cuda:{rank}").manual_seed(args.seed) if args.seed else None
80 | def dummy_checker(images, **kwargs):
81 | # Return the images as is, and set NSFW detection results to False (not NSFW)
82 | return images, [False] * len(images)
83 |
84 | pipe.safety_checker = dummy_checker
85 |
86 | dataset = PromptDataset(prompts)
87 | sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=False)
88 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, sampler=sampler)
89 |
90 | for i, (batch_prompts, indices) in enumerate(dataloader):
91 | # Ensure batch_prompts is a list of strings
92 | batch_prompts = [prompt for prompt in batch_prompts]
93 | images = pipe(prompt=batch_prompts, generator=generator, num_inference_steps=50, guidance_scale=7.5).images
94 |
95 | # Save images immediately after generation
96 | for j, image in enumerate(images):
97 | output_filename = f"{output_dir}/{indices[j]:05}.png"
98 | image.save(output_filename)
99 |
100 | # Set the seed in the pipeline for deterministic image generation
101 | # generator = torch.Generator(device="cuda").manual_seed(seed)
102 |
103 | # Generate the image with a fixed seed #cute dragon creature
104 | # image = pipe(prompt="A naruto with green eyes and red legs.", generator=generator).images[0]
105 | # image = pipe(prompt="digital art of a little cat traveling around forest, wearing a scarf around the neck, carrying a tiny backpack on his back", generator=generator, num_inference_steps=10, guidance_scale=7.5).images[0]
106 | # image = pipe(prompt="cute dragon creature", generator=generator).images[0]
107 | # image.save("yoda-naruto.png")
108 |
109 | cleanup()
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser(description="Simple example of an inference script.")
113 | parser.add_argument(
114 | "--pretrained_model_name_or_path",
115 | type=str,
116 | default="thuanz123/swiftbrush",
117 | required=True,
118 | help="Path to pretrained model or model identifier from huggingface.co/models.",
119 | )
120 | parser.add_argument(
121 | "--captions_file",
122 | type=str,
123 | default="captions.txt",
124 | required=True,
125 | help="Path to the captions file.",
126 | )
127 |
128 | parser.add_argument(
129 | "--seed",
130 | type=int,
131 | default=42,
132 | required=False,
133 | help="Random seed used for inference.",
134 | )
135 |
136 | parser.add_argument(
137 | "--batch_size",
138 | type=int,
139 | default=16,
140 | required=False,
141 | help="Batch size for inference.",
142 | )
143 |
144 | parser.add_argument(
145 | "--world_size",
146 | type=int,
147 | default=torch.cuda.device_count(),
148 | help="Number of GPUs to use.",
149 | )
150 |
151 | args = parser.parse_args()
152 |
153 | # Get the rank from the environment variable set by torchrun
154 | rank = int(os.environ["LOCAL_RANK"])
155 |
156 | main_worker(rank, args)
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: hyperunet
2 | channels:
3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - bzip2=1.0.8=h5eee18b_6
10 | - ca-certificates=2024.9.24=h06a4308_0
11 | - ld_impl_linux-64=2.40=h12ee557_0
12 | - libffi=3.4.4=h6a678d5_1
13 | - libgcc-ng=11.2.0=h1234567_1
14 | - libgomp=11.2.0=h1234567_1
15 | - libstdcxx-ng=11.2.0=h1234567_1
16 | - libuuid=1.41.5=h5eee18b_0
17 | - ncurses=6.4=h6a678d5_0
18 | - openssl=3.0.15=h5eee18b_0
19 | - pip=24.2=py310h06a4308_0
20 | - python=3.10.15=he870216_1
21 | - readline=8.2=h5eee18b_0
22 | - sqlite=3.45.3=h5eee18b_0
23 | - tk=8.6.14=h39e8969_0
24 | - wheel=0.44.0=py310h06a4308_0
25 | - xz=5.4.6=h5eee18b_1
26 | - zlib=1.2.13=h5eee18b_1
27 | - pip:
28 | - accelerate==0.34.2
29 | - addict==2.4.0
30 | - aiohappyeyeballs==2.4.3
31 | - aiohttp==3.10.9
32 | - aiohttp-retry==2.8.3
33 | - aiosignal==1.3.1
34 | - aliyun-python-sdk-core==2.16.0
35 | - aliyun-python-sdk-kms==2.16.5
36 | - amqp==5.2.0
37 | - annotated-types==0.7.0
38 | - antlr4-python3-runtime==4.9.3
39 | - appdirs==1.4.4
40 | - args==0.1.0
41 | - async-timeout==4.0.3
42 | - asyncssh==2.17.0
43 | - atpublic==5.0
44 | - attrs==24.2.0
45 | - billiard==4.2.1
46 | - bitsandbytes==0.42.0
47 | - blis==1.0.1
48 | - braceexpand==0.1.7
49 | - catalogue==2.0.10
50 | - celery==5.4.0
51 | - certifi==2024.8.30
52 | - cffi==1.17.1
53 | - charset-normalizer==3.3.2
54 | - clean-fid==0.1.35
55 | - click==8.1.7
56 | - click-didyoumean==0.3.1
57 | - click-plugins==1.1.1
58 | - click-repl==0.3.0
59 | - clint==0.5.1
60 | - clip==1.0
61 | - clip-benchmark==1.6.1
62 | - cloudpathlib==0.20.0
63 | - cmake==3.30.4
64 | - colorama==0.4.6
65 | - confection==0.1.5
66 | - configobj==5.0.9
67 | - contourpy==1.3.0
68 | - controlnet-aux==0.0.5
69 | - crcmod==1.7
70 | - cryptography==43.0.1
71 | - curated-tokenizers==0.0.9
72 | - curated-transformers==0.1.1
73 | - cycler==0.12.1
74 | - cymem==2.0.8
75 | - datasets==3.0.2
76 | - dictdiffer==0.9.0
77 | - diffusers==0.30.3
78 | - dill==0.3.7
79 | - diskcache==5.6.3
80 | - distro==1.9.0
81 | - docker-pycreds==0.4.0
82 | - dpath==2.2.0
83 | - dulwich==0.22.1
84 | - dvc==3.55.2
85 | - dvc-data==3.16.6
86 | - dvc-http==2.32.0
87 | - dvc-objects==5.1.0
88 | - dvc-render==1.0.2
89 | - dvc-studio-client==0.21.0
90 | - dvc-task==0.40.2
91 | - einops==0.8.0
92 | - en-core-web-trf==3.8.0
93 | - entrypoints==0.4
94 | - exceptiongroup==1.2.2
95 | - fairscale==0.4.13
96 | - filelock==3.14.0
97 | - flatten-dict==0.4.2
98 | - flufl-lock==8.1.0
99 | - fonttools==4.54.1
100 | - frozenlist==1.4.1
101 | - fsspec==2024.6.1
102 | - ftfy==6.3.0
103 | - funcy==2.0
104 | - gitdb==4.0.11
105 | - gitpython==3.1.43
106 | - grandalf==0.8
107 | - gto==1.7.1
108 | - hpsv2==1.2.0
109 | - huggingface-hub==0.26.2
110 | - hydra-core==1.3.2
111 | - idna==3.10
112 | - image-reward==1.5
113 | - imageio==2.35.1
114 | - imageio-ffmpeg==0.5.1
115 | - importlib-metadata==8.5.0
116 | - iniconfig==2.0.0
117 | - iterative-telemetry==0.0.9
118 | - jinja2==3.1.4
119 | - jmespath==0.10.0
120 | - joblib==1.4.2
121 | - kiwisolver==1.4.7
122 | - kombu==5.4.2
123 | - langcodes==3.4.1
124 | - language-data==1.2.0
125 | - lazy-loader==0.4
126 | - lit==18.1.8
127 | - marisa-trie==1.2.1
128 | - markdown==3.7
129 | - markdown-it-py==3.0.0
130 | - markupsafe==2.1.5
131 | - mat4py==0.6.0
132 | - matplotlib==3.8.0
133 | - mdurl==0.1.2
134 | - memsave-torch==1.0.0
135 | - mmcv==2.2.0
136 | - mmdet==3.3.0
137 | - mmengine==0.10.5
138 | - mmpretrain==1.2.0
139 | - model-index==0.1.11
140 | - modelindex==0.0.2
141 | - motmetrics==1.4.0
142 | - mpmath==1.3.0
143 | - multidict==6.1.0
144 | - multiprocess==0.70.15
145 | - munch==4.0.0
146 | - murmurhash==1.0.10
147 | - networkx==3.3
148 | - numpy==1.23.5
149 | - nvidia-cublas-cu12==12.1.3.1
150 | - nvidia-cuda-cupti-cu12==12.1.105
151 | - nvidia-cuda-nvrtc-cu12==12.1.105
152 | - nvidia-cuda-runtime-cu12==12.1.105
153 | - nvidia-cudnn-cu12==9.1.0.70
154 | - nvidia-cufft-cu12==11.0.2.54
155 | - nvidia-curand-cu12==10.3.2.106
156 | - nvidia-cusolver-cu12==11.4.5.107
157 | - nvidia-cusparse-cu12==12.1.0.106
158 | - nvidia-nccl-cu12==2.20.5
159 | - nvidia-nvjitlink-cu12==12.6.77
160 | - nvidia-nvtx-cu12==12.1.105
161 | - omegaconf==2.3.0
162 | - open-clip-torch==2.26.1
163 | - opencv-python==4.8.0.76
164 | - opendatalab==0.0.10
165 | - openmim==0.3.9
166 | - openxlab==0.1.2
167 | - ordered-set==4.1.0
168 | - orjson==3.10.7
169 | - oss2==2.17.0
170 | - packaging==24.1
171 | - pandas==2.2.3
172 | - pathspec==0.12.1
173 | - peft==0.13.2
174 | - pillow==10.0.1
175 | - platformdirs==3.11.0
176 | - pluggy==1.5.0
177 | - preshed==3.0.9
178 | - prompt-toolkit==3.0.48
179 | - protobuf==3.20.3
180 | - psutil==6.0.0
181 | - pyarrow==17.0.0
182 | - pycocoevalcap==1.2
183 | - pycocotools==2.0.8
184 | - pycparser==2.22
185 | - pycryptodome==3.21.0
186 | - pydantic==1.10.18
187 | - pydantic-core==2.23.4
188 | - pydot==3.0.2
189 | - pygit2==1.16.0
190 | - pygments==2.18.0
191 | - pygtrie==2.5.0
192 | - pyparsing==3.1.4
193 | - pytest==7.2.0
194 | - pytest-split==0.8.0
195 | - python-dateutil==2.9.0.post0
196 | - pytz==2023.4
197 | - pyyaml==6.0.1
198 | - regex==2024.9.11
199 | - requests==2.32.2
200 | - rich==13.4.2
201 | - ruamel-yaml==0.18.6
202 | - ruamel-yaml-clib==0.2.8
203 | - safetensors==0.4.5
204 | - scikit-image==0.24.0
205 | - scikit-learn==1.5.2
206 | - scipy==1.9.3
207 | - scmrepo==3.3.8
208 | - seaborn==0.13.2
209 | - semver==3.0.2
210 | - sentencepiece==0.2.0
211 | - sentry-sdk==2.16.0
212 | - setproctitle==1.3.3
213 | - setuptools==60.2.0
214 | - shapely==2.0.6
215 | - shellingham==1.5.4
216 | - shortuuid==1.0.13
217 | - shtab==1.7.1
218 | - six==1.16.0
219 | - smart-open==7.0.5
220 | - smmap==5.0.1
221 | - spacy==3.8.2
222 | - spacy-curated-transformers==0.3.0
223 | - spacy-legacy==3.0.12
224 | - spacy-loggers==1.0.5
225 | - sqltrie==0.11.1
226 | - srsly==2.4.8
227 | - sympy==1.13.3
228 | - tabulate==0.9.0
229 | - termcolor==2.5.0
230 | - terminaltables==3.1.10
231 | - thinc==8.3.2
232 | - threadpoolctl==3.5.0
233 | - tifffile==2024.9.20
234 | - timm==0.6.13
235 | - tokenizers==0.20.1
236 | - tomli==2.0.2
237 | - tomlkit==0.13.2
238 | - torch==2.4.0+cu121
239 | - torchaudio==2.4.0+cu121
240 | - torchvision==0.19.0+cu121
241 | - tqdm==4.66.3
242 | - transformers==4.46.0
243 | - triton==3.0.0
244 | - typer==0.12.5
245 | - typing-extensions==4.12.2
246 | - tzdata==2024.2
247 | - urllib3==1.26.20
248 | - vine==5.1.0
249 | - voluptuous==0.15.2
250 | - wandb==0.18.3
251 | - wasabi==1.1.3
252 | - wcwidth==0.2.13
253 | - weasel==0.4.1
254 | - webdataset==0.2.100
255 | - wrapt==1.16.0
256 | - xformers==0.0.27.post2
257 | - xmltodict==0.14.2
258 | - xxhash==3.5.0
259 | - yapf==0.40.2
260 | - yarl==1.13.1
261 | - zc-lockfile==3.0.post1
262 | - zipp==3.20.2
263 | prefix: /home/u1120240347/miniconda3/envs/hyperunet
264 |
--------------------------------------------------------------------------------
/training-free/infer_sd1-5_x0_optim_mask_final_para.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | from diffusers import StableDiffusionPipeline, UNet2DConditionModel
5 | from diffusers import DPMSolverMultistepScheduler, DDIMScheduler
6 | from sd_utils_x0_dpm import register_sd_forward, register_sdschedule_step
7 | import ImageReward as RM
8 | from torch import Tensor
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | try:
12 | from torchvision.transforms import InterpolationMode
13 | BICUBIC = InterpolationMode.BICUBIC
14 | except ImportError:
15 | from PIL import Image; BICUBIC = Image.BICUBIC  # PIL fallback when InterpolationMode is unavailable
16 |
17 | from torchvision.transforms import ToPILImage
18 | from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
19 | from accelerate import Accelerator  # handles multi-GPU placement and mixed precision
20 | import gc
21 | from hpsv2_loss import HPSV2Loss
22 | import xformers
23 | import time  # for timing image generation
24 | import json
25 |
26 | def gumbel_sigmoid(logits: Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> Tensor:
27 | """
28 | Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
29 | The discretization converts the values greater than threshold to 1 and the rest to 0.
30 | The code is adapted from the official PyTorch implementation of gumbel_softmax:
31 | https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
32 |
33 | Args:
34 | logits: [..., num_features] unnormalized log probabilities
35 | tau: non-negative scalar temperature
36 | hard: if True, the returned samples will be discretized,
37 | but will be differentiated as if it is the soft sample in autograd
38 | threshold: threshold for the discretization,
39 | values greater than this will be set to 1 and the rest to 0
40 |
41 | Returns:
42 | Sampled tensor of same shape as logits from the Gumbel-Sigmoid distribution.
43 | If hard=True, the returned samples are descretized according to threshold, otherwise they will
44 | be probability distributions.
45 |
46 | """
47 | gumbels = (
48 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
49 | ) # ~Gumbel(0, 1)
50 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits, tau)
51 | y_soft = gumbels.sigmoid()
52 |
53 | if hard:
54 | # Straight through.
55 | indices = (y_soft > threshold).nonzero(as_tuple=True)
56 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
57 | y_hard[indices[0], indices[1]] = 1.0
58 | ret = y_hard - y_soft.detach() + y_soft
59 | else:
60 | # Reparametrization trick.
61 | ret = y_soft
62 | return ret
63 |
64 | class MaskApplier:
65 | def __init__(self, unet):
66 | self.unet = unet
67 | self.masks = nn.ParameterDict()
68 | self.hooks = []
69 | self.mask_probs = {}
70 | self.mask_samples = {}
71 | self.initial_masks = {} # store the initial mask state
72 | self.module_dict = {} # map safe_name -> module
73 |
74 | # Initialize mask logits
75 | skip_layers = ["time_emb_proj", "ff", "conv_shortcut", "proj_in", "proj_out"]
76 | for name, module in self.unet.up_blocks.named_modules():
77 | if isinstance(module, (torch.nn.Linear)) and not any(skip_layer in name for skip_layer in skip_layers):
78 | safe_name = name.replace('.', '_')
79 | mask_shape = module.weight.shape
80 | # Initialize logits to a high value to start with masks mostly turned on
81 | mask_logit = torch.ones(mask_shape, device=module.weight.device, dtype=torch.float32)
82 | self.masks[safe_name] = nn.Parameter(mask_logit)
83 | self.initial_masks[safe_name] = mask_logit.clone() # keep the initial state
84 | self.module_dict[safe_name] = module # record the safe_name -> module mapping
85 | hook = module.register_forward_hook(self.hook_fn(safe_name))
86 | self.hooks.append(hook)
87 |
88 | def hook_fn(self, layer_name):
89 | def apply_mask(module, input, output):
90 | # Use gumbel_sigmoid to sample mask
91 | logits = self.masks[layer_name].to(module.weight.device)
92 | mask = gumbel_sigmoid(logits, hard=True)
93 | batch_size = input[0].size(0)
94 | # Store the probabilities and samples for policy gradient update
95 | # with torch.no_grad():
96 | # probs = torch.sigmoid(logits)
97 | # self.mask_probs[layer_name] = probs.detach()
98 | # self.mask_samples[layer_name] = mask.detach()
99 | # Apply the binary mask to the weights
100 | masked_weight = module.weight * mask
101 | # Recompute the output using the masked weights without modifying module.weight
102 | if isinstance(module, nn.Conv2d):
103 | bias = module.bias
104 | stride = module.stride
105 | padding = module.padding
106 | dilation = module.dilation
107 | groups = module.groups
108 | # Recompute the output
109 | return F.conv2d(input[0], masked_weight, bias, stride, padding, dilation, groups)
110 | elif isinstance(module, nn.Linear):
111 | weight_shape = module.weight.shape
112 | # print("input", input[0].shape)
113 |
114 | output = F.linear(input[0], masked_weight)
115 | if module.bias is not None:
116 | output += module.bias
117 | return output
118 | else:
119 | return output # For layers that are not Conv2d or Linear
120 | return apply_mask
121 |
122 | def reset_masks(self):
123 | # Reset all masks to their initial state
124 | for name in self.masks.keys():
125 | self.masks[name].data = self.initial_masks[name].clone()
126 |
127 |
128 | class MaskOptimizer:
129 | def __init__(self, unet, mask_applier, reward_model, hpsv2_model, accelerator, num_iters=20):
130 | self.unet = unet
131 | self.reward_model = reward_model
132 | self.hpsv2_model = hpsv2_model
133 | self.mask_applier = mask_applier
134 | self.num_iters = num_iters
135 | self.latent_cache = None
136 | self.accelerator = accelerator # keep a reference to the accelerator
137 |
138 | def optimize_masks(self, prompt, pipe, num_iterations=20, seed=42, num_inference_steps=10):
139 | rm_input_ids = self.reward_model.module.blip.tokenizer(
140 | prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt"
141 | ).input_ids.to(self.accelerator.device)
142 | rm_attention_mask = self.reward_model.module.blip.tokenizer(
143 | prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt"
144 | ).attention_mask.to(self.accelerator.device)
145 | gen_image = []
146 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, self.mask_applier.masks.parameters()), lr=1e-2)
147 | optimizer = self.accelerator.prepare(optimizer)
148 | # optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.mask_applier.masks.parameters()), lr=1e-2)
149 | # Run an independent mask optimization for each timestep
150 | for stop_step in range(0, num_inference_steps):
151 | pipe.stop_step = stop_step
152 | if stop_step <= 6:
153 | num_iterations = 1
154 | else:
155 | num_iterations = 10
156 | for iteration in range(num_iterations):
157 | # generator = torch.Generator(device="cuda").manual_seed(seed)
158 | generator = None
159 |
160 | # Generate image
161 | with self.accelerator.autocast():
162 | image, cur_latents = pipe.forward(prompt=prompt, num_inference_steps=num_inference_steps,
163 | guidance_scale=7.5, generator=generator,
164 | latents = self.latent_cache, return_dict=False,
165 | width=512, height=512)
166 | x_0_per_step = image
167 |
168 | # Save the generated image
169 | if stop_step == num_inference_steps-1 and iteration == num_iterations-1:
170 | save_imge = x_0_per_step.squeeze(0)
171 | to_pil = ToPILImage()
172 | image_save = to_pil(save_imge.cpu())
173 | gen_image.append(image_save)
174 | # image_save.save(f"{save_path}/{prompt}/seed{seed}_step{stop_step+1}_iter{iteration+1}_x0.png")
175 | # gc.collect()
176 | # torch.cuda.empty_cache() # cleanup
177 | # Preprocess the image for the reward models
178 | rm_preprocess = Compose([
179 | Resize(224, interpolation=BICUBIC),
180 | CenterCrop(224),
181 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
182 | ])
183 | x_0_per_step = rm_preprocess(x_0_per_step)
184 | x_0_per_step = x_0_per_step.to(self.accelerator.device)
185 |
186 |
187 |
188 | # Prepare the inputs
189 |
190 | # torch.autograd.set_detect_anomaly(True)
191 | # Compute reward and loss
192 | with self.accelerator.autocast():
193 | hpsv2_loss = self.hpsv2_model.score_grad(x_0_per_step, prompt)
194 | # print("hpsv2", hpsv2_loss)
195 | reward = self.reward_model.module.score_gard(rm_input_ids, rm_attention_mask, x_0_per_step)
196 | loss = F.relu(-reward + 2).mean()
197 | loss = loss + 5.0*hpsv2_loss
198 | # print(loss)
199 |
200 |
201 |
202 |
203 | self.accelerator.backward(loss) # backward pass via the accelerator
204 |
205 | # for name, param in self.mask_applier.masks.items():
206 | # if param.grad is not None:
207 | # print(f"Mask logits for {name}, grad mean: {param.grad.mean().item()}, grad std: {param.grad.std().item()}")
208 |
209 | # Optional: gradient clipping
210 | # self.accelerator.clip_grad_norm_(self.unet.parameters(), max_norm=1.0)
211 | # self.accelerator.clip_grad_norm_(self.mask_applier.masks.parameters(), max_norm=1.0)
212 |
213 | optimizer.step() # update the mask logits
214 | # for name, param in self.unet.named_parameters():
215 | # if param.requires_grad:
216 | # mask = gumbel_sigmoid(param.data, tau=1.0, hard=True)
217 | # zeros = (mask == 0).sum().item()
218 | # ones = (mask == 1).sum().item()
219 | # print(f"After update - {name}: max {param.data.max()}, min {param.data.min()}")
220 | # print(f"Gumbel Sigmoid applied on {name}: Zeros={zeros}, Ones={ones}")
221 | optimizer.zero_grad() # reset the optimizer gradients
222 |
223 | # Print parameter-update info (debugging)
224 | # for name, param in self.unet.named_parameters():
225 | # if param.requires_grad:
226 | # print(f"After update - {name}: max {param.data.max()}, min {param.data.min()}")
227 | # for name, param in self.unet.named_parameters():
228 | # if param.requires_grad:
229 | # mask = gumbel_sigmoid(param.data, tau=1.0, hard=True)
230 | # zeros = (mask == 0).sum().item()
231 | # ones = (mask == 1).sum().item()
232 | # print(f"After update - {name}: max {param.data.max()}, min {param.data.min()}")
233 | # print(f"Gumbel Sigmoid applied on {name}: Zeros={zeros}, Ones={ones}")
234 |
235 | # Free memory
236 | if stop_step == num_inference_steps-1 and iteration == num_iterations-1:
237 | print(f"Step {stop_step+1}/{num_inference_steps}, Iteration {iteration + 1}/{num_iterations}, Reward: {reward.item()}")
238 | del image, x_0_per_step, reward, loss, hpsv2_loss
239 | gc.collect()
240 | torch.cuda.empty_cache() # release cached GPU memory
241 |
242 | self.latent_cache = cur_latents.detach().clone()
243 | # Reset the masks after each timestep
244 | # gc.collect()
245 | self.mask_applier.reset_masks()
246 | torch.cuda.empty_cache()
247 | # print(f"Mask reset after step {stop_step+1}")
248 | return gen_image
249 |
250 | def main():
251 | # Set a fixed seed for reproducibility
252 | # seed = 42
253 | # torch.manual_seed(seed)
254 | # torch.cuda.manual_seed_all(seed)
255 |
256 | # Initialize the Accelerator
257 | accelerator = Accelerator(device_placement=True, mixed_precision='fp16') # use mixed precision
258 |
259 | # Load the UNet model
260 | # unet = UNet2DConditionModel.from_pretrained(
261 | # "../share/runwayml/stable-diffusion-v1-5", subfolder="unet",
262 | # torch_dtype=torch.float32
263 | # )
264 |
265 | # Create the MaskApplier instance
266 |
267 |
268 | # scheduler = DDIMScheduler.from_pretrained("../share/runwayml/stable-diffusion-v1-5", subfolder="scheduler")
269 | scheduler = DPMSolverMultistepScheduler.from_pretrained("../share/runwayml/stable-diffusion-v1-5", subfolder="scheduler", algorithm_type="dpmsolver++")
270 | # Load the Stable Diffusion pipeline with the modified UNet model
271 | pipe = StableDiffusionPipeline.from_pretrained(
272 | "../share/runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16,
273 | scheduler=scheduler
274 | )
275 |
276 | mask_applier = MaskApplier(pipe.unet)
277 |
278 | # Wrap the UNet and the MaskApplier together with accelerator.prepare()
279 | unet, mask_applier = accelerator.prepare(pipe.unet, mask_applier)
280 |
281 | pipe.to(accelerator.device)
282 |
283 | # pipe.unet.enable_gradient_checkpointing()
284 | pipe.unet.enable_xformers_memory_efficient_attention()
285 |
286 | register_sd_forward(pipe)
287 | register_sdschedule_step(pipe.scheduler)
288 |
289 | def dummy_checker(images, **kwargs):
290 | # Return the images as is, and set NSFW detection results to False (not NSFW)
291 | return images, [False] * len(images)
292 |
293 | pipe.safety_checker = dummy_checker
294 |
295 | save_dir = "training-free_mask_fix_stop7_iter10"
296 |
297 | # Create MaskApplier and ImageReward instances
298 | reward_model = RM.load("ImageReward-v1.0").to(device = accelerator.device, dtype = torch.float16)
299 | reward_model = accelerator.prepare(reward_model)
300 | #HPSV2 model
301 | hpsv2_model = HPSV2Loss(
302 | dtype = torch.float16,
303 | device = accelerator.device,
304 | cache_dir = "./HPSV2_checkpoint"
305 | # memsave = True
306 | )
307 | hpsv2_model = accelerator.prepare(hpsv2_model)
308 |
309 |
310 | prompt_list_file = "./geneval/prompts/evaluation_metadata.jsonl"
311 | with open(prompt_list_file) as fp:
312 | metadatas = [json.loads(line) for line in fp]
313 |
314 |
315 | total_prompts = len(metadatas)
316 | num_processes = accelerator.num_processes
317 | prompts_per_process = total_prompts // num_processes
318 | start_index = accelerator.process_index * prompts_per_process
319 | end_index = start_index + prompts_per_process if accelerator.process_index != num_processes - 1 else total_prompts
320 |
321 | # Process prompts assigned to this process
322 | for idx in range(start_index, end_index):
323 | prompt = metadatas[idx]["prompt"]
324 |
325 | # Create output directory for each prompt
326 | outdir = f"{save_dir}/{idx:0>5}"
327 | os.makedirs(f"{outdir}/samples", exist_ok=True)
328 |
329 | # Save metadata for each prompt
330 | with open(f"{outdir}/metadata.jsonl", "w") as fp:
331 | json.dump(metadatas[idx], fp)
332 |
333 | # Create Mask Optimizer
334 | mask_optimizer = MaskOptimizer(unet, mask_applier, reward_model, hpsv2_model, accelerator)
335 |
336 | # Run the mask optimization process
337 | image = mask_optimizer.optimize_masks([prompt], pipe, seed=None, num_inference_steps=15)
338 |
339 | # Save the generated image
340 | for i, img in enumerate(image):
341 | img.save(f"{outdir}/samples/{i:05}.png")
342 |
343 | del image
344 | torch.cuda.empty_cache() # clear the GPU cache
345 |
346 | accelerator.wait_for_everyone()
347 | print(f"Process {accelerator.process_index}: Finished generating images.")
348 | # sample_out = 0
349 | # for index, metadata in enumerate(metadatas):
350 | # prompt = metadata["prompt"]
351 | # outdir = f"{save_dir}"
352 | # outpath = f"{outdir}/{index:0>5}"
353 | # os.makedirs(f"{outpath}/samples", exist_ok=True)
354 | # with open(f"{outpath}/metadata.jsonl", "w") as fp:
355 | # json.dump(metadata, fp)
356 | # # prompt = "a cat wearing a hat and a dog wearing a glasses"
357 | # # Create Mask Optimizer
358 | # mask_optimizer = MaskOptimizer(unet, mask_applier, reward_model, hpsv2_model, accelerator)
359 |
360 | # # Optimize masks
361 | # # start_time = time.time()
362 | # batch_prompts = [prompt for _ in range(1)] # assume one prompt per batch
363 | # image = mask_optimizer.optimize_masks(batch_prompts, pipe, seed = seed, num_inference_steps=15)
364 | # # end_time = time.time() # record when image generation finished
365 | # # duration = end_time - start_time # compute the generation time
366 | # for i, img in enumerate(image):
367 | # img.save(f"{outpath}/samples/{sample_out + i:05}.png")
368 | # sample_out += len(image)
369 | # # print(f"Time taken for mask gen: {duration:.2f} seconds")
370 | # print("Image generation with optimized mask completed and saved.")
371 |
372 | if __name__ == "__main__":
373 | main()
--------------------------------------------------------------------------------
/training/hyperunet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | from typing import Optional, Union, Tuple, List, Callable, Dict, Any
5 | from copy import deepcopy
6 | from accelerate import Accelerator
7 |
8 | from diffusers.utils import USE_PEFT_BACKEND, deprecate, scale_lora_layers, unscale_lora_layers, BaseOutput
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from diffusers import UNet2DConditionModel
12 | from dataclasses import dataclass
13 | from functools import partial
14 | from einops import rearrange
15 | from torch.cuda.amp import autocast, GradScaler
16 | from torch import Tensor
17 |
18 | @dataclass
19 | class UNet2DConditionOutput(BaseOutput):
20 | """
21 | The output of [`UNet2DConditionModel`].
22 |
23 | Args:
24 | sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
25 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
26 | """
27 |
28 | sample: torch.Tensor = None
29 |
30 | def weights_init_kaiming(m):
31 | """
32 | Constant initialization for Conv2d and Linear layers: small near-identity weights, zero biases.
33 | """
34 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
35 | nn.init.constant_(m.weight, 0.01) # small constant weight, close to an identity-like mapping
36 | if m.bias is not None:
37 | nn.init.constant_(m.bias, 0) # zero bias
38 |
39 | class DecoderWeightsMaskGenerator(nn.Module):
40 | def __init__(self, in_channels, out_channels, hidden_dim=64, factor=16):
41 | super().__init__()
42 |
43 | new_out = out_channels // factor
44 | new_in = in_channels // factor
45 | # new_out = 128
46 | # new_in = 128
47 | self.conv_kernel_mask = nn.Sequential(
48 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(hidden_dim, new_in * new_out, kernel_size=1),
51 | nn.ReLU(inplace=True)
52 | )
53 |
54 | self.proj_temb = nn.Linear(1280, in_channels, bias=False)
55 | # self.proj_pemb = nn.Linear(768, in_channels, bias=False)
56 |
57 | self.in_c = new_in
58 | self.out_c = new_out
59 | self.apply(self.weights_init)
60 |
61 | def weights_init(self, m):
62 | if isinstance(m, nn.Conv2d):
63 | if m.kernel_size == (1, 1) and m.out_channels == self.in_c * self.out_c:
64 | nn.init.constant_(m.weight, 0)
65 | if m.bias is not None:
66 | nn.init.constant_(m.bias, 0)
67 | else:
68 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
69 | if m.bias is not None:
70 | nn.init.constant_(m.bias, 0)
71 |
72 | def forward(self, weight_shape, sample, res_sample, temb, encoder_hidden_states):
73 | batch_size = sample.size(0)
74 | flag = len(weight_shape)
75 | #emb [n,c]
76 | temb = self.proj_temb(temb).unsqueeze(-1).unsqueeze(-1)
77 | # pemb = self.proj_pemb(encoder_hidden_states).permute(0, 2, 1).contiguous().mean(-1).unsqueeze(-1).unsqueeze(-1)
78 | if res_sample is None:
79 | x = sample
80 | else:
81 | x = torch.cat([sample, res_sample], dim=1)
82 | if flag == 4:
83 | _, _, k_h, k_w = weight_shape
84 | x = F.adaptive_avg_pool2d(x, (k_h, k_w))
85 | else:
86 | x = F.adaptive_avg_pool2d(x, (1, 1))
87 |
88 | x = x + temb
89 | # x = temb
90 |
91 | mask = self.conv_kernel_mask(x)
92 | if flag == 4:
93 | mask = mask.view(mask.size(0), self.out_c, self.in_c, k_h, k_w)
94 | else:
95 | mask = mask.view(mask.size(0), self.out_c, self.in_c)
96 |
97 | return mask
98 |
99 |
100 | def gumbel_sigmoid(logits: Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> Tensor:
101 | """
102 | Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
103 | The discretization converts the values greater than `threshold` to 1 and the rest to 0.
104 | The code is adapted from the official PyTorch implementation of gumbel_softmax:
105 | https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
106 |
107 | Args:
108 | logits: `[..., num_features]` unnormalized log probabilities
109 | tau: non-negative scalar temperature
110 | hard: if ``True``, the returned samples will be discretized,
111 | but will be differentiated as if it is the soft sample in autograd
112 | threshold: threshold for the discretization,
113 | values greater than this will be set to 1 and the rest to 0
114 |
115 | Returns:
116 | Sampled tensor of same shape as `logits` from the Gumbel-Sigmoid distribution.
117 | If ``hard=True``, the returned samples are descretized according to `threshold`, otherwise they will
118 | be probability distributions.
119 |
120 | """
121 | gumbels = (
122 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
123 | ) # ~Gumbel(0, 1)
124 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits, tau)
125 | y_soft = gumbels.sigmoid()
126 |
127 | if hard:
128 | # Straight through.
129 | indices = (y_soft > threshold).nonzero(as_tuple=True)
130 | y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
131 | y_hard[indices[0], indices[1]] = 1.0
132 | ret = y_hard - y_soft.detach() + y_soft
133 | else:
134 | # Reparametrization trick.
135 | ret = y_soft
136 | return ret
137 |
138 |
139 | class Adapter(nn.Module):
140 | def __init__(self, out_c, in_c, new_out_c, new_in_c, tau = 1.0):
141 | super(Adapter, self).__init__()
142 | # Use 1x1 convolutions for channel adaptation
143 | self.conv_in = nn.Conv2d(out_c, new_out_c, kernel_size=1)
144 | self.conv_out = nn.Conv2d(in_c, new_in_c, kernel_size=1)
145 | self.tau = tau
146 |
147 | self.apply(self.weights_init)
148 |
149 | def weights_init(self, m):
150 | """
151 | Initialize Conv2d/Linear weights to zero and biases to 5.0 so the Gumbel-Sigmoid masks start (almost) all ones.
152 | """
153 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
154 | nn.init.constant_(m.weight, 0) # weights start at zero
155 | if m.bias is not None:
156 | nn.init.constant_(m.bias, 5.0) # sigmoid(5.0) ≈ 0.99, so masks are initially "on"
157 |
158 | def forward(self, input_tensor_3x3):
159 | # The input tensor shape is [batch_size, out_c, in_c, k_h, k_w]
160 |
161 | # Step 1: Adapt the in_channels using a 1x1 convolution
162 | # We rearrange the tensor to merge the out_c dimension into the batch size for efficient processing
163 | if input_tensor_3x3.dim() == 3:
164 | input_tensor_3x3 = input_tensor_3x3.unsqueeze(-1).unsqueeze(-1)
165 | input_adapted = rearrange(input_tensor_3x3, 'b out_c in_c h w -> (b in_c) out_c h w')
166 |
167 | # if self.conv_in.bias is not None:
168 | # self.conv_in.bias.data = self.conv_in.bias.data.to(input_adapted.dtype)
169 | # Apply the in-channel adaptation
170 | input_adapted = self.conv_in(input_adapted)
171 |
172 | # Step 2: Adapt the out_channels
173 | # Rearrange to bring the adapted in_channels back into the batch dimension for processing
174 | input_adapted = rearrange(input_adapted, '(b in_c) new_out_c h w -> b new_out_c in_c h w', b=input_tensor_3x3.shape[0])
175 | input_adapted = rearrange(input_adapted, 'b new_out_c in_c h w -> (b new_out_c) in_c h w')
176 |
177 | # if self.conv_out.bias is not None:
178 | # self.conv_out.bias.data = self.conv_out.bias.data.to(input_adapted.dtype)
179 | # Apply the out-channel adaptation
180 | output = self.conv_out(input_adapted)
181 | output = gumbel_sigmoid(output, tau=self.tau, hard=True)
182 | # Step 3: Reshape the output back to the original format with new dimensions
183 | output = rearrange(output, '(b new_out_c) new_in_c h w -> b new_out_c new_in_c h w', b=input_tensor_3x3.shape[0])
184 |
185 | if output.size(-1)==1:
186 | output = output.squeeze(-1).squeeze(-1)
187 | return output
188 |
189 |
190 | def get_hook_fn(sample, temb, res_sample, mask_generator, adapter, encoder_hidden_states):
191 | def hook_fn(module, input, output):
192 | batch_size = input[0].size(0)
193 |
194 | if isinstance(module, nn.Conv2d):
195 | weight_shape = module.weight.shape
196 | mask = mask_generator(weight_shape, sample, res_sample, temb, encoder_hidden_states).to(module.weight.device)
197 | # print(module.weight.shape)
198 | # print(mask.shape)
199 | # print()
200 | mask = adapter(mask)
201 | masked_weight = module.weight * mask
202 | masked_weight = masked_weight.reshape(-1, *module.weight.shape[1:]).contiguous() # fold per-sample masked kernels into the output-channel dim (reshape instead of view)
203 | input_reshaped = input[0].reshape(1, -1, *input[0].shape[2:]).contiguous() # stack the batch along channels for a grouped conv (groups=batch_size)
204 |
205 | # If the layer has a bias, expand it to match the grouped convolution
206 | if module.bias is not None:
207 | # Repeat the bias once per sample so its shape matches the grouped conv
208 | masked_bias = module.bias.repeat(batch_size).contiguous()
209 | else:
210 | masked_bias = None
211 |
212 | output = F.conv2d(
213 | input_reshaped,
214 | masked_weight,
215 | bias=masked_bias,
216 | stride=module.stride,
217 | padding=module.padding,
218 | dilation=module.dilation,
219 | groups=batch_size
220 | )
221 | output = output.reshape(batch_size, -1, *output.shape[2:]).contiguous() # unfold back to (batch_size, out_channels, H, W)
222 | return output
223 |
224 | elif isinstance(module, nn.Linear):
225 | if module.weight.dim() == 2:
226 | weight_shape = module.weight.shape
227 | mask = mask_generator(weight_shape, sample, res_sample, temb, encoder_hidden_states).to(module.weight.device)
228 | mask = adapter(mask)
229 | masked_weight = module.weight * mask
230 | if input[0].dim() == 2:
231 | input_batched = input[0].unsqueeze(1) # (batch_size, 1, in_features)
232 | else:
233 | input_batched = input[0]
234 | # print(input_batched.shape)
235 | # assert 2==1
236 | output = torch.bmm(input_batched, masked_weight.permute(0, 2, 1).contiguous()).squeeze(1)
237 | if module.bias is not None:
238 | output += module.bias.unsqueeze(0).expand_as(output)
239 | return output
240 | return hook_fn
241 |
242 | # def __init__(self, *args, cross_attention_dim=768, **kwargs):
243 | # super(FineGrainedUNet2DConditionModel, self).__init__(*args, cross_attention_dim=cross_attention_dim, **kwargs)
244 | class FineGrainedUNet2DConditionModel(UNet2DConditionModel):
245 | def __init__(self, *args, cross_attention_dim=768, use_linear_projection=False,**kwargs):
246 | super(FineGrainedUNet2DConditionModel, self).__init__(*args, cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, **kwargs)
247 | self.mask_generators = nn.ModuleList()
248 | # self.mask_generators_down = nn.ModuleList()
249 | # self.layer_adapters = {}
250 | self.adapters = nn.ModuleList() # ModuleList (not a plain list) so the adapters' parameters are registered
251 | # self.adapters_down = nn.ModuleList() # ModuleList instead of a plain list
252 | # self.decoder_weights_mask_generator.apply(self.init_weights)
253 |
254 | self._hooks = []
255 |
256 | # for name, module in self.up_blocks.named_modules():
257 | # if 'proj_out' in name:
258 | # print(f"{name} layer structure: {module}")
259 | # print(f"Shape of {name} weights: {module.weight.shape}")
260 |
261 | # assert 2==1
262 |
263 | skip_layers = ["time_emb_proj", "ff", "conv_shortcut", "proj_in", "proj_out"]
264 |
265 | # for i, downsample_block in enumerate(self.down_blocks):
266 | # print(f"--- Upsample Block {i} ---")
267 | # first_resnet_conv1 = downsample_block.resnets[0].conv1 # Get resnets.0.conv1 layer
268 | # if isinstance(first_resnet_conv1, nn.Conv2d):
269 | # in_channels = first_resnet_conv1.in_channels
270 | # out_channels = first_resnet_conv1.out_channels
271 | # # out_channels = 256
272 | # print(f"Block {i}: in_channels={in_channels}, out_channels={out_channels}")
273 |
274 | # # Initialize mask_generator for this block
275 | # mask_generator = DecoderWeightsMaskGenerator(in_channels, out_channels)
276 | # self.mask_generators_down.append(mask_generator)
277 | # block_adapters = nn.ModuleDict() # ModuleDict to store this block's adapters
278 | # # Use these channels for all adapters in this block
279 | # for name, sub_module in downsample_block.named_modules():
280 | # # if isinstance(sub_module, nn.Conv2d) and not any(skip_layer in name for skip_layer in skip_layers):
281 | # # sub_in_channels = sub_module.in_channels
282 | # # sub_out_channels = sub_module.out_channels
283 | # if isinstance(sub_module, nn.Linear) and not any(skip_layer in name for skip_layer in skip_layers):
284 | # sub_in_channels = sub_module.in_features
285 | # sub_out_channels = sub_module.out_features
286 | # else:
287 | # continue
288 |
289 | # # Replace "." with "_" in the layer name
290 | # sanitized_name = name.replace(".", "_")
291 | # print(f"Adapter for Layer {sanitized_name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}")
292 |
293 | # # print(f"Adapter for Layer {name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}")
294 |
295 | # # Initialize Adapter using first_resnet_conv1 channels and layer's channels
296 | # # factor
297 | # factor = 8
298 | # adapter = Adapter(out_channels // factor, in_channels // factor, sub_out_channels, sub_in_channels)
299 | # block_adapters[sanitized_name] = adapter
300 | # self.adapters_down.append(block_adapters)
301 | # factor = [16, 8, 4, 2]
302 | for i, upsample_block in enumerate(self.up_blocks):
303 | print(f"--- Upsample Block {i} ---")
304 | first_resnet_conv1 = upsample_block.resnets[0].conv1 # Get resnets.0.conv1 layer
305 | if isinstance(first_resnet_conv1, nn.Conv2d):
306 | in_channels = first_resnet_conv1.in_channels
307 | out_channels = first_resnet_conv1.out_channels
308 | # out_channels = 256
309 | # print(f"Block {i}: in_channels={in_channels}, out_channels={out_channels}")
310 | factor = 4
311 | # Initialize mask_generator for this block
312 | mask_generator = DecoderWeightsMaskGenerator(in_channels, out_channels, factor = factor)
313 | self.mask_generators.append(mask_generator)
314 |                 block_adapters = nn.ModuleDict() # use ModuleDict to store this block's adapters
315 | # Use these channels for all adapters in this block
316 | for name, sub_module in upsample_block.named_modules():
317 | # if isinstance(sub_module, nn.Conv2d) and not any(skip_layer in name for skip_layer in skip_layers):
318 | # sub_in_channels = sub_module.in_channels
319 | # sub_out_channels = sub_module.out_channels
320 | if isinstance(sub_module, nn.Linear) and not any(skip_layer in name for skip_layer in skip_layers):
321 | sub_in_channels = sub_module.in_features
322 | sub_out_channels = sub_module.out_features
323 | else:
324 | continue
325 |
326 |                     # Replace "." in the name with "_"
327 | sanitized_name = name.replace(".", "_")
328 | # print(f"Adapter for Layer {sanitized_name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}")
329 |
330 | # print(f"Adapter for Layer {name}: new_in_c={sub_in_channels}, new_out_c={sub_out_channels}")
331 |
332 | # Initialize Adapter using first_resnet_conv1 channels and layer's channels
333 | # factor
334 | # factor = 16
335 | adapter = Adapter(out_channels // factor, in_channels // factor, sub_out_channels, sub_in_channels)
336 | # adapter = Adapter(128, 128, sub_out_channels, sub_in_channels)
337 | block_adapters[sanitized_name] = adapter
338 | self.adapters.append(block_adapters)
339 |
340 | # assert 2==1
341 |
342 | # for i, upsample_block in enumerate(self.up_blocks):
343 | # print(f"--- Upsample Block {i} ---")
344 |         #     first_resnet_conv1 = upsample_block.resnets[0].conv1 # get the resnets.0.conv1 conv layer
345 | # if isinstance(first_resnet_conv1, nn.Conv2d):
346 | # in_channels = first_resnet_conv1.in_channels
347 | # out_channels = first_resnet_conv1.out_channels
348 | # self.mask_generators.append(DecoderWeightsMaskGenerator(in_channels, out_channels))
349 |
350 | # assert 2==1
351 |
352 | @staticmethod
353 | def init_weights(module):
354 |         # Initialize nn.Conv2d layers
355 | if isinstance(module, nn.Conv2d):
356 | nn.init.xavier_uniform_(module.weight)
357 | if module.bias is not None:
358 | nn.init.constant_(module.bias, 2.0)
359 |
360 |         # # Initialize nn.Parameter parameters (note: nn.Module.apply only visits sub-modules, so this branch is never reached)
361 |         elif isinstance(module, nn.Parameter):
362 |             nn.init.constant_(module, 0) # initialize to 0
363 |
364 | def forward(
365 | self,
366 | sample: torch.Tensor,
367 | timestep: Union[torch.Tensor, float, int],
368 | encoder_hidden_states: torch.Tensor,
369 | class_labels: Optional[torch.Tensor] = None,
370 | timestep_cond: Optional[torch.Tensor] = None,
371 | attention_mask: Optional[torch.Tensor] = None,
372 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
373 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
374 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
375 | mid_block_additional_residual: Optional[torch.Tensor] = None,
376 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
377 | encoder_attention_mask: Optional[torch.Tensor] = None,
378 | return_dict: bool = True,
379 | ) -> Union[UNet2DConditionOutput, Tuple]:
380 |
381 | if self.config.center_input_sample:
382 | sample = 2 * sample - 1.0
383 |
384 | t_emb = self.get_time_embed(sample=sample, timestep=timestep)
385 | emb = self.time_embedding(t_emb, timestep_cond)
386 | aug_emb = None
387 |
388 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
389 | if class_emb is not None:
390 | if self.config.class_embeddings_concat:
391 | emb = torch.cat([emb, class_emb], dim=-1)
392 | else:
393 | emb = emb + class_emb
394 |
395 | aug_emb = self.get_aug_embed(
396 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
397 | )
398 | if self.config.addition_embed_type == "image_hint":
399 | aug_emb, hint = aug_emb
400 | sample = torch.cat([sample, hint], dim=1)
401 |
402 | emb = emb + aug_emb if aug_emb is not None else emb
403 |
404 | if self.time_embed_act is not None:
405 | emb = self.time_embed_act(emb)
406 |
407 | encoder_hidden_states = self.process_encoder_hidden_states(
408 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
409 | )
410 |
411 | # 2. pre-process
412 | # hook = self.conv_in.register_forward_hook(get_hook_fn(sample, emb))
413 | # self._hooks.append(hook)
414 | # self.conv_in.decoder_weights_mask_generator = self.decoder_weights_mask_generator
415 | sample = self.conv_in(sample)
416 |
417 | down_block_res_samples = (sample,)
418 |
419 | for i, downsample_block in enumerate(self.down_blocks):
420 |
421 | # mask_generator = self.mask_generators_down[i]
422 | # adapters = self.adapters_down[i]
423 |
424 | # self._register_hooks(downsample_block, sample, None, emb, mask_generator, adapters)
425 | # self._register_hooks(downsample_block, sample, emb)
426 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
427 | sample, res_samples = downsample_block(
428 | hidden_states=sample,
429 | temb=emb,
430 | encoder_hidden_states=encoder_hidden_states,
431 | attention_mask=attention_mask,
432 | cross_attention_kwargs=cross_attention_kwargs,
433 | encoder_attention_mask=encoder_attention_mask,
434 | )
435 | else:
436 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
437 | down_block_res_samples += res_samples
438 |
439 | # 4. mid
440 | if self.mid_block is not None:
441 | # self._register_hooks(self.mid_block, sample, emb)
442 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
443 | sample = self.mid_block(
444 | sample,
445 | emb,
446 | encoder_hidden_states=encoder_hidden_states,
447 | attention_mask=attention_mask,
448 | cross_attention_kwargs=cross_attention_kwargs,
449 | encoder_attention_mask=encoder_attention_mask,
450 | )
451 | else:
452 | sample = self.mid_block(sample, emb)
453 |
454 | # 5. up
455 | for i, upsample_block in enumerate(self.up_blocks):
456 | # print("sample", sample.shape)
457 | mask_generator = self.mask_generators[i]
458 | adapters = self.adapters[i]
459 | is_final_block = i == len(self.up_blocks) - 1
460 |
461 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
462 | down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
463 |             # Print the shapes of res_samples
464 | # print(res_samples[-1].shape)
465 | # for i, res_sample in enumerate(res_samples):
466 | # print(f"Shape of res_sample {i}: {res_sample.shape}")
467 |
468 | if not is_final_block:
469 | upsample_size = down_block_res_samples[-1].shape[2:]
470 | else:
471 | upsample_size = None
472 |
473 |             self._register_hooks(upsample_block, sample, res_samples[-1], emb, mask_generator, adapters, encoder_hidden_states)
474 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
475 | sample = upsample_block(
476 | hidden_states=sample,
477 | temb=emb,
478 | res_hidden_states_tuple=res_samples,
479 | encoder_hidden_states=encoder_hidden_states,
480 | cross_attention_kwargs=cross_attention_kwargs,
481 | upsample_size=upsample_size,
482 | attention_mask=attention_mask,
483 | encoder_attention_mask=encoder_attention_mask,
484 | )
485 | else:
486 | sample = upsample_block(
487 | hidden_states=sample,
488 | temb=emb,
489 | res_hidden_states_tuple=res_samples,
490 | upsample_size=upsample_size,
491 | )
492 |
493 | # 6. post-process
494 | if self.conv_norm_out:
495 | # self._register_hooks(self.conv_norm_out, sample, emb)
496 | sample = self.conv_norm_out(sample)
497 | sample = self.conv_act(sample)
498 | # self._register_hooks(self.conv_out, sample, emb)
499 | sample = self.conv_out(sample)
500 |
501 | # Remove hooks
502 | for h in self._hooks:
503 | h.remove()
504 | self._hooks = []
505 |
506 | if not return_dict:
507 | return (sample,)
508 |
509 | return UNet2DConditionOutput(sample=sample)
510 |
511 | def _register_hooks(self, module, sample, res_sample, temb, mask_generators, adapters, encoder_hidden_states):
512 | skip_layers = ["time_emb_proj", "ff", "conv_shortcut", "proj_in", "proj_out"]
513 | for name, sub_module in module.named_modules():
514 | #nn.Conv2d,
515 | if isinstance(sub_module, (nn.Linear)) and not any(skip_layer in name for skip_layer in skip_layers):
516 |                 # Build sanitized_name by replacing "." with "_"
517 | sanitized_name = name.replace(".", "_")
518 |
519 |                 # Use "in" to check whether the key exists
520 | if sanitized_name in adapters:
521 | adapter = adapters[sanitized_name]
522 | if adapter is None:
523 | continue
524 | hook = sub_module.register_forward_hook(get_hook_fn(sample, temb, res_sample, mask_generators, adapter, encoder_hidden_states))
525 | self._hooks.append(hook)
526 | # print(f"Hook registered for layer: {name}")
527 |
528 |
529 | # def main():
530 | #     # Example training code
531 | # unet = FineGrainedUNet2DConditionModel.from_pretrained(
532 | # "../share/runwayml/stable-diffusion-v1-5", subfolder="unet",
533 | # low_cpu_mem_usage=False,
534 | # device_map=None
535 | # )
536 |
537 | # # print(unet)
538 |
539 | #     # Freeze the other parameters and only train decoder_weights_mask_generator
540 | # for param in unet.parameters():
541 | # param.requires_grad = False
542 |
543 | # for param in unet.mask_generators.parameters():
544 | # param.requires_grad = True
545 |
546 | # for block_adapters in unet.adapters:
547 | # for adapter in block_adapters.values():
548 | # for param in adapter.parameters():
549 | # param.requires_grad = True
550 |
551 | # sample = torch.randn(1, 4, 64, 64)
552 | # timestep = torch.tensor([50])
553 | # encoder_hidden_states = torch.randn(1, 77, 768)
554 |
555 | # output = unet(sample, timestep, encoder_hidden_states)['sample']
556 |
557 | #     # Assume a simple loss function
558 | # target = torch.randn_like(output)
559 | # loss = F.mse_loss(output, target)
560 | # loss.backward()
561 |
562 | #     # Check gradients
563 | # for i, mask_generator in enumerate(unet.mask_generators):
564 | # for param_name, param in mask_generator.named_parameters():
565 | # if param.grad is None:
566 | # print(f"No gradient for mask_generator in block {i}, parameter: {param_name}")
567 | # else:
568 | # print(f"Gradient for mask_generator in block {i}, parameter {param_name}: {param.grad.norm()}")
569 |
570 |
571 | # for i, block_adapters in enumerate(unet.adapters):
572 | # for name, adapter in block_adapters.items():
573 | # for param_name, param in adapter.named_parameters():
574 | # if param.grad is None:
575 | # print(f"No gradient for adapter {name} in block {i}, parameter: {param_name}")
576 | # else:
577 | # print(f"Gradient for adapter {name} in block {i}, parameter {param_name}: {param.grad.norm()}")
578 |
579 | def main():
580 | # Initialize the accelerator with FP16 precision
581 | accelerator = Accelerator(mixed_precision="fp16")
582 |
583 | # Initialize model
584 | unet = FineGrainedUNet2DConditionModel.from_pretrained(
585 | "../share/runwayml/stable-diffusion-v1-5",
586 | subfolder="unet",
587 | low_cpu_mem_usage=False,
588 | device_map=None,
589 | )
590 |
591 | # Freeze other parameters, only train mask_generators and adapters
592 | for param in unet.parameters():
593 | param.requires_grad = False
594 | for param in unet.mask_generators.parameters():
595 | param.requires_grad = True
596 | for block_adapters in unet.adapters:
597 | for adapter in block_adapters.values():
598 | for param in adapter.parameters():
599 | param.requires_grad = True
600 |
601 | # Initialize optimizer
602 | optimizer = torch.optim.Adam(
603 | filter(lambda p: p.requires_grad, unet.parameters()), lr=1e-4
604 | )
605 |
606 | for name, param in unet.named_parameters():
607 | if param.requires_grad:
608 | print(f"Parameter {name} dtype: {param.dtype}")
609 |             if param.dtype == torch.float32:
610 | print("**********************")
611 |
612 | # Prepare model and optimizer with accelerator
613 | unet, optimizer = accelerator.prepare(unet, optimizer)
614 |
615 | # Simulate input (keep as FP32)
616 | sample = torch.randn(1, 4, 64, 64).to(accelerator.device)
617 | timestep = torch.tensor([50]).to(accelerator.device)
618 | encoder_hidden_states = torch.randn(1, 77, 768).to(accelerator.device)
619 |
620 | # Forward pass
621 | optimizer.zero_grad()
622 | with accelerator.autocast():
623 | output = unet(sample, timestep, encoder_hidden_states)["sample"]
624 | # Compute loss
625 | target = torch.randn_like(output)
626 | loss = F.mse_loss(output, target)
627 |
628 | # Backward pass
629 | accelerator.backward(loss)
630 |
631 | # Clip gradients (no need to unscale manually)
632 | params_to_clip = [p for p in unet.parameters() if p.requires_grad]
633 | accelerator.clip_grad_norm_(params_to_clip, max_norm=1.0)
634 |
635 | # Optimizer step
636 | optimizer.step()
637 | optimizer.zero_grad()
638 |
639 |
640 | if __name__ == "__main__":
641 | main()
642 |
643 |
644 | # if __name__ == "__main__":
645 | # main()
646 |
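647 | # A minimal usage sketch (an assumption, not part of the original script): the masked UNet
648 | # built above can presumably be swapped into a standard diffusers StableDiffusionPipeline for
649 | # inference. The checkpoint path mirrors the one used in main() and is illustrative only.
650 | #
651 | # from diffusers import StableDiffusionPipeline
652 | #
653 | # unet = FineGrainedUNet2DConditionModel.from_pretrained(
654 | #     "../share/runwayml/stable-diffusion-v1-5", subfolder="unet",
655 | #     low_cpu_mem_usage=False, device_map=None,
656 | # )
657 | # pipe = StableDiffusionPipeline.from_pretrained(
658 | #     "../share/runwayml/stable-diffusion-v1-5", unet=unet
659 | # ).to("cuda")
660 | # image = pipe("a photo of an astronaut riding a horse on mars").images[0]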
--------------------------------------------------------------------------------
/training-free/sd_utils_x0_dpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | from typing import Any, Callable, Dict, List, Optional, Union
16 |
17 | import torch
18 | from packaging import version
19 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
20 |
21 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22 | from diffusers.configuration_utils import FrozenDict
23 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
24 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
25 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
26 | from diffusers.models.lora import adjust_lora_scale_text_encoder
27 | from diffusers.schedulers import KarrasDiffusionSchedulers
28 | from diffusers.utils import (
29 | USE_PEFT_BACKEND,
30 | deprecate,
31 | logging,
32 | replace_example_docstring,
33 | scale_lora_layers,
34 | unscale_lora_layers,
35 | )
36 | from diffusers.utils.torch_utils import randn_tensor
37 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
38 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
39 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
40 | import torch.nn.functional as F
41 | import torch.optim as optim
42 | from diffusers.pipelines.stable_diffusion_attend_and_excite.pipeline_stable_diffusion_attend_and_excite import (
43 | AttentionStore,
44 | AttendExciteAttnProcessor
45 | )
46 | from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, align_wordpieces_indices, extract_attribution_indices, extract_attribution_indices_with_verbs, extract_attribution_indices_with_verb_root, extract_entities_only
47 |
48 |
49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50 |
51 | EXAMPLE_DOC_STRING = """
52 | Examples:
53 | ```py
54 | >>> import torch
55 | >>> from diffusers import StableDiffusionPipeline
56 |
57 | >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
58 | >>> pipe = pipe.to("cuda")
59 |
60 | >>> prompt = "a photo of an astronaut riding a horse on mars"
61 | >>> image = pipe(prompt).images[0]
62 | ```
63 | """
64 |
65 |
66 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
67 | """
68 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
69 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
70 | """
71 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
72 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
73 | # rescale the results from guidance (fixes overexposure)
74 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
75 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
76 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
77 | return noise_cfg
78 |
79 |
80 | def retrieve_timesteps(
81 | scheduler,
82 | num_inference_steps: Optional[int] = None,
83 | device: Optional[Union[str, torch.device]] = None,
84 | timesteps: Optional[List[int]] = None,
85 | sigmas: Optional[List[float]] = None,
86 | **kwargs,
87 | ):
88 | """
89 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
90 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
91 |
92 | Args:
93 | scheduler (`SchedulerMixin`):
94 | The scheduler to get timesteps from.
95 | num_inference_steps (`int`):
96 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
97 | must be `None`.
98 | device (`str` or `torch.device`, *optional*):
99 |             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
100 | timesteps (`List[int]`, *optional*):
101 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
102 | `num_inference_steps` and `sigmas` must be `None`.
103 | sigmas (`List[float]`, *optional*):
104 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
105 | `num_inference_steps` and `timesteps` must be `None`.
106 |
107 | Returns:
108 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
109 | second element is the number of inference steps.
110 | """
111 | if timesteps is not None and sigmas is not None:
112 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
113 | if timesteps is not None:
114 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
115 | if not accepts_timesteps:
116 | raise ValueError(
117 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118 | f" timestep schedules. Please check whether you are using the correct scheduler."
119 | )
120 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
121 | timesteps = scheduler.timesteps
122 | num_inference_steps = len(timesteps)
123 | elif sigmas is not None:
124 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
125 | if not accept_sigmas:
126 | raise ValueError(
127 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128 | f" sigmas schedules. Please check whether you are using the correct scheduler."
129 | )
130 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131 | timesteps = scheduler.timesteps
132 | num_inference_steps = len(timesteps)
133 | else:
134 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135 | timesteps = scheduler.timesteps
136 | return timesteps, num_inference_steps
137 |
138 |
139 | def register_sd_forward(model):
140 | def sd_forward(self):
141 | # @torch.no_grad()
142 | @replace_example_docstring(EXAMPLE_DOC_STRING)
143 | def forward(
144 | prompt: Union[str, List[str]] = None,
145 | height: Optional[int] = None,
146 | width: Optional[int] = None,
147 | num_inference_steps: int = 50,
148 | timesteps: List[int] = None,
149 | sigmas: List[float] = None,
150 | guidance_scale: float = 7.5,
151 | negative_prompt: Optional[Union[str, List[str]]] = None,
152 | num_images_per_prompt: Optional[int] = 1,
153 | eta: float = 0.0,
154 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
155 | latents: Optional[torch.Tensor] = None,
156 | prompt_embeds: Optional[torch.Tensor] = None,
157 | negative_prompt_embeds: Optional[torch.Tensor] = None,
158 | ip_adapter_image: Optional[PipelineImageInput] = None,
159 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
160 | output_type: Optional[str] = "pil",
161 | return_dict: bool = True,
162 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
163 | guidance_rescale: float = 0.0,
164 | clip_skip: Optional[int] = None,
165 | callback_on_step_end: Optional[
166 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
167 | ] = None,
168 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
169 | **kwargs,
170 | ):
171 | r"""
172 | The call function to the pipeline for generation.
173 |
174 | Args:
175 | prompt (`str` or `List[str]`, *optional*):
176 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
177 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
178 | The height in pixels of the generated image.
179 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
180 | The width in pixels of the generated image.
181 | num_inference_steps (`int`, *optional*, defaults to 50):
182 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
183 | expense of slower inference.
184 | timesteps (`List[int]`, *optional*):
185 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
186 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
187 | passed will be used. Must be in descending order.
188 | sigmas (`List[float]`, *optional*):
189 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
190 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
191 | will be used.
192 | guidance_scale (`float`, *optional*, defaults to 7.5):
193 | A higher guidance scale value encourages the model to generate images closely linked to the text
194 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
195 | negative_prompt (`str` or `List[str]`, *optional*):
196 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to
197 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
198 | num_images_per_prompt (`int`, *optional*, defaults to 1):
199 | The number of images to generate per prompt.
200 | eta (`float`, *optional*, defaults to 0.0):
201 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
202 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
203 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
204 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
205 | generation deterministic.
206 | latents (`torch.Tensor`, *optional*):
207 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
208 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
209 | tensor is generated by sampling using the supplied random `generator`.
210 | prompt_embeds (`torch.Tensor`, *optional*):
211 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
212 | provided, text embeddings are generated from the `prompt` input argument.
213 | negative_prompt_embeds (`torch.Tensor`, *optional*):
214 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
215 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
216 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
217 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
218 |                     Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
219 |                     of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
220 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
221 | provided, embeddings are computed from the `ip_adapter_image` input argument.
222 | output_type (`str`, *optional*, defaults to `"pil"`):
223 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
224 | return_dict (`bool`, *optional*, defaults to `True`):
225 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
226 | plain tuple.
227 | cross_attention_kwargs (`dict`, *optional*):
228 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
229 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
230 | guidance_rescale (`float`, *optional*, defaults to 0.0):
231 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
232 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
233 | using zero terminal SNR.
234 | clip_skip (`int`, *optional*):
235 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
236 | the output of the pre-final layer will be used for computing the prompt embeddings.
237 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
238 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
239 |                     each denoising step during inference, with the following arguments: `callback_on_step_end(self:
240 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
241 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
242 | callback_on_step_end_tensor_inputs (`List`, *optional*):
243 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
244 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
245 | `._callback_tensor_inputs` attribute of your pipeline class.
246 |
247 | Examples:
248 |
249 | Returns:
250 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
251 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
252 | otherwise a `tuple` is returned where the first element is a list with the generated images and the
253 | second element is a list of `bool`s indicating whether the corresponding generated image contains
254 | "not-safe-for-work" (nsfw) content.
255 | """
256 |
257 | callback = kwargs.pop("callback", None)
258 | callback_steps = kwargs.pop("callback_steps", None)
259 |
260 | if callback is not None:
261 | deprecate(
262 | "callback",
263 | "1.0.0",
264 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
265 | )
266 | if callback_steps is not None:
267 | deprecate(
268 | "callback_steps",
269 | "1.0.0",
270 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
271 | )
272 |
273 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
274 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
275 |
276 | # 0. Default height and width to unet
277 | height = height or self.unet.config.sample_size * self.vae_scale_factor
278 | width = width or self.unet.config.sample_size * self.vae_scale_factor
279 | # to deal with lora scaling and other possible forward hooks
280 |
281 | # 1. Check inputs. Raise error if not correct
282 | self.check_inputs(
283 | prompt,
284 | height,
285 | width,
286 | callback_steps,
287 | negative_prompt,
288 | prompt_embeds,
289 | negative_prompt_embeds,
290 | ip_adapter_image,
291 | ip_adapter_image_embeds,
292 | callback_on_step_end_tensor_inputs,
293 | )
294 |
295 | self._guidance_scale = guidance_scale
296 | self._guidance_rescale = guidance_rescale
297 | self._clip_skip = clip_skip
298 | self._cross_attention_kwargs = cross_attention_kwargs
299 | self._interrupt = False
300 |
301 | # 2. Define call parameters
302 | if prompt is not None and isinstance(prompt, str):
303 | batch_size = 1
304 | elif prompt is not None and isinstance(prompt, list):
305 | batch_size = len(prompt)
306 | else:
307 | batch_size = prompt_embeds.shape[0]
308 |
309 | device = self._execution_device
310 |
311 | # 3. Encode input prompt
312 | lora_scale = (
313 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
314 | )
315 |
316 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
317 | prompt,
318 | device,
319 | num_images_per_prompt,
320 | self.do_classifier_free_guidance,
321 | negative_prompt,
322 | prompt_embeds=prompt_embeds,
323 | negative_prompt_embeds=negative_prompt_embeds,
324 | lora_scale=lora_scale,
325 | clip_skip=self.clip_skip,
326 | )
327 | # print(prompt_embeds.size())
328 | # assert 2==1
329 |
330 | # For classifier free guidance, we need to do two forward passes.
331 | # Here we concatenate the unconditional and text embeddings into a single batch
332 | # to avoid doing two forward passes
333 | if self.do_classifier_free_guidance:
334 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
335 |
336 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
337 | image_embeds = self.prepare_ip_adapter_image_embeds(
338 | ip_adapter_image,
339 | ip_adapter_image_embeds,
340 | device,
341 | batch_size * num_images_per_prompt,
342 | self.do_classifier_free_guidance,
343 | )
344 |
345 | # 4. Prepare timesteps
346 | timesteps, num_inference_steps = retrieve_timesteps(
347 | self.scheduler, num_inference_steps, device, timesteps, sigmas
348 | )
349 |
350 | # 5. Prepare latent variables
351 | num_channels_latents = self.unet.config.in_channels
352 | latents = self.prepare_latents(
353 | batch_size * num_images_per_prompt,
354 | num_channels_latents,
355 | height,
356 | width,
357 | prompt_embeds.dtype,
358 | device,
359 | generator,
360 | latents,
361 | )
362 |
363 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
364 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
365 |
366 | # 6.1 Add image embeds for IP-Adapter
367 | added_cond_kwargs = (
368 | {"image_embeds": image_embeds}
369 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
370 | else None
371 | )
372 |
373 | # 6.2 Optionally get Guidance Scale Embedding
374 | timestep_cond = None
375 | if self.unet.config.time_cond_proj_dim is not None:
376 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
377 | timestep_cond = self.get_guidance_scale_embedding(
378 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
379 | ).to(device=device, dtype=latents.dtype)
380 |
381 | # 7. Denoising loop
382 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
383 | self._num_timesteps = len(timesteps)
384 | with self.progress_bar(total=num_inference_steps) as progress_bar:
385 | for i, t in enumerate(timesteps):
386 | if self.interrupt:
387 | continue
388 |
389 | if i < self.stop_step:
390 | continue
391 |
392 |
393 | # expand the latents if we are doing classifier free guidance
394 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
395 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
396 | # print(self.unet)
397 | # assert 2==1
398 | # predict the noise residual
399 | noise_pred = self.unet(
400 | latent_model_input,
401 | t,
402 | encoder_hidden_states=prompt_embeds,
403 | timestep_cond=timestep_cond,
404 | cross_attention_kwargs=self.cross_attention_kwargs,
405 | added_cond_kwargs=added_cond_kwargs,
406 | return_dict=False,
407 | )[0]
408 |
409 | # perform guidance
410 | if self.do_classifier_free_guidance:
411 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
412 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
413 |
414 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
415 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
416 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
417 |
418 | # compute the previous noisy sample x_t -> x_t-1
419 | latents, pred_original_sample = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)
420 |
421 |
422 | if callback_on_step_end is not None:
423 | callback_kwargs = {}
424 | for k in callback_on_step_end_tensor_inputs:
425 | callback_kwargs[k] = locals()[k]
426 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
427 |
428 | latents = callback_outputs.pop("latents", latents)
429 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
430 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
431 |
432 | # call the callback, if provided
433 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
434 | progress_bar.update()
435 | if callback is not None and i % callback_steps == 0:
436 | step_idx = i // getattr(self.scheduler, "order", 1)
437 | callback(step_idx, t, latents)
438 |
439 | if i >= self.stop_step:
440 | # alpha_t = self.alphas[int(t.item())] ** 0.5
441 | # sigma_t = (1 - self.alphas[int(t.item())]) ** 0.5
442 | # latents = (latents - sigma_t * noise_pred) / alpha_t
443 | latents_cur = latents
444 | latents = pred_original_sample
445 | break
446 | if not output_type == "latent":
447 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
448 | 0
449 | ]
450 |                 has_nsfw_concept = None  # safety checker disabled: self.run_safety_checker(image, device, prompt_embeds.dtype)
451 |             else:
452 |                 image = latents
453 |                 has_nsfw_concept = None
454 | 
455 | # if has_nsfw_concept is None:
456 | # do_denormalize = [True] * image.shape[0]
457 | # else:
458 | # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
459 |
460 | # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
461 | # image = (image / 2 + 0.5).clamp(0, 1)
462 | image = (image / 2 + 0.5).clamp(0, 1)
463 | # optimizer = optim.Adam(self.unet.parameters(), lr=1e-1)
464 |             # input_tensor = image # simulate an input image
465 |             # target = torch.randn_like(input_tensor).to(noise_pred.device).half() # target with the same shape as the input
466 | # loss = F.mse_loss(image, target)
467 | # optimizer.zero_grad()
468 | # loss.backward()
469 | # for name, param in self.unet.named_parameters():
470 | # if param.grad is not None:
471 | # print(f"Gradient for {name}: {param.grad}")
472 | # else:
473 | # print(f"No gradient for {name}")
474 |
475 | # assert 2==1
476 |
477 | # Offload all models
478 | self.maybe_free_model_hooks()
479 |
480 | if not return_dict:
481 | # return (image, has_nsfw_concept)
482 | return (image, latents_cur)
483 |
484 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), latents_cur
485 | return forward
486 | if model.__class__.__name__ == 'StableDiffusionPipeline':
487 | # model.__call__ = sdturbo_forward(model)
488 | model.forward = sd_forward(model)
489 |
490 | from typing import Tuple
491 | from dataclasses import dataclass
492 | from diffusers.schedulers.scheduling_lcm import LCMSchedulerOutput
493 | from diffusers.utils.torch_utils import randn_tensor
494 | from diffusers.utils import BaseOutput
495 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
496 |
497 | def register_sdschedule_step(model):
498 | def sd_schedule_step(self):
499 | def step(
500 | model_output: torch.Tensor,
501 | timestep: Union[int, torch.Tensor],
502 | sample: torch.Tensor,
503 | generator=None,
504 | variance_noise: Optional[torch.Tensor] = None,
505 | return_dict: bool = True,
506 | ) -> Union[SchedulerOutput, Tuple]:
507 | """
508 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
509 | the multistep DPMSolver.
510 |
511 | Args:
512 | model_output (`torch.Tensor`):
513 | The direct output from learned diffusion model.
514 | timestep (`int`):
515 | The current discrete timestep in the diffusion chain.
516 | sample (`torch.Tensor`):
517 | A current instance of a sample created by the diffusion process.
518 | generator (`torch.Generator`, *optional*):
519 | A random number generator.
520 | variance_noise (`torch.Tensor`):
521 | Alternative to generating noise with `generator` by directly providing the noise for the variance
522 | itself. Useful for methods such as [`LEdits++`].
523 | return_dict (`bool`):
524 | Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
525 |
526 | Returns:
527 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
528 | If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
529 | tuple is returned where the first element is the sample tensor.
530 |
531 | """
532 | if self.num_inference_steps is None:
533 | raise ValueError(
534 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
535 | )
536 |
537 | if self.step_index is None:
538 | self._init_step_index(timestep)
539 |
540 | # Improve numerical stability for small number of steps
541 | lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
542 | self.config.euler_at_final
543 | or (self.config.lower_order_final and len(self.timesteps) < 15)
544 | or self.config.final_sigmas_type == "zero"
545 | )
546 | lower_order_second = (
547 | (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
548 | )
549 |
550 | model_output = self.convert_model_output(model_output, sample=sample)
551 | for i in range(self.config.solver_order - 1):
552 | self.model_outputs[i] = self.model_outputs[i + 1]
553 | self.model_outputs[-1] = model_output
554 |
555 | # Upcast to avoid precision issues when computing prev_sample
556 | sample = sample.to(torch.float32)
557 | if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
558 | noise = randn_tensor(
559 | model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
560 | )
561 | elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
562 | noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
563 | else:
564 | noise = None
565 |
566 | if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
567 | prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
568 | elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
569 | prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
570 | else:
571 | prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
572 |
573 | if self.lower_order_nums < self.config.solver_order:
574 | self.lower_order_nums += 1
575 |
576 | # Cast sample back to expected dtype
577 | prev_sample = prev_sample.to(model_output.dtype)
578 |
579 | # upon completion increase step index by one
580 | self._step_index += 1
581 |
582 | if not return_dict:
583 | return (prev_sample, model_output)
584 |
585 | return SchedulerOutput(prev_sample=prev_sample)
586 |
587 | return step
588 | if model.__class__.__name__ == 'DPMSolverMultistepScheduler':
589 | model.step = sd_schedule_step(model)
590 |
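591 | # A minimal usage sketch (an assumption, not part of the original file): the register_* helpers
592 | # above monkey-patch a StableDiffusionPipeline and a DPMSolverMultistepScheduler in place. The
593 | # patched forward reads `self.stop_step`, so it apparently has to be set on the pipeline before
594 | # calling, and the call returns the pipeline output together with the current latents.
595 | #
596 | # from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
597 | #
598 | # pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
599 | # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
600 | # register_sd_forward(pipe)
601 | # register_sdschedule_step(pipe.scheduler)
602 | # pipe.stop_step = 0  # step index at which denoising stops and the x0 prediction is returned
603 | # output, latents_cur = pipe.forward(prompt="a photo of an astronaut riding a horse on mars")
604 | # image = output.images[0]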
--------------------------------------------------------------------------------
/training/train_sd_lora.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
17 |
18 | import argparse
19 | import logging
20 | import math
21 | import os
22 | import random
23 | import shutil
24 | from contextlib import nullcontext
25 | from pathlib import Path
26 |
27 | import datasets
28 | import numpy as np
29 | import torch
30 | import torch.nn.functional as F
31 | import torch.utils.checkpoint
32 | import transformers
33 | from accelerate import Accelerator
34 | from accelerate.logging import get_logger
35 | from accelerate.utils import ProjectConfiguration, set_seed
36 | from datasets import load_dataset
37 | from huggingface_hub import create_repo, upload_folder
38 | from packaging import version
39 | from peft import LoraConfig
40 | from peft.utils import get_peft_model_state_dict
41 | from torchvision import transforms
42 | from tqdm.auto import tqdm
43 | from transformers import CLIPTextModel, CLIPTokenizer
44 |
45 | import diffusers
46 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
47 | from diffusers.optimization import get_scheduler
48 | from diffusers.training_utils import cast_training_params, compute_snr
49 | from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
50 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
51 | from diffusers.utils.import_utils import is_xformers_available
52 | from diffusers.utils.torch_utils import is_compiled_module
53 |
54 |
55 | if is_wandb_available():
56 | import wandb
57 |
58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
59 | # check_min_version("0.31.0.dev0")
60 |
61 | logger = get_logger(__name__, log_level="INFO")
62 |
63 |
64 | def save_model_card(
65 | repo_id: str,
66 | images: list = None,
67 | base_model: str = None,
68 | dataset_name: str = None,
69 | repo_folder: str = None,
70 | ):
71 | img_str = ""
72 | if images is not None:
73 | for i, image in enumerate(images):
74 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
75 | img_str += f"\n"
76 |
77 | model_description = f"""
78 | # LoRA text2image fine-tuning - {repo_id}
79 | These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
80 | {img_str}
81 | """
82 |
83 | model_card = load_or_create_model_card(
84 | repo_id_or_path=repo_id,
85 | from_training=True,
86 | license="creativeml-openrail-m",
87 | base_model=base_model,
88 | model_description=model_description,
89 | inference=True,
90 | )
91 |
92 | tags = [
93 | "stable-diffusion",
94 | "stable-diffusion-diffusers",
95 | "text-to-image",
96 | "diffusers",
97 | "diffusers-training",
98 | "lora",
99 | ]
100 | model_card = populate_model_card(model_card, tags=tags)
101 |
102 | model_card.save(os.path.join(repo_folder, "README.md"))
103 |
104 |
105 | def log_validation(
106 | pipeline,
107 | args,
108 | accelerator,
109 | epoch,
110 | is_final_validation=False,
111 | ):
112 | logger.info(
113 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
114 | f" {args.validation_prompt}."
115 | )
116 | pipeline = pipeline.to(accelerator.device)
117 | pipeline.set_progress_bar_config(disable=True)
118 | generator = torch.Generator(device=accelerator.device)
119 | if args.seed is not None:
120 | generator = generator.manual_seed(args.seed)
121 | images = []
122 | if torch.backends.mps.is_available():
123 | autocast_ctx = nullcontext()
124 | else:
125 | autocast_ctx = torch.autocast(accelerator.device.type)
126 |
127 | with autocast_ctx:
128 | for _ in range(args.num_validation_images):
129 | images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
130 |
131 | for tracker in accelerator.trackers:
132 | phase_name = "test" if is_final_validation else "validation"
133 | if tracker.name == "tensorboard":
134 | np_images = np.stack([np.asarray(img) for img in images])
135 | tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
136 | if tracker.name == "wandb":
137 | tracker.log(
138 | {
139 | phase_name: [
140 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
141 | ]
142 | }
143 | )
144 | return images
145 |
146 |
147 | def parse_args():
148 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
149 | parser.add_argument(
150 | "--pretrained_model_name_or_path",
151 | type=str,
152 | default=None,
153 | required=True,
154 | help="Path to pretrained model or model identifier from huggingface.co/models.",
155 | )
156 | parser.add_argument(
157 | "--revision",
158 | type=str,
159 | default=None,
160 | required=False,
161 | help="Revision of pretrained model identifier from huggingface.co/models.",
162 | )
163 | parser.add_argument(
164 | "--variant",
165 | type=str,
166 | default=None,
167 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
168 | )
169 | parser.add_argument(
170 | "--dataset_name",
171 | type=str,
172 | default=None,
173 | help=(
174 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
175 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
176 | " or to a folder containing files that 🤗 Datasets can understand."
177 | ),
178 | )
179 | parser.add_argument(
180 | "--dataset_config_name",
181 | type=str,
182 | default=None,
183 | help="The config of the Dataset, leave as None if there's only one config.",
184 | )
185 | parser.add_argument(
186 | "--train_data_dir",
187 | type=str,
188 | default=None,
189 | help=(
190 | "A folder containing the training data. Folder contents must follow the structure described in"
191 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
192 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
193 | ),
194 | )
195 | parser.add_argument(
196 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
197 | )
198 | parser.add_argument(
199 | "--caption_column",
200 | type=str,
201 | default="text",
202 | help="The column of the dataset containing a caption or a list of captions.",
203 | )
204 | parser.add_argument(
205 | "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
206 | )
207 | parser.add_argument(
208 | "--num_validation_images",
209 | type=int,
210 | default=4,
211 | help="Number of images that should be generated during validation with `validation_prompt`.",
212 | )
213 | parser.add_argument(
214 | "--validation_epochs",
215 | type=int,
216 | default=1,
217 | help=(
218 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
219 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
220 | ),
221 | )
222 | parser.add_argument(
223 | "--max_train_samples",
224 | type=int,
225 | default=None,
226 | help=(
227 | "For debugging purposes or quicker training, truncate the number of training examples to this "
228 | "value if set."
229 | ),
230 | )
231 | parser.add_argument(
232 | "--output_dir",
233 | type=str,
234 | default="sd-model-finetuned-lora",
235 | help="The output directory where the model predictions and checkpoints will be written.",
236 | )
237 | parser.add_argument(
238 | "--cache_dir",
239 | type=str,
240 | default=None,
241 | help="The directory where the downloaded models and datasets will be stored.",
242 | )
243 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
244 | parser.add_argument(
245 | "--resolution",
246 | type=int,
247 | default=512,
248 | help=(
249 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
250 | " resolution"
251 | ),
252 | )
253 | parser.add_argument(
254 | "--center_crop",
255 | default=False,
256 | action="store_true",
257 | help=(
258 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
259 | " cropped. The images will be resized to the resolution first before cropping."
260 | ),
261 | )
262 | parser.add_argument(
263 | "--random_flip",
264 | action="store_true",
265 | help="whether to randomly flip images horizontally",
266 | )
267 | parser.add_argument(
268 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
269 | )
270 | parser.add_argument("--num_train_epochs", type=int, default=100)
271 | parser.add_argument(
272 | "--max_train_steps",
273 | type=int,
274 | default=None,
275 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
276 | )
277 | parser.add_argument(
278 | "--gradient_accumulation_steps",
279 | type=int,
280 | default=1,
281 | help="Number of updates steps to accumulate before performing a backward/update pass.",
282 | )
283 | parser.add_argument(
284 | "--gradient_checkpointing",
285 | action="store_true",
286 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
287 | )
288 | parser.add_argument(
289 | "--learning_rate",
290 | type=float,
291 | default=1e-4,
292 | help="Initial learning rate (after the potential warmup period) to use.",
293 | )
294 | parser.add_argument(
295 | "--scale_lr",
296 | action="store_true",
297 | default=False,
298 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
299 | )
300 | parser.add_argument(
301 | "--lr_scheduler",
302 | type=str,
303 | default="constant",
304 | help=(
305 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
306 | ' "constant", "constant_with_warmup"]'
307 | ),
308 | )
309 | parser.add_argument(
310 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
311 | )
312 | parser.add_argument(
313 | "--snr_gamma",
314 | type=float,
315 | default=None,
316 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
317 | "More details here: https://arxiv.org/abs/2303.09556.",
318 | )
319 | parser.add_argument(
320 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
321 | )
322 | parser.add_argument(
323 | "--allow_tf32",
324 | action="store_true",
325 | help=(
326 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
327 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
328 | ),
329 | )
330 | parser.add_argument(
331 | "--dataloader_num_workers",
332 | type=int,
333 | default=0,
334 | help=(
335 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
336 | ),
337 | )
338 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
339 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
340 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
341 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
342 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
343 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
344 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
345 | parser.add_argument(
346 | "--prediction_type",
347 | type=str,
348 | default=None,
349 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
350 | )
351 | parser.add_argument(
352 | "--hub_model_id",
353 | type=str,
354 | default=None,
355 | help="The name of the repository to keep in sync with the local `output_dir`.",
356 | )
357 | parser.add_argument(
358 | "--logging_dir",
359 | type=str,
360 | default="logs",
361 | help=(
362 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
363 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
364 | ),
365 | )
366 | parser.add_argument(
367 | "--mixed_precision",
368 | type=str,
369 | default=None,
370 | choices=["no", "fp16", "bf16"],
371 | help=(
372 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
373 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
374 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
375 | ),
376 | )
377 | parser.add_argument(
378 | "--report_to",
379 | type=str,
380 | default="tensorboard",
381 | help=(
382 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
383 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
384 | ),
385 | )
386 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
387 | parser.add_argument(
388 | "--checkpointing_steps",
389 | type=int,
390 | default=500,
391 | help=(
392 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
393 | " training using `--resume_from_checkpoint`."
394 | ),
395 | )
396 | parser.add_argument(
397 | "--checkpoints_total_limit",
398 | type=int,
399 | default=None,
400 | help=("Max number of checkpoints to store."),
401 | )
402 | parser.add_argument(
403 | "--resume_from_checkpoint",
404 | type=str,
405 | default=None,
406 | help=(
407 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
408 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
409 | ),
410 | )
411 | parser.add_argument(
412 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
413 | )
414 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
415 | parser.add_argument(
416 | "--rank",
417 | type=int,
418 | default=4,
419 | help=("The dimension of the LoRA update matrices."),
420 | )
421 |
422 | args = parser.parse_args()
423 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
424 | if env_local_rank != -1 and env_local_rank != args.local_rank:
425 | args.local_rank = env_local_rank
426 |
427 | # Sanity checks
428 | if args.dataset_name is None and args.train_data_dir is None:
429 | raise ValueError("Need either a dataset name or a training folder.")
430 |
431 | return args
432 |
433 |
434 | DATASET_NAME_MAPPING = {
435 | "lambdalabs/naruto-blip-captions": ("image", "text"),
436 | }
437 |
438 |
439 | def main():
440 | args = parse_args()
441 | if args.report_to == "wandb" and args.hub_token is not None:
442 | raise ValueError(
443 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
444 | " Please use `huggingface-cli login` to authenticate with the Hub."
445 | )
446 |
447 | logging_dir = Path(args.output_dir, args.logging_dir)
448 |
449 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
450 |
451 | accelerator = Accelerator(
452 | gradient_accumulation_steps=args.gradient_accumulation_steps,
453 | mixed_precision=args.mixed_precision,
454 | log_with=args.report_to,
455 | project_config=accelerator_project_config,
456 | )
457 |
458 | # Disable AMP for MPS.
459 | if torch.backends.mps.is_available():
460 | accelerator.native_amp = False
461 |
462 | # Make one log on every process with the configuration for debugging.
463 | logging.basicConfig(
464 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
465 | datefmt="%m/%d/%Y %H:%M:%S",
466 | level=logging.INFO,
467 | )
468 | logger.info(accelerator.state, main_process_only=False)
469 | if accelerator.is_local_main_process:
470 | datasets.utils.logging.set_verbosity_warning()
471 | transformers.utils.logging.set_verbosity_warning()
472 | diffusers.utils.logging.set_verbosity_info()
473 | else:
474 | datasets.utils.logging.set_verbosity_error()
475 | transformers.utils.logging.set_verbosity_error()
476 | diffusers.utils.logging.set_verbosity_error()
477 |
478 | # If passed along, set the training seed now.
479 | if args.seed is not None:
480 | set_seed(args.seed)
481 |
482 | # Handle the repository creation
483 | if accelerator.is_main_process:
484 | if args.output_dir is not None:
485 | os.makedirs(args.output_dir, exist_ok=True)
486 |
487 | if args.push_to_hub:
488 | repo_id = create_repo(
489 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
490 | ).repo_id
491 | # Load scheduler, tokenizer and models.
492 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
493 | tokenizer = CLIPTokenizer.from_pretrained(
494 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
495 | )
496 | text_encoder = CLIPTextModel.from_pretrained(
497 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
498 | )
499 | vae = AutoencoderKL.from_pretrained(
500 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
501 | )
502 | unet = UNet2DConditionModel.from_pretrained(
503 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
504 | )
505 | # freeze parameters of models to save more memory
506 | unet.requires_grad_(False)
507 | vae.requires_grad_(False)
508 | text_encoder.requires_grad_(False)
509 |
510 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
511 | # as these weights are only used for inference, keeping weights in full precision is not required.
512 | weight_dtype = torch.float32
513 | if accelerator.mixed_precision == "fp16":
514 | weight_dtype = torch.float16
515 | elif accelerator.mixed_precision == "bf16":
516 | weight_dtype = torch.bfloat16
517 |
518 | # Freeze the unet parameters before adding adapters
519 | for param in unet.parameters():
520 | param.requires_grad_(False)
521 |
522 | unet_lora_config = LoraConfig(
523 | r=256,
524 | lora_alpha=512,
525 | init_lora_weights="gaussian",
526 | target_modules=["to_k", "to_q", "to_v", "to_out.0"],
527 | )
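    # Note: the adapter rank and scaling are hardcoded here (r=256, lora_alpha=512, i.e. a PEFT
    # scaling factor of lora_alpha / r = 2 on the attention projections), so the `--rank` argument
    # defined above is not consumed by this config; lower the values here to shrink the adapter.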
528 |
529 | # Move unet, vae and text_encoder to device and cast to weight_dtype
530 | unet.to(accelerator.device, dtype=weight_dtype)
531 | vae.to(accelerator.device, dtype=weight_dtype)
532 | text_encoder.to(accelerator.device, dtype=weight_dtype)
533 |
534 | # Add adapter and make sure the trainable params are in float32.
535 | unet.add_adapter(unet_lora_config)
536 | if args.mixed_precision == "fp16":
537 | # only upcast trainable parameters (LoRA) into fp32
538 | cast_training_params(unet, dtype=torch.float32)
539 |
540 | if args.enable_xformers_memory_efficient_attention:
541 | if is_xformers_available():
542 | import xformers
543 |
544 | xformers_version = version.parse(xformers.__version__)
545 | if xformers_version == version.parse("0.0.16"):
546 | logger.warning(
547 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
548 | )
549 | unet.enable_xformers_memory_efficient_attention()
550 | else:
551 | raise ValueError("xformers is not available. Make sure it is installed correctly")
552 |
553 |
554 | for n, p in unet.named_parameters():
555 | if not p.requires_grad:
556 | continue # frozen weights
557 | p.data = p.data.to(torch.float32)
558 |
559 |
560 | lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
561 |
562 |
563 |
564 | # for i, param in enumerate(lora_layers):
565 | # print(f"Layer {i}: dtype = {param.dtype}")
566 |
567 | if args.gradient_checkpointing:
568 | unet.enable_gradient_checkpointing()
569 |
570 | # Enable TF32 for faster training on Ampere GPUs,
571 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
572 | if args.allow_tf32:
573 | torch.backends.cuda.matmul.allow_tf32 = True
574 |
575 | if args.scale_lr:
576 | args.learning_rate = (
577 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
578 | )
579 |
580 | # Initialize the optimizer
581 | if args.use_8bit_adam:
582 | try:
583 | import bitsandbytes as bnb
584 | except ImportError:
585 | raise ImportError(
586 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
587 | )
588 |
589 | optimizer_cls = bnb.optim.AdamW8bit
590 | else:
591 | optimizer_cls = torch.optim.AdamW
592 |
593 | optimizer = optimizer_cls(
594 | lora_layers,
595 | lr=args.learning_rate,
596 | betas=(args.adam_beta1, args.adam_beta2),
597 | weight_decay=args.adam_weight_decay,
598 | eps=args.adam_epsilon,
599 | )
600 |
601 | # Get the datasets: you can either provide your own training and evaluation files (see below)
602 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
603 |
604 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
605 | # download the dataset.
606 | if args.dataset_name is not None:
607 | # Downloading and loading a dataset from the hub.
608 | dataset = load_dataset(
609 | args.dataset_name,
610 | args.dataset_config_name,
611 | cache_dir=args.cache_dir,
612 | data_dir=args.train_data_dir,
613 | )
614 | else:
615 | data_files = {}
616 | if args.train_data_dir is not None:
617 | data_files["train"] = os.path.join(args.train_data_dir, "**")
618 | dataset = load_dataset(
619 | "imagefolder",
620 | data_files=data_files,
621 | cache_dir=args.cache_dir,
622 | )
623 | # See more about loading custom images at
624 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
625 |
626 | # Preprocessing the datasets.
627 | # We need to tokenize inputs and targets.
628 | column_names = dataset["train"].column_names
629 |
630 | # 6. Get the column names for input/target.
631 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
632 | if args.image_column is None:
633 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
634 | else:
635 | image_column = args.image_column
636 | if image_column not in column_names:
637 | raise ValueError(
638 |             f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
639 | )
640 | if args.caption_column is None:
641 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
642 | else:
643 | caption_column = args.caption_column
644 | if caption_column not in column_names:
645 | raise ValueError(
646 |             f"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
647 | )
648 |
649 | # Preprocessing the datasets.
650 | # We need to tokenize input captions and transform the images.
651 | def tokenize_captions(examples, is_train=True):
652 | captions = []
653 | for caption in examples[caption_column]:
654 | if isinstance(caption, str):
655 | captions.append(caption)
656 | elif isinstance(caption, (list, np.ndarray)):
657 | # take a random caption if there are multiple
658 | captions.append(random.choice(caption) if is_train else caption[0])
659 | else:
660 | raise ValueError(
661 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
662 | )
663 | inputs = tokenizer(
664 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
665 | )
666 | return inputs.input_ids
667 |
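    # Note on tokenize_captions: when a sample carries several captions, one is drawn at random
    # during training (the first is used otherwise), and every caption is padded/truncated to
    # `tokenizer.model_max_length` (77 tokens for the CLIP tokenizer used by SD v1.x).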
668 | # Preprocessing the datasets.
669 | train_transforms = transforms.Compose(
670 | [
671 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
672 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
673 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
674 | transforms.ToTensor(),
675 | transforms.Normalize([0.5], [0.5]),
676 | ]
677 | )
678 |
679 | def unwrap_model(model):
680 | model = accelerator.unwrap_model(model)
681 | model = model._orig_mod if is_compiled_module(model) else model
682 | return model
683 |
684 | def preprocess_train(examples):
685 | images = [image.convert("RGB") for image in examples[image_column]]
686 | examples["pixel_values"] = [train_transforms(image) for image in images]
687 | examples["input_ids"] = tokenize_captions(examples)
688 | return examples
689 |
690 | with accelerator.main_process_first():
691 | if args.max_train_samples is not None:
692 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
693 | # Set the training transforms
694 | train_dataset = dataset["train"].with_transform(preprocess_train)
695 |
696 | def collate_fn(examples):
697 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
698 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
699 | input_ids = torch.stack([example["input_ids"] for example in examples])
700 | return {"pixel_values": pixel_values, "input_ids": input_ids}
701 |
702 | # DataLoaders creation:
703 | train_dataloader = torch.utils.data.DataLoader(
704 | train_dataset,
705 | shuffle=True,
706 | collate_fn=collate_fn,
707 | batch_size=args.train_batch_size,
708 | num_workers=args.dataloader_num_workers,
709 | )
710 |
711 | # Scheduler and math around the number of training steps.
712 | # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
713 | num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
714 | if args.max_train_steps is None:
715 | len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
716 | num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
717 | num_training_steps_for_scheduler = (
718 | args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
719 | )
720 | else:
721 | num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
722 |
723 | lr_scheduler = get_scheduler(
724 | args.lr_scheduler,
725 | optimizer=optimizer,
726 | num_warmup_steps=num_warmup_steps_for_scheduler,
727 | num_training_steps=num_training_steps_for_scheduler,
728 | )
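    # Note: warmup and total steps are scaled by `accelerator.num_processes` because the scheduler
    # prepared by accelerate advances once per process for every optimizer step (see the PR linked
    # above), so the per-process counts must be multiplied to preserve the intended schedule.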
729 |
730 | # Prepare everything with our `accelerator`.
731 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
732 | unet, optimizer, train_dataloader, lr_scheduler
733 | )
734 |
735 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
736 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
737 | if args.max_train_steps is None:
738 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
739 | if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
740 | logger.warning(
741 | f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
742 | f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
743 | f"This inconsistency may result in the learning rate scheduler not functioning properly."
744 | )
745 | # Afterwards we recalculate our number of training epochs
746 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
747 |
748 | # We need to initialize the trackers we use, and also store our configuration.
749 |     # The trackers initialize automatically on the main process.
750 | if accelerator.is_main_process:
751 | accelerator.init_trackers("text2image-fine-tune", config=vars(args))
752 |
753 | # Train!
754 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
755 |
756 | logger.info("***** Running training *****")
757 | logger.info(f" Num examples = {len(train_dataset)}")
758 | logger.info(f" Num Epochs = {args.num_train_epochs}")
759 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
760 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
761 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
762 | logger.info(f" Total optimization steps = {args.max_train_steps}")
763 | global_step = 0
764 | first_epoch = 0
765 |
766 | # Potentially load in the weights and states from a previous save
767 | if args.resume_from_checkpoint:
768 | if args.resume_from_checkpoint != "latest":
769 | path = os.path.basename(args.resume_from_checkpoint)
770 | else:
771 | # Get the most recent checkpoint
772 | dirs = os.listdir(args.output_dir)
773 | dirs = [d for d in dirs if d.startswith("checkpoint")]
774 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
775 | path = dirs[-1] if len(dirs) > 0 else None
776 |
777 | if path is None:
778 | accelerator.print(
779 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
780 | )
781 | args.resume_from_checkpoint = None
782 | initial_global_step = 0
783 | else:
784 | accelerator.print(f"Resuming from checkpoint {path}")
785 | accelerator.load_state(os.path.join(args.output_dir, path))
786 | global_step = int(path.split("-")[1])
787 |
788 | initial_global_step = global_step
789 | first_epoch = global_step // num_update_steps_per_epoch
790 | else:
791 | initial_global_step = 0
792 |
793 | progress_bar = tqdm(
794 | range(0, args.max_train_steps),
795 | initial=initial_global_step,
796 | desc="Steps",
797 | # Only show the progress bar once on each machine.
798 | disable=not accelerator.is_local_main_process,
799 | )
800 |
801 | for epoch in range(first_epoch, args.num_train_epochs):
802 | unet.train()
803 | train_loss = 0.0
804 | for step, batch in enumerate(train_dataloader):
805 | with accelerator.accumulate(unet):
806 | # Convert images to latent space
807 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
808 | latents = latents * vae.config.scaling_factor
809 |
810 | # Sample noise that we'll add to the latents
811 | noise = torch.randn_like(latents)
812 | if args.noise_offset:
813 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
814 | noise += args.noise_offset * torch.randn(
815 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
816 | )
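                    # Illustrative note: the offset is a single random value per (sample, channel),
                    # broadcast over all spatial positions, which biases training toward
                    # non-zero-mean noise and helps the model reach very dark or very bright images.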
817 |
818 | bsz = latents.shape[0]
819 | # Sample a random timestep for each image
820 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
821 | timesteps = timesteps.long()
822 |
823 | # Add noise to the latents according to the noise magnitude at each timestep
824 | # (this is the forward diffusion process)
825 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
826 |
827 | # Get the text embedding for conditioning
828 | encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
829 |
830 | # Get the target for loss depending on the prediction type
831 | if args.prediction_type is not None:
832 | # set prediction_type of scheduler if defined
833 | noise_scheduler.register_to_config(prediction_type=args.prediction_type)
834 |
835 | if noise_scheduler.config.prediction_type == "epsilon":
836 | target = noise
837 | elif noise_scheduler.config.prediction_type == "v_prediction":
838 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
839 | else:
840 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
841 |
842 | # Predict the noise residual and compute loss
843 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
844 |
845 | if args.snr_gamma is None:
846 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
847 | else:
848 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
849 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
850 | # This is discussed in Section 4.2 of the same paper.
851 | snr = compute_snr(noise_scheduler, timesteps)
852 | mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
853 | dim=1
854 | )[0]
855 | if noise_scheduler.config.prediction_type == "epsilon":
856 | mse_loss_weights = mse_loss_weights / snr
857 | elif noise_scheduler.config.prediction_type == "v_prediction":
858 | mse_loss_weights = mse_loss_weights / (snr + 1)
859 |
860 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
861 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
862 | loss = loss.mean()
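                    # Illustrative note: the resulting per-sample weight is min(SNR, gamma) / SNR for
                    # epsilon prediction and min(SNR, gamma) / (SNR + 1) for v-prediction, so
                    # high-SNR (low-noise) timesteps are down-weighted as in the Min-SNR paper.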
863 |
864 | # Gather the losses across all processes for logging (if we use distributed training).
865 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
866 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
867 |
868 | # Backpropagate
869 | accelerator.backward(loss)
870 | if accelerator.sync_gradients:
871 | params_to_clip = lora_layers
872 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
873 | optimizer.step()
874 | lr_scheduler.step()
875 | optimizer.zero_grad()
876 |
877 | # Checks if the accelerator has performed an optimization step behind the scenes
878 | if accelerator.sync_gradients:
879 | progress_bar.update(1)
880 | global_step += 1
881 | accelerator.log({"train_loss": train_loss}, step=global_step)
882 | train_loss = 0.0
883 |
884 | if global_step % args.checkpointing_steps == 0:
885 | if accelerator.is_main_process:
886 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
887 | if args.checkpoints_total_limit is not None:
888 | checkpoints = os.listdir(args.output_dir)
889 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
890 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
891 |
892 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
893 | if len(checkpoints) >= args.checkpoints_total_limit:
894 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
895 | removing_checkpoints = checkpoints[0:num_to_remove]
896 |
897 | logger.info(
898 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
899 | )
900 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
901 |
902 | for removing_checkpoint in removing_checkpoints:
903 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
904 | shutil.rmtree(removing_checkpoint)
905 |
906 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
907 | accelerator.save_state(save_path)
908 |
909 | unwrapped_unet = unwrap_model(unet)
910 | unet_lora_state_dict = convert_state_dict_to_diffusers(
911 | get_peft_model_state_dict(unwrapped_unet)
912 | )
913 |
914 | StableDiffusionPipeline.save_lora_weights(
915 | save_directory=save_path,
916 | unet_lora_layers=unet_lora_state_dict,
917 | safe_serialization=True,
918 | )
919 |
920 | logger.info(f"Saved state to {save_path}")
921 |
922 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
923 | progress_bar.set_postfix(**logs)
924 |
925 | if global_step >= args.max_train_steps:
926 | break
927 |
928 | if accelerator.is_main_process:
929 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
930 | # create pipeline
931 | pipeline = DiffusionPipeline.from_pretrained(
932 | args.pretrained_model_name_or_path,
933 | unet=unwrap_model(unet),
934 | revision=args.revision,
935 | variant=args.variant,
936 | torch_dtype=weight_dtype,
937 | )
938 | images = log_validation(pipeline, args, accelerator, epoch)
939 |
940 | del pipeline
941 | torch.cuda.empty_cache()
942 |
943 | # Save the lora layers
944 | accelerator.wait_for_everyone()
945 | if accelerator.is_main_process:
946 | unet = unet.to(torch.float32)
947 |
948 | unwrapped_unet = unwrap_model(unet)
949 | unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
950 | StableDiffusionPipeline.save_lora_weights(
951 | save_directory=args.output_dir,
952 | unet_lora_layers=unet_lora_state_dict,
953 | safe_serialization=True,
954 | )
955 |
956 | # Final inference
957 | # Load previous pipeline
958 | if args.validation_prompt is not None:
959 | pipeline = DiffusionPipeline.from_pretrained(
960 | args.pretrained_model_name_or_path,
961 | revision=args.revision,
962 | variant=args.variant,
963 | torch_dtype=weight_dtype,
964 | )
965 |
966 | # load attention processors
967 | pipeline.load_lora_weights(args.output_dir)
968 |
969 | # run inference
970 | images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
971 |
972 | if args.push_to_hub:
973 | save_model_card(
974 | repo_id,
975 | images=images,
976 | base_model=args.pretrained_model_name_or_path,
977 | dataset_name=args.dataset_name,
978 | repo_folder=args.output_dir,
979 | )
980 | upload_folder(
981 | repo_id=repo_id,
982 | folder_path=args.output_dir,
983 | commit_message="End of training",
984 | ignore_patterns=["step_*", "epoch_*"],
985 | )
986 |
987 | accelerator.end_training()
988 |
989 |
990 | if __name__ == "__main__":
991 | main()
--------------------------------------------------------------------------------
/training-free/sd_utils_x0_dpm_syn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import inspect
15 | from typing import Any, Callable, Dict, List, Optional, Union
16 |
17 | import torch
18 | from packaging import version
19 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
20 |
21 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22 | from diffusers.configuration_utils import FrozenDict
23 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
24 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
25 | from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
26 | from diffusers.models.lora import adjust_lora_scale_text_encoder
27 | from diffusers.schedulers import KarrasDiffusionSchedulers
28 | from diffusers.utils import (
29 | USE_PEFT_BACKEND,
30 | deprecate,
31 | logging,
32 | replace_example_docstring,
33 | scale_lora_layers,
34 | unscale_lora_layers,
35 | )
36 | from diffusers.utils.torch_utils import randn_tensor
37 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
38 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
39 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
40 | import torch.nn.functional as F
41 | import torch.optim as optim
42 | from diffusers.pipelines.stable_diffusion_attend_and_excite.pipeline_stable_diffusion_attend_and_excite import (
43 | AttentionStore,
44 | AttendExciteAttnProcessor
45 | )
46 | from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, align_wordpieces_indices, extract_attribution_indices, extract_attribution_indices_with_verbs, extract_attribution_indices_with_verb_root, extract_entities_only
47 |
48 |
49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50 |
51 | EXAMPLE_DOC_STRING = """
52 | Examples:
53 | ```py
54 | >>> import torch
55 | >>> from diffusers import StableDiffusionPipeline
56 |
57 | >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
58 | >>> pipe = pipe.to("cuda")
59 |
60 | >>> prompt = "a photo of an astronaut riding a horse on mars"
61 | >>> image = pipe(prompt).images[0]
62 | ```
63 | """
64 |
65 |
66 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
67 | """
68 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
69 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
70 | """
71 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
72 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
73 | # rescale the results from guidance (fixes overexposure)
74 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
75 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
76 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
77 | return noise_cfg
78 |
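# Usage sketch (illustrative only): with guidance_rescale=0.7 the returned prediction is
# 0.7 * (noise_cfg * std_text / std_cfg) + 0.3 * noise_cfg, i.e. the combined CFG output is pulled
# back toward the per-sample standard deviation of the text-conditioned branch to reduce overexposure.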
79 |
80 | def retrieve_timesteps(
81 | scheduler,
82 | num_inference_steps: Optional[int] = None,
83 | device: Optional[Union[str, torch.device]] = None,
84 | timesteps: Optional[List[int]] = None,
85 | sigmas: Optional[List[float]] = None,
86 | **kwargs,
87 | ):
88 | """
89 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
90 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
91 |
92 | Args:
93 | scheduler (`SchedulerMixin`):
94 | The scheduler to get timesteps from.
95 | num_inference_steps (`int`):
96 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
97 | must be `None`.
98 | device (`str` or `torch.device`, *optional*):
99 |             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
100 | timesteps (`List[int]`, *optional*):
101 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
102 | `num_inference_steps` and `sigmas` must be `None`.
103 | sigmas (`List[float]`, *optional*):
104 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
105 | `num_inference_steps` and `timesteps` must be `None`.
106 |
107 | Returns:
108 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
109 | second element is the number of inference steps.
110 | """
111 | if timesteps is not None and sigmas is not None:
112 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
113 | if timesteps is not None:
114 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
115 | if not accepts_timesteps:
116 | raise ValueError(
117 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118 | f" timestep schedules. Please check whether you are using the correct scheduler."
119 | )
120 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
121 | timesteps = scheduler.timesteps
122 | num_inference_steps = len(timesteps)
123 | elif sigmas is not None:
124 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
125 | if not accept_sigmas:
126 | raise ValueError(
127 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128 | f" sigmas schedules. Please check whether you are using the correct scheduler."
129 | )
130 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131 | timesteps = scheduler.timesteps
132 | num_inference_steps = len(timesteps)
133 | else:
134 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135 | timesteps = scheduler.timesteps
136 | return timesteps, num_inference_steps
137 |
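# Usage sketch (illustrative only), for a scheduler whose `set_timesteps` accepts the given keyword;
# passing both `timesteps` and `sigmas` raises a ValueError, and unsupported keywords are rejected
# by the `inspect` checks above:
#   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=25, device="cuda")
#   timesteps, n = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0], device="cuda")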
138 |
139 | import itertools
140 | from typing import Any, Callable, Dict, Optional, Union, List
141 |
142 | import spacy
143 | import torch
144 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
145 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
146 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
147 | EXAMPLE_DOC_STRING,
148 | rescale_noise_cfg
149 | )
150 | from diffusers.pipelines.stable_diffusion_attend_and_excite.pipeline_stable_diffusion_attend_and_excite import (
151 | AttentionStore,
152 | AttendExciteAttnProcessor
153 | )
154 | import numpy as np
155 | from diffusers.schedulers import KarrasDiffusionSchedulers
156 | from diffusers.utils import (
157 | logging,
158 | replace_example_docstring,
159 | )
160 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
161 |
162 | from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, align_wordpieces_indices, extract_attribution_indices, extract_attribution_indices_with_verbs, extract_attribution_indices_with_verb_root, extract_entities_only
163 |
164 |
165 | logger = logging.get_logger(__name__)
166 |
167 |
168 | class SynGenMaskDiffusionPipeline(StableDiffusionPipeline):
169 | def __init__(self,
170 | vae: AutoencoderKL,
171 | text_encoder: CLIPTextModel,
172 | tokenizer: CLIPTokenizer,
173 | unet: UNet2DConditionModel,
174 | scheduler: KarrasDiffusionSchedulers,
175 | safety_checker: StableDiffusionSafetyChecker,
176 | feature_extractor: CLIPImageProcessor,
177 | image_encoder: CLIPVisionModelWithProjection = None,
178 | requires_safety_checker: bool = True,
179 | include_entities: bool = False,
180 | ):
181 |         super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, image_encoder,
182 | requires_safety_checker)
183 |
184 | self.parser = spacy.load("en_core_web_trf")
185 | self.subtrees_indices = None
186 | self.doc = None
187 | self.include_entities = include_entities
188 |
189 | def _aggregate_and_get_attention_maps_per_token(self):
190 | attention_maps = self.attention_store.aggregate_attention(
191 | from_where=("up", "down", "mid"),
192 | )
193 | attention_maps_list = _get_attention_maps_list(
194 | attention_maps=attention_maps
195 | )
196 | return attention_maps_list
197 |
198 | @staticmethod
199 | def _update_latent(
200 | latents: torch.Tensor, loss: torch.Tensor, step_size: float
201 | ) -> torch.Tensor:
202 | """Update the latent according to the computed loss."""
203 | grad_cond = torch.autograd.grad(
204 | loss.requires_grad_(True), [latents], retain_graph=True
205 | )[0]
206 | latents = latents - step_size * grad_cond
207 | return latents
208 |
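    # Note on _update_latent: this is one plain gradient-descent step on the latents,
    # latents <- latents - step_size * d(loss)/d(latents), taken with torch.autograd.grad rather
    # than an optimizer; the `syngen_step_size` argument of __call__ supplies the step size.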
209 | def register_attention_control(self):
210 | attn_procs = {}
211 | cross_att_count = 0
212 | for name in self.unet.attn_processors.keys():
213 | if name.startswith("mid_block"):
214 | place_in_unet = "mid"
215 | elif name.startswith("up_blocks"):
216 | place_in_unet = "up"
217 | elif name.startswith("down_blocks"):
218 | place_in_unet = "down"
219 | else:
220 | continue
221 |
222 | cross_att_count += 1
223 | attn_procs[name] = AttendExciteAttnProcessor(
224 | attnstore=self.attention_store, place_in_unet=place_in_unet
225 | )
226 |
227 | self.unet.set_attn_processor(attn_procs)
228 | self.attention_store.num_att_layers = cross_att_count
229 |
230 | # @torch.no_grad()
231 | @replace_example_docstring(EXAMPLE_DOC_STRING)
232 | def __call__(
233 | self,
234 | prompt: Union[str, List[str]] = None,
235 | height: Optional[int] = None,
236 | width: Optional[int] = None,
237 | num_inference_steps: int = 50,
238 | timesteps: List[int] = None,
239 | sigmas: List[float] = None,
240 | guidance_scale: float = 7.5,
241 | negative_prompt: Optional[Union[str, List[str]]] = None,
242 | num_images_per_prompt: Optional[int] = 1,
243 | eta: float = 0.0,
244 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
245 | latents: Optional[torch.Tensor] = None,
246 | prompt_embeds: Optional[torch.Tensor] = None,
247 | negative_prompt_embeds: Optional[torch.Tensor] = None,
248 | ip_adapter_image: Optional[PipelineImageInput] = None,
249 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
250 | output_type: Optional[str] = "pil",
251 | return_dict: bool = True,
252 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
253 | guidance_rescale: float = 0.0,
254 | clip_skip: Optional[int] = None,
255 | callback_on_step_end: Optional[
256 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
257 | ] = None,
258 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
259 | attn_res=None,
260 | syngen_step_size: float = 20.0,
261 | parsed_prompt: str = None,
262 | num_intervention_steps: int = 25,
263 | **kwargs,
264 | ):
265 | r"""
266 | The call function to the pipeline for generation.
267 |
268 | Args:
269 | prompt (`str` or `List[str]`, *optional*):
270 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
271 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
272 | The height in pixels of the generated image.
273 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
274 | The width in pixels of the generated image.
275 | num_inference_steps (`int`, *optional*, defaults to 50):
276 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
277 | expense of slower inference.
278 | timesteps (`List[int]`, *optional*):
279 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
280 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
281 | passed will be used. Must be in descending order.
282 | sigmas (`List[float]`, *optional*):
283 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
284 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
285 | will be used.
286 | guidance_scale (`float`, *optional*, defaults to 7.5):
287 | A higher guidance scale value encourages the model to generate images closely linked to the text
288 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
289 | negative_prompt (`str` or `List[str]`, *optional*):
290 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to
291 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
292 | num_images_per_prompt (`int`, *optional*, defaults to 1):
293 | The number of images to generate per prompt.
294 | eta (`float`, *optional*, defaults to 0.0):
295 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
296 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
297 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
298 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
299 | generation deterministic.
300 | latents (`torch.Tensor`, *optional*):
301 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
302 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
303 | tensor is generated by sampling using the supplied random `generator`.
304 | prompt_embeds (`torch.Tensor`, *optional*):
305 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
306 | provided, text embeddings are generated from the `prompt` input argument.
307 | negative_prompt_embeds (`torch.Tensor`, *optional*):
308 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
309 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
310 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
311 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
312 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
313 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
314 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
315 | provided, embeddings are computed from the `ip_adapter_image` input argument.
316 | output_type (`str`, *optional*, defaults to `"pil"`):
317 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
318 | return_dict (`bool`, *optional*, defaults to `True`):
319 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
320 | plain tuple.
321 | cross_attention_kwargs (`dict`, *optional*):
322 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
323 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
324 | guidance_rescale (`float`, *optional*, defaults to 0.0):
325 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
326 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
327 | using zero terminal SNR.
328 | clip_skip (`int`, *optional*):
329 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
330 | the output of the pre-final layer will be used for computing the prompt embeddings.
331 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
332 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
333 |                 each denoising step during inference, with the following arguments: `callback_on_step_end(self:
334 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
335 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
336 | callback_on_step_end_tensor_inputs (`List`, *optional*):
337 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
338 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
339 | `._callback_tensor_inputs` attribute of your pipeline class.
340 |
341 | Examples:
342 |
343 | Returns:
344 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
345 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
346 | otherwise a `tuple` is returned where the first element is a list with the generated images and the
347 | second element is a list of `bool`s indicating whether the corresponding generated image contains
348 | "not-safe-for-work" (nsfw) content.
349 | """
350 |
351 | callback = kwargs.pop("callback", None)
352 | callback_steps = kwargs.pop("callback_steps", None)
353 |
354 | if callback is not None:
355 | deprecate(
356 | "callback",
357 | "1.0.0",
358 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
359 | )
360 | if callback_steps is not None:
361 | deprecate(
362 | "callback_steps",
363 | "1.0.0",
364 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
365 | )
366 |
367 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
368 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
369 | if parsed_prompt:
370 | self.doc = parsed_prompt
371 | else:
372 | self.doc = self.parser(prompt)
373 | # 0. Default height and width to unet
374 | height = height or self.unet.config.sample_size * self.vae_scale_factor
375 | width = width or self.unet.config.sample_size * self.vae_scale_factor
376 | # to deal with lora scaling and other possible forward hooks
377 |
378 | # 1. Check inputs. Raise error if not correct
379 | self.check_inputs(
380 | prompt,
381 | height,
382 | width,
383 | callback_steps,
384 | negative_prompt,
385 | prompt_embeds,
386 | negative_prompt_embeds,
387 | ip_adapter_image,
388 | ip_adapter_image_embeds,
389 | callback_on_step_end_tensor_inputs,
390 | )
391 |
392 | self._guidance_scale = guidance_scale
393 | self._guidance_rescale = guidance_rescale
394 | self._clip_skip = clip_skip
395 | self._cross_attention_kwargs = cross_attention_kwargs
396 | self._interrupt = False
397 |
398 | # 2. Define call parameters
399 | if prompt is not None and isinstance(prompt, str):
400 | batch_size = 1
401 | elif prompt is not None and isinstance(prompt, list):
402 | batch_size = len(prompt)
403 | else:
404 | batch_size = prompt_embeds.shape[0]
405 |
406 | device = self._execution_device
407 |
408 | # 3. Encode input prompt
409 | lora_scale = (
410 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
411 | )
412 |
413 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
414 | prompt,
415 | device,
416 | num_images_per_prompt,
417 | self.do_classifier_free_guidance,
418 | negative_prompt,
419 | prompt_embeds=prompt_embeds,
420 | negative_prompt_embeds=negative_prompt_embeds,
421 | lora_scale=lora_scale,
422 | clip_skip=self.clip_skip,
423 | )
424 | # print(prompt_embeds.size())
425 | # assert 2==1
426 |
427 | # For classifier free guidance, we need to do two forward passes.
428 | # Here we concatenate the unconditional and text embeddings into a single batch
429 | # to avoid doing two forward passes
430 | if self.do_classifier_free_guidance:
431 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
432 |
433 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
434 | image_embeds = self.prepare_ip_adapter_image_embeds(
435 | ip_adapter_image,
436 | ip_adapter_image_embeds,
437 | device,
438 | batch_size * num_images_per_prompt,
439 | self.do_classifier_free_guidance,
440 | )
441 |
442 | # 4. Prepare timesteps
443 | timesteps, num_inference_steps = retrieve_timesteps(
444 | self.scheduler, num_inference_steps, device, timesteps, sigmas
445 | )
446 |
447 | # 5. Prepare latent variables
448 | num_channels_latents = self.unet.config.in_channels
449 | latents = self.prepare_latents(
450 | batch_size * num_images_per_prompt,
451 | num_channels_latents,
452 | height,
453 | width,
454 | prompt_embeds.dtype,
455 | device,
456 | generator,
457 | latents,
458 | )
459 |
460 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
461 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
462 |
463 | if attn_res is None:
464 | attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
465 | self.attn_res = attn_res
466 | self.attention_store = AttentionStore(self.attn_res)
467 | self.register_attention_control()
468 |
469 | # 6.1 Add image embeds for IP-Adapter
470 | added_cond_kwargs = (
471 | {"image_embeds": image_embeds}
472 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
473 | else None
474 | )
475 |
476 | # 6.2 Optionally get Guidance Scale Embedding
477 | timestep_cond = None
478 | if self.unet.config.time_cond_proj_dim is not None:
479 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
480 | timestep_cond = self.get_guidance_scale_embedding(
481 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
482 | ).to(device=device, dtype=latents.dtype)
483 |
484 | text_embeddings = (
485 | prompt_embeds[batch_size * num_images_per_prompt:] if self.do_classifier_free_guidance else prompt_embeds
486 | )
487 | # 7. Denoising loop
488 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
489 | self._num_timesteps = len(timesteps)
490 | with self.progress_bar(total=num_inference_steps) as progress_bar:
491 | for i, t in enumerate(timesteps):
492 |
493 | if self.interrupt:
494 | continue
495 |
496 | if i < self.stop_step:
497 | continue
498 |
499 | if self.stop_step < num_intervention_steps:
500 | latents = self._syngen_step(
501 | latents,
502 | text_embeddings,
503 | t,
504 | self.stop_step,
505 | syngen_step_size,
506 | cross_attention_kwargs,
507 | prompt,
508 | num_intervention_steps=num_intervention_steps,
509 | )
510 | # expand the latents if we are doing classifier free guidance
511 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
512 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
513 | # print(self.unet)
514 | # assert 2==1
515 | # predict the noise residual
516 | noise_pred = self.unet(
517 | latent_model_input,
518 | t,
519 | encoder_hidden_states=prompt_embeds,
520 | timestep_cond=timestep_cond,
521 | cross_attention_kwargs=self.cross_attention_kwargs,
522 | added_cond_kwargs=added_cond_kwargs,
523 | return_dict=False,
524 | )[0]
525 |
526 | # perform guidance
527 | if self.do_classifier_free_guidance:
528 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
529 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
530 |
531 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
532 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
533 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
534 |
535 | # compute the previous noisy sample x_t -> x_t-1
536 | latents, pred_original_sample = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)
537 |
538 |
539 | if callback_on_step_end is not None:
540 | callback_kwargs = {}
541 | for k in callback_on_step_end_tensor_inputs:
542 | callback_kwargs[k] = locals()[k]
543 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
544 |
545 | latents = callback_outputs.pop("latents", latents)
546 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
547 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
548 |
549 | # call the callback, if provided
550 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
551 | progress_bar.update()
552 | if callback is not None and i % callback_steps == 0:
553 | step_idx = i // getattr(self.scheduler, "order", 1)
554 | callback(step_idx, t, latents)
555 |
556 | if i >= self.stop_step:
557 | # alpha_t = self.alphas[int(t.item())] ** 0.5
558 | # sigma_t = (1 - self.alphas[int(t.item())]) ** 0.5
559 | # latents = (latents - sigma_t * noise_pred) / alpha_t
560 | latents_cur = latents
561 | latents = pred_original_sample
562 | break
563 | if not output_type == "latent":
564 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
565 | 0
566 | ]
567 | # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
568 |         else:
569 |             image = latents
570 |         has_nsfw_concept = None  # the safety-checker call above is commented out, so no NSFW flags are produced
571 |
572 | # if has_nsfw_concept is None:
573 | # do_denormalize = [True] * image.shape[0]
574 | # else:
575 | # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
576 |
577 | # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
578 | # image = (image / 2 + 0.5).clamp(0, 1)
579 | image = (image / 2 + 0.5).clamp(0, 1)
580 | # optimizer = optim.Adam(self.unet.parameters(), lr=1e-1)
581 |         # input_tensor = image  # simulate an input image
582 |         # target = torch.randn_like(input_tensor).to(noise_pred.device).half()  # target with the same shape as the input
583 | # loss = F.mse_loss(image, target)
584 | # optimizer.zero_grad()
585 | # loss.backward()
586 | # for name, param in self.unet.named_parameters():
587 | # if param.grad is not None:
588 | # print(f"Gradient for {name}: {param.grad}")
589 | # else:
590 | # print(f"No gradient for {name}")
591 |
592 | # assert 2==1
593 |
594 | # Offload all models
595 | self.maybe_free_model_hooks()
596 |
597 | if not return_dict:
598 | # return (image, has_nsfw_concept)
599 | return (image, latents_cur)
600 |
601 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), latents_cur
602 |
603 | def _syngen_step(
604 | self,
605 | latents,
606 | text_embeddings,
607 | t,
608 | i,
609 | step_size,
610 | cross_attention_kwargs,
611 | prompt,
612 | num_intervention_steps,
613 | ):
614 | with torch.enable_grad():
615 | latents = latents.clone().detach().requires_grad_(True)
616 | updated_latents = []
617 | for latent, text_embedding in zip(latents, text_embeddings):
618 | # Forward pass of denoising with text conditioning
619 | latent = latent.unsqueeze(0)
620 | text_embedding = text_embedding.unsqueeze(0)
621 |
622 | self.unet(
623 | latent,
624 | t,
625 | encoder_hidden_states=text_embedding,
626 | cross_attention_kwargs=cross_attention_kwargs,
627 | return_dict=False,
628 | )[0]
629 | self.unet.zero_grad()
630 | # Get attention maps
631 | attention_maps = self._aggregate_and_get_attention_maps_per_token()
632 | loss = self._compute_loss(attention_maps=attention_maps, prompt=prompt)
633 | # Perform gradient update
634 | if i < num_intervention_steps:
635 | if loss != 0:
636 | latent = self._update_latent(
637 | latents=latent, loss=loss, step_size=step_size
638 | )
639 | logger.info(f"Iteration {i} | Loss: {loss:0.4f}")
640 |
641 | updated_latents.append(latent)
642 |
643 | latents = torch.cat(updated_latents, dim=0)
644 |
645 | return latents
646 |
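    # Note on _syngen_step: each latent in the batch gets its own text-conditioned UNet forward pass
    # purely to populate the attention store; the predicted noise is discarded, and only the
    # attribution loss over the aggregated cross-attention maps drives the latent update above.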
647 | def _compute_loss(
648 | self, attention_maps: List[torch.Tensor], prompt: Union[str, List[str]]
649 | ) -> torch.Tensor:
650 | attn_map_idx_to_wp = get_attention_map_index_to_wordpiece(self.tokenizer, prompt)
651 | loss = self._attribution_loss(attention_maps, prompt, attn_map_idx_to_wp)
652 |
653 | return loss
654 |
655 | def _attribution_loss(
656 | self,
657 | attention_maps: List[torch.Tensor],
658 | prompt: Union[str, List[str]],
659 | attn_map_idx_to_wp,
660 | ) -> torch.Tensor:
661 | if not self.subtrees_indices:
662 | self.subtrees_indices = self._extract_attribution_indices(prompt)
663 | subtrees_indices = self.subtrees_indices
664 |
665 | loss = 0
666 |
667 | for subtree_indices in subtrees_indices:
668 | noun, modifier = split_indices(subtree_indices)
669 | all_subtree_pairs = list(itertools.product(noun, modifier))
670 | if noun and not modifier:
671 | if isinstance(noun, list) and len(noun) == 1:
672 | processed_noun = noun[0]
673 | else:
674 | processed_noun = noun
675 | loss += calculate_negative_loss(
676 | attention_maps, modifier, processed_noun, subtree_indices, attn_map_idx_to_wp
677 | )
678 | else:
679 | positive_loss, negative_loss = self._calculate_losses(
680 | attention_maps,
681 | all_subtree_pairs,
682 | subtree_indices,
683 | attn_map_idx_to_wp,
684 | )
685 |
686 | loss += positive_loss
687 | loss += negative_loss
688 |
689 | return loss
690 |
691 | def _calculate_losses(
692 | self,
693 | attention_maps,
694 | all_subtree_pairs,
695 | subtree_indices,
696 | attn_map_idx_to_wp,
697 | ):
698 | positive_loss = []
699 | negative_loss = []
700 | for pair in all_subtree_pairs:
701 | noun, modifier = pair
702 | positive_loss.append(
703 | calculate_positive_loss(attention_maps, modifier, noun)
704 | )
705 | negative_loss.append(
706 | calculate_negative_loss(
707 | attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
708 | )
709 | )
710 |
711 | positive_loss = sum(positive_loss)
712 | negative_loss = sum(negative_loss)
713 |
714 | return positive_loss, negative_loss
715 |
716 | def _align_indices(self, prompt, spacy_pairs):
717 | wordpieces2indices = get_indices(self.tokenizer, prompt)
718 | paired_indices = []
719 | collected_spacy_indices = (
720 | set()
721 | ) # helps track recurring nouns across different relations (i.e., cases where there is more than one instance of the same word)
722 |
723 | for pair in spacy_pairs:
724 | curr_collected_wp_indices = (
725 | []
726 | ) # helps track which nouns and amods were added to the current pair (this is useful in sentences with repeating amod on the same relation (e.g., "a red red red bear"))
727 | for member in pair:
728 | for idx, wp in wordpieces2indices.items():
729 | if wp in [start_token, end_token]:
730 | continue
731 |
732 |                     wp = wp.replace("</w>", "")  # strip the CLIP BPE end-of-word marker
733 | if member.text.lower() == wp.lower():
734 | if idx not in curr_collected_wp_indices and idx not in collected_spacy_indices:
735 | curr_collected_wp_indices.append(idx)
736 | break
737 | # take care of wordpieces that are split up
738 | elif member.text.lower().startswith(wp.lower()) and wp.lower() != member.text.lower(): # can maybe be while loop
739 | wp_indices = align_wordpieces_indices(
740 | wordpieces2indices, idx, member.text
741 | )
742 | # check if all wp_indices are not already in collected_spacy_indices
743 | if wp_indices and (wp_indices not in curr_collected_wp_indices) and all(
744 | [wp_idx not in collected_spacy_indices for wp_idx in wp_indices]):
745 | curr_collected_wp_indices.append(wp_indices)
746 | break
747 |
748 | for collected_idx in curr_collected_wp_indices:
749 | if isinstance(collected_idx, list):
750 | for idx in collected_idx:
751 | collected_spacy_indices.add(idx)
752 | else:
753 | collected_spacy_indices.add(collected_idx)
754 |
755 | if curr_collected_wp_indices:
756 | paired_indices.append(curr_collected_wp_indices)
757 | else:
758 | print(f"No wordpieces were aligned for {pair} in _align_indices")
759 |
760 | return paired_indices
761 |
762 | def _extract_attribution_indices(self, prompt):
763 | modifier_indices = []
764 | # extract standard attribution indices
765 | modifier_sets_1 = extract_attribution_indices(self.doc)
766 | modifier_indices_1 = self._align_indices(prompt, modifier_sets_1)
767 | if modifier_indices_1:
768 | modifier_indices.append(modifier_indices_1)
769 |
770 | # extract attribution indices with verbs in between
771 | modifier_sets_2 = extract_attribution_indices_with_verb_root(self.doc)
772 | modifier_indices_2 = self._align_indices(prompt, modifier_sets_2)
773 | if modifier_indices_2:
774 | modifier_indices.append(modifier_indices_2)
775 |
776 | modifier_sets_3 = extract_attribution_indices_with_verbs(self.doc)
777 | modifier_indices_3 = self._align_indices(prompt, modifier_sets_3)
778 | if modifier_indices_3:
779 | modifier_indices.append(modifier_indices_3)
780 |
781 | # entities only
782 | if self.include_entities:
783 | modifier_sets_4 = extract_entities_only(self.doc)
784 | modifier_indices_4 = self._align_indices(prompt, modifier_sets_4)
785 | modifier_indices.append(modifier_indices_4)
786 |
787 | # make sure there are no duplicates
788 | modifier_indices = unify_lists(modifier_indices)
789 | print(f"Final modifier indices collected: {modifier_indices}")
790 |
791 | return modifier_indices
792 |
793 |
794 | def _get_attention_maps_list(
795 | attention_maps: torch.Tensor
796 | ) -> List[torch.Tensor]:
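    | # Scale up the attention values, then split the (H, W, tokens) tensor into one map per token.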
797 | attention_maps *= 100
798 | attention_maps_list = [
799 | attention_maps[:, :, i] for i in range(attention_maps.shape[2])
800 | ]
801 |
802 | return attention_maps_list
803 |
804 |
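    | # Flatten the nested index groups and repeatedly merge any two groups that share
    | # an element, so each token index ends up in exactly one group (removes duplicates).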
805 | def unify_lists(list_of_lists):
806 | def flatten(lst):
807 | for elem in lst:
808 | if isinstance(elem, list):
809 | yield from flatten(elem)
810 | else:
811 | yield elem
812 |
813 | def have_common_element(lst1, lst2):
814 | flat_list1 = set(flatten(lst1))
815 | flat_list2 = set(flatten(lst2))
816 | return not flat_list1.isdisjoint(flat_list2)
817 |
818 | lst = []
819 | for l in list_of_lists:
820 | lst += l
821 | changed = True
822 | while changed:
823 | changed = False
824 | merged_list = []
825 | while lst:
826 | first = lst.pop(0)
827 | was_merged = False
828 | for index, other in enumerate(lst):
829 | if have_common_element(first, other):
830 | # If we merge, we should flatten the other list but not first
831 | new_merged = first + [item for item in other if item not in first]
832 | lst[index] = new_merged
833 | changed = True
834 | was_merged = True
835 | break
836 | if not was_merged:
837 | merged_list.append(first)
838 | lst = merged_list
839 |
840 | return lst
841 |
842 | from typing import Tuple
843 | from dataclasses import dataclass
844 | from diffusers.schedulers.scheduling_lcm import LCMSchedulerOutput
845 | from diffusers.utils.torch_utils import randn_tensor
846 | from diffusers.utils import BaseOutput
847 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
848 |
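    | # Monkey-patch DPMSolverMultistepScheduler.step() with a variant that, when called
    | # with return_dict=False, also returns the converted model output (the data / x0-style
    | # prediction under DPM-Solver++) alongside prev_sample.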
849 | def register_sdschedule_step(model):
850 | def sd_schedule_step(self):
851 | def step(
852 | model_output: torch.Tensor,
853 | timestep: Union[int, torch.Tensor],
854 | sample: torch.Tensor,
855 | generator=None,
856 | variance_noise: Optional[torch.Tensor] = None,
857 | return_dict: bool = True,
858 | ) -> Union[SchedulerOutput, Tuple]:
859 | """
860 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
861 | the multistep DPMSolver.
862 |
863 | Args:
864 | model_output (`torch.Tensor`):
865 | The direct output from learned diffusion model.
866 | timestep (`int`):
867 | The current discrete timestep in the diffusion chain.
868 | sample (`torch.Tensor`):
869 | A current instance of a sample created by the diffusion process.
870 | generator (`torch.Generator`, *optional*):
871 | A random number generator.
872 | variance_noise (`torch.Tensor`):
873 | Alternative to generating noise with `generator` by directly providing the noise for the variance
874 | itself. Useful for methods such as [`LEdits++`].
875 | return_dict (`bool`):
876 | Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
877 |
878 | Returns:
879 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
880 | If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
881 | tuple is returned where the first element is the sample tensor.
882 |
883 | """
884 | if self.num_inference_steps is None:
885 | raise ValueError(
886 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
887 | )
888 |
889 | if self.step_index is None:
890 | self._init_step_index(timestep)
891 |
892 | # Improve numerical stability for small number of steps
893 | lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
894 | self.config.euler_at_final
895 | or (self.config.lower_order_final and len(self.timesteps) < 15)
896 | or self.config.final_sigmas_type == "zero"
897 | )
898 | lower_order_second = (
899 | (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
900 | )
901 |
902 | model_output = self.convert_model_output(model_output, sample=sample)
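    | # Shift the solver's history buffer and append the newly converted output.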
903 | for i in range(self.config.solver_order - 1):
904 | self.model_outputs[i] = self.model_outputs[i + 1]
905 | self.model_outputs[-1] = model_output
906 |
907 | # Upcast to avoid precision issues when computing prev_sample
908 | sample = sample.to(torch.float32)
909 | if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
910 | noise = randn_tensor(
911 | model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
912 | )
913 | elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
914 | noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
915 | else:
916 | noise = None
917 |
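    | # Choose a first-, second-, or third-order update depending on the configured
    | # solver order and how many previous outputs have been accumulated.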
918 | if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
919 | prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
920 | elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
921 | prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
922 | else:
923 | prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
924 |
925 | if self.lower_order_nums < self.config.solver_order:
926 | self.lower_order_nums += 1
927 |
928 | # Cast sample back to expected dtype
929 | prev_sample = prev_sample.to(model_output.dtype)
930 |
931 | # upon completion increase step index by one
932 | self._step_index += 1
933 |
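    | # Unlike the stock diffusers scheduler, the tuple also exposes the converted
    | # model output so the caller can reuse the predicted x0.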
934 | if not return_dict:
935 | return (prev_sample, model_output)
936 |
937 | return SchedulerOutput(prev_sample=prev_sample)
938 |
939 | return step
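    | # Only patch DPMSolverMultistepScheduler; other schedulers keep their original step().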
940 | if model.__class__.__name__ == 'DPMSolverMultistepScheduler':
941 | model.step = sd_schedule_step(model)
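    | # Minimal usage sketch (not from this repo): `noise_pred`, `t`, and `latents` are
    | # placeholder names, and the model path is illustrative only. Inside the denoising
    | # loop, the patched step() returns the x0-style prediction as the second element.
    | #
    | #   from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
    | #   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    | #   pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    | #   register_sdschedule_step(pipe.scheduler)
    | #   latents, x0_pred = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)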
--------------------------------------------------------------------------------