├── upscale_swin2sr ├── swin2sr │ ├── __init__.py │ ├── LICENSE │ └── models │ │ └── network_swin2sr.py ├── requirements.txt ├── README.md ├── generate_wds_shards.py └── upscale_images.py ├── README.md ├── requirements.txt ├── Makefile ├── scripts ├── temp_dataloader_validation.py ├── dataloader.py └── train_instruct_pix2pix_sdxl.py └── run.slurm /upscale_swin2sr/swin2sr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /upscale_swin2sr/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | timm 3 | webdataset 4 | ray -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # instructpix2pix-sdxl 2 | Training InstructPix2Pix with SDXL. 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/huggingface/diffusers 2 | accelerate 3 | transformers 4 | wandb 5 | webdataset 6 | xformers -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := . 2 | 3 | quality: 4 | black --check $(check_dirs) 5 | ruff $(check_dirs) 6 | 7 | style: 8 | black $(check_dirs) 9 | ruff $(check_dirs) --fix -------------------------------------------------------------------------------- /upscale_swin2sr/README.md: -------------------------------------------------------------------------------- 1 | Run `upscale_images.py` to upscale the images in [timbrooks/instructpix2pix-clip-filtered](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered). The upscaler is Swin2SR from [this repository](https://github.com/mv-lab/swin2sr). The `swin2sr` directory is taken entirely from the original repository. `upscale_images.py` is a modified version of `predict.py` from [here](https://github.com/mv-lab/swin2sr/blob/main/predict.py). 2 | 3 | * `upscale_images.py` will produce a dataset in the 🤗 `datasets` format locally. 4 | * Then use `generate_wds_shards.py` to convert the dataset to the `webdataset` format. 
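5 | 
6 | For reference, here is a minimal sketch of how the resulting shards could be read back with `webdataset`. The shard path below is an assumption (it follows the `wds_shards_path` naming used in `generate_wds_shards.py`); the keys are the ones that script writes.
7 | 
8 | ```python
9 | import webdataset as wds
10 | 
11 | # Hypothetical shard path -- adjust to wherever generate_wds_shards.py wrote its output.
12 | shard = "/scratch/suraj/instructpix2pix-clip-filtered-upscaled-wds/00000.tar"
13 | 
14 | # "pil" decodes the .jpg entries to PIL images; .txt entries are decoded to strings.
15 | dataset = wds.WebDataset(shard).decode("pil")
16 | 
17 | for sample in dataset:
18 |     print(sample["edit_prompt.txt"])
19 |     sample["original_image.jpg"].save("check_original.jpg")
20 |     sample["edited_image.jpg"].save("check_edited.jpg")
21 |     break
22 | ```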
-------------------------------------------------------------------------------- /upscale_swin2sr/generate_wds_shards.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import ray 4 | import webdataset as wds 5 | from datasets import Dataset 6 | 7 | ray.init() 8 | 9 | 10 | def main(): 11 | dataset_path = "/scratch/suraj/instructpix2pix-clip-filtered-upscaled" 12 | wds_shards_path = "/scratch/suraj/instructpix2pix-clip-filtered-upscaled-wds" 13 | # get all .arrow files in the dataset path 14 | dataset_files = [ 15 | os.path.join(dataset_path, f) 16 | for f in os.listdir(dataset_path) 17 | if f.endswith(".arrow") 18 | ] 19 | 20 | @ray.remote 21 | def create_shard(path): 22 | # get basename of the file 23 | basename = os.path.basename(path) 24 | # get the shard number data-00123-of-01034.arrow -> 00123 25 | shard_num = basename.split("-")[1] 26 | dataset = Dataset.from_file(path) 27 | # create a webdataset shard 28 | shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar")) 29 | for i, example in enumerate(dataset): 30 | wds_example = { 31 | "__key__": str(i), 32 | "original_prompt.txt": example["original_prompt"], 33 | "original_image.jpg": example["original_image"].convert("RGB"), 34 | "edit_prompt.txt": example["edit_prompt"], 35 | "edited_prompt.txt": example["edited_prompt"], 36 | "edited_image.jpg": example["edited_image"].convert("RGB"), 37 | } 38 | shard.write(wds_example) 39 | shard.close() 40 | 41 | futures = [create_shard.remote(path) for path in dataset_files] 42 | ray.get(futures) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/temp_dataloader_validation.py: -------------------------------------------------------------------------------- 1 | from dataloader import get_dataloader 2 | from argparse import Namespace 3 | from PIL import Image 4 | from huggingface_hub import create_repo, upload_folder 5 | import os 6 | 7 | OUTPUT_DIR = "verify_samples" 8 | 9 | if __name__ == "__main__": 10 | args = Namespace( 11 | dataset_path="pipe:aws s3 cp s3://muse-datasets/instructpix2pix-clip-filtered-upscaled-wds/{00000..00519}.tar -", 12 | num_train_examples=313010, 13 | per_gpu_batch_size=8, 14 | global_batch_size=64, 15 | num_workers=4, 16 | center_crop=False, 17 | random_flip=True, 18 | resolution=256, 19 | original_image_column="original_image", 20 | edit_prompt_column="edit_prompt", 21 | edited_image_column="edited_image", 22 | ) 23 | dataloader = get_dataloader(args) 24 | os.makedirs(OUTPUT_DIR, exist_ok=True) 25 | 26 | for sample in dataloader: 27 | print(sample.keys()) 28 | print(sample["original_images"].shape) 29 | print(sample["edited_images"].shape) 30 | print(len(sample["edit_prompts"])) 31 | 32 | for i in range(len(sample["original_images"])): 33 | current_orig_sample = sample["original_images"][i].numpy().squeeze() 34 | current_orig_sample = current_orig_sample.transpose((1, 2, 0)) 35 | current_orig_sample = (current_orig_sample / 2 + 0.5).clip(0, 1) 36 | current_orig_sample *= 255.0 37 | current_orig_sample = current_orig_sample.round().astype("uint8") 38 | current_orig_sample = Image.fromarray(current_orig_sample) 39 | 40 | current_edited_sample = sample["edited_images"][i].numpy().squeeze() 41 | current_edited_sample = current_edited_sample.transpose((1, 2, 0)) 42 | current_edited_sample = (current_edited_sample / 2 + 0.5).clip(0, 1) 43 | current_edited_sample *= 255.0 44 | current_edited_sample = 
current_edited_sample.round().astype("uint8") 45 | current_edited_sample = Image.fromarray(current_edited_sample) 46 | 47 | current_orig_sample.save(os.path.join(OUTPUT_DIR, f"{i}_orig.png")) 48 | current_edited_sample.save(os.path.join(OUTPUT_DIR, f"{i}_edited.png")) 49 | with open(os.path.join(OUTPUT_DIR, f"{i}_edited_prompt.txt"), "w") as f: 50 | f.write(sample["edit_prompts"][i]) 51 | 52 | break 53 | 54 | repo_id = create_repo(repo_id="upscaled-validation-logging", exist_ok=True).repo_id 55 | upload_folder(repo_id=repo_id, folder_path=OUTPUT_DIR) 56 | -------------------------------------------------------------------------------- /run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=instructpix2pix-sdxl 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 5 | #SBATCH --cpus-per-task=96 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --exclusive 8 | #SBATCH --partition=production-cluster 9 | #SBATCH --output=/admin/home/suraj/logs/maskgit-imagenet/%x-%j.out 10 | 11 | set -x -e 12 | 13 | source /admin/home/suraj/.bashrc 14 | source /fsx/suraj/miniconda3/etc/profile.d/conda.sh 15 | conda activate muse 16 | 17 | echo "START TIME: $(date)" 18 | 19 | REPO=/admin/home/suraj/code/instructpix2pix-sdxl 20 | OUTPUT_DIR=/fsx/suraj/instructpix2pix-sdxl 21 | LOG_PATH=$OUTPUT_DIR/main_log.txt 22 | ACCELERATE_CONFIG_FILE="$OUTPUT_DIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated" 23 | 24 | mkdir -p $OUTPUT_DIR 25 | touch $LOG_PATH 26 | pushd $REPO 27 | 28 | 29 | GPUS_PER_NODE=8 30 | NNODES=$SLURM_NNODES 31 | NUM_GPUS=$((GPUS_PER_NODE*SLURM_NNODES)) 32 | 33 | # so processes know who to talk to 34 | MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 35 | MASTER_PORT=6000 36 | 37 | 38 | # Auto-generate the accelerate config 39 | cat << EOT > $ACCELERATE_CONFIG_FILE 40 | compute_environment: LOCAL_MACHINE 41 | deepspeed_config: {} 42 | distributed_type: MULTI_GPU 43 | fsdp_config: {} 44 | machine_rank: 0 45 | main_process_ip: $MASTER_ADDR 46 | main_process_port: $MASTER_PORT 47 | main_training_function: main 48 | num_machines: $SLURM_NNODES 49 | num_processes: $NUM_GPUS 50 | use_cpu: false 51 | EOT 52 | 53 | 54 | export MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0" 55 | 56 | PROGRAM="train_instruct_pix2pix_sdxl.py \ 57 | --pretrained_model_name_or_path=$MODEL_ID \ 58 | --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \ 59 | --dataset_path='pipe:aws s3 cp s3://muse-datasets/instructpix2pix-clip-filtered-wds/{000000..000062}.tar -' \ 60 | --use_ema \ 61 | --enable_xformers_memory_efficient_attention \ 62 | --resolution=256 --random_flip \ 63 | --per_gpu_batch_size=16 --gradient_accumulation_steps=4 \ 64 | --num_workers=4 \ 65 | --max_train_steps=10000 \ 66 | --checkpointing_steps=2500 \ 67 | --learning_rate=1e-5 --lr_warmup_steps=0 \ 68 | --mixed_precision=fp16 \ 69 | --val_image_url='https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png' \ 70 | --validation_prompt='Turn sky into a cloudy one' \ 71 | --seed=42 \ 72 | --output_dir=$OUTPUT_DIR \ 73 | --report_to=wandb \ 74 | --push_to_hub 75 | " 76 | 77 | # Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable 78 | export LAUNCHER="accelerate launch \ 79 | --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,max_restarts=0,tee=3" \ 80 | --config_file $ACCELERATE_CONFIG_FILE \ 81 | --main_process_ip 
$MASTER_ADDR \ 82 | --main_process_port $MASTER_PORT \ 83 | --num_processes $NUM_GPUS \ 84 | --machine_rank \$SLURM_PROCID \ 85 | " 86 | 87 | 88 | export CMD="$LAUNCHER $PROGRAM" 89 | echo $CMD 90 | 91 | # hide duplicated errors using this hack - will be properly fixed in pt-1.12 92 | # export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json 93 | 94 | # force crashing on nccl issues like hanging broadcast 95 | export NCCL_ASYNC_ERROR_HANDLING=1 96 | # export NCCL_DEBUG=INFO 97 | # export NCCL_DEBUG_SUBSYS=COLL 98 | # export NCCL_SOCKET_NTHREADS=1 99 | # export NCCL_NSOCKS_PERTHREAD=1 100 | # export CUDA_LAUNCH_BLOCKING=1 101 | 102 | # AWS specific 103 | export NCCL_PROTO=simple 104 | export RDMAV_FORK_SAFE=1 105 | export FI_EFA_FORK_SAFE=1 106 | export FI_EFA_USE_DEVICE_RDMA=1 107 | export FI_PROVIDER=efa 108 | export FI_LOG_LEVEL=1 109 | export NCCL_IB_DISABLE=1 110 | export NCCL_SOCKET_IFNAME=ens 111 | 112 | 113 | # srun error handling: 114 | # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks 115 | # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code 116 | SRUN_ARGS=" \ 117 | --wait=60 \ 118 | --kill-on-bad-exit=1 \ 119 | " 120 | 121 | clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee $LOG_PATH 122 | 123 | echo "END TIME: $(date)" 124 | -------------------------------------------------------------------------------- /scripts/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import math 18 | from argparse import Namespace 19 | 20 | import torch 21 | import webdataset as wds 22 | from torchvision import transforms 23 | from torchvision.transforms.functional import crop 24 | import random 25 | 26 | 27 | def filter_keys(key_set): 28 | def _f(dictionary): 29 | return {k: v for k, v in dictionary.items() if k in key_set} 30 | 31 | return _f 32 | 33 | 34 | def get_dataloader(args): 35 | # num_train_examples: 313,010 36 | num_batches = math.ceil(args.num_train_examples / args.global_batch_size) 37 | num_worker_batches = math.ceil( 38 | args.num_train_examples / (args.global_batch_size * args.num_workers) 39 | ) # per dataloader worker 40 | num_batches = num_worker_batches * args.num_workers 41 | num_samples = num_batches * args.global_batch_size 42 | 43 | # Preprocessing the datasets. 
44 | train_resize = transforms.Resize( 45 | args.resolution, interpolation=transforms.InterpolationMode.BILINEAR 46 | ) 47 | train_crop = ( 48 | transforms.CenterCrop(args.resolution) 49 | if args.center_crop 50 | else transforms.RandomCrop(args.resolution) 51 | ) 52 | train_flip = transforms.RandomHorizontalFlip(p=1.0) 53 | normalize = transforms.Normalize([0.5], [0.5]) 54 | 55 | def preprocess_images(sample): 56 | # We need to ensure that the original and the edited images undergo the same 57 | # augmentation transforms. 58 | # Some utilities have been taken from 59 | # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py 60 | orig_image = sample["original_image"] 61 | images = torch.stack( 62 | [ 63 | transforms.ToTensor()(sample["original_image"]), 64 | transforms.ToTensor()(sample["edited_image"]), 65 | ] 66 | ) 67 | images = train_resize(images) 68 | if args.center_crop: 69 | y1 = max(0, int(round((orig_image.height - args.resolution) / 2.0))) 70 | x1 = max(0, int(round((orig_image.width - args.resolution) / 2.0))) 71 | images = train_crop(images) 72 | else: 73 | y1, x1, h, w = train_crop.get_params( 74 | images, (args.resolution, args.resolution) 75 | ) 76 | images = crop(images, y1, x1, h, w) 77 | 78 | if args.random_flip and random.random() < 0.5: 79 | # flip 80 | x1 = orig_image.width - x1 81 | images = train_flip(images) 82 | crop_top_left = (y1, x1) 83 | 84 | transformed_images = normalize(images) 85 | 86 | # Separate the original and edited images and the edit prompt. 87 | original_image, edited_image = transformed_images.chunk(2) 88 | original_image = original_image.squeeze(0) 89 | edited_image = edited_image.squeeze(0) 90 | 91 | return { 92 | "original_image": original_image, 93 | "edited_image": edited_image, 94 | "edit_prompt": sample["edit_prompt"], 95 | "original_size": (orig_image.height, orig_image.width), 96 | "crop_top_left": crop_top_left, 97 | } 98 | 99 | def collate_fn(samples): 100 | original_images = torch.stack([sample["original_image"] for sample in samples]) 101 | original_images = original_images.to( 102 | memory_format=torch.contiguous_format 103 | ).float() 104 | 105 | edited_images = torch.stack([sample["edited_image"] for sample in samples]) 106 | edited_images = edited_images.to(memory_format=torch.contiguous_format).float() 107 | 108 | edit_prompts = [sample["edit_prompt"] for sample in samples] 109 | 110 | original_sizes = [sample["original_size"] for sample in samples] 111 | crop_top_lefts = [sample["crop_top_left"] for sample in samples] 112 | 113 | return { 114 | "original_images": original_images, 115 | "edited_images": edited_images, 116 | "original_sizes": original_sizes, 117 | "crop_top_lefts": crop_top_lefts, 118 | "edit_prompts": edit_prompts, 119 | } 120 | 121 | dataset = ( 122 | wds.WebDataset(args.dataset_path, resampled=True, handler=wds.warn_and_continue) 123 | .shuffle(690, handler=wds.warn_and_continue) 124 | .decode("pil", handler=wds.warn_and_continue) 125 | .rename( 126 | orig_prompt_ids="original_prompt.txt", 127 | original_image="original_image.jpg", 128 | edit_prompt="edit_prompt.txt", 129 | edited_image="edited_image.jpg", 130 | handler=wds.warn_and_continue, 131 | ) 132 | .map( 133 | filter_keys( 134 | { 135 | args.original_image_column, 136 | args.edit_prompt_column, 137 | args.edited_image_column, 138 | } 139 | ), 140 | handler=wds.warn_and_continue, 141 | ) 142 | .map(preprocess_images, handler=wds.warn_and_continue) 143 | .batched(args.per_gpu_batch_size, partial=False, 
collation_fn=collate_fn) 144 | .with_epoch(num_worker_batches) 145 | ) 146 | 147 | dataloader = wds.WebLoader( 148 | dataset, 149 | batch_size=None, 150 | shuffle=False, 151 | num_workers=args.num_workers, 152 | pin_memory=True, 153 | persistent_workers=True, 154 | ) 155 | 156 | # add meta-data to dataloader instance for convenience 157 | dataloader.num_batches = num_batches 158 | dataloader.num_samples = num_samples 159 | 160 | return dataloader 161 | 162 | 163 | if __name__ == "__main__": 164 | args = Namespace( 165 | dataset_path="pipe:aws s3 cp s3://muse-datasets/instructpix2pix-clip-filtered-wds/{000000..000062}.tar -", 166 | num_train_examples=313010, 167 | per_gpu_batch_size=8, 168 | global_batch_size=64, 169 | num_workers=4, 170 | center_crop=False, 171 | random_flip=True, 172 | resolution=256, 173 | original_image_column="original_image", 174 | edit_prompt_column="edit_prompt", 175 | edited_image_column="edited_image", 176 | ) 177 | dataloader = get_dataloader(args) 178 | for sample in dataloader: 179 | print(sample.keys()) 180 | print(sample["original_images"].shape) 181 | print(sample["edited_images"].shape) 182 | print(len(sample["edit_prompts"])) 183 | for s, c in zip(sample["original_sizes"], sample["crop_top_lefts"]): 184 | print(f"Original size: {s}, {type(s)}") 185 | print(f"Crop: {c}, {type(c)}") 186 | break 187 | -------------------------------------------------------------------------------- /upscale_swin2sr/upscale_images.py: -------------------------------------------------------------------------------- 1 | from swin2sr.models.network_swin2sr import Swin2SR as net 2 | 3 | from datasets import Dataset, Features 4 | from datasets import Image as ImageFeature 5 | from datasets import Value 6 | import numpy as np 7 | import PIL 8 | import os 9 | import requests 10 | import torch 11 | import datasets 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | 15 | MODEL_PATH = "model_zoo/swin2sr/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth" 16 | PARAM_KEY_G = "params_ema" 17 | SCALE = 4 18 | WINDOW_SIZE = 8 19 | DOWNSAMPLE_TO = 256 20 | BATCH_SIZE = 32 21 | 22 | NUM_WORKERS = 4 23 | DATASET_NAME = "timbrooks/instructpix2pix-clip-filtered" 24 | NEW_DATASET_NAME = "instructpix2pix-clip-filtered-upscaled" 25 | PROJECT_DIR = "/scratch" 26 | 27 | 28 | def download_model_weights() -> None: 29 | os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True) 30 | url = "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/{}".format( 31 | os.path.basename(MODEL_PATH) 32 | ) 33 | r = requests.get(url, allow_redirects=True) 34 | with open(MODEL_PATH, "wb") as f: 35 | f.write(r.content) 36 | 37 | 38 | def load_model() -> torch.nn.Module: 39 | if not os.path.exists(MODEL_PATH): 40 | download_model_weights() 41 | model = net( 42 | upscale=SCALE, 43 | in_chans=3, 44 | img_size=64, 45 | window_size=8, 46 | img_range=1.0, 47 | depths=[6, 6, 6, 6, 6, 6], 48 | embed_dim=180, 49 | num_heads=[6, 6, 6, 6, 6, 6], 50 | mlp_ratio=2, 51 | upsampler="nearest+conv", 52 | resi_connection="1conv", 53 | ) 54 | pretrained_model = torch.load(MODEL_PATH) 55 | model.load_state_dict( 56 | pretrained_model[PARAM_KEY_G] 57 | if PARAM_KEY_G in pretrained_model.keys() 58 | else pretrained_model, 59 | strict=True, 60 | ) 61 | return model 62 | 63 | 64 | def preprocesss_image(image: PIL.Image.Image) -> torch.FloatTensor: 65 | image = image.resize((DOWNSAMPLE_TO, DOWNSAMPLE_TO)) 66 | image = np.array(image).astype("float32") / 255.0 67 | image = np.transpose(image, (2, 0, 1)) # HWC -> CHW 68 | img_lq 
= torch.from_numpy(image).float().unsqueeze(0) 69 | 70 | _, _, h_old, w_old = img_lq.size() 71 | h_pad = (h_old // WINDOW_SIZE + 1) * WINDOW_SIZE - h_old 72 | w_pad = (w_old // WINDOW_SIZE + 1) * WINDOW_SIZE - w_old 73 | img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, : h_old + h_pad, :] 74 | img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, : w_old + w_pad] 75 | return image 76 | 77 | 78 | def postprocess_image(output: torch.Tensor) -> PIL.Image.Image: 79 | output = output.data.float().cpu().clamp_(0, 1).numpy() 80 | output = (output * 255).round().astype("uint8") 81 | output = output.transpose(1, 2, 0) 82 | return PIL.Image.fromarray(output) 83 | 84 | 85 | def gen_examples( 86 | original_prompts, original_images, edit_prompts, edited_prompts, edited_images 87 | ): 88 | def fn(): 89 | for i in range(len(original_prompts)): 90 | yield { 91 | "original_prompt": original_prompts[i], 92 | "original_image": {"path": original_images[i]}, 93 | "edit_prompt": edit_prompts[i], 94 | "edited_prompt": edited_prompts[i], 95 | "edited_image": {"path": edited_images[i]}, 96 | } 97 | 98 | return fn 99 | 100 | 101 | if __name__ == "__main__": 102 | dataset = datasets.load_dataset( 103 | DATASET_NAME, split="train", num_proc=4, cache_dir=PROJECT_DIR 104 | ) 105 | print(f"Dataset has got {len(dataset)} samples.") 106 | 107 | model = load_model().eval().to("cuda:1") 108 | print("Model loaded.") 109 | 110 | folder_path = os.path.join(PROJECT_DIR, "sayak") 111 | os.makedirs(folder_path, exist_ok=True) 112 | 113 | def pp(examples): 114 | examples["original_image"] = [ 115 | preprocesss_image(image) for image in examples["original_image"] 116 | ] 117 | examples["edited_image"] = [ 118 | preprocesss_image(image) for image in examples["edited_image"] 119 | ] 120 | examples["original_prompt"] = [prompt for prompt in examples["original_prompt"]] 121 | examples["edit_prompt"] = [prompt for prompt in examples["edit_prompt"]] 122 | examples["edited_prompt"] = [prompt for prompt in examples["edited_prompt"]] 123 | return examples 124 | 125 | dataset = dataset.with_transform(pp) 126 | dataloader = DataLoader( 127 | dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True 128 | ) 129 | print("Dataloader prepared.") 130 | 131 | all_upscaled_original_paths = [] 132 | all_upscaled_edited_paths = [] 133 | all_original_prompts = [] 134 | all_edit_prompts = [] 135 | all_edited_prompts = [] 136 | 137 | with torch.no_grad(): 138 | for idx, batch in enumerate(tqdm(dataloader)): 139 | # Collate the original and edited images so that we do only a single 140 | # forward pass. 141 | images = [image for image in batch["original_image"]] 142 | images += [image for image in batch["edited_image"]] 143 | images = torch.stack(images).to( 144 | "cuda:1", memory_format=torch.contiguous_format 145 | ) 146 | 147 | # Inference. 148 | output_images = model(images) 149 | original_images, edited_images = output_images.chunk(2) 150 | 151 | # Postprocess. 152 | original_images = [postprocess_image(image) for image in original_images] 153 | edited_images = [postprocess_image(image) for image in edited_images] 154 | 155 | # Pack rest of the stuff. 
156 | all_original_prompts += [prompt for prompt in batch["original_prompt"]] 157 | all_edit_prompts += [prompt for prompt in batch["edit_prompt"]] 158 | all_edited_prompts += [prompt for prompt in batch["edited_prompt"]] 159 | 160 | orig_img_paths = [ 161 | os.path.join(folder_path, f"{idx}_{i}_original_img.png") 162 | for i in range(len(original_images)) 163 | ] 164 | all_upscaled_original_paths += [path for path in orig_img_paths] 165 | edited_img_paths = [ 166 | os.path.join(folder_path, f"{idx}_{i}_edited_img.png") 167 | for i in range(len(edited_images)) 168 | ] 169 | all_upscaled_edited_paths += [path for path in edited_img_paths] 170 | 171 | for i in range(len(orig_img_paths)): 172 | original_images[i].save(orig_img_paths[i]) 173 | edited_images[i].save(edited_img_paths[i]) 174 | 175 | # Prep the dataset and get ready. 176 | generator_fn = gen_examples( 177 | original_prompts=all_original_prompts, 178 | original_images=all_upscaled_original_paths, 179 | edit_prompts=all_edit_prompts, 180 | edited_prompts=all_edited_prompts, 181 | edited_images=all_upscaled_edited_paths, 182 | ) 183 | ds = Dataset.from_generator( 184 | generator_fn, 185 | features=Features( 186 | original_prompt=Value("string"), 187 | original_image=ImageFeature(), 188 | edit_prompt=Value("string"), 189 | edited_prompt=Value("string"), 190 | edited_image=ImageFeature(), 191 | ), 192 | ) 193 | ds.save_to_disk(os.path.join(folder_path, NEW_DATASET_NAME), max_shard_size="1GB") 194 | -------------------------------------------------------------------------------- /upscale_swin2sr/swin2sr/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2021] [SwinIR Authors] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /upscale_swin2sr/swin2sr/models/network_swin2sr.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------------- 2 | # Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345 3 | # Written by Conde and Choi et al. 
4 | # ----------------------------------------------------------------------------------- 5 | 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__( 17 | self, 18 | in_features, 19 | hidden_features=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | drop=0.0, 23 | ): 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.fc1 = nn.Linear(in_features, hidden_features) 28 | self.act = act_layer() 29 | self.fc2 = nn.Linear(hidden_features, out_features) 30 | self.drop = nn.Dropout(drop) 31 | 32 | def forward(self, x): 33 | x = self.fc1(x) 34 | x = self.act(x) 35 | x = self.drop(x) 36 | x = self.fc2(x) 37 | x = self.drop(x) 38 | return x 39 | 40 | 41 | def window_partition(x, window_size): 42 | """ 43 | Args: 44 | x: (B, H, W, C) 45 | window_size (int): window size 46 | Returns: 47 | windows: (num_windows*B, window_size, window_size, C) 48 | """ 49 | B, H, W, C = x.shape 50 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 51 | windows = ( 52 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 53 | ) 54 | return windows 55 | 56 | 57 | def window_reverse(windows, window_size, H, W): 58 | """ 59 | Args: 60 | windows: (num_windows*B, window_size, window_size, C) 61 | window_size (int): Window size 62 | H (int): Height of image 63 | W (int): Width of image 64 | Returns: 65 | x: (B, H, W, C) 66 | """ 67 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 68 | x = windows.view( 69 | B, H // window_size, W // window_size, window_size, window_size, -1 70 | ) 71 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 72 | return x 73 | 74 | 75 | class WindowAttention(nn.Module): 76 | r"""Window based multi-head self attention (W-MSA) module with relative position bias. 77 | It supports both of shifted and non-shifted window. 78 | Args: 79 | dim (int): Number of input channels. 80 | window_size (tuple[int]): The height and width of the window. 81 | num_heads (int): Number of attention heads. 82 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 83 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 84 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 85 | pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
86 | """ 87 | 88 | def __init__( 89 | self, 90 | dim, 91 | window_size, 92 | num_heads, 93 | qkv_bias=True, 94 | attn_drop=0.0, 95 | proj_drop=0.0, 96 | pretrained_window_size=[0, 0], 97 | ): 98 | super().__init__() 99 | self.dim = dim 100 | self.window_size = window_size # Wh, Ww 101 | self.pretrained_window_size = pretrained_window_size 102 | self.num_heads = num_heads 103 | 104 | self.logit_scale = nn.Parameter( 105 | torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True 106 | ) 107 | 108 | # mlp to generate continuous relative position bias 109 | self.cpb_mlp = nn.Sequential( 110 | nn.Linear(2, 512, bias=True), 111 | nn.ReLU(inplace=True), 112 | nn.Linear(512, num_heads, bias=False), 113 | ) 114 | 115 | # get relative_coords_table 116 | relative_coords_h = torch.arange( 117 | -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32 118 | ) 119 | relative_coords_w = torch.arange( 120 | -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32 121 | ) 122 | relative_coords_table = ( 123 | torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) 124 | .permute(1, 2, 0) 125 | .contiguous() 126 | .unsqueeze(0) 127 | ) # 1, 2*Wh-1, 2*Ww-1, 2 128 | if pretrained_window_size[0] > 0: 129 | relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 130 | relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 131 | else: 132 | relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 133 | relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 134 | relative_coords_table *= 8 # normalize to -8, 8 135 | relative_coords_table = ( 136 | torch.sign(relative_coords_table) 137 | * torch.log2(torch.abs(relative_coords_table) + 1.0) 138 | / np.log2(8) 139 | ) 140 | 141 | self.register_buffer("relative_coords_table", relative_coords_table) 142 | 143 | # get pair-wise relative position index for each token inside the window 144 | coords_h = torch.arange(self.window_size[0]) 145 | coords_w = torch.arange(self.window_size[1]) 146 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 147 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 148 | relative_coords = ( 149 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 150 | ) # 2, Wh*Ww, Wh*Ww 151 | relative_coords = relative_coords.permute( 152 | 1, 2, 0 153 | ).contiguous() # Wh*Ww, Wh*Ww, 2 154 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 155 | relative_coords[:, :, 1] += self.window_size[1] - 1 156 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 157 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 158 | self.register_buffer("relative_position_index", relative_position_index) 159 | 160 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 161 | if qkv_bias: 162 | self.q_bias = nn.Parameter(torch.zeros(dim)) 163 | self.v_bias = nn.Parameter(torch.zeros(dim)) 164 | else: 165 | self.q_bias = None 166 | self.v_bias = None 167 | self.attn_drop = nn.Dropout(attn_drop) 168 | self.proj = nn.Linear(dim, dim) 169 | self.proj_drop = nn.Dropout(proj_drop) 170 | self.softmax = nn.Softmax(dim=-1) 171 | 172 | def forward(self, x, mask=None): 173 | """ 174 | Args: 175 | x: input features with shape of (num_windows*B, N, C) 176 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 177 | """ 178 | B_, N, C = x.shape 179 | qkv_bias = None 180 | if self.q_bias is not None: 181 | qkv_bias = torch.cat( 182 | ( 183 | self.q_bias, 184 | torch.zeros_like(self.v_bias, requires_grad=False), 185 | 
self.v_bias, 186 | ) 187 | ) 188 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 189 | qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 190 | q, k, v = ( 191 | qkv[0], 192 | qkv[1], 193 | qkv[2], 194 | ) # make torchscript happy (cannot use tensor as tuple) 195 | 196 | # cosine attention 197 | attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) 198 | logit_scale = torch.clamp( 199 | self.logit_scale, 200 | max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device), 201 | ).exp() 202 | attn = attn * logit_scale 203 | 204 | relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view( 205 | -1, self.num_heads 206 | ) 207 | relative_position_bias = relative_position_bias_table[ 208 | self.relative_position_index.view(-1) 209 | ].view( 210 | self.window_size[0] * self.window_size[1], 211 | self.window_size[0] * self.window_size[1], 212 | -1, 213 | ) # Wh*Ww,Wh*Ww,nH 214 | relative_position_bias = relative_position_bias.permute( 215 | 2, 0, 1 216 | ).contiguous() # nH, Wh*Ww, Wh*Ww 217 | relative_position_bias = 16 * torch.sigmoid(relative_position_bias) 218 | attn = attn + relative_position_bias.unsqueeze(0) 219 | 220 | if mask is not None: 221 | nW = mask.shape[0] 222 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 223 | 1 224 | ).unsqueeze(0) 225 | attn = attn.view(-1, self.num_heads, N, N) 226 | attn = self.softmax(attn) 227 | else: 228 | attn = self.softmax(attn) 229 | 230 | attn = self.attn_drop(attn) 231 | 232 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 233 | x = self.proj(x) 234 | x = self.proj_drop(x) 235 | return x 236 | 237 | def extra_repr(self) -> str: 238 | return ( 239 | f"dim={self.dim}, window_size={self.window_size}, " 240 | f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" 241 | ) 242 | 243 | def flops(self, N): 244 | # calculate flops for 1 window with token length of N 245 | flops = 0 246 | # qkv = self.qkv(x) 247 | flops += N * self.dim * 3 * self.dim 248 | # attn = (q @ k.transpose(-2, -1)) 249 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 250 | # x = (attn @ v) 251 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 252 | # x = self.proj(x) 253 | flops += N * self.dim * self.dim 254 | return flops 255 | 256 | 257 | class SwinTransformerBlock(nn.Module): 258 | r"""Swin Transformer Block. 259 | Args: 260 | dim (int): Number of input channels. 261 | input_resolution (tuple[int]): Input resulotion. 262 | num_heads (int): Number of attention heads. 263 | window_size (int): Window size. 264 | shift_size (int): Shift size for SW-MSA. 265 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 266 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 267 | drop (float, optional): Dropout rate. Default: 0.0 268 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 269 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 270 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 271 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 272 | pretrained_window_size (int): Window size in pre-training. 
273 | """ 274 | 275 | def __init__( 276 | self, 277 | dim, 278 | input_resolution, 279 | num_heads, 280 | window_size=7, 281 | shift_size=0, 282 | mlp_ratio=4.0, 283 | qkv_bias=True, 284 | drop=0.0, 285 | attn_drop=0.0, 286 | drop_path=0.0, 287 | act_layer=nn.GELU, 288 | norm_layer=nn.LayerNorm, 289 | pretrained_window_size=0, 290 | ): 291 | super().__init__() 292 | self.dim = dim 293 | self.input_resolution = input_resolution 294 | self.num_heads = num_heads 295 | self.window_size = window_size 296 | self.shift_size = shift_size 297 | self.mlp_ratio = mlp_ratio 298 | if min(self.input_resolution) <= self.window_size: 299 | # if window size is larger than input resolution, we don't partition windows 300 | self.shift_size = 0 301 | self.window_size = min(self.input_resolution) 302 | assert ( 303 | 0 <= self.shift_size < self.window_size 304 | ), "shift_size must in 0-window_size" 305 | 306 | self.norm1 = norm_layer(dim) 307 | self.attn = WindowAttention( 308 | dim, 309 | window_size=to_2tuple(self.window_size), 310 | num_heads=num_heads, 311 | qkv_bias=qkv_bias, 312 | attn_drop=attn_drop, 313 | proj_drop=drop, 314 | pretrained_window_size=to_2tuple(pretrained_window_size), 315 | ) 316 | 317 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 318 | self.norm2 = norm_layer(dim) 319 | mlp_hidden_dim = int(dim * mlp_ratio) 320 | self.mlp = Mlp( 321 | in_features=dim, 322 | hidden_features=mlp_hidden_dim, 323 | act_layer=act_layer, 324 | drop=drop, 325 | ) 326 | 327 | if self.shift_size > 0: 328 | attn_mask = self.calculate_mask(self.input_resolution) 329 | else: 330 | attn_mask = None 331 | 332 | self.register_buffer("attn_mask", attn_mask) 333 | 334 | def calculate_mask(self, x_size): 335 | # calculate attention mask for SW-MSA 336 | H, W = x_size 337 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 338 | h_slices = ( 339 | slice(0, -self.window_size), 340 | slice(-self.window_size, -self.shift_size), 341 | slice(-self.shift_size, None), 342 | ) 343 | w_slices = ( 344 | slice(0, -self.window_size), 345 | slice(-self.window_size, -self.shift_size), 346 | slice(-self.shift_size, None), 347 | ) 348 | cnt = 0 349 | for h in h_slices: 350 | for w in w_slices: 351 | img_mask[:, h, w, :] = cnt 352 | cnt += 1 353 | 354 | mask_windows = window_partition( 355 | img_mask, self.window_size 356 | ) # nW, window_size, window_size, 1 357 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 358 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 359 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( 360 | attn_mask == 0, float(0.0) 361 | ) 362 | 363 | return attn_mask 364 | 365 | def forward(self, x, x_size): 366 | H, W = x_size 367 | B, L, C = x.shape 368 | # assert L == H * W, "input feature has wrong size" 369 | 370 | shortcut = x 371 | x = x.view(B, H, W, C) 372 | 373 | # cyclic shift 374 | if self.shift_size > 0: 375 | shifted_x = torch.roll( 376 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) 377 | ) 378 | else: 379 | shifted_x = x 380 | 381 | # partition windows 382 | x_windows = window_partition( 383 | shifted_x, self.window_size 384 | ) # nW*B, window_size, window_size, C 385 | x_windows = x_windows.view( 386 | -1, self.window_size * self.window_size, C 387 | ) # nW*B, window_size*window_size, C 388 | 389 | # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size 390 | if self.input_resolution == x_size: 391 | attn_windows = self.attn( 392 | x_windows, 
mask=self.attn_mask 393 | ) # nW*B, window_size*window_size, C 394 | else: 395 | attn_windows = self.attn( 396 | x_windows, mask=self.calculate_mask(x_size).to(x.device) 397 | ) 398 | 399 | # merge windows 400 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 401 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 402 | 403 | # reverse cyclic shift 404 | if self.shift_size > 0: 405 | x = torch.roll( 406 | shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) 407 | ) 408 | else: 409 | x = shifted_x 410 | x = x.view(B, H * W, C) 411 | x = shortcut + self.drop_path(self.norm1(x)) 412 | 413 | # FFN 414 | x = x + self.drop_path(self.norm2(self.mlp(x))) 415 | 416 | return x 417 | 418 | def extra_repr(self) -> str: 419 | return ( 420 | f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " 421 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 422 | ) 423 | 424 | def flops(self): 425 | flops = 0 426 | H, W = self.input_resolution 427 | # norm1 428 | flops += self.dim * H * W 429 | # W-MSA/SW-MSA 430 | nW = H * W / self.window_size / self.window_size 431 | flops += nW * self.attn.flops(self.window_size * self.window_size) 432 | # mlp 433 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 434 | # norm2 435 | flops += self.dim * H * W 436 | return flops 437 | 438 | 439 | class PatchMerging(nn.Module): 440 | r"""Patch Merging Layer. 441 | Args: 442 | input_resolution (tuple[int]): Resolution of input feature. 443 | dim (int): Number of input channels. 444 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 445 | """ 446 | 447 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 448 | super().__init__() 449 | self.input_resolution = input_resolution 450 | self.dim = dim 451 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 452 | self.norm = norm_layer(2 * dim) 453 | 454 | def forward(self, x): 455 | """ 456 | x: B, H*W, C 457 | """ 458 | H, W = self.input_resolution 459 | B, L, C = x.shape 460 | assert L == H * W, "input feature has wrong size" 461 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 462 | 463 | x = x.view(B, H, W, C) 464 | 465 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 466 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 467 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 468 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 469 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 470 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 471 | 472 | x = self.reduction(x) 473 | x = self.norm(x) 474 | 475 | return x 476 | 477 | def extra_repr(self) -> str: 478 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 479 | 480 | def flops(self): 481 | H, W = self.input_resolution 482 | flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 483 | flops += H * W * self.dim // 2 484 | return flops 485 | 486 | 487 | class BasicLayer(nn.Module): 488 | """A basic Swin Transformer layer for one stage. 489 | Args: 490 | dim (int): Number of input channels. 491 | input_resolution (tuple[int]): Input resolution. 492 | depth (int): Number of blocks. 493 | num_heads (int): Number of attention heads. 494 | window_size (int): Local window size. 495 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 496 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 497 | drop (float, optional): Dropout rate. 
Default: 0.0 498 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 499 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 500 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 501 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 502 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 503 | pretrained_window_size (int): Local window size in pre-training. 504 | """ 505 | 506 | def __init__( 507 | self, 508 | dim, 509 | input_resolution, 510 | depth, 511 | num_heads, 512 | window_size, 513 | mlp_ratio=4.0, 514 | qkv_bias=True, 515 | drop=0.0, 516 | attn_drop=0.0, 517 | drop_path=0.0, 518 | norm_layer=nn.LayerNorm, 519 | downsample=None, 520 | use_checkpoint=False, 521 | pretrained_window_size=0, 522 | ): 523 | super().__init__() 524 | self.dim = dim 525 | self.input_resolution = input_resolution 526 | self.depth = depth 527 | self.use_checkpoint = use_checkpoint 528 | 529 | # build blocks 530 | self.blocks = nn.ModuleList( 531 | [ 532 | SwinTransformerBlock( 533 | dim=dim, 534 | input_resolution=input_resolution, 535 | num_heads=num_heads, 536 | window_size=window_size, 537 | shift_size=0 if (i % 2 == 0) else window_size // 2, 538 | mlp_ratio=mlp_ratio, 539 | qkv_bias=qkv_bias, 540 | drop=drop, 541 | attn_drop=attn_drop, 542 | drop_path=drop_path[i] 543 | if isinstance(drop_path, list) 544 | else drop_path, 545 | norm_layer=norm_layer, 546 | pretrained_window_size=pretrained_window_size, 547 | ) 548 | for i in range(depth) 549 | ] 550 | ) 551 | 552 | # patch merging layer 553 | if downsample is not None: 554 | self.downsample = downsample( 555 | input_resolution, dim=dim, norm_layer=norm_layer 556 | ) 557 | else: 558 | self.downsample = None 559 | 560 | def forward(self, x, x_size): 561 | for blk in self.blocks: 562 | if self.use_checkpoint: 563 | x = checkpoint.checkpoint(blk, x, x_size) 564 | else: 565 | x = blk(x, x_size) 566 | if self.downsample is not None: 567 | x = self.downsample(x) 568 | return x 569 | 570 | def extra_repr(self) -> str: 571 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 572 | 573 | def flops(self): 574 | flops = 0 575 | for blk in self.blocks: 576 | flops += blk.flops() 577 | if self.downsample is not None: 578 | flops += self.downsample.flops() 579 | return flops 580 | 581 | def _init_respostnorm(self): 582 | for blk in self.blocks: 583 | nn.init.constant_(blk.norm1.bias, 0) 584 | nn.init.constant_(blk.norm1.weight, 0) 585 | nn.init.constant_(blk.norm2.bias, 0) 586 | nn.init.constant_(blk.norm2.weight, 0) 587 | 588 | 589 | class PatchEmbed(nn.Module): 590 | r"""Image to Patch Embedding 591 | Args: 592 | img_size (int): Image size. Default: 224. 593 | patch_size (int): Patch token size. Default: 4. 594 | in_chans (int): Number of input image channels. Default: 3. 595 | embed_dim (int): Number of linear projection output channels. Default: 96. 596 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 597 | """ 598 | 599 | def __init__( 600 | self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None 601 | ): 602 | super().__init__() 603 | img_size = to_2tuple(img_size) 604 | patch_size = to_2tuple(patch_size) 605 | patches_resolution = [ 606 | img_size[0] // patch_size[0], 607 | img_size[1] // patch_size[1], 608 | ] 609 | self.img_size = img_size 610 | self.patch_size = patch_size 611 | self.patches_resolution = patches_resolution 612 | self.num_patches = patches_resolution[0] * patches_resolution[1] 613 | 614 | self.in_chans = in_chans 615 | self.embed_dim = embed_dim 616 | 617 | self.proj = nn.Conv2d( 618 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 619 | ) 620 | if norm_layer is not None: 621 | self.norm = norm_layer(embed_dim) 622 | else: 623 | self.norm = None 624 | 625 | def forward(self, x): 626 | B, C, H, W = x.shape 627 | # FIXME look at relaxing size constraints 628 | # assert H == self.img_size[0] and W == self.img_size[1], 629 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 630 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 631 | if self.norm is not None: 632 | x = self.norm(x) 633 | return x 634 | 635 | def flops(self): 636 | Ho, Wo = self.patches_resolution 637 | flops = ( 638 | Ho 639 | * Wo 640 | * self.embed_dim 641 | * self.in_chans 642 | * (self.patch_size[0] * self.patch_size[1]) 643 | ) 644 | if self.norm is not None: 645 | flops += Ho * Wo * self.embed_dim 646 | return flops 647 | 648 | 649 | class RSTB(nn.Module): 650 | """Residual Swin Transformer Block (RSTB). 651 | 652 | Args: 653 | dim (int): Number of input channels. 654 | input_resolution (tuple[int]): Input resolution. 655 | depth (int): Number of blocks. 656 | num_heads (int): Number of attention heads. 657 | window_size (int): Local window size. 658 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 659 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 660 | drop (float, optional): Dropout rate. Default: 0.0 661 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 662 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 663 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 664 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 665 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 666 | img_size: Input image size. 667 | patch_size: Patch size. 668 | resi_connection: The convolutional block before residual connection. 
669 | """ 670 | 671 | def __init__( 672 | self, 673 | dim, 674 | input_resolution, 675 | depth, 676 | num_heads, 677 | window_size, 678 | mlp_ratio=4.0, 679 | qkv_bias=True, 680 | drop=0.0, 681 | attn_drop=0.0, 682 | drop_path=0.0, 683 | norm_layer=nn.LayerNorm, 684 | downsample=None, 685 | use_checkpoint=False, 686 | img_size=224, 687 | patch_size=4, 688 | resi_connection="1conv", 689 | ): 690 | super(RSTB, self).__init__() 691 | 692 | self.dim = dim 693 | self.input_resolution = input_resolution 694 | 695 | self.residual_group = BasicLayer( 696 | dim=dim, 697 | input_resolution=input_resolution, 698 | depth=depth, 699 | num_heads=num_heads, 700 | window_size=window_size, 701 | mlp_ratio=mlp_ratio, 702 | qkv_bias=qkv_bias, 703 | drop=drop, 704 | attn_drop=attn_drop, 705 | drop_path=drop_path, 706 | norm_layer=norm_layer, 707 | downsample=downsample, 708 | use_checkpoint=use_checkpoint, 709 | ) 710 | 711 | if resi_connection == "1conv": 712 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 713 | elif resi_connection == "3conv": 714 | # to save parameters and memory 715 | self.conv = nn.Sequential( 716 | nn.Conv2d(dim, dim // 4, 3, 1, 1), 717 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 718 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), 719 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 720 | nn.Conv2d(dim // 4, dim, 3, 1, 1), 721 | ) 722 | 723 | self.patch_embed = PatchEmbed( 724 | img_size=img_size, 725 | patch_size=patch_size, 726 | in_chans=dim, 727 | embed_dim=dim, 728 | norm_layer=None, 729 | ) 730 | 731 | self.patch_unembed = PatchUnEmbed( 732 | img_size=img_size, 733 | patch_size=patch_size, 734 | in_chans=dim, 735 | embed_dim=dim, 736 | norm_layer=None, 737 | ) 738 | 739 | def forward(self, x, x_size): 740 | return ( 741 | self.patch_embed( 742 | self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) 743 | ) 744 | + x 745 | ) 746 | 747 | def flops(self): 748 | flops = 0 749 | flops += self.residual_group.flops() 750 | H, W = self.input_resolution 751 | flops += H * W * self.dim * self.dim * 9 752 | flops += self.patch_embed.flops() 753 | flops += self.patch_unembed.flops() 754 | 755 | return flops 756 | 757 | 758 | class PatchUnEmbed(nn.Module): 759 | r"""Image to Patch Unembedding 760 | 761 | Args: 762 | img_size (int): Image size. Default: 224. 763 | patch_size (int): Patch token size. Default: 4. 764 | in_chans (int): Number of input image channels. Default: 3. 765 | embed_dim (int): Number of linear projection output channels. Default: 96. 766 | norm_layer (nn.Module, optional): Normalization layer. Default: None 767 | """ 768 | 769 | def __init__( 770 | self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None 771 | ): 772 | super().__init__() 773 | img_size = to_2tuple(img_size) 774 | patch_size = to_2tuple(patch_size) 775 | patches_resolution = [ 776 | img_size[0] // patch_size[0], 777 | img_size[1] // patch_size[1], 778 | ] 779 | self.img_size = img_size 780 | self.patch_size = patch_size 781 | self.patches_resolution = patches_resolution 782 | self.num_patches = patches_resolution[0] * patches_resolution[1] 783 | 784 | self.in_chans = in_chans 785 | self.embed_dim = embed_dim 786 | 787 | def forward(self, x, x_size): 788 | B, HW, C = x.shape 789 | x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C 790 | return x 791 | 792 | def flops(self): 793 | flops = 0 794 | return flops 795 | 796 | 797 | class Upsample(nn.Sequential): 798 | """Upsample module. 799 | 800 | Args: 801 | scale (int): Scale factor. 
Supported scales: 2^n and 3. 802 | num_feat (int): Channel number of intermediate features. 803 | """ 804 | 805 | def __init__(self, scale, num_feat): 806 | m = [] 807 | if (scale & (scale - 1)) == 0: # scale = 2^n 808 | for _ in range(int(math.log(scale, 2))): 809 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 810 | m.append(nn.PixelShuffle(2)) 811 | elif scale == 3: 812 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 813 | m.append(nn.PixelShuffle(3)) 814 | else: 815 | raise ValueError( 816 | f"scale {scale} is not supported. " "Supported scales: 2^n and 3." 817 | ) 818 | super(Upsample, self).__init__(*m) 819 | 820 | 821 | class Upsample_hf(nn.Sequential): 822 | """Upsample module. 823 | 824 | Args: 825 | scale (int): Scale factor. Supported scales: 2^n and 3. 826 | num_feat (int): Channel number of intermediate features. 827 | """ 828 | 829 | def __init__(self, scale, num_feat): 830 | m = [] 831 | if (scale & (scale - 1)) == 0: # scale = 2^n 832 | for _ in range(int(math.log(scale, 2))): 833 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 834 | m.append(nn.PixelShuffle(2)) 835 | elif scale == 3: 836 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 837 | m.append(nn.PixelShuffle(3)) 838 | else: 839 | raise ValueError( 840 | f"scale {scale} is not supported. " "Supported scales: 2^n and 3." 841 | ) 842 | super(Upsample_hf, self).__init__(*m) 843 | 844 | 845 | class UpsampleOneStep(nn.Sequential): 846 | """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) 847 | Used in lightweight SR to save parameters. 848 | 849 | Args: 850 | scale (int): Scale factor. Supported scales: 2^n and 3. 851 | num_feat (int): Channel number of intermediate features. 852 | 853 | """ 854 | 855 | def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): 856 | self.num_feat = num_feat 857 | self.input_resolution = input_resolution 858 | m = [] 859 | m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) 860 | m.append(nn.PixelShuffle(scale)) 861 | super(UpsampleOneStep, self).__init__(*m) 862 | 863 | def flops(self): 864 | H, W = self.input_resolution 865 | flops = H * W * self.num_feat * 3 * 9 866 | return flops 867 | 868 | 869 | class Swin2SR(nn.Module): 870 | r"""Swin2SR 871 | A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. 872 | 873 | Args: 874 | img_size (int | tuple(int)): Input image size. Default 64 875 | patch_size (int | tuple(int)): Patch size. Default: 1 876 | in_chans (int): Number of input image channels. Default: 3 877 | embed_dim (int): Patch embedding dimension. Default: 96 878 | depths (tuple(int)): Depth of each Swin Transformer layer. 879 | num_heads (tuple(int)): Number of attention heads in different layers. 880 | window_size (int): Window size. Default: 7 881 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 882 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 883 | drop_rate (float): Dropout rate. Default: 0 884 | attn_drop_rate (float): Attention dropout rate. Default: 0 885 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 886 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 887 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 888 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 889 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
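# --- Editor's aside, not part of the original file: the 2^n branch of Upsample above builds a x4
# upsampler as two (3x3 conv to 4*num_feat channels, then PixelShuffle(2)) stages; each stage
# trades a 4x channel expansion for a 2x spatial expansion, so the channel count is preserved.
# `_num_feat` and the 16x16 input are illustrative values.
import torch
import torch.nn as nn

_num_feat = 64
_up_x4 = nn.Sequential(
    nn.Conv2d(_num_feat, 4 * _num_feat, 3, 1, 1), nn.PixelShuffle(2),
    nn.Conv2d(_num_feat, 4 * _num_feat, 3, 1, 1), nn.PixelShuffle(2),
)
assert _up_x4(torch.randn(1, _num_feat, 16, 16)).shape == (1, _num_feat, 64, 64)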
Default: False 890 | upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction 891 | img_range: Image range. 1. or 255. 892 | upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None 893 | resi_connection: The convolutional block before residual connection. '1conv'/'3conv' 894 | """ 895 | 896 | def __init__( 897 | self, 898 | img_size=64, 899 | patch_size=1, 900 | in_chans=3, 901 | embed_dim=96, 902 | depths=[6, 6, 6, 6], 903 | num_heads=[6, 6, 6, 6], 904 | window_size=7, 905 | mlp_ratio=4.0, 906 | qkv_bias=True, 907 | drop_rate=0.0, 908 | attn_drop_rate=0.0, 909 | drop_path_rate=0.1, 910 | norm_layer=nn.LayerNorm, 911 | ape=False, 912 | patch_norm=True, 913 | use_checkpoint=False, 914 | upscale=2, 915 | img_range=1.0, 916 | upsampler="", 917 | resi_connection="1conv", 918 | **kwargs, 919 | ): 920 | super(Swin2SR, self).__init__() 921 | num_in_ch = in_chans 922 | num_out_ch = in_chans 923 | num_feat = 64 924 | self.img_range = img_range 925 | if in_chans == 3: 926 | rgb_mean = (0.4488, 0.4371, 0.4040) 927 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 928 | else: 929 | self.mean = torch.zeros(1, 1, 1, 1) 930 | self.upscale = upscale 931 | self.upsampler = upsampler 932 | self.window_size = window_size 933 | 934 | ##################################################################################################### 935 | ################################### 1, shallow feature extraction ################################### 936 | self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) 937 | 938 | ##################################################################################################### 939 | ################################### 2, deep feature extraction ###################################### 940 | self.num_layers = len(depths) 941 | self.embed_dim = embed_dim 942 | self.ape = ape 943 | self.patch_norm = patch_norm 944 | self.num_features = embed_dim 945 | self.mlp_ratio = mlp_ratio 946 | 947 | # split image into non-overlapping patches 948 | self.patch_embed = PatchEmbed( 949 | img_size=img_size, 950 | patch_size=patch_size, 951 | in_chans=embed_dim, 952 | embed_dim=embed_dim, 953 | norm_layer=norm_layer if self.patch_norm else None, 954 | ) 955 | num_patches = self.patch_embed.num_patches 956 | patches_resolution = self.patch_embed.patches_resolution 957 | self.patches_resolution = patches_resolution 958 | 959 | # merge non-overlapping patches into image 960 | self.patch_unembed = PatchUnEmbed( 961 | img_size=img_size, 962 | patch_size=patch_size, 963 | in_chans=embed_dim, 964 | embed_dim=embed_dim, 965 | norm_layer=norm_layer if self.patch_norm else None, 966 | ) 967 | 968 | # absolute position embedding 969 | if self.ape: 970 | self.absolute_pos_embed = nn.Parameter( 971 | torch.zeros(1, num_patches, embed_dim) 972 | ) 973 | trunc_normal_(self.absolute_pos_embed, std=0.02) 974 | 975 | self.pos_drop = nn.Dropout(p=drop_rate) 976 | 977 | # stochastic depth 978 | dpr = [ 979 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) 980 | ] # stochastic depth decay rule 981 | 982 | # build Residual Swin Transformer blocks (RSTB) 983 | self.layers = nn.ModuleList() 984 | for i_layer in range(self.num_layers): 985 | layer = RSTB( 986 | dim=embed_dim, 987 | input_resolution=(patches_resolution[0], patches_resolution[1]), 988 | depth=depths[i_layer], 989 | num_heads=num_heads[i_layer], 990 | window_size=window_size, 991 | mlp_ratio=self.mlp_ratio, 992 | qkv_bias=qkv_bias, 993 | 
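# Editor's note, not part of the original file: `dpr` above is torch.linspace(0, drop_path_rate,
# sum(depths)), so with the default depths=[6, 6, 6, 6] and drop_path_rate=0.1 the 24 per-block
# rates ramp linearly from 0.0 to 0.1, and layer i receives the slice
# dpr[sum(depths[:i]) : sum(depths[:i + 1])] (layer 0 gets dpr[0:6], layer 1 gets dpr[6:12], ...).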
drop=drop_rate, 994 | attn_drop=attn_drop_rate, 995 | drop_path=dpr[ 996 | sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) 997 | ], # no impact on SR results 998 | norm_layer=norm_layer, 999 | downsample=None, 1000 | use_checkpoint=use_checkpoint, 1001 | img_size=img_size, 1002 | patch_size=patch_size, 1003 | resi_connection=resi_connection, 1004 | ) 1005 | self.layers.append(layer) 1006 | 1007 | if self.upsampler == "pixelshuffle_hf": 1008 | self.layers_hf = nn.ModuleList() 1009 | for i_layer in range(self.num_layers): 1010 | layer = RSTB( 1011 | dim=embed_dim, 1012 | input_resolution=(patches_resolution[0], patches_resolution[1]), 1013 | depth=depths[i_layer], 1014 | num_heads=num_heads[i_layer], 1015 | window_size=window_size, 1016 | mlp_ratio=self.mlp_ratio, 1017 | qkv_bias=qkv_bias, 1018 | drop=drop_rate, 1019 | attn_drop=attn_drop_rate, 1020 | drop_path=dpr[ 1021 | sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) 1022 | ], # no impact on SR results 1023 | norm_layer=norm_layer, 1024 | downsample=None, 1025 | use_checkpoint=use_checkpoint, 1026 | img_size=img_size, 1027 | patch_size=patch_size, 1028 | resi_connection=resi_connection, 1029 | ) 1030 | self.layers_hf.append(layer) 1031 | 1032 | self.norm = norm_layer(self.num_features) 1033 | 1034 | # build the last conv layer in deep feature extraction 1035 | if resi_connection == "1conv": 1036 | self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) 1037 | elif resi_connection == "3conv": 1038 | # to save parameters and memory 1039 | self.conv_after_body = nn.Sequential( 1040 | nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), 1041 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 1042 | nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), 1043 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 1044 | nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1), 1045 | ) 1046 | 1047 | ##################################################################################################### 1048 | ################################ 3, high quality image reconstruction ################################ 1049 | if self.upsampler == "pixelshuffle": 1050 | # for classical SR 1051 | self.conv_before_upsample = nn.Sequential( 1052 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1053 | ) 1054 | self.upsample = Upsample(upscale, num_feat) 1055 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1056 | elif self.upsampler == "pixelshuffle_aux": 1057 | self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 1058 | self.conv_before_upsample = nn.Sequential( 1059 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1060 | ) 1061 | self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1062 | self.conv_after_aux = nn.Sequential( 1063 | nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1064 | ) 1065 | self.upsample = Upsample(upscale, num_feat) 1066 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1067 | 1068 | elif self.upsampler == "pixelshuffle_hf": 1069 | self.conv_before_upsample = nn.Sequential( 1070 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1071 | ) 1072 | self.upsample = Upsample(upscale, num_feat) 1073 | self.upsample_hf = Upsample_hf(upscale, num_feat) 1074 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1075 | self.conv_first_hf = nn.Sequential( 1076 | nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True) 1077 | ) 1078 | self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) 1079 | self.conv_before_upsample_hf = 
nn.Sequential( 1080 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1081 | ) 1082 | self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1083 | 1084 | elif self.upsampler == "pixelshuffledirect": 1085 | # for lightweight SR (to save parameters) 1086 | self.upsample = UpsampleOneStep( 1087 | upscale, 1088 | embed_dim, 1089 | num_out_ch, 1090 | (patches_resolution[0], patches_resolution[1]), 1091 | ) 1092 | elif self.upsampler == "nearest+conv": 1093 | # for real-world SR (less artifacts) 1094 | assert self.upscale == 4, "only support x4 now." 1095 | self.conv_before_upsample = nn.Sequential( 1096 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1097 | ) 1098 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 1099 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 1100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 1101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1102 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 1103 | else: 1104 | # for image denoising and JPEG compression artifact reduction 1105 | self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) 1106 | 1107 | self.apply(self._init_weights) 1108 | 1109 | def _init_weights(self, m): 1110 | if isinstance(m, nn.Linear): 1111 | trunc_normal_(m.weight, std=0.02) 1112 | if isinstance(m, nn.Linear) and m.bias is not None: 1113 | nn.init.constant_(m.bias, 0) 1114 | elif isinstance(m, nn.LayerNorm): 1115 | nn.init.constant_(m.bias, 0) 1116 | nn.init.constant_(m.weight, 1.0) 1117 | 1118 | @torch.jit.ignore 1119 | def no_weight_decay(self): 1120 | return {"absolute_pos_embed"} 1121 | 1122 | @torch.jit.ignore 1123 | def no_weight_decay_keywords(self): 1124 | return {"relative_position_bias_table"} 1125 | 1126 | def check_image_size(self, x): 1127 | _, _, h, w = x.size() 1128 | mod_pad_h = (self.window_size - h % self.window_size) % self.window_size 1129 | mod_pad_w = (self.window_size - w % self.window_size) % self.window_size 1130 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") 1131 | return x 1132 | 1133 | def forward_features(self, x): 1134 | x_size = (x.shape[2], x.shape[3]) 1135 | x = self.patch_embed(x) 1136 | if self.ape: 1137 | x = x + self.absolute_pos_embed 1138 | x = self.pos_drop(x) 1139 | 1140 | for layer in self.layers: 1141 | x = layer(x, x_size) 1142 | 1143 | x = self.norm(x) # B L C 1144 | x = self.patch_unembed(x, x_size) 1145 | 1146 | return x 1147 | 1148 | def forward_features_hf(self, x): 1149 | x_size = (x.shape[2], x.shape[3]) 1150 | x = self.patch_embed(x) 1151 | if self.ape: 1152 | x = x + self.absolute_pos_embed 1153 | x = self.pos_drop(x) 1154 | 1155 | for layer in self.layers_hf: 1156 | x = layer(x, x_size) 1157 | 1158 | x = self.norm(x) # B L C 1159 | x = self.patch_unembed(x, x_size) 1160 | 1161 | return x 1162 | 1163 | def forward(self, x): 1164 | H, W = x.shape[2:] 1165 | x = self.check_image_size(x) 1166 | 1167 | self.mean = self.mean.type_as(x) 1168 | x = (x - self.mean) * self.img_range 1169 | 1170 | if self.upsampler == "pixelshuffle": 1171 | # for classical SR 1172 | x = self.conv_first(x) 1173 | x = self.conv_after_body(self.forward_features(x)) + x 1174 | x = self.conv_before_upsample(x) 1175 | x = self.conv_last(self.upsample(x)) 1176 | elif self.upsampler == "pixelshuffle_aux": 1177 | bicubic = F.interpolate( 1178 | x, 1179 | size=(H * self.upscale, W * self.upscale), 1180 | mode="bicubic", 1181 | align_corners=False, 1182 | ) 1183 | bicubic = self.conv_bicubic(bicubic) 1184 | x = 
self.conv_first(x) 1185 | x = self.conv_after_body(self.forward_features(x)) + x 1186 | x = self.conv_before_upsample(x) 1187 | aux = self.conv_aux(x) # b, 3, LR_H, LR_W 1188 | x = self.conv_after_aux(aux) 1189 | x = ( 1190 | self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale] 1191 | + bicubic[:, :, : H * self.upscale, : W * self.upscale] 1192 | ) 1193 | x = self.conv_last(x) 1194 | aux = aux / self.img_range + self.mean 1195 | elif self.upsampler == "pixelshuffle_hf": 1196 | # for classical SR with HF 1197 | x = self.conv_first(x) 1198 | x = self.conv_after_body(self.forward_features(x)) + x 1199 | x_before = self.conv_before_upsample(x) 1200 | x_out = self.conv_last(self.upsample(x_before)) 1201 | 1202 | x_hf = self.conv_first_hf(x_before) 1203 | x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf 1204 | x_hf = self.conv_before_upsample_hf(x_hf) 1205 | x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) 1206 | x = x_out + x_hf 1207 | x_hf = x_hf / self.img_range + self.mean 1208 | 1209 | elif self.upsampler == "pixelshuffledirect": 1210 | # for lightweight SR 1211 | x = self.conv_first(x) 1212 | x = self.conv_after_body(self.forward_features(x)) + x 1213 | x = self.upsample(x) 1214 | elif self.upsampler == "nearest+conv": 1215 | # for real-world SR 1216 | x = self.conv_first(x) 1217 | x = self.conv_after_body(self.forward_features(x)) + x 1218 | x = self.conv_before_upsample(x) 1219 | x = self.lrelu( 1220 | self.conv_up1( 1221 | torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") 1222 | ) 1223 | ) 1224 | x = self.lrelu( 1225 | self.conv_up2( 1226 | torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") 1227 | ) 1228 | ) 1229 | x = self.conv_last(self.lrelu(self.conv_hr(x))) 1230 | else: 1231 | # for image denoising and JPEG compression artifact reduction 1232 | x_first = self.conv_first(x) 1233 | res = self.conv_after_body(self.forward_features(x_first)) + x_first 1234 | x = x + self.conv_last(res) 1235 | 1236 | x = x / self.img_range + self.mean 1237 | if self.upsampler == "pixelshuffle_aux": 1238 | return x[:, :, : H * self.upscale, : W * self.upscale], aux 1239 | 1240 | elif self.upsampler == "pixelshuffle_hf": 1241 | x_out = x_out / self.img_range + self.mean 1242 | return ( 1243 | x_out[:, :, : H * self.upscale, : W * self.upscale], 1244 | x[:, :, : H * self.upscale, : W * self.upscale], 1245 | x_hf[:, :, : H * self.upscale, : W * self.upscale], 1246 | ) 1247 | 1248 | else: 1249 | return x[:, :, : H * self.upscale, : W * self.upscale] 1250 | 1251 | def flops(self): 1252 | flops = 0 1253 | H, W = self.patches_resolution 1254 | flops += H * W * 3 * self.embed_dim * 9 1255 | flops += self.patch_embed.flops() 1256 | for i, layer in enumerate(self.layers): 1257 | flops += layer.flops() 1258 | flops += H * W * 3 * self.embed_dim * self.embed_dim 1259 | flops += self.upsample.flops() 1260 | return flops 1261 | 1262 | 1263 | if __name__ == "__main__": 1264 | upscale = 4 1265 | window_size = 8 1266 | height = (1024 // upscale // window_size + 1) * window_size 1267 | width = (720 // upscale // window_size + 1) * window_size 1268 | model = Swin2SR( 1269 | upscale=2, 1270 | img_size=(height, width), 1271 | window_size=window_size, 1272 | img_range=1.0, 1273 | depths=[6, 6, 6, 6], 1274 | embed_dim=60, 1275 | num_heads=[6, 6, 6, 6], 1276 | mlp_ratio=2, 1277 | upsampler="pixelshuffledirect", 1278 | ) 1279 | print(model) 1280 | print(height, width, model.flops() / 1e9) 1281 | 1282 | x = torch.randn((1, 3, height, width)) 1283 | x = model(x) 1284 
| print(x.shape) 1285 | -------------------------------------------------------------------------------- /scripts/train_instruct_pix2pix_sdxl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | import logging 19 | import math 20 | import os 21 | import shutil 22 | import warnings 23 | from pathlib import Path 24 | from urllib.parse import urlparse 25 | from dataloader import get_dataloader 26 | 27 | import accelerate 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from huggingface_hub import create_repo, upload_folder 37 | from packaging import version 38 | from PIL import Image 39 | from tqdm.auto import tqdm 40 | from transformers import AutoTokenizer, PretrainedConfig 41 | 42 | import diffusers 43 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 44 | from diffusers.optimization import get_scheduler 45 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import ( 46 | StableDiffusionXLInstructPix2PixPipeline, 47 | ) 48 | from diffusers.training_utils import EMAModel 49 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image 50 | from diffusers.utils.import_utils import is_xformers_available 51 | 52 | if is_wandb_available(): 53 | import wandb 54 | 55 | 56 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 57 | check_min_version("0.20.0.dev0") 58 | 59 | logger = get_logger(__name__, log_level="INFO") 60 | 61 | WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] 62 | 63 | 64 | def import_model_class_from_model_name_or_path( 65 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 66 | ): 67 | text_encoder_config = PretrainedConfig.from_pretrained( 68 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 69 | ) 70 | model_class = text_encoder_config.architectures[0] 71 | 72 | if model_class == "CLIPTextModel": 73 | from transformers import CLIPTextModel 74 | 75 | return CLIPTextModel 76 | elif model_class == "CLIPTextModelWithProjection": 77 | from transformers import CLIPTextModelWithProjection 78 | 79 | return CLIPTextModelWithProjection 80 | else: 81 | raise ValueError(f"{model_class} is not supported.") 82 | 83 | 84 | def parse_args(): 85 | parser = argparse.ArgumentParser( 86 | description="Script to train Stable Diffusion XL for InstructPix2Pix." 
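# Editor's note, not part of the original script: `import_model_class_from_model_name_or_path`
# above dispatches on `config.architectures[0]`. For the SDXL base checkpoint, the "text_encoder"
# subfolder resolves to CLIPTextModel and "text_encoder_2" to CLIPTextModelWithProjection, which
# is why exactly those two branches are handled.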
87 | ) 88 | parser.add_argument( 89 | "--pretrained_model_name_or_path", 90 | type=str, 91 | default=None, 92 | required=True, 93 | help="Path to pretrained model or model identifier from huggingface.co/models.", 94 | ) 95 | parser.add_argument( 96 | "--pretrained_vae_model_name_or_path", 97 | type=str, 98 | default=None, 99 | help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", 100 | ) 101 | parser.add_argument( 102 | "--revision", 103 | type=str, 104 | default=None, 105 | required=False, 106 | help="Revision of pretrained model identifier from huggingface.co/models.", 107 | ) 108 | parser.add_argument( 109 | "--dataset_path", 110 | type=str, 111 | default=None, 112 | help="Path to the dataset shards stored S3 in the webdataset format. " 113 | "Example: pipe:aws s3 cp s3://my-datasets/my-project/{000000..000010}.tar -", 114 | ) 115 | parser.add_argument( 116 | "--original_image_column", 117 | type=str, 118 | default="original_image", 119 | help="The column of the dataset containing the original image on which edits where made.", 120 | ) 121 | parser.add_argument( 122 | "--edited_image_column", 123 | type=str, 124 | default="edited_image", 125 | help="The column of the dataset containing the edited image.", 126 | ) 127 | parser.add_argument( 128 | "--edit_prompt_column", 129 | type=str, 130 | default="edit_prompt", 131 | help="The column of the dataset containing the edit instruction.", 132 | ) 133 | parser.add_argument( 134 | "--val_image_url_or_path", 135 | type=str, 136 | default=None, 137 | help="URL or path to the original image that you would like to edit (used during inference for debugging purposes).", 138 | ) 139 | parser.add_argument( 140 | "--validation_prompt", 141 | type=str, 142 | default=None, 143 | help="A prompt that is sampled during training for inference.", 144 | ) 145 | parser.add_argument( 146 | "--num_validation_images", 147 | type=int, 148 | default=4, 149 | help="Number of images that should be generated during validation with `validation_prompt`.", 150 | ) 151 | parser.add_argument( 152 | "--validation_steps", 153 | type=int, 154 | default=100, 155 | help=( 156 | "Run fine-tuning validation every X steps. The validation process consists of running the prompt" 157 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 158 | ), 159 | ) 160 | parser.add_argument( 161 | "--output_dir", 162 | type=str, 163 | default="instruct-pix2pix-sdxl-model", 164 | help="The output directory where the model predictions and checkpoints will be written.", 165 | ) 166 | parser.add_argument( 167 | "--cache_dir", 168 | type=str, 169 | default=None, 170 | help="The directory where the downloaded models and datasets will be stored.", 171 | ) 172 | parser.add_argument( 173 | "--seed", type=int, default=None, help="A seed for reproducible training." 174 | ) 175 | parser.add_argument( 176 | "--resolution", 177 | type=int, 178 | default=256, 179 | help=( 180 | "The resolution for input images, all the images in the train/validation dataset will be resized to this resolution." 181 | ), 182 | ) 183 | parser.add_argument( 184 | "--center_crop", 185 | default=False, 186 | action="store_true", 187 | help=( 188 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 189 | " cropped. The images will be resized to the resolution first before cropping." 
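# Editor's note, not part of the original script: the "pipe:" prefix in --dataset_path is handled
# by webdataset itself; the shell command after the prefix is executed and its stdout is read as a
# tar stream, and the brace range is expanded client-side, so the "{000000..000010}.tar" example
# above covers 11 shards streamed straight from S3 without being staged on local disk.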
190 | ), 191 | ) 192 | parser.add_argument( 193 | "--random_flip", 194 | action="store_true", 195 | help="whether to randomly flip images horizontally", 196 | ) 197 | parser.add_argument( 198 | "--per_gpu_batch_size", 199 | type=int, 200 | default=8, 201 | help="Batch size (per device) for the training dataloader.", 202 | ) 203 | parser.add_argument( 204 | "--num_train_examples", 205 | type=int, 206 | default=313010, 207 | help="Number of training examples.", 208 | ) 209 | parser.add_argument("--num_train_epochs", type=int, default=100) 210 | parser.add_argument( 211 | "--max_train_steps", 212 | type=int, 213 | default=None, 214 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 215 | ) 216 | parser.add_argument( 217 | "--gradient_accumulation_steps", 218 | type=int, 219 | default=1, 220 | help="Number of updates steps to accumulate before performing a backward/update pass.", 221 | ) 222 | parser.add_argument( 223 | "--gradient_checkpointing", 224 | action="store_true", 225 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 226 | ) 227 | parser.add_argument( 228 | "--learning_rate", 229 | type=float, 230 | default=1e-4, 231 | help="Initial learning rate (after the potential warmup period) to use.", 232 | ) 233 | parser.add_argument( 234 | "--scale_lr", 235 | action="store_true", 236 | default=False, 237 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 238 | ) 239 | parser.add_argument( 240 | "--lr_scheduler", 241 | type=str, 242 | default="constant", 243 | help=( 244 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 245 | ' "constant", "constant_with_warmup"]' 246 | ), 247 | ) 248 | parser.add_argument( 249 | "--lr_warmup_steps", 250 | type=int, 251 | default=500, 252 | help="Number of steps for the warmup in the lr scheduler.", 253 | ) 254 | parser.add_argument( 255 | "--conditioning_dropout_prob", 256 | type=float, 257 | default=None, 258 | help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.", 259 | ) 260 | parser.add_argument( 261 | "--use_8bit_adam", 262 | action="store_true", 263 | help="Whether or not to use 8-bit Adam from bitsandbytes.", 264 | ) 265 | parser.add_argument( 266 | "--allow_tf32", 267 | action="store_true", 268 | help=( 269 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 270 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 271 | ), 272 | ) 273 | parser.add_argument( 274 | "--use_ema", action="store_true", help="Whether to use EMA model." 275 | ) 276 | parser.add_argument( 277 | "--non_ema_revision", 278 | type=str, 279 | default=None, 280 | required=False, 281 | help=( 282 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" 283 | " remote repository specified with --pretrained_model_name_or_path." 284 | ), 285 | ) 286 | parser.add_argument( 287 | "--num_workers", 288 | type=int, 289 | default=0, 290 | help=( 291 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
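# Editor's note, not part of the original script: with --conditioning_dropout_prob=p, the training
# loop further below drops only the edit prompt for roughly a fraction p of examples, both
# conditionings for another p, and only the original image for another p, so each conditioning is
# absent with total probability 2p; p=0.05 reproduces the 5% rates from section 3.2.1 of the
# InstructPix2Pix paper.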
292 | ), 293 | ) 294 | parser.add_argument( 295 | "--adam_beta1", 296 | type=float, 297 | default=0.9, 298 | help="The beta1 parameter for the Adam optimizer.", 299 | ) 300 | parser.add_argument( 301 | "--adam_beta2", 302 | type=float, 303 | default=0.999, 304 | help="The beta2 parameter for the Adam optimizer.", 305 | ) 306 | parser.add_argument( 307 | "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." 308 | ) 309 | parser.add_argument( 310 | "--adam_epsilon", 311 | type=float, 312 | default=1e-08, 313 | help="Epsilon value for the Adam optimizer", 314 | ) 315 | parser.add_argument( 316 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 317 | ) 318 | parser.add_argument( 319 | "--push_to_hub", 320 | action="store_true", 321 | help="Whether or not to push the model to the Hub.", 322 | ) 323 | parser.add_argument( 324 | "--hub_token", 325 | type=str, 326 | default=None, 327 | help="The token to use to push to the Model Hub.", 328 | ) 329 | parser.add_argument( 330 | "--hub_model_id", 331 | type=str, 332 | default=None, 333 | help="The name of the repository to keep in sync with the local `output_dir`.", 334 | ) 335 | parser.add_argument( 336 | "--logging_dir", 337 | type=str, 338 | default="logs", 339 | help=( 340 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 341 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 342 | ), 343 | ) 344 | parser.add_argument( 345 | "--mixed_precision", 346 | type=str, 347 | default=None, 348 | choices=["no", "fp16", "bf16"], 349 | help=( 350 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 351 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 352 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 353 | ), 354 | ) 355 | parser.add_argument( 356 | "--report_to", 357 | type=str, 358 | default="tensorboard", 359 | help=( 360 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 361 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 362 | ), 363 | ) 364 | parser.add_argument( 365 | "--local_rank", 366 | type=int, 367 | default=-1, 368 | help="For distributed training: local_rank", 369 | ) 370 | parser.add_argument( 371 | "--checkpointing_steps", 372 | type=int, 373 | default=500, 374 | help=( 375 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 376 | " training using `--resume_from_checkpoint`." 377 | ), 378 | ) 379 | parser.add_argument( 380 | "--checkpoints_total_limit", 381 | type=int, 382 | default=None, 383 | help=("Max number of checkpoints to store."), 384 | ) 385 | parser.add_argument( 386 | "--resume_from_checkpoint", 387 | type=str, 388 | default=None, 389 | help=( 390 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 391 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
392 | ), 393 | ) 394 | parser.add_argument( 395 | "--enable_xformers_memory_efficient_attention", 396 | action="store_true", 397 | help="Whether or not to use xformers.", 398 | ) 399 | 400 | args = parser.parse_args() 401 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 402 | if env_local_rank != -1 and env_local_rank != args.local_rank: 403 | args.local_rank = env_local_rank 404 | 405 | # Sanity checks 406 | if args.dataset_path is None: 407 | raise ValueError("dataset_path cannot be None.") 408 | 409 | # default to using the same revision for the non-ema model if not specified 410 | if args.non_ema_revision is None: 411 | args.non_ema_revision = args.revision 412 | 413 | return args 414 | 415 | 416 | def log_validation( 417 | vae, 418 | unet, 419 | text_encoder_1, 420 | text_encoder_2, 421 | tokenizer_1, 422 | tokenizer_2, 423 | args, 424 | accelerator, 425 | weight_dtype, 426 | global_step, 427 | ): 428 | logger.info( 429 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 430 | f" {args.validation_prompt}." 431 | ) 432 | 433 | # The models need unwrapping because for compatibility in distributed training mode. 434 | pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( 435 | args.pretrained_model_name_or_path, 436 | unet=accelerator.unwrap_model(unet), 437 | text_encoder=text_encoder_1, 438 | text_encoder_2=text_encoder_2, 439 | tokenizer=tokenizer_1, 440 | tokenizer_2=tokenizer_2, 441 | vae=vae, 442 | revision=args.revision, 443 | torch_dtype=weight_dtype, 444 | ) 445 | pipeline = pipeline.to(accelerator.device) 446 | pipeline.set_progress_bar_config(disable=True) 447 | 448 | if args.enable_xformers_memory_efficient_attention: 449 | pipeline.enable_xformers_memory_efficient_attention() 450 | 451 | if args.seed is None: 452 | generator = None 453 | else: 454 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 455 | 456 | # run inference 457 | # Save validation images 458 | val_save_dir = os.path.join(args.output_dir, "validation_images") 459 | if not os.path.exists(val_save_dir): 460 | os.makedirs(val_save_dir) 461 | 462 | original_image = ( 463 | lambda image_url_or_path: load_image(image_url_or_path) 464 | if urlparse(image_url_or_path).scheme 465 | else Image.open(image_url_or_path).convert("RGB") 466 | )(args.val_image_url_or_path) 467 | with torch.autocast( 468 | accelerator.device.type, 469 | enabled=accelerator.mixed_precision == "fp16", 470 | ): 471 | edited_images = [] 472 | for val_img_idx in range(args.num_validation_images): 473 | a_val_img = pipeline( 474 | args.validation_prompt, 475 | image=original_image, 476 | num_inference_steps=20, 477 | image_guidance_scale=1.5, 478 | guidance_scale=7, 479 | generator=generator, 480 | ).images[0] 481 | edited_images.append(a_val_img) 482 | a_val_img.save( 483 | os.path.join( 484 | val_save_dir, 485 | f"step_{global_step}_val_img_{val_img_idx}.png", 486 | ) 487 | ) 488 | 489 | for tracker in accelerator.trackers: 490 | if tracker.name == "wandb": 491 | wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) 492 | for edited_image in edited_images: 493 | wandb_table.add_data( 494 | wandb.Image(original_image), 495 | wandb.Image(edited_image), 496 | args.validation_prompt, 497 | ) 498 | tracker.log({"validation": wandb_table}) 499 | 500 | del pipeline 501 | torch.cuda.empty_cache() 502 | 503 | 504 | def main(): 505 | args = parse_args() 506 | 507 | if args.non_ema_revision is not None: 508 | deprecate( 509 | "non_ema_revision!=None", 510 | 
"0.15.0", 511 | message=( 512 | "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" 513 | " use `--variant=non_ema` instead." 514 | ), 515 | ) 516 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 517 | accelerator_project_config = ProjectConfiguration( 518 | project_dir=args.output_dir, logging_dir=logging_dir 519 | ) 520 | accelerator = Accelerator( 521 | gradient_accumulation_steps=args.gradient_accumulation_steps, 522 | mixed_precision=args.mixed_precision, 523 | log_with=args.report_to, 524 | project_config=accelerator_project_config, 525 | ) 526 | 527 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 528 | 529 | if args.report_to == "wandb": 530 | if not is_wandb_available(): 531 | raise ImportError( 532 | "Make sure to install wandb if you want to use it for logging during training." 533 | ) 534 | import wandb 535 | 536 | # Make one log on every process with the configuration for debugging. 537 | logging.basicConfig( 538 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 539 | datefmt="%m/%d/%Y %H:%M:%S", 540 | level=logging.INFO, 541 | ) 542 | logger.info(accelerator.state, main_process_only=False) 543 | if accelerator.is_local_main_process: 544 | transformers.utils.logging.set_verbosity_warning() 545 | diffusers.utils.logging.set_verbosity_info() 546 | else: 547 | transformers.utils.logging.set_verbosity_error() 548 | diffusers.utils.logging.set_verbosity_error() 549 | 550 | # If passed along, set the training seed now. 551 | if args.seed is not None: 552 | set_seed(args.seed) 553 | 554 | # Handle the repository creation 555 | if accelerator.is_main_process: 556 | if args.output_dir is not None: 557 | os.makedirs(args.output_dir, exist_ok=True) 558 | 559 | if args.push_to_hub: 560 | repo_id = create_repo( 561 | repo_id=args.hub_model_id or Path(args.output_dir).name, 562 | exist_ok=True, 563 | token=args.hub_token, 564 | ).repo_id 565 | 566 | vae_path = ( 567 | args.pretrained_model_name_or_path 568 | if args.pretrained_vae_model_name_or_path is None 569 | else args.pretrained_vae_model_name_or_path 570 | ) 571 | vae = AutoencoderKL.from_pretrained( 572 | vae_path, 573 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 574 | revision=args.revision, 575 | ) 576 | unet = UNet2DConditionModel.from_pretrained( 577 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 578 | ) 579 | 580 | # InstructPix2Pix uses an additional image for conditioning. To accommodate that, 581 | # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is 582 | # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized 583 | # from the pre-trained checkpoints. For the extra channels added to the first layer, they are 584 | # initialized to zero. 585 | logger.info("Initializing the XL InstructPix2Pix UNet from the pretrained UNet.") 586 | in_channels = 8 587 | out_channels = unet.conv_in.out_channels 588 | unet.register_to_config(in_channels=in_channels) 589 | 590 | with torch.no_grad(): 591 | new_conv_in = nn.Conv2d( 592 | in_channels, 593 | out_channels, 594 | unet.conv_in.kernel_size, 595 | unet.conv_in.stride, 596 | unet.conv_in.padding, 597 | ) 598 | new_conv_in.weight.zero_() 599 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 600 | unet.conv_in = new_conv_in 601 | 602 | # Create EMA for the unet. 
603 | if args.use_ema: 604 | ema_unet = EMAModel( 605 | unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config 606 | ) 607 | 608 | if args.enable_xformers_memory_efficient_attention: 609 | if is_xformers_available(): 610 | import xformers 611 | 612 | xformers_version = version.parse(xformers.__version__) 613 | if xformers_version == version.parse("0.0.16"): 614 | logger.warn( 615 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 616 | ) 617 | unet.enable_xformers_memory_efficient_attention() 618 | else: 619 | raise ValueError( 620 | "xformers is not available. Make sure it is installed correctly" 621 | ) 622 | 623 | # `accelerate` 0.16.0 will have better support for customized saving 624 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 625 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 626 | def save_model_hook(models, weights, output_dir): 627 | if args.use_ema: 628 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 629 | 630 | for i, model in enumerate(models): 631 | model.save_pretrained(os.path.join(output_dir, "unet")) 632 | 633 | # make sure to pop weight so that corresponding model is not saved again 634 | weights.pop() 635 | 636 | def load_model_hook(models, input_dir): 637 | if args.use_ema: 638 | load_model = EMAModel.from_pretrained( 639 | os.path.join(input_dir, "unet_ema"), UNet2DConditionModel 640 | ) 641 | ema_unet.load_state_dict(load_model.state_dict()) 642 | ema_unet.to(accelerator.device) 643 | del load_model 644 | 645 | for i in range(len(models)): 646 | # pop models so that they are not loaded again 647 | model = models.pop() 648 | 649 | # load diffusers style into model 650 | load_model = UNet2DConditionModel.from_pretrained( 651 | input_dir, subfolder="unet" 652 | ) 653 | model.register_to_config(**load_model.config) 654 | 655 | model.load_state_dict(load_model.state_dict()) 656 | del load_model 657 | 658 | accelerator.register_save_state_pre_hook(save_model_hook) 659 | accelerator.register_load_state_pre_hook(load_model_hook) 660 | 661 | if args.gradient_checkpointing: 662 | unet.enable_gradient_checkpointing() 663 | 664 | # Enable TF32 for faster training on Ampere GPUs, 665 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 666 | if args.allow_tf32: 667 | torch.backends.cuda.matmul.allow_tf32 = True 668 | 669 | if args.scale_lr: 670 | args.learning_rate = ( 671 | args.learning_rate 672 | * args.gradient_accumulation_steps 673 | * args.per_gpu_batch_size 674 | * accelerator.num_processes 675 | ) 676 | 677 | # Initialize the optimizer 678 | if args.use_8bit_adam: 679 | try: 680 | import bitsandbytes as bnb 681 | except ImportError: 682 | raise ImportError( 683 | "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" 684 | ) 685 | 686 | optimizer_cls = bnb.optim.AdamW8bit 687 | else: 688 | optimizer_cls = torch.optim.AdamW 689 | 690 | optimizer = optimizer_cls( 691 | unet.parameters(), 692 | lr=args.learning_rate, 693 | betas=(args.adam_beta1, args.adam_beta2), 694 | weight_decay=args.adam_weight_decay, 695 | eps=args.adam_epsilon, 696 | ) 697 | 698 | # Get the datasets: you can either provide your own training and evaluation files (see below) 699 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 700 | args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes 701 | train_dataloader = get_dataloader(args) 702 | 703 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 704 | # as these models are only used for inference, keeping weights in full precision is not required. 705 | weight_dtype = torch.float32 706 | if accelerator.mixed_precision == "fp16": 707 | weight_dtype = torch.float16 708 | warnings.warn( 709 | f"weight_dtype {weight_dtype} may cause nan during vae encoding", 710 | UserWarning, 711 | ) 712 | 713 | elif accelerator.mixed_precision == "bf16": 714 | weight_dtype = torch.bfloat16 715 | warnings.warn( 716 | f"weight_dtype {weight_dtype} may cause nan during vae encoding", 717 | UserWarning, 718 | ) 719 | 720 | # Load scheduler, tokenizer and models. 721 | tokenizer_1 = AutoTokenizer.from_pretrained( 722 | args.pretrained_model_name_or_path, 723 | subfolder="tokenizer", 724 | revision=args.revision, 725 | use_fast=False, 726 | ) 727 | tokenizer_2 = AutoTokenizer.from_pretrained( 728 | args.pretrained_model_name_or_path, 729 | subfolder="tokenizer_2", 730 | revision=args.revision, 731 | use_fast=False, 732 | ) 733 | text_encoder_cls_1 = import_model_class_from_model_name_or_path( 734 | args.pretrained_model_name_or_path, args.revision 735 | ) 736 | text_encoder_cls_2 = import_model_class_from_model_name_or_path( 737 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 738 | ) 739 | 740 | # Load scheduler and models 741 | noise_scheduler = DDPMScheduler.from_pretrained( 742 | args.pretrained_model_name_or_path, subfolder="scheduler" 743 | ) 744 | text_encoder_1 = text_encoder_cls_1.from_pretrained( 745 | args.pretrained_model_name_or_path, 746 | subfolder="text_encoder", 747 | revision=args.revision, 748 | ) 749 | text_encoder_2 = text_encoder_cls_2.from_pretrained( 750 | args.pretrained_model_name_or_path, 751 | subfolder="text_encoder_2", 752 | revision=args.revision, 753 | ) 754 | 755 | # We ALWAYS pre-compute the additional condition embeddings needed for SDXL 756 | # UNet as the model is already big and it uses two text encoders. 
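# Editor's note, not part of the original script: for the SDXL base checkpoint, encode_prompt
# below concatenates the penultimate hidden states of the two encoders along the feature axis
# (768 + 1280 = 2048 per token) and keeps the pooled output (1280-dim) only from text_encoder_2;
# those are the shapes the SDXL UNet expects for `encoder_hidden_states` and
# `added_cond_kwargs["text_embeds"]` further down in the training loop.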
757 | text_encoder_1.to(accelerator.device, dtype=weight_dtype) 758 | text_encoder_2.to(accelerator.device, dtype=weight_dtype) 759 | tokenizers = [tokenizer_1, tokenizer_2] 760 | text_encoders = [text_encoder_1, text_encoder_2] 761 | 762 | # Freeze vae and text_encoders 763 | vae.requires_grad_(False) 764 | text_encoder_1.requires_grad_(False) 765 | text_encoder_2.requires_grad_(False) 766 | 767 | # Adapted from diffusers.pipelines.StableDiffusionXLPipeline.encode_prompt 768 | def encode_prompt(prompts, text_encoders, tokenizers): 769 | prompt_embeds_list = [] 770 | 771 | with torch.no_grad(): 772 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 773 | text_inputs = tokenizer( 774 | prompts, 775 | padding="max_length", 776 | max_length=tokenizer.model_max_length, 777 | truncation=True, 778 | return_tensors="pt", 779 | ) 780 | text_input_ids = text_inputs.input_ids 781 | prompt_embeds = text_encoder( 782 | text_input_ids.to(text_encoder.device), 783 | output_hidden_states=True, 784 | ) 785 | 786 | # We are only ALWAYS interested in the pooled output of the final text encoder 787 | pooled_prompt_embeds = prompt_embeds[0] 788 | prompt_embeds = prompt_embeds.hidden_states[-2] 789 | bs_embed, seq_len, _ = prompt_embeds.shape 790 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 791 | prompt_embeds_list.append(prompt_embeds) 792 | 793 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 794 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 795 | return prompt_embeds, pooled_prompt_embeds 796 | 797 | def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers): 798 | prompt_embeds_all, pooled_prompt_embeds_all = encode_prompt( 799 | prompts, text_encoders, tokenizers 800 | ) 801 | prompt_embeds_all = prompt_embeds_all.to(accelerator.device) 802 | pooled_prompt_embeds_all = pooled_prompt_embeds_all.to(accelerator.device) 803 | return prompt_embeds_all, pooled_prompt_embeds_all 804 | 805 | def tokenize_captions(captions, tokenizer): 806 | inputs = tokenizer( 807 | captions, 808 | max_length=tokenizer.model_max_length, 809 | padding="max_length", 810 | truncation=True, 811 | return_tensors="pt", 812 | ) 813 | return inputs.input_ids 814 | 815 | # Get null conditioning. 816 | def compute_null_conditioning(): 817 | null_conditioning_list = [] 818 | for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders): 819 | null_conditioning_list.append( 820 | a_text_encoder( 821 | tokenize_captions([""], tokenizer=a_tokenizer).to( 822 | accelerator.device 823 | ), 824 | output_hidden_states=True, 825 | ).hidden_states[-2] 826 | ) 827 | return torch.concat(null_conditioning_list, dim=-1) 828 | 829 | null_conditioning = compute_null_conditioning() 830 | 831 | # Scheduler and math around the number of training steps. 832 | overrode_max_train_steps = False 833 | num_update_steps_per_epoch = math.ceil( 834 | train_dataloader.num_batches / args.gradient_accumulation_steps 835 | ) 836 | if args.max_train_steps is None: 837 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 838 | overrode_max_train_steps = True 839 | 840 | lr_scheduler = get_scheduler( 841 | args.lr_scheduler, 842 | optimizer=optimizer, 843 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 844 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 845 | ) 846 | 847 | # Prepare everything with our `accelerator` except the `train_dataloader` since it's already 848 | # prepared by webdataset. 
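# Editor's sketch, not part of the original script: the step bookkeeping above with hypothetical
# numbers. The same ceil() is recomputed after accelerator.prepare() further below in case the
# dataloader length changed.
import math

_num_batches, _grad_accum, _epochs = 4890, 4, 2             # hypothetical values
_steps_per_epoch = math.ceil(_num_batches / _grad_accum)    # 1223
_max_train_steps = _epochs * _steps_per_epoch               # 2446 when --max_train_steps is unset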
849 | unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) 850 | 851 | if args.use_ema: 852 | ema_unet.to(accelerator.device) 853 | 854 | # Move vae, unet and text_encoder to device and cast to weight_dtype 855 | # The VAE is in float32 to avoid NaN losses. 856 | if args.pretrained_vae_model_name_or_path is not None: 857 | vae.to(accelerator.device, dtype=weight_dtype) 858 | else: 859 | vae.to(accelerator.device, dtype=torch.float32) 860 | 861 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 862 | num_update_steps_per_epoch = math.ceil( 863 | train_dataloader.num_batches / args.gradient_accumulation_steps 864 | ) 865 | if overrode_max_train_steps: 866 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 867 | # Afterwards we recalculate our number of training epochs 868 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 869 | 870 | # We need to initialize the trackers we use, and also store our configuration. 871 | # The trackers initializes automatically on the main process. 872 | if accelerator.is_main_process: 873 | logger.info("Preparing trackers") 874 | accelerator.init_trackers("instruct-pix2pix-sdxl", config=vars(args)) 875 | 876 | # Train! 877 | total_batch_size = ( 878 | args.per_gpu_batch_size 879 | * accelerator.num_processes 880 | * args.gradient_accumulation_steps 881 | ) 882 | 883 | logger.info("***** Running training *****") 884 | logger.info(f" Num examples = {train_dataloader.num_samples}") 885 | logger.info(f" Num Epochs = {args.num_train_epochs}") 886 | logger.info(f" Instantaneous batch size per device = {args.per_gpu_batch_size}") 887 | logger.info( 888 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 889 | ) 890 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 891 | logger.info(f" Total optimization steps = {args.max_train_steps}") 892 | global_step = 0 893 | first_epoch = 0 894 | 895 | # Potentially load in the weights and states from a previous save 896 | if args.resume_from_checkpoint: 897 | if args.resume_from_checkpoint != "latest": 898 | path = os.path.basename(args.resume_from_checkpoint) 899 | else: 900 | # Get the most recent checkpoint 901 | dirs = os.listdir(args.output_dir) 902 | dirs = [d for d in dirs if d.startswith("checkpoint")] 903 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 904 | path = dirs[-1] if len(dirs) > 0 else None 905 | 906 | if path is None: 907 | accelerator.print( 908 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 909 | ) 910 | args.resume_from_checkpoint = None 911 | else: 912 | accelerator.print(f"Resuming from checkpoint {path}") 913 | accelerator.load_state(os.path.join(args.output_dir, path)) 914 | global_step = int(path.split("-")[1]) 915 | 916 | first_epoch = global_step // num_update_steps_per_epoch 917 | 918 | # Only show the progress bar once on each machine. 
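# Editor's note, not part of the original script: resuming parses the step count out of the
# checkpoint directory name, e.g. "checkpoint-1500" yields global_step = 1500 and, with say 1223
# update steps per epoch, first_epoch = 1500 // 1223 = 1, so the epoch loop below restarts at index 1.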
919 | progress_bar = tqdm( 920 | range(global_step, args.max_train_steps), 921 | disable=not accelerator.is_local_main_process, 922 | ) 923 | progress_bar.set_description("Steps") 924 | 925 | for epoch in range(first_epoch, args.num_train_epochs): 926 | unet.train() 927 | for step, batch in enumerate(train_dataloader): 928 | with accelerator.accumulate(unet): 929 | # We want to learn the denoising process w.r.t the edited images which 930 | # are conditioned on the original image (which was edited) and the edit instruction. 931 | # So, first, convert images to latent space. 932 | if args.pretrained_vae_model_name_or_path is not None: 933 | edited_pixel_values = batch["edited_images"].to(dtype=weight_dtype) 934 | if vae.dtype != weight_dtype: 935 | vae.to(dtype=weight_dtype) 936 | else: 937 | edited_pixel_values = batch["edited_images"] 938 | edited_pixel_values = edited_pixel_values.to( 939 | accelerator.device, non_blocking=True 940 | ) 941 | latents = vae.encode(edited_pixel_values).latent_dist.sample() 942 | latents = latents * vae.config.scaling_factor 943 | if args.pretrained_vae_model_name_or_path is None: 944 | latents = latents.to(weight_dtype) 945 | 946 | # Sample noise that we'll add to the latents 947 | noise = torch.randn_like(latents) 948 | bsz = latents.shape[0] 949 | # Sample a random timestep for each image 950 | timesteps = torch.randint( 951 | 0, 952 | noise_scheduler.config.num_train_timesteps, 953 | (bsz,), 954 | device=latents.device, 955 | ) 956 | timesteps = timesteps.long() 957 | 958 | # Add noise to the latents according to the noise magnitude at each timestep 959 | # (this is the forward diffusion process) 960 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 961 | 962 | # time ids 963 | def compute_time_ids(original_size, crops_coords_top_left): 964 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 965 | target_size = (args.resolution, args.resolution) 966 | if not isinstance(original_size, tuple): 967 | original_size = tuple(original_size) 968 | if not isinstance(crops_coords_top_left, tuple): 969 | crops_coords_top_left = tuple(crops_coords_top_left) 970 | add_time_ids = list( 971 | original_size + crops_coords_top_left + target_size 972 | ) 973 | add_time_ids = torch.tensor([add_time_ids]) 974 | add_time_ids = add_time_ids.to( 975 | accelerator.device, dtype=weight_dtype 976 | ) 977 | return add_time_ids 978 | 979 | # Pack SDXL conditions. 980 | add_time_ids = torch.cat( 981 | [ 982 | compute_time_ids(s, c) 983 | for s, c in zip( 984 | batch["original_sizes"], batch["crop_top_lefts"] 985 | ) 986 | ] 987 | ) 988 | prompt_embeds, pooled_prompt_embeds = compute_embeddings_for_prompts( 989 | batch["edit_prompts"], text_encoders, tokenizers 990 | ) 991 | added_cond_kwargs = { 992 | "text_embeds": pooled_prompt_embeds, 993 | "time_ids": add_time_ids, 994 | } 995 | 996 | # Get the additional image embedding for conditioning. 997 | # Instead of getting a diagonal Gaussian here, we simply take the mode. 
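# Editor's note, not part of the original script: despite the comment above, the code below calls
# `latent_dist.sample()` on the original-image latents rather than `latent_dist.mode()`; the
# reference InstructPix2Pix training script in diffusers takes the mode here, so if the comment
# reflects the intent, the `.sample()` on the original image should be changed to `.mode()`.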
998 |                 if args.pretrained_vae_model_name_or_path is not None:
999 |                     original_pixel_values = batch["original_images"].to(
1000 |                         dtype=weight_dtype
1001 |                     )
1002 |                 else:
1003 |                     original_pixel_values = batch["original_images"]
1004 |                 original_pixel_values = original_pixel_values.to(
1005 |                     accelerator.device, non_blocking=True
1006 |                 )
1007 |                 original_image_embeds = vae.encode(
1008 |                     original_pixel_values
1009 |                 ).latent_dist.sample()
1010 |                 if args.pretrained_vae_model_name_or_path is None:
1011 |                     original_image_embeds = original_image_embeds.to(weight_dtype)
1012 |
1013 |                 # Conditioning dropout to support classifier-free guidance during inference. For more details
1014 |                 # check out section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800 (a short standalone illustration of the dropout windows follows after this file listing).
1015 |                 if args.conditioning_dropout_prob is not None:
1016 |                     random_p = torch.rand(
1017 |                         bsz, device=latents.device, generator=generator
1018 |                     )
1019 |                     # Sample masks for the edit prompts: the text conditioning is nulled when random_p < 2 * p (p = conditioning_dropout_prob).
1020 |                     prompt_mask = random_p < 2 * args.conditioning_dropout_prob
1021 |                     prompt_mask = prompt_mask.reshape(bsz, 1, 1)
1022 |                     # Final text conditioning.
1023 |                     prompt_embeds = torch.where(
1024 |                         prompt_mask, null_conditioning, prompt_embeds
1025 |                     )
1026 |
1027 |                     # Sample masks for the original images: the image conditioning is zeroed when p <= random_p < 3 * p; the overlap [p, 2 * p) drops both.
1028 |                     image_mask_dtype = original_image_embeds.dtype
1029 |                     image_mask = 1 - (
1030 |                         (random_p >= args.conditioning_dropout_prob).to(
1031 |                             image_mask_dtype
1032 |                         )
1033 |                         * (random_p < 3 * args.conditioning_dropout_prob).to(
1034 |                             image_mask_dtype
1035 |                         )
1036 |                     )
1037 |                     image_mask = image_mask.reshape(bsz, 1, 1, 1)
1038 |                     # Final image conditioning.
1039 |                     original_image_embeds = image_mask * original_image_embeds
1040 |
1041 |                 # Concatenate the `original_image_embeds` with the `noisy_latents`.
1042 |                 concatenated_noisy_latents = torch.cat(
1043 |                     [noisy_latents, original_image_embeds], dim=1
1044 |                 )
1045 |
1046 |                 # Get the target for loss depending on the prediction type
1047 |                 if noise_scheduler.config.prediction_type == "epsilon":
1048 |                     target = noise
1049 |                 elif noise_scheduler.config.prediction_type == "v_prediction":
1050 |                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
1051 |                 else:
1052 |                     raise ValueError(
1053 |                         f"Unknown prediction type {noise_scheduler.config.prediction_type}"
1054 |                     )
1055 |
1056 |                 # Predict the noise residual and compute loss
1057 |                 model_pred = unet(
1058 |                     concatenated_noisy_latents,
1059 |                     timesteps,
1060 |                     encoder_hidden_states=prompt_embeds,
1061 |                     added_cond_kwargs=added_cond_kwargs,
1062 |                 ).sample
1063 |                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1064 |
1065 |                 # Backpropagate
1066 |                 accelerator.backward(loss)
1067 |                 if accelerator.sync_gradients:
1068 |                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
1069 |                 optimizer.step()
1070 |                 lr_scheduler.step()
1071 |                 optimizer.zero_grad()
1072 |
1073 |             # Checks if the accelerator has performed an optimization step behind the scenes
1074 |             if accelerator.sync_gradients:
1075 |                 if args.use_ema:
1076 |                     ema_unet.step(unet.parameters())
1077 |                 progress_bar.update(1)
1078 |                 global_step += 1
1079 |
1080 |                 if accelerator.is_main_process:
1081 |                     if global_step % args.checkpointing_steps == 0:
1082 |                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1083 |                         if args.checkpoints_total_limit is not None:
1084 |                             checkpoints = os.listdir(args.output_dir)
1085 |                             checkpoints = [
1086 |                                 d for d in checkpoints if d.startswith("checkpoint")
1087 |                             ]
1088 |                             checkpoints = sorted(
1089 |                                 checkpoints, key=lambda x: int(x.split("-")[1])
1090 |                             )
1091 |
1092 |                             # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1093 |                             if len(checkpoints) >= args.checkpoints_total_limit:
1094 |                                 num_to_remove = (
1095 |                                     len(checkpoints) - args.checkpoints_total_limit + 1
1096 |                                 )
1097 |                                 removing_checkpoints = checkpoints[0:num_to_remove]
1098 |
1099 |                                 logger.info(
1100 |                                     f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1101 |                                 )
1102 |                                 logger.info(
1103 |                                     f"removing checkpoints: {', '.join(removing_checkpoints)}"
1104 |                                 )
1105 |
1106 |                                 for removing_checkpoint in removing_checkpoints:
1107 |                                     removing_checkpoint = os.path.join(
1108 |                                         args.output_dir, removing_checkpoint
1109 |                                     )
1110 |                                     shutil.rmtree(removing_checkpoint)
1111 |
1112 |                         save_path = os.path.join(
1113 |                             args.output_dir, f"checkpoint-{global_step}"
1114 |                         )
1115 |                         accelerator.save_state(save_path)
1116 |                         logger.info(f"Saved state to {save_path}")
1117 |
1118 |                     if global_step % args.validation_steps == 0:
1119 |                         if (args.val_image_url_or_path is not None) and (
1120 |                             args.validation_prompt is not None
1121 |                         ):
1122 |                             # create pipeline
1123 |                             if args.use_ema:
1124 |                                 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1125 |                                 ema_unet.store(unet.parameters())
1126 |                                 ema_unet.copy_to(unet.parameters())
1127 |
1128 |                             log_validation(
1129 |                                 vae=vae,
1130 |                                 unet=unet,
1131 |                                 text_encoder_1=text_encoder_1,
1132 |                                 text_encoder_2=text_encoder_2,
1133 |                                 tokenizer_1=tokenizer_1,
1134 |                                 tokenizer_2=tokenizer_2,
1135 |                                 args=args,
1136 |                                 accelerator=accelerator,
1137 |                                 weight_dtype=weight_dtype,
1138 |                                 global_step=global_step,
1139 |                             )
1140 |
1141 |                             if args.use_ema:
1142 |                                 # Switch back to the original UNet parameters.
1143 |                                 ema_unet.restore(unet.parameters())
1144 |
1145 |             logs = {
1146 |                 "step_loss": loss.detach().item(),
1147 |                 "lr": lr_scheduler.get_last_lr()[0],
1148 |             }
1149 |             progress_bar.set_postfix(**logs)
1150 |             accelerator.log(logs, step=global_step)
1151 |
1152 |             if global_step >= args.max_train_steps:
1153 |                 break
1154 |
1155 |     # Create the pipeline using the trained modules and save it.
1156 |     accelerator.wait_for_everyone()
1157 |     if accelerator.is_main_process:
1158 |         unet = accelerator.unwrap_model(unet)
1159 |         if args.use_ema:
1160 |             ema_unet.copy_to(unet.parameters())
1161 |
1162 |         pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
1163 |             args.pretrained_model_name_or_path,
1164 |             text_encoder=text_encoder_1,
1165 |             text_encoder_2=text_encoder_2,
1166 |             tokenizer=tokenizer_1,
1167 |             tokenizer_2=tokenizer_2,
1168 |             vae=vae,
1169 |             unet=unet,
1170 |             revision=args.revision,
1171 |         )
1172 |         pipeline.save_pretrained(args.output_dir)
1173 |
1174 |         if args.push_to_hub:
1175 |             upload_folder(
1176 |                 repo_id=repo_id,
1177 |                 folder_path=args.output_dir,
1178 |                 commit_message="End of training",
1179 |                 ignore_patterns=["step_*", "epoch_*", "checkpoint-*"],
1180 |             )
1181 |
1182 |         if args.validation_prompt is not None:
1183 |             edited_images = []
1184 |             pipeline = pipeline.to(accelerator.device)
1185 |             original_image = (
1186 |                 lambda image_url_or_path: load_image(image_url_or_path)
1187 |                 if urlparse(image_url_or_path).scheme
1188 |                 else Image.open(image_url_or_path).convert("RGB")
1189 |             )(args.val_image_url_or_path)
1190 |             with torch.autocast(
1191 |                 accelerator.device.type,
1192 |                 enabled=accelerator.mixed_precision == "fp16",
1193 |             ):
1194 |                 for _ in range(args.num_validation_images):
1195 |                     edited_images.append(
1196 |                         pipeline(
1197 |                             args.validation_prompt,
1198 |                             image=original_image,
1199 |                             num_inference_steps=20,
1200 |                             image_guidance_scale=1.5,
1201 |                             guidance_scale=7,
1202 |                             generator=generator,
1203 |                         ).images[0]
1204 |                     )
1205 |
1206 |             for tracker in accelerator.trackers:
1207 |                 if tracker.name == "wandb":
1208 |                     wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1209 |                     for edited_image in edited_images:
1210 |                         wandb_table.add_data(
1211 |                             wandb.Image(original_image),
1212 |                             wandb.Image(edited_image),
1213 |                             args.validation_prompt,
1214 |                         )
1215 |                     tracker.log({"test": wandb_table})
1216 |
1217 |     accelerator.end_training()
1218 |
1219 |
1220 | if __name__ == "__main__":
1221 |     main()
1222 |
--------------------------------------------------------------------------------
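
A quick aside on the conditioning dropout in the training loop above (referenced from the code comments): the two comparisons against `random_p` split the interval [0, 3p) into three equal windows, so "drop text only", "drop image only", and "drop both" each happen with probability p. The snippet below is a standalone illustration only; the value of `p` is made up and nothing here is taken from the training configuration.

import torch

p = 0.05  # illustrative stand-in for args.conditioning_dropout_prob
random_p = torch.rand(1_000_000)

drop_text = random_p < 2 * p                       # text conditioning nulled on [0, 2p)
drop_image = (random_p >= p) & (random_p < 3 * p)  # image conditioning zeroed on [p, 3p)
drop_both = drop_text & drop_image                 # overlap [p, 2p)

print(drop_text.float().mean().item())   # ~0.10 of examples lose the text conditioning
print(drop_image.float().mean().item())  # ~0.10 of examples lose the image conditioning
print(drop_both.float().mean().item())   # ~0.05 of examples lose both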
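
Once training completes, `pipeline.save_pretrained(args.output_dir)` leaves a full SDXL InstructPix2Pix pipeline on disk. Below is a minimal inference sketch; the output-directory path, image URL, and edit instruction are placeholders, and the generation arguments simply mirror the validation call in the script.

import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils import load_image

# Placeholder path: point this at the args.output_dir used during training.
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "path/to/output_dir", torch_dtype=torch.float16
).to("cuda")

# Placeholder input; resize to the resolution the model was trained at (args.resolution).
image = load_image("https://example.com/input.jpg").resize((768, 768))

edited = pipeline(
    "turn the sky into a sunset",  # placeholder edit instruction
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7,
).images[0]
edited.save("edited.png")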