├── my_model ├── __init__.py ├── unet_2d_condition.py ├── attention.py └── unet_2d_blocks.py ├── FreeMono.ttf ├── example_input └── text_inversion │ └── cat │ └── cat.png ├── requirements.txt ├── .gitignore ├── conf ├── base_config.yaml ├── unet │ └── config.json └── real_image_editing_config.yaml ├── README.md ├── utils.py ├── inference.py ├── text_inversion.py └── dreambooth.py /my_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FreeMono.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silent-chen/layout-guidance/HEAD/FreeMono.ttf -------------------------------------------------------------------------------- /example_input/text_inversion/cat/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silent-chen/layout-guidance/HEAD/example_input/text_inversion/cat/cat.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | omegaconf==2.2.3 4 | opencv-python 5 | imageio==2.9.0 6 | transformers==4.25.1 7 | diffusers==0.11.1 8 | accelerate==0.13.2 9 | scipy==1.9.1 10 | git+https://github.com/openai/CLIP.git 11 | hydra-core==1.2.0 12 | tqdm -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.pyc 4 | *.pyo 5 | *.pyd 6 | *.pyc 7 | .Python 8 | env/ 9 | venv/ 10 | ENV/ 11 | env.bak/ 12 | venv.bak/ 13 | *.pyc 14 | 15 | # PyCharm 16 | .idea/ 17 | *.iml 18 | 19 | # macOS 20 | .DS_Store 21 | .AppleDouble 22 | .LSOverride 23 | ._* 24 | .DocumentRevisions-V100 25 | .fseventsd 26 | .Spotlight-V100 27 | .TemporaryItems 28 | .Trashes 29 | .VolumeIcon.icns 30 | .com.apple.timemachine.donotpresent 31 | .AppleDB 32 | .AppleDesktop 33 | Network Trash Folder 34 | Temporary Items 35 | .apdisk -------------------------------------------------------------------------------- /conf/base_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | general: 3 | save_path: './example_output' 4 | model_path: 'runwayml/stable-diffusion-v1-5' 5 | unet_config: './conf/unet/config.json' 6 | real_image_editing: False 7 | 8 | inference: 9 | loss_scale: 30 10 | batch_size: 1 11 | loss_threshold: 0.2 12 | max_iter: 5 13 | max_index_step: 10 14 | timesteps: 51 15 | classifier_free_guidance: 7.5 16 | rand_seed: 445 17 | 18 | noise_schedule: 19 | beta_start: 0.00085 20 | beta_end: 0.012 21 | beta_schedule: "scaled_linear" 22 | num_train_timesteps: 1000 23 | 24 | real_image_editing: 25 | dreambooth_path: '' 26 | text_inversion_path: '' 27 | placeholder_token: '' 28 | -------------------------------------------------------------------------------- /conf/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "attention_head_dim": 8, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "center_input_sample": false, 13 | "cross_attention_dim": 768, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock2D", 16 | 
"CrossAttnDownBlock2D", 17 | "CrossAttnDownBlock2D", 18 | "DownBlock2D" 19 | ], 20 | "downsample_padding": 1, 21 | "flip_sin_to_cos": true, 22 | "freq_shift": 0, 23 | "in_channels": 4, 24 | "layers_per_block": 2, 25 | "mid_block_scale_factor": 1, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 64, 30 | "up_block_types": [ 31 | "UpBlock2D", 32 | "CrossAttnUpBlock2D", 33 | "CrossAttnUpBlock2D", 34 | "CrossAttnUpBlock2D" 35 | ] 36 | } -------------------------------------------------------------------------------- /conf/real_image_editing_config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | save_path: './example_output/real_image_editing' 3 | model_path: 'runwayml/stable-diffusion-v1-5' 4 | unet_config: './conf/unet/config.json' 5 | seed: 0 6 | 7 | text_inversion: 8 | use_ema: True 9 | batch_size: 4 10 | adam_beta1: 0.9 11 | adam_beta2: 0.999 12 | adam_weight_decay: 1e-2 13 | adam_epsilon: 1e-08 14 | lr_scheduler: constant 15 | lr_warmup_steps: 0 16 | max_train_steps: 500 17 | text_finetune_step: 50 18 | unet_finetune_step: 50 19 | alpha: 0.1 20 | min_lr: 1e-6 21 | warmup_epochs: 0 22 | num_train_epochs: 100 23 | gradient_accumulation_steps: 1 24 | lr: 5.0e-04 25 | placeholder_token: 26 | initial_token: pet 27 | scale_lr: True 28 | resolution: 512 29 | repeats: 100 30 | learnable_property: 'object' 31 | center_crop: False 32 | unet_pretrain: 'runwayml/stable-diffusion-v1-5' 33 | save_steps: 50 34 | randaug: False 35 | image_path: './example_input/text_inversion/cat/' 36 | inference: False 37 | embedding_ckp: '' 38 | example_prompt: 'a photo of a {}' 39 | new_prompt: '' 40 | inference_batch_size: 4 41 | 42 | dreambooth: 43 | with_prior_preservation: True 44 | class_data_dir: './dreambooth_class_preservation_dir/cat' 45 | num_class_images: 300 46 | pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5' 47 | class_prompt: 'a photo of a cat' 48 | scale_lr: True 49 | lr: 1e-6 50 | gradient_accumulation_steps: 1 51 | train_batch_size: 1 52 | train_text_encoder: True 53 | adam_beta1: 0.9 54 | adam_beta2: 0.999 55 | adam_weight_decay: 1e-2 56 | adam_epsilon: 1e-08 57 | instance_data_dir: './example_input/text_inversion/cat/' 58 | instance_prompt: 'a photo of a ' 59 | resolution: 512 60 | center_crop: False 61 | max_train_steps: 150 62 | num_train_epochs: 1000 63 | lr_scheduler: 'constant' 64 | lr_warmup_steps: 0 65 | max_grad_norm: 1.0 66 | prior_loss_weight: 1.0 67 | sample_batch_size: 4 68 | inference: False 69 | text_inversion_path: '' 70 | example_prompt: 'a photo of a {}' 71 | new_prompt: '' 72 | inference_batch_size: 4 73 | ckp_path: '' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training-Free Layout Control with Cross-Attention Guidance 2 | [Minghao Chen](https://silent-chen.github.io), [Iro Laina](), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/) 3 | 4 | [[Paper](https://arxiv.org/abs/2304.03373)] [[Project Page](https://silent-chen.github.io/layout-guidance/)] [[Demo](https://huggingface.co/spaces/silentchen/layout-guidance)] 5 | 6 | https://user-images.githubusercontent.com/30588507/229642269-57527ded-3189-4aa2-9590-1f3de4d51cad.mp4 7 | 8 | 9 |

10 | 11 |
12 | [teaser figure] 13 |
14 | 15 | Our method manages to control the layout of images generated by large pretrained text-to-image diffusion models **without training**, through layout guidance performed on the cross-attention maps. 16 | 17 | ## Abstract 18 | Recent diffusion-based generators can produce high-quality images based only on textual prompts. However, they do not correctly interpret instructions that specify the spatial layout of the composition. We propose a simple approach that can achieve robust layout control without requiring training or fine-tuning the image generator. Our technique, which we call layout guidance, manipulates the cross-attention layers that the model uses to interface textual and visual information and steers the reconstruction in the desired direction given, e.g., a user-specified layout. In order to determine how to best guide attention, we study the role of different attention maps when generating images and experiment with two alternative strategies, forward and backward guidance. We evaluate our method quantitatively and qualitatively with several experiments, validating its effectiveness. We further demonstrate its versatility by extending layout guidance to the task of editing the layout and context of a given real image. 19 | 20 | ## Environment Setup 21 | 22 | To set up the environment, run the following commands: 23 | ```bash 24 | conda create -n layout-guidance python=3.8 25 | conda activate layout-guidance 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## Inference 30 | 31 | We provide an example inference script. The example outputs, including the log file, generated images, and a copy of the config, are saved to the specified path `./example_output`. Detailed configuration options can be found in `./conf/base_config.yaml` and `inference.py`. 32 | ```bash 33 | python inference.py general.save_path=./example_output 34 | ``` 35 | 36 | ## Applications 37 | 38 | ### Real Image Editing 39 | 40 | We achieve real image editing based on DreamBooth and Textual Inversion. Specifically, we can change the context, location, and size of the objects in the original image. 41 | 42 | 43 | There are three steps to achieve real image editing based on layout guidance. Please check the config file `./conf/real_image_editing_config.yaml` for more detailed configuration. 44 | 45 | Step 1: Use Textual Inversion to learn a special token that describes the desired object. 46 | ```bash 47 | python text_inversion.py \ 48 | general.save_path=./example_output/real_image_editing \ 49 | text_inversion.image_path=./example_input/text_inversion/cat/ \ 50 | text_inversion.initial_token='pet' 51 | ``` 52 | Step 2: Use DreamBooth to fine-tune the U-Net and text encoder. 53 | 54 | ```bash 55 | python dreambooth.py dreambooth.text_inversion_path=./example_output/real_image_editing/text_inversion/learned_embeds_iteration_500.bin 56 | ``` 57 | 58 | Step 3: Perform layout guidance on the fine-tuned text encoder and U-Net. 59 | 60 | ```bash 61 | python inference.py \ 62 | general.save_path=./example_output/real_image_editing/ \ 63 | general.real_image_editing=True \ 64 | real_image_editing.dreambooth_path=./example_output/real_image_editing/dreambooth/dreambooth_150.ckp \ 65 | real_image_editing.text_inversion_path=./example_output/real_image_editing/text_inversion/learned_embeds_iteration_500.bin 66 | ``` 67 | 68 | Here are some example outputs of real image editing. 69 | 70 |
71 | [figure: example outputs of real image editing] 72 |
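For reference, here is a minimal sketch of the backward-guidance step that `inference.py` performs during the first few denoising steps (used both for the basic example above and for Step 3). It is condensed from `compute_ca_loss` in `utils.py`; the helper name `layout_loss`, the use of a single attention map, and the tensor shapes are illustrative simplifications rather than the released implementation.

```python
# Minimal sketch (not the exact implementation) of one backward-guidance update:
# measure how much of each phrase's cross-attention mass falls inside its target
# box and nudge the latents so that this fraction increases.
import torch

def layout_loss(attn_map, boxes, token_ids):
    """attn_map: (batch, H*W, num_text_tokens) cross-attention probabilities."""
    b, hw, _ = attn_map.shape
    H = W = int(hw ** 0.5)
    loss = attn_map.new_zeros(())
    for box, tokens in zip(boxes, token_ids):
        mask = attn_map.new_zeros(H, W)
        x0, y0, x1, y1 = int(box[0] * W), int(box[1] * H), int(box[2] * W), int(box[3] * H)
        mask[y0:y1, x0:x1] = 1
        for tok in tokens:
            ca = attn_map[:, :, tok].reshape(b, H, W)
            # fraction of this token's attention that lands inside the box
            inside = (ca * mask).reshape(b, -1).sum(-1) / ca.reshape(b, -1).sum(-1)
            loss = loss + ((1 - inside) ** 2).mean() / len(tokens)
    return loss / max(len(boxes), 1)

# Inside the denoising loop, before the usual classifier-free-guidance step:
#   latents.requires_grad_(True)
#   _, attn_up, attn_mid, _ = unet(scaled_latents, t, encoder_hidden_states=cond_embeddings)
#   loss = layout_loss(attn_mid[0], boxes, token_ids) * cfg.inference.loss_scale
#   grad = torch.autograd.grad(loss, [latents])[0]
#   latents = latents - grad * noise_scheduler.sigmas[i] ** 2
```

In the released code this loss is averaged over the mid-block and the first up-block attention maps, and the update is repeated up to `inference.max_iter` times per step for the first `inference.max_index_step` denoising steps (see `conf/base_config.yaml`).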
73 | 74 | 75 | ## Citation 76 | 77 | If this repo is helpful for you, please consider to cite it. Thank you! :) 78 | 79 | ```bibtex 80 | @article{chen2023trainingfree, 81 | title={Training-Free Layout Control with Cross-Attention Guidance}, 82 | author={Minghao Chen and Iro Laina and Andrea Vedaldi}, 83 | journal={arXiv preprint arXiv:2304.03373}, 84 | year={2023} 85 | } 86 | 87 | ``` 88 | 89 | ## To Do List 90 | 91 | - [x] Basic Backward Guidance 92 | - [ ] Support Different Layer of Backward Guidance 93 | - [ ] Forward Guidance 94 | - [x] Real Image Editting Example 95 | 96 | ## Acknowledgements 97 | 98 | This research is supported by ERC-CoG UNION 101001212. 99 | The codes are inspired by [Diffuser](https://github.com/huggingface/diffusers) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion). -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from PIL import Image, ImageDraw, ImageFont 4 | import logging 5 | import os 6 | 7 | 8 | def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions): 9 | loss = 0 10 | object_number = len(bboxes) 11 | if object_number == 0: 12 | return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float() 13 | for attn_map_integrated in attn_maps_mid: 14 | attn_map = attn_map_integrated 15 | 16 | # 17 | b, i, j = attn_map.shape 18 | H = W = int(math.sqrt(i)) 19 | for obj_idx in range(object_number): 20 | obj_loss = 0 21 | mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W)) 22 | for obj_box in bboxes[obj_idx]: 23 | 24 | x_min, y_min, x_max, y_max = int(obj_box[0] * W), \ 25 | int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H) 26 | mask[y_min: y_max, x_min: x_max] = 1 27 | 28 | for obj_position in object_positions[obj_idx]: 29 | ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W) 30 | 31 | activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1) 32 | 33 | obj_loss += torch.mean((1 - activation_value) ** 2) 34 | loss += (obj_loss/len(object_positions[obj_idx])) 35 | 36 | for attn_map_integrated in attn_maps_up[0]: 37 | attn_map = attn_map_integrated 38 | # 39 | b, i, j = attn_map.shape 40 | H = W = int(math.sqrt(i)) 41 | 42 | for obj_idx in range(object_number): 43 | obj_loss = 0 44 | mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W)) 45 | for obj_box in bboxes[obj_idx]: 46 | x_min, y_min, x_max, y_max = int(obj_box[0] * W), \ 47 | int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H) 48 | mask[y_min: y_max, x_min: x_max] = 1 49 | 50 | for obj_position in object_positions[obj_idx]: 51 | ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W) 52 | # ca_map_obj = attn_map[:, :, object_positions[obj_position]].reshape(b, H, W) 53 | 54 | activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum( 55 | dim=-1) 56 | 57 | obj_loss += torch.mean((1 - activation_value) ** 2) 58 | loss += (obj_loss / len(object_positions[obj_idx])) 59 | loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid))) 60 | return loss 61 | 62 | def Pharse2idx(prompt, phrases): 63 | phrases = [x.strip() for x in phrases.split(';')] 64 | prompt_list = prompt.strip('.').split(' ') 65 | object_positions = [] 66 | for obj in phrases: 67 | obj_position = [] 68 | for 
word in obj.split(' '): 69 | obj_first_index = prompt_list.index(word) + 1 70 | obj_position.append(obj_first_index) 71 | object_positions.append(obj_position) 72 | 73 | return object_positions 74 | 75 | def draw_box(pil_img, bboxes, phrases, save_path): 76 | draw = ImageDraw.Draw(pil_img) 77 | font = ImageFont.truetype('./FreeMono.ttf', 25) 78 | phrases = [x.strip() for x in phrases.split(';')] 79 | for obj_bboxes, phrase in zip(bboxes, phrases): 80 | for obj_bbox in obj_bboxes: 81 | x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3] 82 | draw.rectangle([int(x_0 * 512), int(y_0 * 512), int(x_1 * 512), int(y_1 * 512)], outline='red', width=5) 83 | draw.text((int(x_0 * 512) + 5, int(y_0 * 512) + 5), phrase, font=font, fill=(255, 0, 0)) 84 | pil_img.save(save_path) 85 | 86 | 87 | 88 | def setup_logger(save_path, logger_name): 89 | logger = logging.getLogger(logger_name) 90 | logger.setLevel(logging.INFO) 91 | 92 | # Create a file handler to write logs to a file 93 | file_handler = logging.FileHandler(os.path.join(save_path, f"{logger_name}.log")) 94 | file_handler.setLevel(logging.INFO) 95 | 96 | # Create a formatter to format log messages 97 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 98 | 99 | # Set the formatter for the file handler 100 | file_handler.setFormatter(formatter) 101 | 102 | # Add the file handler to the logger 103 | logger.addHandler(file_handler) 104 | 105 | return logger 106 | 107 | def load_text_inversion(text_encoder, tokenizer, placeholder_token, embedding_ckp_path): 108 | num_added_tokens = tokenizer.add_tokens(placeholder_token) 109 | if num_added_tokens == 0: 110 | raise ValueError( 111 | f"The tokenizer already contains the token {placeholder_token}. Please pass a different" 112 | " `placeholder_token` that is not already in the tokenizer." 
113 | ) 114 | 115 | placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token) 116 | 117 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 118 | text_encoder.resize_token_embeddings(len(tokenizer)) 119 | 120 | # Initialise the newly added placeholder token with the embeddings of the initializer token 121 | token_embeds = text_encoder.get_input_embeddings().weight.data 122 | learned_embedding = torch.load(embedding_ckp_path) 123 | token_embeds[placeholder_token_id] = learned_embedding[placeholder_token] 124 | return text_encoder, tokenizer 125 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from diffusers import AutoencoderKL, LMSDiscreteScheduler 5 | from my_model import unet_2d_condition 6 | import json 7 | from PIL import Image 8 | from utils import compute_ca_loss, Pharse2idx, draw_box, setup_logger 9 | import hydra 10 | import os 11 | from tqdm import tqdm 12 | from utils import load_text_inversion 13 | def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, phrases, cfg, logger): 14 | 15 | 16 | logger.info("Inference") 17 | logger.info(f"Prompt: {prompt}") 18 | logger.info(f"Phrases: {phrases}") 19 | 20 | # Get Object Positions 21 | 22 | logger.info("Convert Phrases to Object Positions") 23 | object_positions = Pharse2idx(prompt, phrases) 24 | 25 | # Encode Classifier Embeddings 26 | uncond_input = tokenizer( 27 | [""] * cfg.inference.batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" 28 | ) 29 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] 30 | 31 | # Encode Prompt 32 | input_ids = tokenizer( 33 | [prompt] * cfg.inference.batch_size, 34 | padding="max_length", 35 | truncation=True, 36 | max_length=tokenizer.model_max_length, 37 | return_tensors="pt", 38 | ) 39 | 40 | cond_embeddings = text_encoder(input_ids.input_ids.to(device))[0] 41 | text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) 42 | generator = torch.manual_seed(cfg.inference.rand_seed) # Seed generator to create the initial latent noise 43 | 44 | noise_scheduler = LMSDiscreteScheduler(beta_start=cfg.noise_schedule.beta_start, beta_end=cfg.noise_schedule.beta_end, 45 | beta_schedule=cfg.noise_schedule.beta_schedule, num_train_timesteps=cfg.noise_schedule.num_train_timesteps) 46 | 47 | latents = torch.randn( 48 | (cfg.inference.batch_size, 4, 64, 64), 49 | generator=generator, 50 | ).to(device) 51 | 52 | noise_scheduler.set_timesteps(cfg.inference.timesteps) 53 | 54 | latents = latents * noise_scheduler.init_noise_sigma 55 | 56 | loss = torch.tensor(10000) 57 | 58 | for index, t in enumerate(tqdm(noise_scheduler.timesteps)): 59 | iteration = 0 60 | 61 | while loss.item() / cfg.inference.loss_scale > cfg.inference.loss_threshold and iteration < cfg.inference.max_iter and index < cfg.inference.max_index_step: 62 | latents = latents.requires_grad_(True) 63 | latent_model_input = latents 64 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 65 | noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \ 66 | unet(latent_model_input, t, encoder_hidden_states=cond_embeddings) 67 | 68 | # update latents with guidance 69 | loss = compute_ca_loss(attn_map_integrated_mid, 
attn_map_integrated_up, bboxes=bboxes, 70 | object_positions=object_positions) * cfg.inference.loss_scale 71 | 72 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0] 73 | 74 | latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2 75 | iteration += 1 76 | torch.cuda.empty_cache() 77 | 78 | with torch.no_grad(): 79 | latent_model_input = torch.cat([latents] * 2) 80 | 81 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 82 | noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \ 83 | unet(latent_model_input, t, encoder_hidden_states=text_embeddings) 84 | 85 | noise_pred = noise_pred.sample 86 | 87 | # perform guidance 88 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 89 | noise_pred = noise_pred_uncond + cfg.inference.classifier_free_guidance * (noise_pred_text - noise_pred_uncond) 90 | 91 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 92 | torch.cuda.empty_cache() 93 | 94 | with torch.no_grad(): 95 | logger.info("Decode Image...") 96 | latents = 1 / 0.18215 * latents 97 | image = vae.decode(latents).sample 98 | image = (image / 2 + 0.5).clamp(0, 1) 99 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 100 | images = (image * 255).round().astype("uint8") 101 | pil_images = [Image.fromarray(image) for image in images] 102 | return pil_images 103 | 104 | 105 | @hydra.main(version_base=None, config_path="conf", config_name="base_config") 106 | def main(cfg): 107 | 108 | # build and load model 109 | with open(cfg.general.unet_config) as f: 110 | unet_config = json.load(f) 111 | unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained(cfg.general.model_path, subfolder="unet") 112 | tokenizer = CLIPTokenizer.from_pretrained(cfg.general.model_path, subfolder="tokenizer") 113 | text_encoder = CLIPTextModel.from_pretrained(cfg.general.model_path, subfolder="text_encoder") 114 | vae = AutoencoderKL.from_pretrained(cfg.general.model_path, subfolder="vae") 115 | 116 | if cfg.general.real_image_editing: 117 | text_encoder, tokenizer = load_text_inversion(text_encoder, tokenizer, cfg.real_image_editing.placeholder_token, cfg.real_image_editing.text_inversion_path) 118 | unet.load_state_dict(torch.load(cfg.real_image_editing.dreambooth_path)['unet']) 119 | text_encoder.load_state_dict(torch.load(cfg.real_image_editing.dreambooth_path)['encoder']) 120 | 121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 122 | 123 | unet.to(device) 124 | text_encoder.to(device) 125 | vae.to(device) 126 | 127 | 128 | 129 | # ------------------ example input ------------------ 130 | examples = {"prompt": "A hello kitty toy is playing with a purple ball.", 131 | "phrases": "hello kitty; ball", 132 | "bboxes": [[[0.1, 0.2, 0.5, 0.8]], [[0.75, 0.6, 0.95, 0.8]]], 133 | 'save_path': cfg.general.save_path 134 | } 135 | 136 | # ------------------ real image editing example input ------------------ 137 | if cfg.general.real_image_editing: 138 | examples = {"prompt": "A {} is standing on grass.".format(cfg.real_image_editing.placeholder_token), 139 | "phrases": "{}".format(cfg.real_image_editing.placeholder_token), 140 | "bboxes": [[[0.4, 0.2, 0.9, 0.9]]], 141 | 'save_path': cfg.general.save_path 142 | } 143 | # --------------------------------------------------- 144 | # Prepare the save path 145 | if not os.path.exists(cfg.general.save_path): 146 | os.makedirs(cfg.general.save_path) 147 | logger = setup_logger(cfg.general.save_path, __name__) 148 | 
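    # Note on the example inputs above: bounding boxes are given as (x_min, y_min, x_max, y_max)
    # fractions of the image size in [0, 1], and each word of every phrase in `phrases`
    # (phrases are separated by ';') must appear verbatim in the prompt so that
    # Pharse2idx can map it to token positions.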
149 | logger.info(cfg) 150 | # Save cfg 151 | logger.info("save config to {}".format(os.path.join(cfg.general.save_path, 'config.yaml'))) 152 | OmegaConf.save(cfg, os.path.join(cfg.general.save_path, 'config.yaml')) 153 | 154 | # Inference 155 | pil_images = inference(device, unet, vae, tokenizer, text_encoder, examples['prompt'], examples['bboxes'], examples['phrases'], cfg, logger) 156 | 157 | # Save example images 158 | for index, pil_image in enumerate(pil_images): 159 | image_path = os.path.join(cfg.general.save_path, 'example_{}.png'.format(index)) 160 | logger.info('save example image to {}'.format(image_path)) 161 | draw_box(pil_image, examples['bboxes'], examples['phrases'], image_path) 162 | 163 | if __name__ == "__main__": 164 | main() -------------------------------------------------------------------------------- /my_model/unet_2d_condition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pdb 15 | from dataclasses import dataclass 16 | from typing import Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.utils.checkpoint 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.modeling_utils import ModelMixin 24 | from diffusers.utils import BaseOutput, logging 25 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 26 | from .unet_2d_blocks import ( 27 | CrossAttnDownBlock2D, 28 | CrossAttnUpBlock2D, 29 | DownBlock2D, 30 | UNetMidBlock2DCrossAttn, 31 | UpBlock2D, 32 | get_down_block, 33 | get_up_block, 34 | ) 35 | 36 | 37 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 38 | 39 | 40 | @dataclass 41 | class UNet2DConditionOutput(BaseOutput): 42 | """ 43 | Args: 44 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 45 | Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. 46 | """ 47 | 48 | sample: torch.FloatTensor 49 | 50 | 51 | class UNet2DConditionModel(ModelMixin, ConfigMixin): 52 | r""" 53 | UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep 54 | and returns sample shaped output. 55 | 56 | This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library 57 | implements for all the models (such as downloading or saving, etc.) 58 | 59 | Parameters: 60 | sample_size (`int`, *optional*): The size of the input sample. 61 | in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. 62 | out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. 63 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 
64 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 65 | Whether to flip the sin to cos in the time embedding. 66 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 67 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 68 | The tuple of downsample blocks to use. 69 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): 70 | The tuple of upsample blocks to use. 71 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 72 | The tuple of output channels for each block. 73 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 74 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 75 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 76 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 77 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 78 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 79 | cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. 80 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 81 | """ 82 | 83 | _supports_gradient_checkpointing = True 84 | 85 | @register_to_config 86 | def __init__( 87 | self, 88 | sample_size: Optional[int] = None, 89 | in_channels: int = 4, 90 | out_channels: int = 4, 91 | center_input_sample: bool = False, 92 | flip_sin_to_cos: bool = True, 93 | freq_shift: int = 0, 94 | down_block_types: Tuple[str] = ( 95 | "CrossAttnDownBlock2D", 96 | "CrossAttnDownBlock2D", 97 | "CrossAttnDownBlock2D", 98 | "DownBlock2D", 99 | ), 100 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 101 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 102 | layers_per_block: int = 2, 103 | downsample_padding: int = 1, 104 | mid_block_scale_factor: float = 1, 105 | act_fn: str = "silu", 106 | norm_num_groups: int = 32, 107 | norm_eps: float = 1e-5, 108 | cross_attention_dim: int = 1280, 109 | attention_head_dim: int = 8, 110 | ): 111 | super().__init__() 112 | 113 | self.sample_size = sample_size 114 | time_embed_dim = block_out_channels[0] * 4 115 | 116 | # input 117 | self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 118 | 119 | # time 120 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 121 | timestep_input_dim = block_out_channels[0] 122 | 123 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 124 | 125 | self.down_blocks = nn.ModuleList([]) 126 | self.mid_block = None 127 | self.up_blocks = nn.ModuleList([]) 128 | 129 | # down 130 | output_channel = block_out_channels[0] 131 | for i, down_block_type in enumerate(down_block_types): 132 | input_channel = output_channel 133 | output_channel = block_out_channels[i] 134 | is_final_block = i == len(block_out_channels) - 1 135 | 136 | down_block = get_down_block( 137 | down_block_type, 138 | num_layers=layers_per_block, 139 | in_channels=input_channel, 140 | out_channels=output_channel, 141 | 
temb_channels=time_embed_dim, 142 | add_downsample=not is_final_block, 143 | resnet_eps=norm_eps, 144 | resnet_act_fn=act_fn, 145 | resnet_groups=norm_num_groups, 146 | cross_attention_dim=cross_attention_dim, 147 | attn_num_head_channels=attention_head_dim, 148 | downsample_padding=downsample_padding, 149 | ) 150 | self.down_blocks.append(down_block) 151 | 152 | # mid 153 | self.mid_block = UNetMidBlock2DCrossAttn( 154 | in_channels=block_out_channels[-1], 155 | temb_channels=time_embed_dim, 156 | resnet_eps=norm_eps, 157 | resnet_act_fn=act_fn, 158 | output_scale_factor=mid_block_scale_factor, 159 | resnet_time_scale_shift="default", 160 | cross_attention_dim=cross_attention_dim, 161 | attn_num_head_channels=attention_head_dim, 162 | resnet_groups=norm_num_groups, 163 | ) 164 | 165 | # count how many layers upsample the images 166 | self.num_upsamplers = 0 167 | 168 | # up 169 | reversed_block_out_channels = list(reversed(block_out_channels)) 170 | output_channel = reversed_block_out_channels[0] 171 | for i, up_block_type in enumerate(up_block_types): 172 | is_final_block = i == len(block_out_channels) - 1 173 | 174 | prev_output_channel = output_channel 175 | output_channel = reversed_block_out_channels[i] 176 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 177 | 178 | # add upsample block for all BUT final layer 179 | if not is_final_block: 180 | add_upsample = True 181 | self.num_upsamplers += 1 182 | else: 183 | add_upsample = False 184 | 185 | up_block = get_up_block( 186 | up_block_type, 187 | num_layers=layers_per_block + 1, 188 | in_channels=input_channel, 189 | out_channels=output_channel, 190 | prev_output_channel=prev_output_channel, 191 | temb_channels=time_embed_dim, 192 | add_upsample=add_upsample, 193 | resnet_eps=norm_eps, 194 | resnet_act_fn=act_fn, 195 | resnet_groups=norm_num_groups, 196 | cross_attention_dim=cross_attention_dim, 197 | attn_num_head_channels=attention_head_dim, 198 | ) 199 | self.up_blocks.append(up_block) 200 | prev_output_channel = output_channel 201 | 202 | # out 203 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 204 | self.conv_act = nn.SiLU() 205 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) 206 | 207 | def set_attention_slice(self, slice_size): 208 | if slice_size is not None and self.config.attention_head_dim % slice_size != 0: 209 | raise ValueError( 210 | f"Make sure slice_size {slice_size} is a divisor of " 211 | f"the number of heads used in cross_attention {self.config.attention_head_dim}" 212 | ) 213 | if slice_size is not None and slice_size > self.config.attention_head_dim: 214 | raise ValueError( 215 | f"Chunk_size {slice_size} has to be smaller or equal to " 216 | f"the number of heads used in cross_attention {self.config.attention_head_dim}" 217 | ) 218 | 219 | for block in self.down_blocks: 220 | if hasattr(block, "attentions") and block.attentions is not None: 221 | block.set_attention_slice(slice_size) 222 | 223 | self.mid_block.set_attention_slice(slice_size) 224 | 225 | for block in self.up_blocks: 226 | if hasattr(block, "attentions") and block.attentions is not None: 227 | block.set_attention_slice(slice_size) 228 | 229 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 230 | for block in self.down_blocks: 231 | if hasattr(block, "attentions") and block.attentions is not None: 232 | 
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 233 | 234 | self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 235 | 236 | for block in self.up_blocks: 237 | if hasattr(block, "attentions") and block.attentions is not None: 238 | block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 239 | 240 | def _set_gradient_checkpointing(self, module, value=False): 241 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): 242 | module.gradient_checkpointing = value 243 | 244 | def forward( 245 | self, 246 | sample: torch.FloatTensor, 247 | timestep: Union[torch.Tensor, float, int], 248 | encoder_hidden_states: torch.Tensor, 249 | return_dict: bool = True, 250 | ) -> Union[UNet2DConditionOutput, Tuple]: 251 | r""" 252 | Args: 253 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs_coarse tensor 254 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 255 | encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states 256 | return_dict (`bool`, *optional*, defaults to `True`): 257 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 258 | 259 | Returns: 260 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 261 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 262 | returning a tuple, the first element is the sample tensor. 263 | """ 264 | # By default samples have to be AT least a multiple of the overall upsampling factor. 265 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 266 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 267 | # on the fly if necessary. 268 | default_overall_up_factor = 2**self.num_upsamplers 269 | 270 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 271 | forward_upsample_size = False 272 | upsample_size = None 273 | 274 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 275 | logger.info("Forward upsample size to force interpolation output size.") 276 | forward_upsample_size = True 277 | 278 | # 0. center input if necessary 279 | if self.config.center_input_sample: 280 | sample = 2 * sample - 1.0 281 | 282 | # 1. time 283 | timesteps = timestep 284 | if not torch.is_tensor(timesteps): 285 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 286 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 287 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 288 | timesteps = timesteps[None].to(sample.device) 289 | 290 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 291 | timesteps = timesteps.expand(sample.shape[0]) 292 | 293 | t_emb = self.time_proj(timesteps) 294 | 295 | # timesteps does not contain any weights and will always return f32 tensors 296 | # but time_embedding might actually be running in fp16. so we need to cast here. 297 | # there might be better ways to encapsulate this. 298 | t_emb = t_emb.to(dtype=self.dtype) 299 | emb = self.time_embedding(t_emb) 300 | # 2. pre-process 301 | sample = self.conv_in(sample) 302 | # 3. 
down 303 | attn_down = [] 304 | down_block_res_samples = (sample,) 305 | for block_idx, downsample_block in enumerate(self.down_blocks): 306 | if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: 307 | sample, res_samples, cross_atten_prob = downsample_block( 308 | hidden_states=sample, 309 | temb=emb, 310 | encoder_hidden_states=encoder_hidden_states 311 | ) 312 | attn_down.append(cross_atten_prob) 313 | else: 314 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 315 | 316 | down_block_res_samples += res_samples 317 | 318 | # 4. mid 319 | sample, attn_mid = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 320 | 321 | # 5. up 322 | attn_up = [] 323 | for i, upsample_block in enumerate(self.up_blocks): 324 | is_final_block = i == len(self.up_blocks) - 1 325 | 326 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 327 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 328 | 329 | # if we have not reached the final block and need to forward the 330 | # upsample size, we do it here 331 | if not is_final_block and forward_upsample_size: 332 | upsample_size = down_block_res_samples[-1].shape[2:] 333 | 334 | if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: 335 | sample, cross_atten_prob = upsample_block( 336 | hidden_states=sample, 337 | temb=emb, 338 | res_hidden_states_tuple=res_samples, 339 | encoder_hidden_states=encoder_hidden_states, 340 | upsample_size=upsample_size, 341 | ) 342 | attn_up.append(cross_atten_prob) 343 | else: 344 | sample = upsample_block( 345 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 346 | ) 347 | # 6. post-process 348 | sample = self.conv_norm_out(sample) 349 | sample = self.conv_act(sample) 350 | sample = self.conv_out(sample) 351 | 352 | if not return_dict: 353 | return (sample,) 354 | 355 | return UNet2DConditionOutput(sample=sample), attn_up, attn_mid, attn_down 356 | -------------------------------------------------------------------------------- /text_inversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import math 4 | import os 5 | import random 6 | from pathlib import Path 7 | from typing import Optional 8 | import hydra 9 | import json 10 | from omegaconf import DictConfig 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | from torch.utils.data import Dataset 17 | from omegaconf import OmegaConf 18 | 19 | import PIL 20 | from accelerate import Accelerator 21 | from accelerate.utils import set_seed 22 | from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler 23 | from my_model import unet_2d_condition 24 | from diffusers.optimization import get_scheduler 25 | from PIL import Image 26 | from torchvision import transforms 27 | from tqdm.auto import tqdm 28 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 29 | from utils import setup_logger, load_text_inversion 30 | def save_progress(text_encoder, placeholder_token_id, accelerator, iteration_idx, cfg, logger): 31 | logger.info("Saving embeddings to {}".format(os.path.join(cfg.general.save_path, "learned_embeds_iteration_{}.bin"))) 32 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] 33 | learned_embeds_dict = 
{cfg.text_inversion.placeholder_token: learned_embeds.detach().cpu()} 34 | torch.save(learned_embeds_dict, os.path.join(cfg.general.save_path, "learned_embeds_iteration_{}.bin".format(iteration_idx))) 35 | 36 | imagenet_templates_small = [ 37 | "a photo of a {}", 38 | "a rendering of a {}", 39 | "a cropped photo of the {}", 40 | "the photo of a {}", 41 | "a photo of a clean {}", 42 | "a photo of a dirty {}", 43 | "a dark photo of the {}", 44 | "a photo of my {}", 45 | "a photo of the cool {}", 46 | "a close-up photo of a {}", 47 | "a bright photo of the {}", 48 | "a cropped photo of a {}", 49 | "a photo of the {}", 50 | "a good photo of the {}", 51 | "a photo of one {}", 52 | "a close-up photo of the {}", 53 | "a rendition of the {}", 54 | "a photo of the clean {}", 55 | "a rendition of a {}", 56 | "a photo of a nice {}", 57 | "a good photo of a {}", 58 | "a photo of the nice {}", 59 | "a photo of the small {}", 60 | "a photo of the weird {}", 61 | "a photo of the large {}", 62 | "a photo of a cool {}", 63 | "a photo of a small {}", 64 | ] 65 | 66 | imagenet_style_templates_small = [ 67 | "a painting in the style of {}", 68 | "a rendering in the style of {}", 69 | "a cropped painting in the style of {}", 70 | "the painting in the style of {}", 71 | "a clean painting in the style of {}", 72 | "a dirty painting in the style of {}", 73 | "a dark painting in the style of {}", 74 | "a picture in the style of {}", 75 | "a cool painting in the style of {}", 76 | "a close-up painting in the style of {}", 77 | "a bright painting in the style of {}", 78 | "a cropped painting in the style of {}", 79 | "a good painting in the style of {}", 80 | "a close-up painting in the style of {}", 81 | "a rendition in the style of {}", 82 | "a nice painting in the style of {}", 83 | "a small painting in the style of {}", 84 | "a weird painting in the style of {}", 85 | "a large painting in the style of {}", 86 | ] 87 | 88 | 89 | class TextualInversionDataset(Dataset): 90 | def __init__( 91 | self, 92 | data_root, 93 | tokenizer, 94 | learnable_property="object", # [object, style] 95 | size=512, 96 | repeats=100, 97 | interpolation="bicubic", 98 | flip_p=0.5, 99 | set="train", 100 | placeholder_token="*", 101 | center_crop=False, 102 | randaug=False 103 | ): 104 | self.data_root = data_root 105 | self.tokenizer = tokenizer 106 | self.learnable_property = learnable_property 107 | self.size = size 108 | self.placeholder_token = placeholder_token 109 | self.center_crop = center_crop 110 | self.flip_p = flip_p 111 | self.randaug = randaug 112 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 113 | 114 | self.num_images = len(self.image_paths) 115 | self._length = self.num_images 116 | 117 | if set == "train": 118 | self._length = self.num_images * repeats 119 | 120 | self.interpolation = { 121 | "linear": PIL.Image.LINEAR, 122 | "bilinear": PIL.Image.BILINEAR, 123 | "bicubic": PIL.Image.BICUBIC, 124 | "lanczos": PIL.Image.LANCZOS, 125 | }[interpolation] 126 | 127 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 128 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 129 | self.transform = transforms.RandAugment() 130 | # self.transform = transforms.Compose([ 131 | # transforms.Resize(int(size * 5/4)), 132 | # transforms.CenterCrop(int(size * 5/4)), 133 | # transforms.RandomApply([ 134 | # transforms.RandomRotation(degrees=10, fill=255), 135 | # transforms.CenterCrop(int(size * 
5/6)), 136 | # transforms.Resize(size), 137 | # ], p=0.75), 138 | # transforms.RandomResizedCrop(size, scale=(0.85, 1.15)), 139 | # # transforms.RandomApply([transforms.ColorJitter(0.04, 0.04, 0.04, 0.04)], p=0.75), 140 | # # transforms.RandomGrayscale(p=0.10), 141 | # transforms.RandomApply([transforms.GaussianBlur(5, (0.1, 2))], p=0.10), 142 | # ]) 143 | def __len__(self): 144 | return self._length 145 | 146 | def __getitem__(self, i): 147 | example = {} 148 | image = Image.open(self.image_paths[i % self.num_images]) 149 | 150 | if not image.mode == "RGB": 151 | image = image.convert("RGB") 152 | 153 | placeholder_string = self.placeholder_token 154 | text = random.choice(self.templates).format(placeholder_string) 155 | 156 | example["input_ids"] = self.tokenizer( 157 | text, 158 | padding="max_length", 159 | truncation=True, 160 | max_length=self.tokenizer.model_max_length, 161 | return_tensors="pt", 162 | ).input_ids[0] 163 | 164 | # default to score-sde preprocessing 165 | img = np.array(image).astype(np.uint8) 166 | 167 | if self.center_crop: 168 | crop = min(img.shape[0], img.shape[1]) 169 | h, w, = ( 170 | img.shape[0], 171 | img.shape[1], 172 | ) 173 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 174 | 175 | image = Image.fromarray(img) 176 | image = image.resize((self.size, self.size), resample=self.interpolation) 177 | if self.randaug: 178 | print("using randaug") 179 | image = self.transform(image) 180 | image = self.flip_transform(image) 181 | image = np.array(image).astype(np.uint8) 182 | image = (image / 127.5 - 1.0).astype(np.float32) 183 | 184 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 185 | return example 186 | 187 | 188 | def unfreeze_params(params): 189 | for param in params: 190 | param.requires_grad = True 191 | 192 | def freeze_params(params): 193 | for param in params: 194 | param.requires_grad = False 195 | 196 | 197 | def text_inversion(device, unet, vae, tokenizer, text_encoder, cfg, accelerator, logger): 198 | num_added_tokens = tokenizer.add_tokens(cfg.text_inversion.placeholder_token) 199 | if num_added_tokens == 0: 200 | raise ValueError( 201 | f"The tokenizer already contains the token {cfg.text_inversion.placeholder_token}. Please pass a different" 202 | " `placeholder_token` that is not already in the tokenizer." 
203 | ) 204 | 205 | # Convert the initializer_token, placeholder_token to ids 206 | token_ids = tokenizer.encode(cfg.text_inversion.initial_token, add_special_tokens=False) 207 | # Check if initializer_token is a single token or a sequence of tokens 208 | if len(token_ids) > 1: 209 | raise ValueError("The initializer token must be a single token.") 210 | 211 | initializer_token_id = token_ids[0] 212 | placeholder_token_id = tokenizer.convert_tokens_to_ids(cfg.text_inversion.placeholder_token) 213 | 214 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 215 | text_encoder.resize_token_embeddings(len(tokenizer)) 216 | 217 | # Initialise the newly added placeholder token with the embeddings of the initializer token 218 | token_embeds = text_encoder.get_input_embeddings().weight.data 219 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] 220 | 221 | # Freeze vae and unet 222 | freeze_params(vae.parameters()) 223 | freeze_params(unet.parameters()) 224 | # Freeze all parameters except for the token embeddings in text encoder 225 | params_to_freeze = itertools.chain( 226 | text_encoder.text_model.encoder.parameters(), 227 | text_encoder.text_model.final_layer_norm.parameters(), 228 | text_encoder.text_model.embeddings.position_embedding.parameters(), 229 | ) 230 | freeze_params(params_to_freeze) 231 | 232 | if cfg.text_inversion.scale_lr: 233 | cfg.text_inversion.lr = ( 234 | cfg.text_inversion.lr * cfg.text_inversion.gradient_accumulation_steps * cfg.text_inversion.batch_size * accelerator.num_processes 235 | ) 236 | 237 | # Initialize the optimizer 238 | optimizer = torch.optim.AdamW( 239 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 240 | lr=cfg.text_inversion.lr, 241 | betas=(cfg.text_inversion.adam_beta1, cfg.text_inversion.adam_beta2), 242 | weight_decay=cfg.text_inversion.adam_weight_decay, 243 | eps=cfg.text_inversion.adam_epsilon, 244 | ) 245 | 246 | noise_scheduler = DDPMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") 247 | 248 | lr_scheduler = get_scheduler( 249 | cfg.text_inversion.lr_scheduler, 250 | optimizer=optimizer, 251 | num_warmup_steps=cfg.text_inversion.lr_warmup_steps, 252 | num_training_steps=cfg.text_inversion.max_train_steps 253 | ) 254 | inverted_image_path = cfg.general.save_path if cfg.text_inversion.image_path == ' ' else cfg.text_inversion.image_path 255 | logger.info('load image at {}'.format(inverted_image_path)) 256 | train_dataset = TextualInversionDataset( 257 | data_root=inverted_image_path, 258 | tokenizer=tokenizer, 259 | size=cfg.text_inversion.resolution, 260 | placeholder_token=cfg.text_inversion.placeholder_token, 261 | repeats=cfg.text_inversion.repeats, 262 | learnable_property=cfg.text_inversion.learnable_property, 263 | center_crop=cfg.text_inversion.center_crop, 264 | set="train", 265 | randaug=cfg.text_inversion.randaug, 266 | ) 267 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.text_inversion.batch_size, shuffle=True) 268 | 269 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 270 | text_encoder, optimizer, train_dataloader, lr_scheduler 271 | ) 272 | 273 | 274 | 275 | # Keep vae and unet in eval model as we don't train these 276 | vae.eval() 277 | unet.eval() 278 | 279 | 280 | total_batch_size = cfg.text_inversion.batch_size * accelerator.num_processes * cfg.text_inversion.gradient_accumulation_steps 281 | 282 | logger.info("***** Running training *****") 283 | 
logger.info(f" Num examples = {len(train_dataset)}") 284 | logger.info(f" Num Epochs = {cfg.text_inversion.num_train_epochs}") 285 | logger.info(f" Instantaneous batch size per device = {cfg.text_inversion.batch_size}") 286 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 287 | logger.info(f" Gradient Accumulation steps = {cfg.text_inversion.gradient_accumulation_steps}") 288 | logger.info(f" Total optimization steps = {cfg.text_inversion.max_train_steps}") 289 | # Only show the progress bar once on each machine. 290 | progress_bar = tqdm(range(cfg.text_inversion.max_train_steps), disable=not accelerator.is_local_main_process) 291 | progress_bar.set_description("Steps") 292 | global_step = 0 293 | 294 | for epoch in range(cfg.text_inversion.num_train_epochs): 295 | text_encoder.train() 296 | for step, batch in enumerate(train_dataloader): 297 | 298 | with accelerator.accumulate(text_encoder): 299 | # Convert images to latent space 300 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() 301 | latents = latents * 0.18215 302 | 303 | # Sample noise that we'll add to the latents 304 | noise = torch.randn(latents.shape).to(latents.device) 305 | bsz = latents.shape[0] 306 | # Sample a random timestep for each image 307 | timesteps = torch.randint( 308 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device 309 | ).long() 310 | 311 | # Add noise to the latents according to the noise magnitude at each timestep 312 | # (this is the forward diffusion process) 313 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 314 | 315 | # Get the text embedding for conditioning 316 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 317 | 318 | # Predict the noise residual 319 | noise_pred, _, _, _ = unet(noisy_latents, timesteps, encoder_hidden_states) 320 | noise_pred = noise_pred.sample 321 | 322 | # import pdb; pdb.set_trace() 323 | 324 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 325 | accelerator.backward(loss) 326 | 327 | # Zero out the gradients for all token embeddings except the newly added 328 | # embeddings for the concept, as we only want to optimize the concept embeddings 329 | if accelerator.num_processes > 1: 330 | grads = text_encoder.module.get_input_embeddings().weight.grad 331 | else: 332 | grads = text_encoder.get_input_embeddings().weight.grad 333 | # Get the index for tokens that we want to zero the grads for 334 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id 335 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) 336 | 337 | optimizer.step() 338 | lr_scheduler.step() 339 | optimizer.zero_grad() 340 | 341 | # Checks if the accelerator has performed an optimization step behind the scenes 342 | if accelerator.sync_gradients: 343 | progress_bar.update(1) 344 | global_step += 1 345 | if global_step % cfg.text_inversion.save_steps == 0: 346 | save_progress(text_encoder, placeholder_token_id, accelerator, global_step, cfg, logger) 347 | 348 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 349 | progress_bar.set_postfix(**logs) 350 | # accelerator.log(logs, step=global_step) 351 | 352 | if global_step >= cfg.text_inversion.max_train_steps: 353 | logger.info("reach the maximum iteration") 354 | return 355 | 356 | accelerator.wait_for_everyone() 357 | def inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg, logger): 358 | 
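    # Quick qualitative check of the learned embedding: standard Stable Diffusion
    # sampling with classifier-free guidance (no layout guidance is applied here);
    # the generated images are written to cfg.general.save_path.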
vae.eval() 359 | unet.eval() 360 | text_encoder.eval() 361 | 362 | uncond_input = tokenizer( 363 | [""] * cfg.dreambooth.inference_batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" 364 | ) 365 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] 366 | input_ids = tokenizer( 367 | [prompt] * cfg.dreambooth.inference_batch_size, 368 | padding="max_length", 369 | truncation=True, 370 | max_length=tokenizer.model_max_length, 371 | return_tensors="pt", 372 | ).input_ids.to(device) 373 | text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]]) 374 | 375 | latents = torch.randn( 376 | (cfg.dreambooth.inference_batch_size, 4, 64, 64), 377 | ).to(device) 378 | 379 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 380 | noise_scheduler.set_timesteps(51) 381 | guidance_scale = 7.5 382 | latents = latents * noise_scheduler.init_noise_sigma 383 | 384 | for index, t in enumerate(tqdm(noise_scheduler.timesteps)): 385 | with torch.no_grad(): 386 | latent_model_input = torch.cat([latents] * 2) 387 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 388 | 389 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)[0] 390 | noise_pred = noise_pred.sample 391 | 392 | # perform guidance 393 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 394 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 395 | 396 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 397 | 398 | logger.info("save images to {}".format(cfg.general.save_path)) 399 | 400 | vae.to(latents.device, dtype=latents.dtype) 401 | with torch.no_grad(): 402 | latents = 1 / 0.18215 * latents 403 | image = vae.decode(latents).sample 404 | image = (image / 2 + 0.5).clamp(0, 1) 405 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 406 | images = (image * 255).round().astype("uint8") 407 | pil_images = [Image.fromarray(image) for image in images] 408 | for idx, pil_image in enumerate(pil_images): 409 | pil_image.save(os.path.join(cfg.general.save_path, "{}_{}.png".format('_'.join(prompt.split(' ')), idx))) 410 | 411 | @hydra.main(version_base=None, config_path="conf", config_name="real_image_editing_config") 412 | def main(cfg: DictConfig): 413 | 414 | cfg.general.save_path = os.path.join(cfg.general.save_path, 'text_inversion') 415 | 416 | if cfg.general.seed is not None: 417 | set_seed(cfg.general.seed) 418 | 419 | with open(cfg.general.unet_config) as f: 420 | unet_config = json.load(f) 421 | # load pretrained models and schedular 422 | unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained(cfg.general.model_path, subfolder="unet") 423 | tokenizer = CLIPTokenizer.from_pretrained(cfg.general.model_path, subfolder="tokenizer") 424 | text_encoder = CLIPTextModel.from_pretrained(cfg.general.model_path, subfolder="text_encoder") 425 | vae = AutoencoderKL.from_pretrained(cfg.general.model_path, subfolder="vae") 426 | 427 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 428 | mixed_precision = 'fp16' if torch.cuda.is_available() else 'no' 429 | accelerator = Accelerator( 430 | gradient_accumulation_steps=cfg.text_inversion.gradient_accumulation_steps, 431 | mixed_precision=mixed_precision 432 | ) 433 | 434 | if not os.path.exists(cfg.general.save_path) and accelerator.is_main_process: 435 | os.makedirs(cfg.general.save_path) 436 
| 437 | logger = setup_logger(cfg.general.save_path, __name__) 438 | 439 | logger.info(cfg) 440 | # Save cfg 441 | logger.info("save config to {}".format(os.path.join(cfg.general.save_path, 'config.yaml'))) 442 | OmegaConf.save(cfg, os.path.join(cfg.general.save_path, 'config.yaml')) 443 | 444 | 445 | # Move models to device 446 | vae.to(accelerator.device) 447 | unet.to(accelerator.device) 448 | text_encoder.to(accelerator.device) 449 | 450 | if not cfg.text_inversion.inference: 451 | text_inversion(device, unet, vae, tokenizer, text_encoder, cfg, accelerator, logger) 452 | else: 453 | text_encoder, tokenizer = load_text_inversion(text_encoder, tokenizer, cfg.text_inversion.placeholder_token, cfg.text_inversion.embedding_ckp) 454 | 455 | if cfg.text_inversion.new_prompt != '': 456 | prompt = cfg.text_inversion.new_prompt.format(cfg.text_inversion.placeholder_token) 457 | else: 458 | prompt = cfg.text_inversion.example_prompt.format(cfg.text_inversion.placeholder_token) 459 | inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg, logger) 460 | 461 | 462 | 463 | 464 | 465 | if __name__ == "__main__": 466 | main() -------------------------------------------------------------------------------- /dreambooth.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 4 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline 5 | import torch.nn.functional as F 6 | from PIL import Image, ImageDraw, ImageFont 7 | from pathlib import Path 8 | from accelerate import Accelerator 9 | from omegaconf import DictConfig, OmegaConf 10 | from datetime import datetime 11 | import itertools 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | from diffusers import LMSDiscreteScheduler 15 | from diffusers.optimization import get_scheduler 16 | import math 17 | from my_model import unet_2d_condition 18 | import os 19 | import json 20 | from accelerate.logging import get_logger 21 | import hashlib 22 | from torch.utils.data import Dataset 23 | from torchvision import transforms 24 | from utils import setup_logger, load_text_inversion 25 | from accelerate.utils import set_seed 26 | 27 | class DreamBoothDataset(Dataset): 28 | """ 29 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 30 | It pre-processes the images and the tokenizes prompts. 
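    Each example is a dict with "instance_images" and "instance_prompt_ids"; when a
    class_data_root is supplied for prior preservation, "class_images" and
    "class_prompt_ids" are included as well.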
31 | """ 32 | 33 | def __init__( 34 | self, 35 | instance_data_root, 36 | instance_prompt, 37 | tokenizer, 38 | class_data_root=None, 39 | class_prompt=None, 40 | size=512, 41 | center_crop=False, 42 | ): 43 | self.size = size 44 | self.center_crop = center_crop 45 | self.tokenizer = tokenizer 46 | 47 | self.instance_data_root = Path(instance_data_root) 48 | if not self.instance_data_root.exists(): 49 | raise ValueError("Instance images root doesn't exists.") 50 | 51 | self.instance_images_path = list(Path(instance_data_root).iterdir()) 52 | self.num_instance_images = len(self.instance_images_path) 53 | self.instance_prompt = instance_prompt 54 | self._length = self.num_instance_images 55 | 56 | if class_data_root is not None: 57 | self.class_data_root = Path(class_data_root) 58 | self.class_data_root.mkdir(parents=True, exist_ok=True) 59 | self.class_images_path = list(self.class_data_root.iterdir()) 60 | self.num_class_images = len(self.class_images_path) 61 | self._length = max(self.num_class_images, self.num_instance_images) 62 | self.class_prompt = class_prompt 63 | else: 64 | self.class_data_root = None 65 | 66 | self.image_transforms = transforms.Compose( 67 | [ 68 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 69 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 70 | transforms.ToTensor(), 71 | transforms.Normalize([0.5], [0.5]), 72 | ] 73 | ) 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, index): 79 | example = {} 80 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 81 | if not instance_image.mode == "RGB": 82 | instance_image = instance_image.convert("RGB") 83 | example["instance_images"] = self.image_transforms(instance_image) 84 | example["instance_prompt_ids"] = self.tokenizer( 85 | self.instance_prompt, 86 | padding="do_not_pad", 87 | truncation=True, 88 | max_length=self.tokenizer.model_max_length, 89 | ).input_ids 90 | 91 | if self.class_data_root: 92 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 93 | if not class_image.mode == "RGB": 94 | class_image = class_image.convert("RGB") 95 | example["class_images"] = self.image_transforms(class_image) 96 | example["class_prompt_ids"] = self.tokenizer( 97 | self.class_prompt, 98 | padding="do_not_pad", 99 | truncation=True, 100 | max_length=self.tokenizer.model_max_length, 101 | ).input_ids 102 | 103 | return example 104 | 105 | class PromptDataset(Dataset): 106 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
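    # A minimal usage sketch for the two datasets (kept as a comment; the path and the
    # placeholder token are illustrative, not values read from the config):
    #
    #   ds = DreamBoothDataset(
    #       instance_data_root="/path/to/instance/images",
    #       instance_prompt="a photo of a <placeholder>",
    #       tokenizer=tokenizer,
    #       size=512,
    #   )
    #   example = ds[0]   # {"instance_images": FloatTensor of shape (3, 512, 512),
    #                     #  "instance_prompt_ids": list of token ids}
    #
    # PromptDataset itself simply repeats one class prompt num_samples times so the
    # prior-preservation class images can be generated in batches.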
107 | 108 | def __init__(self, prompt, num_samples): 109 | self.prompt = prompt 110 | self.num_samples = num_samples 111 | 112 | def __len__(self): 113 | return self.num_samples 114 | 115 | def __getitem__(self, index): 116 | example = {} 117 | example["prompt"] = self.prompt 118 | example["index"] = index 119 | return example 120 | 121 | def train_dreambooth(device, unet, vae, tokenizer, text_encoder, cfg, accelerator, logger): 122 | if cfg.dreambooth.with_prior_preservation: 123 | class_images_dir = Path(cfg.dreambooth.class_data_dir) 124 | if not class_images_dir.exists(): 125 | class_images_dir.mkdir(parents=True) 126 | cur_class_images = len(list(class_images_dir.iterdir())) 127 | 128 | if cur_class_images < cfg.dreambooth.num_class_images: 129 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 130 | pipeline = StableDiffusionPipeline.from_pretrained( 131 | cfg.dreambooth.pretrained_model_name_or_path, 132 | torch_dtype=torch_dtype, 133 | safety_checker=None, 134 | ) 135 | pipeline.set_progress_bar_config(disable=True) 136 | 137 | num_new_images = cfg.dreambooth.num_class_images - cur_class_images 138 | logger.info(f"Number of class images to sample: {num_new_images}.") 139 | 140 | sample_dataset = PromptDataset(cfg.dreambooth.class_prompt, num_new_images) 141 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=cfg.dreambooth.sample_batch_size) 142 | 143 | sample_dataloader = accelerator.prepare(sample_dataloader) 144 | pipeline.to(accelerator.device) 145 | 146 | for example in tqdm( 147 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 148 | ): 149 | images = pipeline(example["prompt"]).images 150 | 151 | for i, image in enumerate(images): 152 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 153 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 154 | image.save(image_filename) 155 | 156 | del pipeline 157 | if torch.cuda.is_available(): 158 | torch.cuda.empty_cache() 159 | 160 | vae.requires_grad_(False) 161 | if not cfg.dreambooth.train_text_encoder: 162 | text_encoder.requires_grad_(False) 163 | 164 | if cfg.dreambooth.scale_lr: 165 | cfg.dreambooth.lr = ( 166 | cfg.dreambooth.lr * cfg.dreambooth.gradient_accumulation_steps * cfg.dreambooth.train_batch_size * accelerator.num_processes 167 | ) 168 | optimizer_class = torch.optim.AdamW 169 | params_to_optimize = ( 170 | itertools.chain(unet.parameters(), text_encoder.parameters()) if cfg.dreambooth.train_text_encoder else unet.parameters() 171 | ) 172 | optimizer = optimizer_class( 173 | params_to_optimize, 174 | lr=cfg.dreambooth.lr, 175 | betas=(cfg.dreambooth.adam_beta1, cfg.dreambooth.adam_beta2), 176 | weight_decay=cfg.dreambooth.adam_weight_decay, 177 | eps=cfg.dreambooth.adam_epsilon, 178 | ) 179 | noise_scheduler = DDPMScheduler.from_config(cfg.dreambooth.pretrained_model_name_or_path, subfolder="scheduler") 180 | train_dataset = DreamBoothDataset( 181 | instance_data_root=cfg.dreambooth.instance_data_dir, 182 | instance_prompt=cfg.dreambooth.instance_prompt, 183 | class_data_root=cfg.dreambooth.class_data_dir if cfg.dreambooth.with_prior_preservation else None, 184 | class_prompt=cfg.dreambooth.class_prompt, 185 | tokenizer=tokenizer, 186 | size=cfg.dreambooth.resolution, 187 | center_crop=cfg.dreambooth.center_crop, 188 | ) 189 | def collate_fn(examples): 190 | input_ids = [example["instance_prompt_ids"] for example in examples] 191 | pixel_values = 
[example["instance_images"] for example in examples] 192 | 193 | # Concat class and instance examples for prior preservation. 194 | # We do this to avoid doing two forward passes. 195 | if cfg.dreambooth.with_prior_preservation: 196 | input_ids += [example["class_prompt_ids"] for example in examples] 197 | pixel_values += [example["class_images"] for example in examples] 198 | 199 | pixel_values = torch.stack(pixel_values) 200 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 201 | 202 | input_ids = tokenizer.pad( 203 | {"input_ids": input_ids}, 204 | padding="max_length", 205 | max_length=tokenizer.model_max_length, 206 | return_tensors="pt", 207 | ).input_ids 208 | 209 | batch = { 210 | "input_ids": input_ids, 211 | "pixel_values": pixel_values, 212 | } 213 | return batch 214 | 215 | train_dataloader = torch.utils.data.DataLoader( 216 | train_dataset, batch_size=cfg.dreambooth.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 217 | ) 218 | 219 | # Scheduler and math around the number of training steps. 220 | overrode_max_train_steps = False 221 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.dreambooth.gradient_accumulation_steps) 222 | if cfg.dreambooth.max_train_steps is None: 223 | cfg.dreambooth.max_train_steps = cfg.dreambooth.num_train_epochs * num_update_steps_per_epoch 224 | overrode_max_train_steps = True 225 | 226 | lr_scheduler = get_scheduler( 227 | cfg.dreambooth.lr_scheduler, 228 | optimizer=optimizer, 229 | num_warmup_steps=cfg.dreambooth.lr_warmup_steps * cfg.dreambooth.gradient_accumulation_steps, 230 | num_training_steps=cfg.dreambooth.max_train_steps * cfg.dreambooth.gradient_accumulation_steps, 231 | ) 232 | 233 | if cfg.dreambooth.train_text_encoder: 234 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 235 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 236 | ) 237 | else: 238 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 239 | unet, optimizer, train_dataloader, lr_scheduler 240 | ) 241 | 242 | weight_dtype = torch.float32 243 | if accelerator.mixed_precision == "fp16": 244 | weight_dtype = torch.float16 245 | elif accelerator.mixed_precision == "bf16": 246 | weight_dtype = torch.bfloat16 247 | 248 | # Move text_encode and vae to gpu. 249 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 250 | # as these models are only used for inference, keeping weights in full precision is not required. 251 | vae.to(accelerator.device, dtype=weight_dtype) 252 | if not cfg.dreambooth.train_text_encoder: 253 | text_encoder.to(accelerator.device, dtype=weight_dtype) 254 | 255 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 256 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.dreambooth.gradient_accumulation_steps) 257 | if overrode_max_train_steps: 258 | cfg.dreambooth.max_train_steps = cfg.dreambooth.num_train_epochs * num_update_steps_per_epoch 259 | # Afterwards we recalculate our number of training epochs 260 | cfg.dreambooth.num_train_epochs = math.ceil(cfg.dreambooth.max_train_steps / num_update_steps_per_epoch) 261 | 262 | 263 | # Train! 
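    # The figures below are informational only. As a worked example (numbers chosen for
    # illustration): with train_batch_size=2, two processes and
    # gradient_accumulation_steps=4, total_batch_size = 2 * 2 * 4 = 16 samples per
    # optimizer step; prior preservation additionally doubles each forward batch via
    # collate_fn by appending the class examples.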
264 | total_batch_size = cfg.dreambooth.train_batch_size * accelerator.num_processes * cfg.dreambooth.gradient_accumulation_steps 265 | 266 | logger.info("***** Running Dreambooth training *****") 267 | logger.info(f" Num examples = {len(train_dataset)}") 268 | logger.info(f" Instantaneous batch size per device = {cfg.dreambooth.train_batch_size}") 269 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 270 | logger.info(f" Gradient Accumulation steps = {cfg.dreambooth.gradient_accumulation_steps}") 271 | logger.info(f" Total optimization steps = {cfg.dreambooth.max_train_steps}") 272 | # Only show the progress bar once on each machine. 273 | progress_bar = tqdm(range(cfg.dreambooth.max_train_steps), disable=not accelerator.is_local_main_process) 274 | progress_bar.set_description("Steps") 275 | global_step = 0 276 | 277 | for epoch in range(cfg.dreambooth.num_train_epochs): 278 | unet.train() 279 | if cfg.dreambooth.train_text_encoder: 280 | text_encoder.train() 281 | for step, batch in enumerate(train_dataloader): 282 | with accelerator.accumulate(unet): 283 | # Convert images to latent space 284 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 285 | latents = latents * 0.18215 286 | 287 | # Sample noise that we'll add to the latents 288 | noise = torch.randn_like(latents) 289 | bsz = latents.shape[0] 290 | # Sample a random timestep for each image 291 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 292 | timesteps = timesteps.long() 293 | 294 | # Add noise to the latents according to the noise magnitude at each timestep 295 | # (this is the forward diffusion process) 296 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 297 | 298 | # Get the text embedding for conditioning 299 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 300 | 301 | # Predict the noise residual 302 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states)[0] 303 | noise_pred = noise_pred.sample 304 | 305 | if cfg.dreambooth.with_prior_preservation: 306 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 307 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 308 | noise, noise_prior = torch.chunk(noise, 2, dim=0) 309 | 310 | # Compute instance loss 311 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() 312 | 313 | # Compute prior loss 314 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") 315 | # Add the prior loss to the instance loss. 
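                    # Equivalently: loss = MSE(noise_pred, noise)
                    #                      + prior_loss_weight * MSE(noise_pred_prior, noise_prior),
                    # so the model is pulled towards the instance images while being
                    # regularised against its own renderings of the class prompt.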
316 | loss = loss + cfg.dreambooth.prior_loss_weight * prior_loss 317 | else: 318 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 319 | 320 | accelerator.backward(loss) 321 | if accelerator.sync_gradients: 322 | params_to_clip = ( 323 | itertools.chain(unet.parameters(), text_encoder.parameters()) 324 | if cfg.dreambooth.train_text_encoder 325 | else unet.parameters() 326 | ) 327 | accelerator.clip_grad_norm_(params_to_clip, cfg.dreambooth.max_grad_norm) 328 | optimizer.step() 329 | lr_scheduler.step() 330 | optimizer.zero_grad() 331 | 332 | # Checks if the accelerator has performed an optimization step behind the scenes 333 | if accelerator.sync_gradients: 334 | progress_bar.update(1) 335 | global_step += 1 336 | 337 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 338 | progress_bar.set_postfix(**logs) 339 | 340 | if accelerator.is_main_process and global_step % 50 == 0: 341 | logger.info("Ready to save dreambooth model!!!!") 342 | save_state = { 343 | 'unet': accelerator.unwrap_model(unet).state_dict(), 344 | 'encoder': accelerator.unwrap_model(text_encoder).state_dict(), 345 | } 346 | logger.info('saving model at {}'.format( 347 | os.path.join(cfg.general.save_path, 'dreambooth_{}.ckp'.format(global_step)))) 348 | torch.save(save_state, os.path.join(cfg.general.save_path, 'dreambooth_{}.ckp'.format(global_step))) 349 | 350 | torch.cuda.empty_cache() 351 | if global_step > cfg.dreambooth.max_train_steps: 352 | return 353 | accelerator.wait_for_everyone() 354 | 355 | def inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg, logger): 356 | vae.eval() 357 | unet.eval() 358 | text_encoder.eval() 359 | 360 | uncond_input = tokenizer( 361 | [""] * cfg.dreambooth.inference_batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" 362 | ) 363 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] 364 | input_ids = tokenizer( 365 | [prompt] * cfg.dreambooth.inference_batch_size, 366 | padding="max_length", 367 | truncation=True, 368 | max_length=tokenizer.model_max_length, 369 | return_tensors="pt", 370 | ).input_ids.to(device) 371 | text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]]) 372 | 373 | latents = torch.randn( 374 | (cfg.dreambooth.inference_batch_size, 4, 64, 64), 375 | ).to(device) 376 | 377 | noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 378 | noise_scheduler.set_timesteps(51) 379 | guidance_scale = 7.5 380 | latents = latents * noise_scheduler.init_noise_sigma 381 | 382 | for index, t in enumerate(tqdm(noise_scheduler.timesteps)): 383 | with torch.no_grad(): 384 | latent_model_input = torch.cat([latents] * 2) 385 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 386 | 387 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)[0] 388 | noise_pred = noise_pred.sample 389 | 390 | # perform guidance 391 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 392 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 393 | 394 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 395 | 396 | logger.info("save images to {}".format(cfg.general.save_path)) 397 | 398 | vae.to(latents.device, dtype=latents.dtype) 399 | with torch.no_grad(): 400 | latents = 1 / 0.18215 * latents 401 | image = vae.decode(latents).sample 402 | image = (image / 2 + 0.5).clamp(0, 
1) 403 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 404 | images = (image * 255).round().astype("uint8") 405 | pil_images = [Image.fromarray(image) for image in images] 406 | for idx, pil_image in enumerate(pil_images): 407 | pil_image.save(os.path.join(cfg.general.save_path, "{}_{}.png".format('_'.join(prompt.split(' ')), idx))) 408 | 409 | @hydra.main(version_base=None, config_path="conf", config_name="real_image_editing_config") 410 | def main(cfg: DictConfig): 411 | 412 | cfg.general.save_path = os.path.join(cfg.general.save_path, 'dreambooth') 413 | 414 | if cfg.general.seed is not None: 415 | set_seed(cfg.general.seed) 416 | 417 | with open(cfg.general.unet_config) as f: 418 | unet_config = json.load(f) 419 | # load pretrained models and schedular 420 | unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained(cfg.general.model_path, subfolder="unet") 421 | tokenizer = CLIPTokenizer.from_pretrained(cfg.general.model_path, subfolder="tokenizer") 422 | text_encoder = CLIPTextModel.from_pretrained(cfg.general.model_path, subfolder="text_encoder") 423 | vae = AutoencoderKL.from_pretrained(cfg.general.model_path, subfolder="vae") 424 | 425 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 426 | mixed_precision = 'fp16' if torch.cuda.is_available() else 'no' 427 | accelerator = Accelerator( 428 | gradient_accumulation_steps=cfg.text_inversion.gradient_accumulation_steps, 429 | mixed_precision=mixed_precision 430 | ) 431 | 432 | if not os.path.exists(cfg.general.save_path) and accelerator.is_main_process: 433 | os.makedirs(cfg.general.save_path) 434 | 435 | logger = setup_logger(cfg.general.save_path, __name__) 436 | 437 | logger.info(cfg) 438 | # Save cfg 439 | logger.info("save config to {}".format(os.path.join(cfg.general.save_path, 'config.yaml'))) 440 | OmegaConf.save(cfg, os.path.join(cfg.general.save_path, 'config.yaml')) 441 | 442 | # Move vae and unet to device 443 | vae.to(device) 444 | unet.to(device) 445 | text_encoder.to(device) 446 | 447 | if cfg.dreambooth.text_inversion_path != '': 448 | logger.info("load text inversion ckp from {}".format(cfg.dreambooth.text_inversion_path)) 449 | text_encoder, tokenizer = load_text_inversion(text_encoder, tokenizer, cfg.text_inversion.placeholder_token, cfg.dreambooth.text_inversion_path) 450 | 451 | if cfg.dreambooth.inference: 452 | ckp = torch.load(cfg.dreambooth.ckp_path) 453 | unet.load_state_dict(ckp['unet']) 454 | text_encoder.load_state_dict(ckp['encoder']) 455 | 456 | 457 | if not cfg.dreambooth.inference: 458 | train_dreambooth(device, unet, vae, tokenizer, text_encoder, cfg, accelerator, logger) 459 | 460 | if cfg.dreambooth.new_prompt != '': 461 | prompt = cfg.dreambooth.new_prompt.format(cfg.text_inversion.placeholder_token) 462 | else: 463 | prompt = cfg.dreambooth.example_prompt.format(cfg.text_inversion.placeholder_token) 464 | inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg, logger) 465 | 466 | if __name__ == "__main__": 467 | main() 468 | 469 | 470 | -------------------------------------------------------------------------------- /my_model/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | from dataclasses import dataclass 16 | from typing import Optional 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from torch import nn 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.modeling_utils import ModelMixin 24 | from diffusers.models.embeddings import ImagePositionalEmbeddings 25 | from diffusers.utils import BaseOutput 26 | from diffusers.utils.import_utils import is_xformers_available 27 | 28 | @dataclass 29 | class Transformer2DModelOutput(BaseOutput): 30 | """ 31 | Args: 32 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 33 | Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions 34 | for the unnoised latent pixels. 35 | """ 36 | 37 | sample: torch.FloatTensor 38 | 39 | 40 | if is_xformers_available(): 41 | import xformers 42 | import xformers.ops 43 | else: 44 | xformers = None 45 | 46 | 47 | class Transformer2DModel(ModelMixin, ConfigMixin): 48 | """ 49 | Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual 50 | embeddings) inputs_coarse. 51 | 52 | When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard 53 | transformer action. Finally, reshape to image. 54 | 55 | When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional 56 | embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict 57 | classes of unnoised image. 58 | 59 | Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised 60 | image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. 61 | 62 | Parameters: 63 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 64 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 65 | in_channels (`int`, *optional*): 66 | Pass if the input is continuous. The number of channels in the input and output. 67 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 68 | dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. 69 | cross_attention_dim (`int`, *optional*): The number of context dimensions to use. 70 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 71 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 72 | `ImagePositionalEmbeddings`. 73 | num_vector_embeds (`int`, *optional*): 74 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. 75 | Includes the class for the masked latent pixel. 
76 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 77 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. 78 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used 79 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for 80 | up to but not more than steps than `num_embeds_ada_norm`. 81 | attention_bias (`bool`, *optional*): 82 | Configure if the TransformerBlocks' attention should contain a bias parameter. 83 | """ 84 | 85 | @register_to_config 86 | def __init__( 87 | self, 88 | num_attention_heads: int = 16, 89 | attention_head_dim: int = 88, 90 | in_channels: Optional[int] = None, 91 | num_layers: int = 1, 92 | dropout: float = 0.0, 93 | norm_num_groups: int = 32, 94 | cross_attention_dim: Optional[int] = None, 95 | attention_bias: bool = False, 96 | sample_size: Optional[int] = None, 97 | num_vector_embeds: Optional[int] = None, 98 | activation_fn: str = "geglu", 99 | num_embeds_ada_norm: Optional[int] = None, 100 | ): 101 | super().__init__() 102 | self.num_attention_heads = num_attention_heads 103 | self.attention_head_dim = attention_head_dim 104 | inner_dim = num_attention_heads * attention_head_dim 105 | 106 | # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 107 | # Define whether input is continuous or discrete depending on configuration 108 | self.is_input_continuous = in_channels is not None 109 | self.is_input_vectorized = num_vector_embeds is not None 110 | 111 | if self.is_input_continuous and self.is_input_vectorized: 112 | raise ValueError( 113 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 114 | " sure that either `in_channels` or `num_vector_embeds` is None." 115 | ) 116 | elif not self.is_input_continuous and not self.is_input_vectorized: 117 | raise ValueError( 118 | f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" 119 | " sure that either `in_channels` or `num_vector_embeds` is not None." 120 | ) 121 | 122 | # 2. Define input layers 123 | if self.is_input_continuous: 124 | self.in_channels = in_channels 125 | 126 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 127 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 128 | elif self.is_input_vectorized: 129 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 130 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 131 | 132 | self.height = sample_size 133 | self.width = sample_size 134 | self.num_vector_embeds = num_vector_embeds 135 | self.num_latent_pixels = self.height * self.width 136 | 137 | self.latent_image_embedding = ImagePositionalEmbeddings( 138 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 139 | ) 140 | 141 | # 3. 
Define transformers blocks 142 | self.transformer_blocks = nn.ModuleList( 143 | [ 144 | BasicTransformerBlock( 145 | inner_dim, 146 | num_attention_heads, 147 | attention_head_dim, 148 | dropout=dropout, 149 | cross_attention_dim=cross_attention_dim, 150 | activation_fn=activation_fn, 151 | num_embeds_ada_norm=num_embeds_ada_norm, 152 | attention_bias=attention_bias, 153 | ) 154 | for d in range(num_layers) 155 | ] 156 | ) 157 | 158 | # 4. Define output layers 159 | if self.is_input_continuous: 160 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 161 | elif self.is_input_vectorized: 162 | self.norm_out = nn.LayerNorm(inner_dim) 163 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 164 | 165 | def _set_attention_slice(self, slice_size): 166 | for block in self.transformer_blocks: 167 | block._set_attention_slice(slice_size) 168 | 169 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attn_map=None, attn_shift=False, obj_ids=None, relationship=None, return_dict: bool = True): 170 | """ 171 | Args: 172 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 173 | When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 174 | hidden_states 175 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): 176 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 177 | self-attention. 178 | timestep ( `torch.long`, *optional*): 179 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 180 | return_dict (`bool`, *optional*, defaults to `True`): 181 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 182 | 183 | Returns: 184 | [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] 185 | if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample 186 | tensor. 187 | """ 188 | # 1. Input 189 | if self.is_input_continuous: 190 | batch, channel, height, weight = hidden_states.shape 191 | residual = hidden_states 192 | hidden_states = self.norm(hidden_states) 193 | hidden_states = self.proj_in(hidden_states) 194 | inner_dim = hidden_states.shape[1] 195 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 196 | elif self.is_input_vectorized: 197 | hidden_states = self.latent_image_embedding(hidden_states) 198 | 199 | # 2. Blocks 200 | for block in self.transformer_blocks: 201 | hidden_states, cross_attn_prob = block(hidden_states, context=encoder_hidden_states, timestep=timestep) 202 | 203 | # 3. 
Output 204 | if self.is_input_continuous: 205 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) 206 | hidden_states = self.proj_out(hidden_states) 207 | output = hidden_states + residual 208 | elif self.is_input_vectorized: 209 | hidden_states = self.norm_out(hidden_states) 210 | logits = self.out(hidden_states) 211 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 212 | logits = logits.permute(0, 2, 1) 213 | 214 | # log(p(x_0)) 215 | output = F.log_softmax(logits.double(), dim=1).float() 216 | 217 | if not return_dict: 218 | return (output,) 219 | 220 | return Transformer2DModelOutput(sample=output), cross_attn_prob 221 | 222 | def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 223 | for block in self.transformer_blocks: 224 | block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 225 | 226 | 227 | class AttentionBlock(nn.Module): 228 | """ 229 | An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted 230 | to the N-d case. 231 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 232 | Uses three q, k, v linear layers to compute attention. 233 | 234 | Parameters: 235 | channels (`int`): The number of channels in the input and output. 236 | num_head_channels (`int`, *optional*): 237 | The number of channels in each head. If None, then `num_heads` = 1. 238 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. 239 | rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. 240 | eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 
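    Note that this block performs plain spatial self-attention over the height * width
    positions of its input and receives no text embeddings; prompt conditioning is
    handled by the `CrossAttention` layers used inside `Transformer2DModel` /
    `BasicTransformerBlock` in this file.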
241 | """ 242 | 243 | def __init__( 244 | self, 245 | channels: int, 246 | num_head_channels: Optional[int] = None, 247 | norm_num_groups: int = 32, 248 | rescale_output_factor: float = 1.0, 249 | eps: float = 1e-5, 250 | ): 251 | super().__init__() 252 | self.channels = channels 253 | 254 | self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 255 | self.num_head_size = num_head_channels 256 | self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) 257 | 258 | # define q,k,v as linear layers 259 | self.query = nn.Linear(channels, channels) 260 | self.key = nn.Linear(channels, channels) 261 | self.value = nn.Linear(channels, channels) 262 | 263 | self.rescale_output_factor = rescale_output_factor 264 | self.proj_attn = nn.Linear(channels, channels, 1) 265 | 266 | def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: 267 | new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) 268 | # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) 269 | new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) 270 | return new_projection 271 | 272 | def forward(self, hidden_states): 273 | residual = hidden_states 274 | batch, channel, height, width = hidden_states.shape 275 | 276 | # norm 277 | hidden_states = self.group_norm(hidden_states) 278 | 279 | hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) 280 | 281 | # proj to q, k, v 282 | query_proj = self.query(hidden_states) 283 | key_proj = self.key(hidden_states) 284 | value_proj = self.value(hidden_states) 285 | 286 | # transpose 287 | query_states = self.transpose_for_scores(query_proj) 288 | key_states = self.transpose_for_scores(key_proj) 289 | value_states = self.transpose_for_scores(value_proj) 290 | 291 | # get scores 292 | scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) 293 | attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm 294 | attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) 295 | 296 | # compute attention output 297 | hidden_states = torch.matmul(attention_probs, value_states) 298 | 299 | hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() 300 | new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) 301 | hidden_states = hidden_states.view(new_hidden_states_shape) 302 | 303 | # compute next hidden_states 304 | hidden_states = self.proj_attn(hidden_states) 305 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) 306 | 307 | # res connect and rescale 308 | hidden_states = (hidden_states + residual) / self.rescale_output_factor 309 | return hidden_states 310 | 311 | 312 | class BasicTransformerBlock(nn.Module): 313 | r""" 314 | A basic Transformer block. 315 | 316 | Parameters: 317 | dim (`int`): The number of channels in the input and output. 318 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 319 | attention_head_dim (`int`): The number of channels in each head. 320 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 321 | cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention. 322 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 
323 | num_embeds_ada_norm (: 324 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 325 | attention_bias (: 326 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 327 | """ 328 | 329 | def __init__( 330 | self, 331 | dim: int, 332 | num_attention_heads: int, 333 | attention_head_dim: int, 334 | dropout=0.0, 335 | cross_attention_dim: Optional[int] = None, 336 | activation_fn: str = "geglu", 337 | num_embeds_ada_norm: Optional[int] = None, 338 | attention_bias: bool = False, 339 | ): 340 | super().__init__() 341 | self.attn1 = CrossAttention( 342 | query_dim=dim, 343 | heads=num_attention_heads, 344 | dim_head=attention_head_dim, 345 | dropout=dropout, 346 | bias=attention_bias, 347 | ) # is a self-attention 348 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 349 | self.attn2 = CrossAttention( 350 | query_dim=dim, 351 | cross_attention_dim=cross_attention_dim, 352 | heads=num_attention_heads, 353 | dim_head=attention_head_dim, 354 | dropout=dropout, 355 | bias=attention_bias, 356 | ) # is self-attn if context is none 357 | 358 | # layer norms 359 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 360 | if self.use_ada_layer_norm: 361 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 362 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 363 | else: 364 | self.norm1 = nn.LayerNorm(dim) 365 | self.norm2 = nn.LayerNorm(dim) 366 | self.norm3 = nn.LayerNorm(dim) 367 | 368 | def _set_attention_slice(self, slice_size): 369 | self.attn1._slice_size = slice_size 370 | self.attn2._slice_size = slice_size 371 | 372 | def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 373 | if not is_xformers_available(): 374 | print("Here is how to install it") 375 | raise ModuleNotFoundError( 376 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 377 | " xformers", 378 | name="xformers", 379 | ) 380 | elif not torch.cuda.is_available(): 381 | raise ValueError( 382 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 383 | " available for GPU " 384 | ) 385 | else: 386 | try: 387 | # Make sure we can run the memory efficient attention 388 | _ = xformers.ops.memory_efficient_attention( 389 | torch.randn((1, 2, 40), device="cuda"), 390 | torch.randn((1, 2, 40), device="cuda"), 391 | torch.randn((1, 2, 40), device="cuda"), 392 | ) 393 | except Exception as e: 394 | raise e 395 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 396 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 397 | 398 | def forward(self, hidden_states, context=None, timestep=None): 399 | # 1. Self-Attention 400 | norm_hidden_states = ( 401 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 402 | ) 403 | tmp_hidden_states, cross_attn_prob = self.attn1(norm_hidden_states) 404 | hidden_states = tmp_hidden_states + hidden_states 405 | 406 | # 2. Cross-Attention 407 | norm_hidden_states = ( 408 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 409 | ) 410 | tmp_hidden_states, cross_attn_prob = self.attn2(norm_hidden_states, context=context) 411 | hidden_states = tmp_hidden_states + hidden_states 412 | 413 | # 3. 
Feed-forward 414 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 415 | 416 | return hidden_states, cross_attn_prob 417 | 418 | 419 | class CrossAttention(nn.Module): 420 | r""" 421 | A cross attention layer. 422 | 423 | Parameters: 424 | query_dim (`int`): The number of channels in the query. 425 | cross_attention_dim (`int`, *optional*): 426 | The number of channels in the context. If not given, defaults to `query_dim`. 427 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. 428 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. 429 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 430 | bias (`bool`, *optional*, defaults to False): 431 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 432 | """ 433 | 434 | def __init__( 435 | self, 436 | query_dim: int, 437 | cross_attention_dim: Optional[int] = None, 438 | heads: int = 8, 439 | dim_head: int = 64, 440 | dropout: float = 0.0, 441 | bias=False, 442 | ): 443 | super().__init__() 444 | inner_dim = dim_head * heads 445 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim 446 | 447 | self.scale = dim_head**-0.5 448 | self.heads = heads 449 | # for slice_size > 0 the attention score computation 450 | # is split across the batch axis to save memory 451 | # You can set slice_size with `set_attention_slice` 452 | self._slice_size = None 453 | self._use_memory_efficient_attention_xformers = False 454 | 455 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) 456 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 457 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 458 | 459 | self.to_out = nn.ModuleList([]) 460 | self.to_out.append(nn.Linear(inner_dim, query_dim)) 461 | self.to_out.append(nn.Dropout(dropout)) 462 | 463 | def reshape_heads_to_batch_dim(self, tensor): 464 | batch_size, seq_len, dim = tensor.shape 465 | head_size = self.heads 466 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 467 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 468 | return tensor 469 | 470 | def reshape_batch_dim_to_heads(self, tensor): 471 | batch_size, seq_len, dim = tensor.shape 472 | head_size = self.heads 473 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 474 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 475 | return tensor 476 | 477 | def forward(self, hidden_states, context=None, mask=None): 478 | batch_size, sequence_length, _ = hidden_states.shape 479 | 480 | query = self.to_q(hidden_states) 481 | context = context if context is not None else hidden_states 482 | key = self.to_k(context) 483 | value = self.to_v(context) 484 | 485 | dim = query.shape[-1] 486 | 487 | query = self.reshape_heads_to_batch_dim(query) 488 | key = self.reshape_heads_to_batch_dim(key) 489 | value = self.reshape_heads_to_batch_dim(value) 490 | 491 | # TODO(PVP) - mask is currently never used. 
Remember to re-implement when used 492 | 493 | # attention, what we cannot get enough of 494 | if self._use_memory_efficient_attention_xformers: 495 | hidden_states = self._memory_efficient_attention_xformers(query, key, value) 496 | else: 497 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 498 | hidden_states, attention_probs = self._attention(query, key, value) 499 | else: 500 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) 501 | 502 | # linear proj 503 | hidden_states = self.to_out[0](hidden_states) 504 | # dropout 505 | hidden_states = self.to_out[1](hidden_states) 506 | return hidden_states, attention_probs 507 | 508 | def _attention(self, query, key, value): 509 | # TODO: use baddbmm for better performance 510 | if query.device.type == "mps": 511 | # Better performance on mps (~20-25%) 512 | attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale 513 | else: 514 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale 515 | attention_probs = attention_scores.softmax(dim=-1) 516 | # compute attention output 517 | 518 | if query.device.type == "mps": 519 | hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value) 520 | else: 521 | hidden_states = torch.matmul(attention_probs, value) 522 | 523 | # reshape hidden_states 524 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 525 | return hidden_states, attention_probs 526 | 527 | def _sliced_attention(self, query, key, value, sequence_length, dim): 528 | batch_size_attention = query.shape[0] 529 | hidden_states = torch.zeros( 530 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype 531 | ) 532 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] 533 | for i in range(hidden_states.shape[0] // slice_size): 534 | start_idx = i * slice_size 535 | end_idx = (i + 1) * slice_size 536 | if query.device.type == "mps": 537 | # Better performance on mps (~20-25%) 538 | attn_slice = ( 539 | torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) 540 | * self.scale 541 | ) 542 | else: 543 | attn_slice = ( 544 | torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale 545 | ) # TODO: use baddbmm for better performance 546 | attn_slice = attn_slice.softmax(dim=-1) 547 | if query.device.type == "mps": 548 | attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) 549 | else: 550 | attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) 551 | 552 | hidden_states[start_idx:end_idx] = attn_slice 553 | 554 | # reshape hidden_states 555 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 556 | return hidden_states 557 | 558 | def _memory_efficient_attention_xformers(self, query, key, value): 559 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) 560 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 561 | return hidden_states 562 | 563 | 564 | class FeedForward(nn.Module): 565 | r""" 566 | A feed-forward layer. 567 | 568 | Parameters: 569 | dim (`int`): The number of channels in the input. 570 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 571 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 
572 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 573 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 574 | """ 575 | 576 | def __init__( 577 | self, 578 | dim: int, 579 | dim_out: Optional[int] = None, 580 | mult: int = 4, 581 | dropout: float = 0.0, 582 | activation_fn: str = "geglu", 583 | ): 584 | super().__init__() 585 | inner_dim = int(dim * mult) 586 | dim_out = dim_out if dim_out is not None else dim 587 | 588 | if activation_fn == "geglu": 589 | geglu = GEGLU(dim, inner_dim) 590 | elif activation_fn == "geglu-approximate": 591 | geglu = ApproximateGELU(dim, inner_dim) 592 | 593 | self.net = nn.ModuleList([]) 594 | # project in 595 | self.net.append(geglu) 596 | # project dropout 597 | self.net.append(nn.Dropout(dropout)) 598 | # project out 599 | self.net.append(nn.Linear(inner_dim, dim_out)) 600 | 601 | def forward(self, hidden_states): 602 | for module in self.net: 603 | hidden_states = module(hidden_states) 604 | return hidden_states 605 | 606 | 607 | # feedforward 608 | class GEGLU(nn.Module): 609 | r""" 610 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 611 | 612 | Parameters: 613 | dim_in (`int`): The number of channels in the input. 614 | dim_out (`int`): The number of channels in the output. 615 | """ 616 | 617 | def __init__(self, dim_in: int, dim_out: int): 618 | super().__init__() 619 | self.proj = nn.Linear(dim_in, dim_out * 2) 620 | 621 | def gelu(self, gate): 622 | if gate.device.type != "mps": 623 | return F.gelu(gate) 624 | # mps: gelu is not implemented for float16 625 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 626 | 627 | def forward(self, hidden_states): 628 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 629 | return hidden_states * self.gelu(gate) 630 | 631 | 632 | class ApproximateGELU(nn.Module): 633 | """ 634 | The approximate form of Gaussian Error Linear Unit (GELU) 635 | 636 | For more details, see section 2: https://arxiv.org/abs/1606.08415 637 | """ 638 | 639 | def __init__(self, dim_in: int, dim_out: int): 640 | super().__init__() 641 | self.proj = nn.Linear(dim_in, dim_out) 642 | 643 | def forward(self, x): 644 | x = self.proj(x) 645 | return x * torch.sigmoid(1.702 * x) 646 | 647 | 648 | class AdaLayerNorm(nn.Module): 649 | """ 650 | Norm layer modified to incorporate timestep embeddings. 651 | """ 652 | 653 | def __init__(self, embedding_dim, num_embeddings): 654 | super().__init__() 655 | self.emb = nn.Embedding(num_embeddings, embedding_dim) 656 | self.silu = nn.SiLU() 657 | self.linear = nn.Linear(embedding_dim, embedding_dim * 2) 658 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) 659 | 660 | def forward(self, x, timestep): 661 | emb = self.linear(self.silu(self.emb(timestep))) 662 | scale, shift = torch.chunk(emb, 2) 663 | x = self.norm(x) * (1 + scale) + shift 664 | return x -------------------------------------------------------------------------------- /my_model/unet_2d_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | import torch 16 | from torch import nn 17 | 18 | from .attention import AttentionBlock, Transformer2DModel 19 | from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D 20 | 21 | 22 | def get_down_block( 23 | down_block_type, 24 | num_layers, 25 | in_channels, 26 | out_channels, 27 | temb_channels, 28 | add_downsample, 29 | resnet_eps, 30 | resnet_act_fn, 31 | attn_num_head_channels, 32 | resnet_groups=None, 33 | cross_attention_dim=None, 34 | downsample_padding=None, 35 | ): 36 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 37 | if down_block_type == "DownBlock2D": 38 | return DownBlock2D( 39 | num_layers=num_layers, 40 | in_channels=in_channels, 41 | out_channels=out_channels, 42 | temb_channels=temb_channels, 43 | add_downsample=add_downsample, 44 | resnet_eps=resnet_eps, 45 | resnet_act_fn=resnet_act_fn, 46 | resnet_groups=resnet_groups, 47 | downsample_padding=downsample_padding, 48 | ) 49 | elif down_block_type == "AttnDownBlock2D": 50 | return AttnDownBlock2D( 51 | num_layers=num_layers, 52 | in_channels=in_channels, 53 | out_channels=out_channels, 54 | temb_channels=temb_channels, 55 | add_downsample=add_downsample, 56 | resnet_eps=resnet_eps, 57 | resnet_act_fn=resnet_act_fn, 58 | resnet_groups=resnet_groups, 59 | downsample_padding=downsample_padding, 60 | attn_num_head_channels=attn_num_head_channels, 61 | ) 62 | elif down_block_type == "CrossAttnDownBlock2D": 63 | if cross_attention_dim is None: 64 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") 65 | return CrossAttnDownBlock2D( 66 | num_layers=num_layers, 67 | in_channels=in_channels, 68 | out_channels=out_channels, 69 | temb_channels=temb_channels, 70 | add_downsample=add_downsample, 71 | resnet_eps=resnet_eps, 72 | resnet_act_fn=resnet_act_fn, 73 | resnet_groups=resnet_groups, 74 | downsample_padding=downsample_padding, 75 | cross_attention_dim=cross_attention_dim, 76 | attn_num_head_channels=attn_num_head_channels, 77 | ) 78 | elif down_block_type == "SkipDownBlock2D": 79 | return SkipDownBlock2D( 80 | num_layers=num_layers, 81 | in_channels=in_channels, 82 | out_channels=out_channels, 83 | temb_channels=temb_channels, 84 | add_downsample=add_downsample, 85 | resnet_eps=resnet_eps, 86 | resnet_act_fn=resnet_act_fn, 87 | downsample_padding=downsample_padding, 88 | ) 89 | elif down_block_type == "AttnSkipDownBlock2D": 90 | return AttnSkipDownBlock2D( 91 | num_layers=num_layers, 92 | in_channels=in_channels, 93 | out_channels=out_channels, 94 | temb_channels=temb_channels, 95 | add_downsample=add_downsample, 96 | resnet_eps=resnet_eps, 97 | resnet_act_fn=resnet_act_fn, 98 | downsample_padding=downsample_padding, 99 | attn_num_head_channels=attn_num_head_channels, 100 | ) 101 | elif down_block_type == "DownEncoderBlock2D": 102 | return DownEncoderBlock2D( 103 | num_layers=num_layers, 104 | in_channels=in_channels, 105 | out_channels=out_channels, 106 | add_downsample=add_downsample, 107 | resnet_eps=resnet_eps, 108 | 
resnet_act_fn=resnet_act_fn, 109 | resnet_groups=resnet_groups, 110 | downsample_padding=downsample_padding, 111 | ) 112 | elif down_block_type == "AttnDownEncoderBlock2D": 113 | return AttnDownEncoderBlock2D( 114 | num_layers=num_layers, 115 | in_channels=in_channels, 116 | out_channels=out_channels, 117 | add_downsample=add_downsample, 118 | resnet_eps=resnet_eps, 119 | resnet_act_fn=resnet_act_fn, 120 | resnet_groups=resnet_groups, 121 | downsample_padding=downsample_padding, 122 | attn_num_head_channels=attn_num_head_channels, 123 | ) 124 | raise ValueError(f"{down_block_type} does not exist.") 125 | 126 | 127 | def get_up_block( 128 | up_block_type, 129 | num_layers, 130 | in_channels, 131 | out_channels, 132 | prev_output_channel, 133 | temb_channels, 134 | add_upsample, 135 | resnet_eps, 136 | resnet_act_fn, 137 | attn_num_head_channels, 138 | resnet_groups=None, 139 | cross_attention_dim=None, 140 | ): 141 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 142 | if up_block_type == "UpBlock2D": 143 | return UpBlock2D( 144 | num_layers=num_layers, 145 | in_channels=in_channels, 146 | out_channels=out_channels, 147 | prev_output_channel=prev_output_channel, 148 | temb_channels=temb_channels, 149 | add_upsample=add_upsample, 150 | resnet_eps=resnet_eps, 151 | resnet_act_fn=resnet_act_fn, 152 | resnet_groups=resnet_groups, 153 | ) 154 | elif up_block_type == "CrossAttnUpBlock2D": 155 | if cross_attention_dim is None: 156 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") 157 | return CrossAttnUpBlock2D( 158 | num_layers=num_layers, 159 | in_channels=in_channels, 160 | out_channels=out_channels, 161 | prev_output_channel=prev_output_channel, 162 | temb_channels=temb_channels, 163 | add_upsample=add_upsample, 164 | resnet_eps=resnet_eps, 165 | resnet_act_fn=resnet_act_fn, 166 | resnet_groups=resnet_groups, 167 | cross_attention_dim=cross_attention_dim, 168 | attn_num_head_channels=attn_num_head_channels, 169 | ) 170 | elif up_block_type == "AttnUpBlock2D": 171 | return AttnUpBlock2D( 172 | num_layers=num_layers, 173 | in_channels=in_channels, 174 | out_channels=out_channels, 175 | prev_output_channel=prev_output_channel, 176 | temb_channels=temb_channels, 177 | add_upsample=add_upsample, 178 | resnet_eps=resnet_eps, 179 | resnet_act_fn=resnet_act_fn, 180 | resnet_groups=resnet_groups, 181 | attn_num_head_channels=attn_num_head_channels, 182 | ) 183 | elif up_block_type == "SkipUpBlock2D": 184 | return SkipUpBlock2D( 185 | num_layers=num_layers, 186 | in_channels=in_channels, 187 | out_channels=out_channels, 188 | prev_output_channel=prev_output_channel, 189 | temb_channels=temb_channels, 190 | add_upsample=add_upsample, 191 | resnet_eps=resnet_eps, 192 | resnet_act_fn=resnet_act_fn, 193 | ) 194 | elif up_block_type == "AttnSkipUpBlock2D": 195 | return AttnSkipUpBlock2D( 196 | num_layers=num_layers, 197 | in_channels=in_channels, 198 | out_channels=out_channels, 199 | prev_output_channel=prev_output_channel, 200 | temb_channels=temb_channels, 201 | add_upsample=add_upsample, 202 | resnet_eps=resnet_eps, 203 | resnet_act_fn=resnet_act_fn, 204 | attn_num_head_channels=attn_num_head_channels, 205 | ) 206 | elif up_block_type == "UpDecoderBlock2D": 207 | return UpDecoderBlock2D( 208 | num_layers=num_layers, 209 | in_channels=in_channels, 210 | out_channels=out_channels, 211 | add_upsample=add_upsample, 212 | resnet_eps=resnet_eps, 213 | resnet_act_fn=resnet_act_fn, 214 | resnet_groups=resnet_groups, 215 | ) 216 | elif 
up_block_type == "AttnUpDecoderBlock2D": 217 | return AttnUpDecoderBlock2D( 218 | num_layers=num_layers, 219 | in_channels=in_channels, 220 | out_channels=out_channels, 221 | add_upsample=add_upsample, 222 | resnet_eps=resnet_eps, 223 | resnet_act_fn=resnet_act_fn, 224 | resnet_groups=resnet_groups, 225 | attn_num_head_channels=attn_num_head_channels, 226 | ) 227 | raise ValueError(f"{up_block_type} does not exist.") 228 | 229 | 230 | class UNetMidBlock2D(nn.Module): 231 | def __init__( 232 | self, 233 | in_channels: int, 234 | temb_channels: int, 235 | dropout: float = 0.0, 236 | num_layers: int = 1, 237 | resnet_eps: float = 1e-6, 238 | resnet_time_scale_shift: str = "default", 239 | resnet_act_fn: str = "swish", 240 | resnet_groups: int = 32, 241 | resnet_pre_norm: bool = True, 242 | attn_num_head_channels=1, 243 | attention_type="default", 244 | output_scale_factor=1.0, 245 | **kwargs, 246 | ): 247 | super().__init__() 248 | 249 | self.attention_type = attention_type 250 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 251 | 252 | # there is always at least one resnet 253 | resnets = [ 254 | ResnetBlock2D( 255 | in_channels=in_channels, 256 | out_channels=in_channels, 257 | temb_channels=temb_channels, 258 | eps=resnet_eps, 259 | groups=resnet_groups, 260 | dropout=dropout, 261 | time_embedding_norm=resnet_time_scale_shift, 262 | non_linearity=resnet_act_fn, 263 | output_scale_factor=output_scale_factor, 264 | pre_norm=resnet_pre_norm, 265 | ) 266 | ] 267 | attentions = [] 268 | 269 | for _ in range(num_layers): 270 | attentions.append( 271 | AttentionBlock( 272 | in_channels, 273 | num_head_channels=attn_num_head_channels, 274 | rescale_output_factor=output_scale_factor, 275 | eps=resnet_eps, 276 | norm_num_groups=resnet_groups, 277 | ) 278 | ) 279 | resnets.append( 280 | ResnetBlock2D( 281 | in_channels=in_channels, 282 | out_channels=in_channels, 283 | temb_channels=temb_channels, 284 | eps=resnet_eps, 285 | groups=resnet_groups, 286 | dropout=dropout, 287 | time_embedding_norm=resnet_time_scale_shift, 288 | non_linearity=resnet_act_fn, 289 | output_scale_factor=output_scale_factor, 290 | pre_norm=resnet_pre_norm, 291 | ) 292 | ) 293 | 294 | self.attentions = nn.ModuleList(attentions) 295 | self.resnets = nn.ModuleList(resnets) 296 | 297 | def forward(self, hidden_states, temb=None, encoder_states=None): 298 | hidden_states = self.resnets[0](hidden_states, temb) 299 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 300 | if self.attention_type == "default": 301 | hidden_states = attn(hidden_states) 302 | else: 303 | hidden_states = attn(hidden_states, encoder_states) 304 | hidden_states = resnet(hidden_states, temb) 305 | 306 | return hidden_states 307 | 308 | 309 | class UNetMidBlock2DCrossAttn(nn.Module): 310 | def __init__( 311 | self, 312 | in_channels: int, 313 | temb_channels: int, 314 | dropout: float = 0.0, 315 | num_layers: int = 1, 316 | resnet_eps: float = 1e-6, 317 | resnet_time_scale_shift: str = "default", 318 | resnet_act_fn: str = "swish", 319 | resnet_groups: int = 32, 320 | resnet_pre_norm: bool = True, 321 | attn_num_head_channels=1, 322 | attention_type="default", 323 | output_scale_factor=1.0, 324 | cross_attention_dim=1280, 325 | **kwargs, 326 | ): 327 | super().__init__() 328 | 329 | self.attention_type = attention_type 330 | self.attn_num_head_channels = attn_num_head_channels 331 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 332 | 333 | # there is always at 
least one resnet 334 | resnets = [ 335 | ResnetBlock2D( 336 | in_channels=in_channels, 337 | out_channels=in_channels, 338 | temb_channels=temb_channels, 339 | eps=resnet_eps, 340 | groups=resnet_groups, 341 | dropout=dropout, 342 | time_embedding_norm=resnet_time_scale_shift, 343 | non_linearity=resnet_act_fn, 344 | output_scale_factor=output_scale_factor, 345 | pre_norm=resnet_pre_norm, 346 | ) 347 | ] 348 | attentions = [] 349 | 350 | for _ in range(num_layers): 351 | attentions.append( 352 | Transformer2DModel( 353 | attn_num_head_channels, 354 | in_channels // attn_num_head_channels, 355 | in_channels=in_channels, 356 | num_layers=1, 357 | cross_attention_dim=cross_attention_dim, 358 | norm_num_groups=resnet_groups, 359 | ) 360 | ) 361 | resnets.append( 362 | ResnetBlock2D( 363 | in_channels=in_channels, 364 | out_channels=in_channels, 365 | temb_channels=temb_channels, 366 | eps=resnet_eps, 367 | groups=resnet_groups, 368 | dropout=dropout, 369 | time_embedding_norm=resnet_time_scale_shift, 370 | non_linearity=resnet_act_fn, 371 | output_scale_factor=output_scale_factor, 372 | pre_norm=resnet_pre_norm, 373 | ) 374 | ) 375 | 376 | self.attentions = nn.ModuleList(attentions) 377 | self.resnets = nn.ModuleList(resnets) 378 | 379 | def set_attention_slice(self, slice_size): 380 | if slice_size is not None and self.attn_num_head_channels % slice_size != 0: 381 | raise ValueError( 382 | f"Make sure slice_size {slice_size} is a divisor of " 383 | f"the number of heads used in cross_attention {self.attn_num_head_channels}" 384 | ) 385 | if slice_size is not None and slice_size > self.attn_num_head_channels: 386 | raise ValueError( 387 | f"Chunk_size {slice_size} has to be smaller or equal to " 388 | f"the number of heads used in cross_attention {self.attn_num_head_channels}" 389 | ) 390 | 391 | for attn in self.attentions: 392 | attn._set_attention_slice(slice_size) 393 | 394 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 395 | for attn in self.attentions: 396 | attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 397 | 398 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 399 | hidden_states = self.resnets[0](hidden_states, temb) 400 | mid_attn = [] 401 | for layer_idx, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])): 402 | hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states) 403 | hidden_states = hidden_states.sample 404 | hidden_states = resnet(hidden_states, temb) 405 | mid_attn.append(cross_attn_prob) 406 | return hidden_states, mid_attn 407 | 408 | 409 | class AttnDownBlock2D(nn.Module): 410 | def __init__( 411 | self, 412 | in_channels: int, 413 | out_channels: int, 414 | temb_channels: int, 415 | dropout: float = 0.0, 416 | num_layers: int = 1, 417 | resnet_eps: float = 1e-6, 418 | resnet_time_scale_shift: str = "default", 419 | resnet_act_fn: str = "swish", 420 | resnet_groups: int = 32, 421 | resnet_pre_norm: bool = True, 422 | attn_num_head_channels=1, 423 | attention_type="default", 424 | output_scale_factor=1.0, 425 | downsample_padding=1, 426 | add_downsample=True, 427 | ): 428 | super().__init__() 429 | resnets = [] 430 | attentions = [] 431 | 432 | self.attention_type = attention_type 433 | 434 | for i in range(num_layers): 435 | in_channels = in_channels if i == 0 else out_channels 436 | resnets.append( 437 | ResnetBlock2D( 438 | in_channels=in_channels, 439 | out_channels=out_channels, 440 | 
temb_channels=temb_channels, 441 | eps=resnet_eps, 442 | groups=resnet_groups, 443 | dropout=dropout, 444 | time_embedding_norm=resnet_time_scale_shift, 445 | non_linearity=resnet_act_fn, 446 | output_scale_factor=output_scale_factor, 447 | pre_norm=resnet_pre_norm, 448 | ) 449 | ) 450 | attentions.append( 451 | AttentionBlock( 452 | out_channels, 453 | num_head_channels=attn_num_head_channels, 454 | rescale_output_factor=output_scale_factor, 455 | eps=resnet_eps, 456 | norm_num_groups=resnet_groups, 457 | ) 458 | ) 459 | 460 | self.attentions = nn.ModuleList(attentions) 461 | self.resnets = nn.ModuleList(resnets) 462 | 463 | if add_downsample: 464 | self.downsamplers = nn.ModuleList( 465 | [ 466 | Downsample2D( 467 | in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 468 | ) 469 | ] 470 | ) 471 | else: 472 | self.downsamplers = None 473 | 474 | def forward(self, hidden_states, temb=None): 475 | output_states = () 476 | 477 | for resnet, attn in zip(self.resnets, self.attentions): 478 | hidden_states = resnet(hidden_states, temb) 479 | hidden_states = attn(hidden_states) 480 | output_states += (hidden_states,) 481 | 482 | if self.downsamplers is not None: 483 | for downsampler in self.downsamplers: 484 | hidden_states = downsampler(hidden_states) 485 | 486 | output_states += (hidden_states,) 487 | 488 | return hidden_states, output_states 489 | 490 | 491 | class CrossAttnDownBlock2D(nn.Module): 492 | def __init__( 493 | self, 494 | in_channels: int, 495 | out_channels: int, 496 | temb_channels: int, 497 | dropout: float = 0.0, 498 | num_layers: int = 1, 499 | resnet_eps: float = 1e-6, 500 | resnet_time_scale_shift: str = "default", 501 | resnet_act_fn: str = "swish", 502 | resnet_groups: int = 32, 503 | resnet_pre_norm: bool = True, 504 | attn_num_head_channels=1, 505 | cross_attention_dim=1280, 506 | attention_type="default", 507 | output_scale_factor=1.0, 508 | downsample_padding=1, 509 | add_downsample=True, 510 | ): 511 | super().__init__() 512 | resnets = [] 513 | attentions = [] 514 | 515 | self.attention_type = attention_type 516 | self.attn_num_head_channels = attn_num_head_channels 517 | 518 | for i in range(num_layers): 519 | in_channels = in_channels if i == 0 else out_channels 520 | resnets.append( 521 | ResnetBlock2D( 522 | in_channels=in_channels, 523 | out_channels=out_channels, 524 | temb_channels=temb_channels, 525 | eps=resnet_eps, 526 | groups=resnet_groups, 527 | dropout=dropout, 528 | time_embedding_norm=resnet_time_scale_shift, 529 | non_linearity=resnet_act_fn, 530 | output_scale_factor=output_scale_factor, 531 | pre_norm=resnet_pre_norm, 532 | ) 533 | ) 534 | attentions.append( 535 | Transformer2DModel( 536 | attn_num_head_channels, 537 | out_channels // attn_num_head_channels, 538 | in_channels=out_channels, 539 | num_layers=1, 540 | cross_attention_dim=cross_attention_dim, 541 | norm_num_groups=resnet_groups, 542 | ) 543 | ) 544 | self.attentions = nn.ModuleList(attentions) 545 | self.resnets = nn.ModuleList(resnets) 546 | 547 | if add_downsample: 548 | self.downsamplers = nn.ModuleList( 549 | [ 550 | Downsample2D( 551 | in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 552 | ) 553 | ] 554 | ) 555 | else: 556 | self.downsamplers = None 557 | 558 | self.gradient_checkpointing = False 559 | 560 | def set_attention_slice(self, slice_size): 561 | if slice_size is not None and self.attn_num_head_channels % slice_size != 0: 562 | raise ValueError( 563 | f"Make sure slice_size 
{slice_size} is a divisor of " 564 | f"the number of heads used in cross_attention {self.attn_num_head_channels}" 565 | ) 566 | if slice_size is not None and slice_size > self.attn_num_head_channels: 567 | raise ValueError( 568 | f"Chunk_size {slice_size} has to be smaller or equal to " 569 | f"the number of heads used in cross_attention {self.attn_num_head_channels}" 570 | ) 571 | 572 | for attn in self.attentions: 573 | attn._set_attention_slice(slice_size) 574 | 575 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 576 | for attn in self.attentions: 577 | attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 578 | 579 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 580 | output_states = () 581 | cross_attn_prob_list = [] 582 | for layer_idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 583 | if self.training and self.gradient_checkpointing: 584 | 585 | def create_custom_forward(module, return_dict=None): 586 | def custom_forward(*inputs): 587 | if return_dict is not None: 588 | return module(*inputs, return_dict=return_dict) 589 | else: 590 | return module(*inputs) 591 | 592 | return custom_forward 593 | 594 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 595 | hidden_states = torch.utils.checkpoint.checkpoint( 596 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states 597 | )[0] 598 | else: 599 | hidden_states = resnet(hidden_states, temb) 600 | tmp_hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) 601 | hidden_states = tmp_hidden_states.sample 602 | 603 | output_states += (hidden_states,) 604 | cross_attn_prob_list.append(cross_attn_prob) 605 | if self.downsamplers is not None: 606 | for downsampler in self.downsamplers: 607 | hidden_states = downsampler(hidden_states) 608 | 609 | output_states += (hidden_states,) 610 | 611 | return hidden_states, output_states, cross_attn_prob_list 612 | 613 | 614 | class DownBlock2D(nn.Module): 615 | def __init__( 616 | self, 617 | in_channels: int, 618 | out_channels: int, 619 | temb_channels: int, 620 | dropout: float = 0.0, 621 | num_layers: int = 1, 622 | resnet_eps: float = 1e-6, 623 | resnet_time_scale_shift: str = "default", 624 | resnet_act_fn: str = "swish", 625 | resnet_groups: int = 32, 626 | resnet_pre_norm: bool = True, 627 | output_scale_factor=1.0, 628 | add_downsample=True, 629 | downsample_padding=1, 630 | ): 631 | super().__init__() 632 | resnets = [] 633 | 634 | for i in range(num_layers): 635 | in_channels = in_channels if i == 0 else out_channels 636 | resnets.append( 637 | ResnetBlock2D( 638 | in_channels=in_channels, 639 | out_channels=out_channels, 640 | temb_channels=temb_channels, 641 | eps=resnet_eps, 642 | groups=resnet_groups, 643 | dropout=dropout, 644 | time_embedding_norm=resnet_time_scale_shift, 645 | non_linearity=resnet_act_fn, 646 | output_scale_factor=output_scale_factor, 647 | pre_norm=resnet_pre_norm, 648 | ) 649 | ) 650 | 651 | self.resnets = nn.ModuleList(resnets) 652 | 653 | if add_downsample: 654 | self.downsamplers = nn.ModuleList( 655 | [ 656 | Downsample2D( 657 | in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 658 | ) 659 | ] 660 | ) 661 | else: 662 | self.downsamplers = None 663 | 664 | self.gradient_checkpointing = False 665 | 666 | def forward(self, hidden_states, temb=None): 667 | 
output_states = () 668 | 669 | for resnet in self.resnets: 670 | if self.training and self.gradient_checkpointing: 671 | 672 | def create_custom_forward(module): 673 | def custom_forward(*inputs): 674 | return module(*inputs) 675 | 676 | return custom_forward 677 | 678 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 679 | else: 680 | hidden_states = resnet(hidden_states, temb) 681 | 682 | output_states += (hidden_states,) 683 | 684 | if self.downsamplers is not None: 685 | for downsampler in self.downsamplers: 686 | hidden_states = downsampler(hidden_states) 687 | 688 | output_states += (hidden_states,) 689 | 690 | return hidden_states, output_states 691 | 692 | 693 | class DownEncoderBlock2D(nn.Module): 694 | def __init__( 695 | self, 696 | in_channels: int, 697 | out_channels: int, 698 | dropout: float = 0.0, 699 | num_layers: int = 1, 700 | resnet_eps: float = 1e-6, 701 | resnet_time_scale_shift: str = "default", 702 | resnet_act_fn: str = "swish", 703 | resnet_groups: int = 32, 704 | resnet_pre_norm: bool = True, 705 | output_scale_factor=1.0, 706 | add_downsample=True, 707 | downsample_padding=1, 708 | ): 709 | super().__init__() 710 | resnets = [] 711 | 712 | for i in range(num_layers): 713 | in_channels = in_channels if i == 0 else out_channels 714 | resnets.append( 715 | ResnetBlock2D( 716 | in_channels=in_channels, 717 | out_channels=out_channels, 718 | temb_channels=None, 719 | eps=resnet_eps, 720 | groups=resnet_groups, 721 | dropout=dropout, 722 | time_embedding_norm=resnet_time_scale_shift, 723 | non_linearity=resnet_act_fn, 724 | output_scale_factor=output_scale_factor, 725 | pre_norm=resnet_pre_norm, 726 | ) 727 | ) 728 | 729 | self.resnets = nn.ModuleList(resnets) 730 | 731 | if add_downsample: 732 | self.downsamplers = nn.ModuleList( 733 | [ 734 | Downsample2D( 735 | in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 736 | ) 737 | ] 738 | ) 739 | else: 740 | self.downsamplers = None 741 | 742 | def forward(self, hidden_states): 743 | for resnet in self.resnets: 744 | hidden_states = resnet(hidden_states, temb=None) 745 | 746 | if self.downsamplers is not None: 747 | for downsampler in self.downsamplers: 748 | hidden_states = downsampler(hidden_states) 749 | 750 | return hidden_states 751 | 752 | 753 | class AttnDownEncoderBlock2D(nn.Module): 754 | def __init__( 755 | self, 756 | in_channels: int, 757 | out_channels: int, 758 | dropout: float = 0.0, 759 | num_layers: int = 1, 760 | resnet_eps: float = 1e-6, 761 | resnet_time_scale_shift: str = "default", 762 | resnet_act_fn: str = "swish", 763 | resnet_groups: int = 32, 764 | resnet_pre_norm: bool = True, 765 | attn_num_head_channels=1, 766 | output_scale_factor=1.0, 767 | add_downsample=True, 768 | downsample_padding=1, 769 | ): 770 | super().__init__() 771 | resnets = [] 772 | attentions = [] 773 | 774 | for i in range(num_layers): 775 | in_channels = in_channels if i == 0 else out_channels 776 | resnets.append( 777 | ResnetBlock2D( 778 | in_channels=in_channels, 779 | out_channels=out_channels, 780 | temb_channels=None, 781 | eps=resnet_eps, 782 | groups=resnet_groups, 783 | dropout=dropout, 784 | time_embedding_norm=resnet_time_scale_shift, 785 | non_linearity=resnet_act_fn, 786 | output_scale_factor=output_scale_factor, 787 | pre_norm=resnet_pre_norm, 788 | ) 789 | ) 790 | attentions.append( 791 | AttentionBlock( 792 | out_channels, 793 | num_head_channels=attn_num_head_channels, 794 | 
rescale_output_factor=output_scale_factor, 795 | eps=resnet_eps, 796 | norm_num_groups=resnet_groups, 797 | ) 798 | ) 799 | 800 | self.attentions = nn.ModuleList(attentions) 801 | self.resnets = nn.ModuleList(resnets) 802 | 803 | if add_downsample: 804 | self.downsamplers = nn.ModuleList( 805 | [ 806 | Downsample2D( 807 | in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 808 | ) 809 | ] 810 | ) 811 | else: 812 | self.downsamplers = None 813 | 814 | def forward(self, hidden_states): 815 | for resnet, attn in zip(self.resnets, self.attentions): 816 | hidden_states = resnet(hidden_states, temb=None) 817 | hidden_states = attn(hidden_states) 818 | 819 | if self.downsamplers is not None: 820 | for downsampler in self.downsamplers: 821 | hidden_states = downsampler(hidden_states) 822 | 823 | return hidden_states 824 | 825 | 826 | class AttnSkipDownBlock2D(nn.Module): 827 | def __init__( 828 | self, 829 | in_channels: int, 830 | out_channels: int, 831 | temb_channels: int, 832 | dropout: float = 0.0, 833 | num_layers: int = 1, 834 | resnet_eps: float = 1e-6, 835 | resnet_time_scale_shift: str = "default", 836 | resnet_act_fn: str = "swish", 837 | resnet_pre_norm: bool = True, 838 | attn_num_head_channels=1, 839 | attention_type="default", 840 | output_scale_factor=np.sqrt(2.0), 841 | downsample_padding=1, 842 | add_downsample=True, 843 | ): 844 | super().__init__() 845 | self.attentions = nn.ModuleList([]) 846 | self.resnets = nn.ModuleList([]) 847 | 848 | self.attention_type = attention_type 849 | 850 | for i in range(num_layers): 851 | in_channels = in_channels if i == 0 else out_channels 852 | self.resnets.append( 853 | ResnetBlock2D( 854 | in_channels=in_channels, 855 | out_channels=out_channels, 856 | temb_channels=temb_channels, 857 | eps=resnet_eps, 858 | groups=min(in_channels // 4, 32), 859 | groups_out=min(out_channels // 4, 32), 860 | dropout=dropout, 861 | time_embedding_norm=resnet_time_scale_shift, 862 | non_linearity=resnet_act_fn, 863 | output_scale_factor=output_scale_factor, 864 | pre_norm=resnet_pre_norm, 865 | ) 866 | ) 867 | self.attentions.append( 868 | AttentionBlock( 869 | out_channels, 870 | num_head_channels=attn_num_head_channels, 871 | rescale_output_factor=output_scale_factor, 872 | eps=resnet_eps, 873 | ) 874 | ) 875 | 876 | if add_downsample: 877 | self.resnet_down = ResnetBlock2D( 878 | in_channels=out_channels, 879 | out_channels=out_channels, 880 | temb_channels=temb_channels, 881 | eps=resnet_eps, 882 | groups=min(out_channels // 4, 32), 883 | dropout=dropout, 884 | time_embedding_norm=resnet_time_scale_shift, 885 | non_linearity=resnet_act_fn, 886 | output_scale_factor=output_scale_factor, 887 | pre_norm=resnet_pre_norm, 888 | use_in_shortcut=True, 889 | down=True, 890 | kernel="fir", 891 | ) 892 | self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) 893 | self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) 894 | else: 895 | self.resnet_down = None 896 | self.downsamplers = None 897 | self.skip_conv = None 898 | 899 | def forward(self, hidden_states, temb=None, skip_sample=None): 900 | output_states = () 901 | 902 | for resnet, attn in zip(self.resnets, self.attentions): 903 | hidden_states = resnet(hidden_states, temb) 904 | hidden_states = attn(hidden_states) 905 | output_states += (hidden_states,) 906 | 907 | if self.downsamplers is not None: 908 | hidden_states = self.resnet_down(hidden_states, temb) 909 | for downsampler in self.downsamplers: 910 | 
skip_sample = downsampler(skip_sample) 911 | 912 | hidden_states = self.skip_conv(skip_sample) + hidden_states 913 | 914 | output_states += (hidden_states,) 915 | 916 | return hidden_states, output_states, skip_sample 917 | 918 | 919 | class SkipDownBlock2D(nn.Module): 920 | def __init__( 921 | self, 922 | in_channels: int, 923 | out_channels: int, 924 | temb_channels: int, 925 | dropout: float = 0.0, 926 | num_layers: int = 1, 927 | resnet_eps: float = 1e-6, 928 | resnet_time_scale_shift: str = "default", 929 | resnet_act_fn: str = "swish", 930 | resnet_pre_norm: bool = True, 931 | output_scale_factor=np.sqrt(2.0), 932 | add_downsample=True, 933 | downsample_padding=1, 934 | ): 935 | super().__init__() 936 | self.resnets = nn.ModuleList([]) 937 | 938 | for i in range(num_layers): 939 | in_channels = in_channels if i == 0 else out_channels 940 | self.resnets.append( 941 | ResnetBlock2D( 942 | in_channels=in_channels, 943 | out_channels=out_channels, 944 | temb_channels=temb_channels, 945 | eps=resnet_eps, 946 | groups=min(in_channels // 4, 32), 947 | groups_out=min(out_channels // 4, 32), 948 | dropout=dropout, 949 | time_embedding_norm=resnet_time_scale_shift, 950 | non_linearity=resnet_act_fn, 951 | output_scale_factor=output_scale_factor, 952 | pre_norm=resnet_pre_norm, 953 | ) 954 | ) 955 | 956 | if add_downsample: 957 | self.resnet_down = ResnetBlock2D( 958 | in_channels=out_channels, 959 | out_channels=out_channels, 960 | temb_channels=temb_channels, 961 | eps=resnet_eps, 962 | groups=min(out_channels // 4, 32), 963 | dropout=dropout, 964 | time_embedding_norm=resnet_time_scale_shift, 965 | non_linearity=resnet_act_fn, 966 | output_scale_factor=output_scale_factor, 967 | pre_norm=resnet_pre_norm, 968 | use_in_shortcut=True, 969 | down=True, 970 | kernel="fir", 971 | ) 972 | self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) 973 | self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) 974 | else: 975 | self.resnet_down = None 976 | self.downsamplers = None 977 | self.skip_conv = None 978 | 979 | def forward(self, hidden_states, temb=None, skip_sample=None): 980 | output_states = () 981 | 982 | for resnet in self.resnets: 983 | hidden_states = resnet(hidden_states, temb) 984 | output_states += (hidden_states,) 985 | 986 | if self.downsamplers is not None: 987 | hidden_states = self.resnet_down(hidden_states, temb) 988 | for downsampler in self.downsamplers: 989 | skip_sample = downsampler(skip_sample) 990 | 991 | hidden_states = self.skip_conv(skip_sample) + hidden_states 992 | 993 | output_states += (hidden_states,) 994 | 995 | return hidden_states, output_states, skip_sample 996 | 997 | 998 | class AttnUpBlock2D(nn.Module): 999 | def __init__( 1000 | self, 1001 | in_channels: int, 1002 | prev_output_channel: int, 1003 | out_channels: int, 1004 | temb_channels: int, 1005 | dropout: float = 0.0, 1006 | num_layers: int = 1, 1007 | resnet_eps: float = 1e-6, 1008 | resnet_time_scale_shift: str = "default", 1009 | resnet_act_fn: str = "swish", 1010 | resnet_groups: int = 32, 1011 | resnet_pre_norm: bool = True, 1012 | attention_type="default", 1013 | attn_num_head_channels=1, 1014 | output_scale_factor=1.0, 1015 | add_upsample=True, 1016 | ): 1017 | super().__init__() 1018 | resnets = [] 1019 | attentions = [] 1020 | 1021 | self.attention_type = attention_type 1022 | 1023 | for i in range(num_layers): 1024 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 1025 | resnet_in_channels = 
prev_output_channel if i == 0 else out_channels 1026 | 1027 | resnets.append( 1028 | ResnetBlock2D( 1029 | in_channels=resnet_in_channels + res_skip_channels, 1030 | out_channels=out_channels, 1031 | temb_channels=temb_channels, 1032 | eps=resnet_eps, 1033 | groups=resnet_groups, 1034 | dropout=dropout, 1035 | time_embedding_norm=resnet_time_scale_shift, 1036 | non_linearity=resnet_act_fn, 1037 | output_scale_factor=output_scale_factor, 1038 | pre_norm=resnet_pre_norm, 1039 | ) 1040 | ) 1041 | attentions.append( 1042 | AttentionBlock( 1043 | out_channels, 1044 | num_head_channels=attn_num_head_channels, 1045 | rescale_output_factor=output_scale_factor, 1046 | eps=resnet_eps, 1047 | norm_num_groups=resnet_groups, 1048 | ) 1049 | ) 1050 | 1051 | self.attentions = nn.ModuleList(attentions) 1052 | self.resnets = nn.ModuleList(resnets) 1053 | 1054 | if add_upsample: 1055 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 1056 | else: 1057 | self.upsamplers = None 1058 | 1059 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None): 1060 | for resnet, attn in zip(self.resnets, self.attentions): 1061 | # pop res hidden states 1062 | res_hidden_states = res_hidden_states_tuple[-1] 1063 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1064 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1065 | 1066 | hidden_states = resnet(hidden_states, temb) 1067 | hidden_states = attn(hidden_states) 1068 | 1069 | if self.upsamplers is not None: 1070 | for upsampler in self.upsamplers: 1071 | hidden_states = upsampler(hidden_states) 1072 | 1073 | return hidden_states 1074 | 1075 | 1076 | class CrossAttnUpBlock2D(nn.Module): 1077 | def __init__( 1078 | self, 1079 | in_channels: int, 1080 | out_channels: int, 1081 | prev_output_channel: int, 1082 | temb_channels: int, 1083 | dropout: float = 0.0, 1084 | num_layers: int = 1, 1085 | resnet_eps: float = 1e-6, 1086 | resnet_time_scale_shift: str = "default", 1087 | resnet_act_fn: str = "swish", 1088 | resnet_groups: int = 32, 1089 | resnet_pre_norm: bool = True, 1090 | attn_num_head_channels=1, 1091 | cross_attention_dim=1280, 1092 | attention_type="default", 1093 | output_scale_factor=1.0, 1094 | add_upsample=True, 1095 | ): 1096 | super().__init__() 1097 | resnets = [] 1098 | attentions = [] 1099 | 1100 | self.attention_type = attention_type 1101 | self.attn_num_head_channels = attn_num_head_channels 1102 | 1103 | for i in range(num_layers): 1104 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 1105 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 1106 | 1107 | resnets.append( 1108 | ResnetBlock2D( 1109 | in_channels=resnet_in_channels + res_skip_channels, 1110 | out_channels=out_channels, 1111 | temb_channels=temb_channels, 1112 | eps=resnet_eps, 1113 | groups=resnet_groups, 1114 | dropout=dropout, 1115 | time_embedding_norm=resnet_time_scale_shift, 1116 | non_linearity=resnet_act_fn, 1117 | output_scale_factor=output_scale_factor, 1118 | pre_norm=resnet_pre_norm, 1119 | ) 1120 | ) 1121 | attentions.append( 1122 | Transformer2DModel( 1123 | attn_num_head_channels, 1124 | out_channels // attn_num_head_channels, 1125 | in_channels=out_channels, 1126 | num_layers=1, 1127 | cross_attention_dim=cross_attention_dim, 1128 | norm_num_groups=resnet_groups, 1129 | ) 1130 | ) 1131 | self.attentions = nn.ModuleList(attentions) 1132 | self.resnets = nn.ModuleList(resnets) 1133 | 1134 | if add_upsample: 1135 | self.upsamplers = 
nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 1136 | else: 1137 | self.upsamplers = None 1138 | 1139 | self.gradient_checkpointing = False 1140 | 1141 | def set_attention_slice(self, slice_size): 1142 | if slice_size is not None and self.attn_num_head_channels % slice_size != 0: 1143 | raise ValueError( 1144 | f"Make sure slice_size {slice_size} is a divisor of " 1145 | f"the number of heads used in cross_attention {self.attn_num_head_channels}" 1146 | ) 1147 | if slice_size is not None and slice_size > self.attn_num_head_channels: 1148 | raise ValueError( 1149 | f"Chunk_size {slice_size} has to be smaller or equal to " 1150 | f"the number of heads used in cross_attention {self.attn_num_head_channels}" 1151 | ) 1152 | 1153 | for attn in self.attentions: 1154 | attn._set_attention_slice(slice_size) 1155 | 1156 | self.gradient_checkpointing = False 1157 | 1158 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 1159 | for attn in self.attentions: 1160 | attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) 1161 | 1162 | def forward( 1163 | self, 1164 | hidden_states, 1165 | res_hidden_states_tuple, 1166 | temb=None, 1167 | encoder_hidden_states=None, 1168 | upsample_size=None, 1169 | ): 1170 | cross_attn_prob_list = list() 1171 | for layer_idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 1172 | # pop res hidden states 1173 | res_hidden_states = res_hidden_states_tuple[-1] 1174 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1175 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1176 | 1177 | if self.training and self.gradient_checkpointing: 1178 | 1179 | def create_custom_forward(module, return_dict=None): 1180 | def custom_forward(*inputs): 1181 | if return_dict is not None: 1182 | return module(*inputs, return_dict=return_dict) 1183 | else: 1184 | return module(*inputs) 1185 | 1186 | return custom_forward 1187 | 1188 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 1189 | hidden_states = torch.utils.checkpoint.checkpoint( 1190 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states 1191 | )[0] 1192 | else: 1193 | hidden_states = resnet(hidden_states, temb) 1194 | tmp_hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) 1195 | hidden_states = tmp_hidden_states.sample 1196 | cross_attn_prob_list.append(cross_attn_prob) 1197 | if self.upsamplers is not None: 1198 | for upsampler in self.upsamplers: 1199 | hidden_states = upsampler(hidden_states, upsample_size) 1200 | 1201 | return hidden_states, cross_attn_prob_list 1202 | 1203 | 1204 | class UpBlock2D(nn.Module): 1205 | def __init__( 1206 | self, 1207 | in_channels: int, 1208 | prev_output_channel: int, 1209 | out_channels: int, 1210 | temb_channels: int, 1211 | dropout: float = 0.0, 1212 | num_layers: int = 1, 1213 | resnet_eps: float = 1e-6, 1214 | resnet_time_scale_shift: str = "default", 1215 | resnet_act_fn: str = "swish", 1216 | resnet_groups: int = 32, 1217 | resnet_pre_norm: bool = True, 1218 | output_scale_factor=1.0, 1219 | add_upsample=True, 1220 | ): 1221 | super().__init__() 1222 | resnets = [] 1223 | 1224 | for i in range(num_layers): 1225 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 1226 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 1227 | 1228 | resnets.append( 1229 
| ResnetBlock2D( 1230 | in_channels=resnet_in_channels + res_skip_channels, 1231 | out_channels=out_channels, 1232 | temb_channels=temb_channels, 1233 | eps=resnet_eps, 1234 | groups=resnet_groups, 1235 | dropout=dropout, 1236 | time_embedding_norm=resnet_time_scale_shift, 1237 | non_linearity=resnet_act_fn, 1238 | output_scale_factor=output_scale_factor, 1239 | pre_norm=resnet_pre_norm, 1240 | ) 1241 | ) 1242 | 1243 | self.resnets = nn.ModuleList(resnets) 1244 | 1245 | if add_upsample: 1246 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 1247 | else: 1248 | self.upsamplers = None 1249 | 1250 | self.gradient_checkpointing = False 1251 | 1252 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 1253 | for resnet in self.resnets: 1254 | # pop res hidden states 1255 | res_hidden_states = res_hidden_states_tuple[-1] 1256 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1257 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1258 | 1259 | if self.training and self.gradient_checkpointing: 1260 | 1261 | def create_custom_forward(module): 1262 | def custom_forward(*inputs): 1263 | return module(*inputs) 1264 | 1265 | return custom_forward 1266 | 1267 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 1268 | else: 1269 | hidden_states = resnet(hidden_states, temb) 1270 | 1271 | if self.upsamplers is not None: 1272 | for upsampler in self.upsamplers: 1273 | hidden_states = upsampler(hidden_states, upsample_size) 1274 | 1275 | return hidden_states 1276 | 1277 | 1278 | class UpDecoderBlock2D(nn.Module): 1279 | def __init__( 1280 | self, 1281 | in_channels: int, 1282 | out_channels: int, 1283 | dropout: float = 0.0, 1284 | num_layers: int = 1, 1285 | resnet_eps: float = 1e-6, 1286 | resnet_time_scale_shift: str = "default", 1287 | resnet_act_fn: str = "swish", 1288 | resnet_groups: int = 32, 1289 | resnet_pre_norm: bool = True, 1290 | output_scale_factor=1.0, 1291 | add_upsample=True, 1292 | ): 1293 | super().__init__() 1294 | resnets = [] 1295 | 1296 | for i in range(num_layers): 1297 | input_channels = in_channels if i == 0 else out_channels 1298 | 1299 | resnets.append( 1300 | ResnetBlock2D( 1301 | in_channels=input_channels, 1302 | out_channels=out_channels, 1303 | temb_channels=None, 1304 | eps=resnet_eps, 1305 | groups=resnet_groups, 1306 | dropout=dropout, 1307 | time_embedding_norm=resnet_time_scale_shift, 1308 | non_linearity=resnet_act_fn, 1309 | output_scale_factor=output_scale_factor, 1310 | pre_norm=resnet_pre_norm, 1311 | ) 1312 | ) 1313 | 1314 | self.resnets = nn.ModuleList(resnets) 1315 | 1316 | if add_upsample: 1317 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 1318 | else: 1319 | self.upsamplers = None 1320 | 1321 | def forward(self, hidden_states): 1322 | for resnet in self.resnets: 1323 | hidden_states = resnet(hidden_states, temb=None) 1324 | 1325 | if self.upsamplers is not None: 1326 | for upsampler in self.upsamplers: 1327 | hidden_states = upsampler(hidden_states) 1328 | 1329 | return hidden_states 1330 | 1331 | 1332 | class AttnUpDecoderBlock2D(nn.Module): 1333 | def __init__( 1334 | self, 1335 | in_channels: int, 1336 | out_channels: int, 1337 | dropout: float = 0.0, 1338 | num_layers: int = 1, 1339 | resnet_eps: float = 1e-6, 1340 | resnet_time_scale_shift: str = "default", 1341 | resnet_act_fn: str = "swish", 1342 | resnet_groups: int = 32, 
1343 | resnet_pre_norm: bool = True, 1344 | attn_num_head_channels=1, 1345 | output_scale_factor=1.0, 1346 | add_upsample=True, 1347 | ): 1348 | super().__init__() 1349 | resnets = [] 1350 | attentions = [] 1351 | 1352 | for i in range(num_layers): 1353 | input_channels = in_channels if i == 0 else out_channels 1354 | 1355 | resnets.append( 1356 | ResnetBlock2D( 1357 | in_channels=input_channels, 1358 | out_channels=out_channels, 1359 | temb_channels=None, 1360 | eps=resnet_eps, 1361 | groups=resnet_groups, 1362 | dropout=dropout, 1363 | time_embedding_norm=resnet_time_scale_shift, 1364 | non_linearity=resnet_act_fn, 1365 | output_scale_factor=output_scale_factor, 1366 | pre_norm=resnet_pre_norm, 1367 | ) 1368 | ) 1369 | attentions.append( 1370 | AttentionBlock( 1371 | out_channels, 1372 | num_head_channels=attn_num_head_channels, 1373 | rescale_output_factor=output_scale_factor, 1374 | eps=resnet_eps, 1375 | norm_num_groups=resnet_groups, 1376 | ) 1377 | ) 1378 | 1379 | self.attentions = nn.ModuleList(attentions) 1380 | self.resnets = nn.ModuleList(resnets) 1381 | 1382 | if add_upsample: 1383 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 1384 | else: 1385 | self.upsamplers = None 1386 | 1387 | def forward(self, hidden_states): 1388 | for resnet, attn in zip(self.resnets, self.attentions): 1389 | hidden_states = resnet(hidden_states, temb=None) 1390 | hidden_states = attn(hidden_states) 1391 | 1392 | if self.upsamplers is not None: 1393 | for upsampler in self.upsamplers: 1394 | hidden_states = upsampler(hidden_states) 1395 | 1396 | return hidden_states 1397 | 1398 | 1399 | class AttnSkipUpBlock2D(nn.Module): 1400 | def __init__( 1401 | self, 1402 | in_channels: int, 1403 | prev_output_channel: int, 1404 | out_channels: int, 1405 | temb_channels: int, 1406 | dropout: float = 0.0, 1407 | num_layers: int = 1, 1408 | resnet_eps: float = 1e-6, 1409 | resnet_time_scale_shift: str = "default", 1410 | resnet_act_fn: str = "swish", 1411 | resnet_pre_norm: bool = True, 1412 | attn_num_head_channels=1, 1413 | attention_type="default", 1414 | output_scale_factor=np.sqrt(2.0), 1415 | upsample_padding=1, 1416 | add_upsample=True, 1417 | ): 1418 | super().__init__() 1419 | self.attentions = nn.ModuleList([]) 1420 | self.resnets = nn.ModuleList([]) 1421 | 1422 | self.attention_type = attention_type 1423 | 1424 | for i in range(num_layers): 1425 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 1426 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 1427 | 1428 | self.resnets.append( 1429 | ResnetBlock2D( 1430 | in_channels=resnet_in_channels + res_skip_channels, 1431 | out_channels=out_channels, 1432 | temb_channels=temb_channels, 1433 | eps=resnet_eps, 1434 | groups=min(resnet_in_channels + res_skip_channels // 4, 32), 1435 | groups_out=min(out_channels // 4, 32), 1436 | dropout=dropout, 1437 | time_embedding_norm=resnet_time_scale_shift, 1438 | non_linearity=resnet_act_fn, 1439 | output_scale_factor=output_scale_factor, 1440 | pre_norm=resnet_pre_norm, 1441 | ) 1442 | ) 1443 | 1444 | self.attentions.append( 1445 | AttentionBlock( 1446 | out_channels, 1447 | num_head_channels=attn_num_head_channels, 1448 | rescale_output_factor=output_scale_factor, 1449 | eps=resnet_eps, 1450 | ) 1451 | ) 1452 | 1453 | self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) 1454 | if add_upsample: 1455 | self.resnet_up = ResnetBlock2D( 1456 | in_channels=out_channels, 1457 | out_channels=out_channels, 
1458 | temb_channels=temb_channels, 1459 | eps=resnet_eps, 1460 | groups=min(out_channels // 4, 32), 1461 | groups_out=min(out_channels // 4, 32), 1462 | dropout=dropout, 1463 | time_embedding_norm=resnet_time_scale_shift, 1464 | non_linearity=resnet_act_fn, 1465 | output_scale_factor=output_scale_factor, 1466 | pre_norm=resnet_pre_norm, 1467 | use_in_shortcut=True, 1468 | up=True, 1469 | kernel="fir", 1470 | ) 1471 | self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 1472 | self.skip_norm = torch.nn.GroupNorm( 1473 | num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True 1474 | ) 1475 | self.act = nn.SiLU() 1476 | else: 1477 | self.resnet_up = None 1478 | self.skip_conv = None 1479 | self.skip_norm = None 1480 | self.act = None 1481 | 1482 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): 1483 | for resnet in self.resnets: 1484 | # pop res hidden states 1485 | res_hidden_states = res_hidden_states_tuple[-1] 1486 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1487 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1488 | 1489 | hidden_states = resnet(hidden_states, temb) 1490 | 1491 | hidden_states = self.attentions[0](hidden_states) 1492 | 1493 | if skip_sample is not None: 1494 | skip_sample = self.upsampler(skip_sample) 1495 | else: 1496 | skip_sample = 0 1497 | 1498 | if self.resnet_up is not None: 1499 | skip_sample_states = self.skip_norm(hidden_states) 1500 | skip_sample_states = self.act(skip_sample_states) 1501 | skip_sample_states = self.skip_conv(skip_sample_states) 1502 | 1503 | skip_sample = skip_sample + skip_sample_states 1504 | 1505 | hidden_states = self.resnet_up(hidden_states, temb) 1506 | 1507 | return hidden_states, skip_sample 1508 | 1509 | 1510 | class SkipUpBlock2D(nn.Module): 1511 | def __init__( 1512 | self, 1513 | in_channels: int, 1514 | prev_output_channel: int, 1515 | out_channels: int, 1516 | temb_channels: int, 1517 | dropout: float = 0.0, 1518 | num_layers: int = 1, 1519 | resnet_eps: float = 1e-6, 1520 | resnet_time_scale_shift: str = "default", 1521 | resnet_act_fn: str = "swish", 1522 | resnet_pre_norm: bool = True, 1523 | output_scale_factor=np.sqrt(2.0), 1524 | add_upsample=True, 1525 | upsample_padding=1, 1526 | ): 1527 | super().__init__() 1528 | self.resnets = nn.ModuleList([]) 1529 | 1530 | for i in range(num_layers): 1531 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 1532 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 1533 | 1534 | self.resnets.append( 1535 | ResnetBlock2D( 1536 | in_channels=resnet_in_channels + res_skip_channels, 1537 | out_channels=out_channels, 1538 | temb_channels=temb_channels, 1539 | eps=resnet_eps, 1540 | groups=min((resnet_in_channels + res_skip_channels) // 4, 32), 1541 | groups_out=min(out_channels // 4, 32), 1542 | dropout=dropout, 1543 | time_embedding_norm=resnet_time_scale_shift, 1544 | non_linearity=resnet_act_fn, 1545 | output_scale_factor=output_scale_factor, 1546 | pre_norm=resnet_pre_norm, 1547 | ) 1548 | ) 1549 | 1550 | self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) 1551 | if add_upsample: 1552 | self.resnet_up = ResnetBlock2D( 1553 | in_channels=out_channels, 1554 | out_channels=out_channels, 1555 | temb_channels=temb_channels, 1556 | eps=resnet_eps, 1557 | groups=min(out_channels // 4, 32), 1558 | groups_out=min(out_channels // 4, 32), 1559 | dropout=dropout, 1560 | 
time_embedding_norm=resnet_time_scale_shift, 1561 | non_linearity=resnet_act_fn, 1562 | output_scale_factor=output_scale_factor, 1563 | pre_norm=resnet_pre_norm, 1564 | use_in_shortcut=True, 1565 | up=True, 1566 | kernel="fir", 1567 | ) 1568 | self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 1569 | self.skip_norm = torch.nn.GroupNorm( 1570 | num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True 1571 | ) 1572 | self.act = nn.SiLU() 1573 | else: 1574 | self.resnet_up = None 1575 | self.skip_conv = None 1576 | self.skip_norm = None 1577 | self.act = None 1578 | 1579 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): 1580 | for resnet in self.resnets: 1581 | # pop res hidden states 1582 | res_hidden_states = res_hidden_states_tuple[-1] 1583 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1584 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1585 | 1586 | hidden_states = resnet(hidden_states, temb) 1587 | 1588 | if skip_sample is not None: 1589 | skip_sample = self.upsampler(skip_sample) 1590 | else: 1591 | skip_sample = 0 1592 | 1593 | if self.resnet_up is not None: 1594 | skip_sample_states = self.skip_norm(hidden_states) 1595 | skip_sample_states = self.act(skip_sample_states) 1596 | skip_sample_states = self.skip_conv(skip_sample_states) 1597 | 1598 | skip_sample = skip_sample + skip_sample_states 1599 | 1600 | hidden_states = self.resnet_up(hidden_states, temb) 1601 | 1602 | return hidden_states, skip_sample 1603 | --------------------------------------------------------------------------------
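The cross-attention blocks in this file depart from stock `diffusers` 0.11.1 in one key way: their `forward` methods also collect and return the per-layer cross-attention probability maps (`cross_attn_prob_list` / `mid_attn`), which is what the layout-guidance objective is computed on. Below is a minimal, self-contained sketch (not part of the repository) that exercises `CrossAttnDownBlock2D` on dummy tensors to obtain those maps; the channel sizes are illustrative choices matching the first down block of a Stable Diffusion 1.5 UNet, and the sketch assumes the patched `Transformer2DModel` imported by this module returns `(output, attention_probabilities)`, as the forward passes above indicate.

```python
# Hypothetical usage sketch: run the modified CrossAttnDownBlock2D on dummy inputs
# and inspect the cross-attention maps it returns alongside the hidden states.
import torch
from my_model.unet_2d_blocks import CrossAttnDownBlock2D

block = CrossAttnDownBlock2D(
    in_channels=320,          # latent feature channels entering the block
    out_channels=320,
    temb_channels=1280,       # timestep-embedding width
    num_layers=2,
    attn_num_head_channels=8,
    cross_attention_dim=768,  # text-embedding width
)

hidden_states = torch.randn(1, 320, 64, 64)      # dummy latent features
temb = torch.randn(1, 1280)                      # dummy timestep embedding
encoder_hidden_states = torch.randn(1, 77, 768)  # dummy text embeddings (77 tokens)

with torch.no_grad():
    hidden_states, output_states, cross_attn_probs = block(
        hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states
    )

# One attention-probability tensor per transformer layer; layout guidance
# aggregates these maps over the token indices of the phrase being localized.
print(len(output_states), len(cross_attn_probs))
```

The same pattern applies to `CrossAttnUpBlock2D` and `UNetMidBlock2DCrossAttn`, whose forwards return `(hidden_states, cross_attn_prob_list)` and `(hidden_states, mid_attn)` respectively, so the calling UNet can expose every cross-attention map to the guidance loss.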