├── .gitignore ├── README.md ├── configs ├── data │ └── local.yaml ├── experiment │ ├── sample_b-lora_sdxl.yaml │ ├── sample_struct.yaml │ ├── sample_struct_attn.yaml │ ├── sample_style.yaml │ ├── sample_two.yaml │ ├── train_struct_sd15.yaml │ └── train_style_sd15.yaml ├── lora │ ├── encoder │ │ ├── ViT-B-32.yaml │ │ ├── ViT-H-14-proj.yaml │ │ ├── ViT-L-14-336-proj.yaml │ │ ├── ViT-L-14-proj.yaml │ │ ├── ViT-L-14.yaml │ │ ├── canny.yaml │ │ ├── hed.yaml │ │ ├── identity.yaml │ │ ├── midas.yaml │ │ └── openpose.yaml │ ├── mapper_network │ │ ├── asm15.yaml │ │ ├── fsm15.yaml │ │ ├── fsmXL.yaml │ │ └── simple.yaml │ ├── struct.yaml │ ├── struct_attn.yaml │ └── style.yaml ├── model │ ├── sd15.yaml │ └── sdxl.yaml ├── sample.yaml └── train.yaml ├── data ├── basin.jpg ├── beach.png ├── bedroom.jpg ├── coffe-cup.jpg ├── deer.jpg ├── dog.jpg ├── door.jpg ├── eagle.png ├── elephant.jpg ├── flowers.jpg ├── giraffe.jpg ├── girl.png ├── house.jpg ├── lake.jpg ├── living-room.jpg ├── motor-bike.jpg ├── pizza.jpg ├── shoes.jpg ├── statue.png ├── stormstrooper.jpg ├── vases.jpg ├── volleyball.jpg ├── woman-warrior.png └── woman.jpg ├── docs ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── images │ ├── 20.jpg │ ├── architecture.png │ ├── main.jpg │ ├── qual_struct.png │ ├── qual_style.png │ ├── quan_depth.png │ └── quan_style.png │ └── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js ├── download_weights.sh ├── environment.yaml ├── sample.py ├── sample_two.py ├── src ├── __init__.py ├── annotators │ ├── __init__.py │ ├── canny.py │ ├── hed.py │ ├── midas.py │ ├── openclip.py │ ├── style.py │ └── util.py ├── data │ ├── __init__.py │ ├── local.py │ └── transforms.py ├── lora.py ├── mapper_network.py ├── model.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | outputs/ 3 | checkpoints/ 4 | *.pyc 5 | ignore -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2 | 3 |

4 |

Conditional LoRAdapter for Efficient 0-Shot Control & Altering of T2I Models

5 |

6 | Nick Stracke1 · Stefan A. Baumann1 · Josh Susskind2 · Miguel A. Bautista2 · Björn Ommer1 7 |

8 |

9 | 1 CompVis Group @ LMU Munich
10 | 2 Apple 11 |

12 |

13 |

14 | ECCV 2024 15 |

16 | 17 | 18 | [![Project Page](https://img.shields.io/badge/Project-Page-blue)](https://compvis.github.io/LoRAdapter/) 19 | [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2405.07913) 20 | 21 | This repository contains an implementation of the paper "CTRLorALTer: Conditional LoRAdapter for Efficient 0-Shot Control & Altering of T2I Models". 22 | 23 | We present LoRAdapter, an approach that unifies both style and structure conditioning under the same formulation using a novel conditional LoRA block that enables zero-shot control. 24 | LoRAdapter is an efficient, powerful, and architecture-agnostic approach to condition text-to-image diffusion models, which enables fine-grained control conditioning during generation and outperforms recent state-of-the-art approaches. 25 | 26 | ![teaser](./docs/static/images/main.jpg) 27 | 28 | ## 🔥 Updates 29 | - Implemented [B-LoRA](https://b-lora.github.io/B-LoRA/) implicit content and style disentanglement using LoRAdapter 30 | - Released code and weights for inference 31 | 32 | ## 💪 TODO 33 | - [x] Add training code 34 | - [ ] Add more structure conditioning checkpoints (including SDXL) 35 | - [ ] Experiment with SD3 36 | 37 | 38 | ## Setup 39 | 40 | Create the conda environment: 41 | 42 | `conda env create -f environment.yaml` 43 | 44 | Activate the conda environment: 45 | 46 | `conda activate loradapter` 47 | 48 | ## Weights 49 | 50 | All weights are available on [HuggingFace](https://huggingface.co/kliyer/LoRAdapter/tree/main). 51 | 52 | For ease of use, you can also run the provided bash script `download_weights.sh` to automatically download all available weights and place them in the right directory. 53 | 54 | 55 | ## Usage 56 | Sampling works according to the following scheme: 57 | ``` 58 | python sample.py experiment=<experiment_name> 59 | ``` 60 | All currently available experiments are listed in `configs/experiment`. Feel free to adjust the configs according to your own needs. 61 | 62 | ### B-LoRA 63 | Sampling with the [B-LoRA](https://b-lora.github.io/B-LoRA/) LoRAdapter is possible using the config `sample_b-lora_sdxl.yaml`. By default, this will condition on both the content and the style of the image. For conditioning on _only_ content or _only_ style, change the `adaption_mode` to either `b-lora_content` or `b-lora_style`. In that case, also set `ignore_check` to true, since the checkpoint is then only loaded partially. 64 | 65 | For best results, provide information about the missing modality via the text prompt or through another LoRAdapter.
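For example, the following invocations show how the provided entry points and experiment configs fit together. They assume the checkpoints from `download_weights.sh` are in place; the Hydra override paths mirror the configs in `configs/experiment` and are meant as a sketch to adapt rather than a fixed recipe.

```
# fetch all released checkpoints into the expected directory
bash download_weights.sh

# style conditioning with SD 1.5
python sample.py experiment=sample_style

# depth-conditioned structure control with SD 1.5
python sample.py experiment=sample_struct

# combined style + structure conditioning via the dedicated entry point
python sample_two.py experiment=sample_two

# B-LoRA on SDXL, conditioning on content only (partial checkpoint load)
python sample.py experiment=sample_b-lora_sdxl lora.style.config.adaption_mode=b-lora_content ignore_check=true

# train the structure LoRAdapter on SD 1.5
python train.py experiment=train_struct_sd15
```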
66 | 67 | 68 | 69 | ## 🎓 Citation 70 | 71 | If you use this codebase or otherwise found our work valuable, please cite our paper: 72 | 73 | ```bibtex 74 | @misc{stracke2024loradapter, 75 | title={CTRLorALTer: Conditional LoRAdapter for Efficient 0-Shot Control & Altering of T2I Models}, 76 | author={Nick Stracke and Stefan Andreas Baumann and Joshua Susskind and Miguel Angel Bautista and Björn Ommer}, 77 | year={2024}, 78 | eprint={2405.07913}, 79 | archivePrefix={arXiv}, 80 | primaryClass={cs.CV} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /configs/data/local.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.data.local.ImageDataModule 5 | directories: 6 | - /p/project/degeai/stracke1/condition_lora/ip_data 7 | transform: 8 | - _target_: torchvision.transforms.Resize 9 | size: 512 10 | - _target_: torchvision.transforms.CenterCrop 11 | size: 512 12 | - _target_: torchvision.transforms.ToTensor 13 | - _target_: torchvision.transforms.Normalize 14 | mean: 15 | - 0.5 16 | - 0.5 17 | - 0.5 18 | std: 19 | - 0.5 20 | - 0.5 21 | - 0.5 22 | batch_size: 1 23 | 24 | -------------------------------------------------------------------------------- /configs/experiment/sample_b-lora_sdxl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /lora@lora.style: style 5 | - override /model: sdxl 6 | - override /data: local 7 | - _self_ 8 | 9 | size: 1024 10 | 11 | 12 | 13 | data: 14 | caption_from_name: true 15 | caption_prefix: "a picture of " 16 | directories: 17 | - data 18 | 19 | 20 | prompt: null 21 | 22 | n_samples: 4 23 | 24 | seed: 87651331119 25 | 26 | model: 27 | guidance_scale: 10 28 | 29 | 30 | save_grid: false 31 | 32 | tag: b-lora 33 | 34 | bf16: true 35 | 36 | # set true if using b-lora_content or b-lora_style 37 | ignore_check: true 38 | 39 | 40 | lora: 41 | style: 42 | ckpt_path: checkpoints/sdxl_b-lora_256 43 | config: 44 | adaption_mode: b-lora # b-lora_content or b-lora_style 45 | rank: 256 46 | c_dim: 1024 47 | lora_scale: 1 48 | 49 | local_files_only: false -------------------------------------------------------------------------------- /configs/experiment/sample_struct.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /lora@lora.struct: struct 5 | - override /lora/encoder@lora.struct.encoder: midas # hed 6 | - override /model: sd15 7 | - override /data: local 8 | - _self_ 9 | 10 | 11 | size: 512 12 | n_samples: 4 13 | 14 | save_grid: true 15 | log_cond: true 16 | 17 | data: 18 | caption_from_name: true 19 | caption_prefix: "a picture of " 20 | directories: 21 | - data 22 | 23 | model: 24 | guidance_scale: 7.5 25 | 26 | 27 | lora: 28 | struct: 29 | cfg: false 30 | # ckpt_path: checkpoints/sd15-hed-128-only-res 31 | ckpt_path: checkpoints/sd15-depth-128-only-res 32 | config: 33 | c_dim: 128 34 | rank: 128 35 | adaption_mode: only_res_conv 36 | 37 | tag: struct -------------------------------------------------------------------------------- /configs/experiment/sample_struct_attn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /lora@lora.struct: struct_attn 5 | - override /lora/encoder@lora.struct.encoder: midas # hed 6 | - override /model: sd15 7 | - override /data: local 8 | - _self_ 9 | 10 | 11 | 
size: 512 12 | n_samples: 4 13 | 14 | save_grid: true 15 | log_cond: true 16 | 17 | data: 18 | caption_from_name: true 19 | caption_prefix: "a picture of " 20 | directories: 21 | - data 22 | 23 | model: 24 | guidance_scale: 7.5 25 | 26 | prompt: '' 27 | 28 | lora: 29 | struct: 30 | cfg: false 31 | # ckpt_path: checkpoints/sd15-hed-128-only-res 32 | ckpt_path: checkpoints/sd15-depth-02-self 33 | config: 34 | 35 | c_dim: 128 36 | rank: 0.2 37 | adaption_mode: only_self 38 | 39 | tag: struct 40 | 41 | bf16: true -------------------------------------------------------------------------------- /configs/experiment/sample_style.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /lora@lora.style: style 5 | - override /model: sd15 6 | - override /data: local 7 | - _self_ 8 | 9 | size: 512 10 | 11 | data: 12 | caption_from_name: true 13 | caption_prefix: "a picture of " 14 | directories: 15 | - data 16 | 17 | # replace for fixed prompt 18 | prompt: null 19 | 20 | n_samples: 4 21 | 22 | seed: 7534 23 | 24 | model: 25 | guidance_scale: 7.5 26 | 27 | lora: 28 | style: 29 | ckpt_path: checkpoints/sd15-style-cross-160-h 30 | config: 31 | lora_scale: 1 32 | 33 | 34 | tag: style -------------------------------------------------------------------------------- /configs/experiment/sample_two.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # USE WITH sample_two 4 | 5 | defaults: 6 | - /lora@lora.style: style 7 | - /lora@lora.struct: struct 8 | - /data@data2: local 9 | - override /model: sd15 10 | - override /data: local 11 | - override /lora/encoder@lora.struct.encoder: midas 12 | - _self_ 13 | 14 | size: 512 15 | prompt: '' 16 | 17 | save_grid: true 18 | log_cond: true 19 | 20 | data: 21 | caption_from_name: true 22 | caption_prefix: "a picture of " 23 | directories: 24 | - data 25 | 26 | data2: 27 | caption_from_name: true 28 | caption_prefix: "a picture of " 29 | directories: 30 | - data 31 | 32 | model: 33 | guidance_scale: 7.5 34 | 35 | 36 | n_samples: 4 37 | 38 | seed: 7534 39 | 40 | lora: 41 | style: 42 | ckpt_path: checkpoints/sd15-style-cross-160-h 43 | config: 44 | lora_scale: 1 45 | struct: 46 | cfg: false 47 | # ckpt_path: checkpoints/sd15-hed-128-only-res 48 | ckpt_path: checkpoints/sd15-depth-128-only-res 49 | config: 50 | c_dim: 128 51 | rank: 128 52 | adaption_mode: only_res_conv 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/experiment/train_struct_sd15.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /lora@lora.struct: struct 5 | - override /lora/encoder@lora.struct.encoder: midas 6 | - override /model: sd15 7 | - override /data: local 8 | - _self_ 9 | 10 | data: 11 | batch_size: 8 12 | caption_from_name: true 13 | caption_prefix: "a picture of " 14 | directories: 15 | - data 16 | 17 | lora: 18 | struct: 19 | optimize: true 20 | 21 | 22 | size: 512 23 | 24 | log_c: true 25 | 26 | val_batches: 4 27 | 28 | learning_rate: 1e-4 29 | 30 | ckpt_steps: 3000 31 | val_steps: 3000 32 | 33 | epochs: 10 34 | 35 | prompt: null 36 | 37 | # model: 38 | # guidance_scale: 1.5 -------------------------------------------------------------------------------- /configs/experiment/train_style_sd15.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | 
defaults: 4 | - /lora@lora.style: style 5 | - override /model: sd15 6 | - override /data: local 7 | - _self_ 8 | 9 | data: 10 | batch_size: 8 11 | caption_from_name: true 12 | caption_prefix: "a picture of " 13 | directories: 14 | - data 15 | 16 | val_batches: 1 17 | 18 | lora: 19 | style: 20 | # rank: 208 21 | # rank: 16 22 | adaption_mode: only_cross 23 | optimize: true 24 | 25 | size: 512 26 | 27 | learning_rate: 1e-4 28 | 29 | ckpt_steps: 1000 30 | val_steps: 1000 31 | 32 | epochs: 100 33 | 34 | prompt: null 35 | 36 | # model: 37 | # guidance_scale: 1.5 -------------------------------------------------------------------------------- /configs/lora/encoder/ViT-B-32.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.style.VisionModel 5 | clip_model: openai/clip-vit-base-patch32 6 | -------------------------------------------------------------------------------- /configs/lora/encoder/ViT-H-14-proj.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.openclip.VisionModel 5 | clip_model: 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K' 6 | local_files_only: ${local_files_only} -------------------------------------------------------------------------------- /configs/lora/encoder/ViT-L-14-336-proj.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.style.VisionModel 5 | clip_model: openai/clip-vit-large-patch14-336 6 | with_projection: true -------------------------------------------------------------------------------- /configs/lora/encoder/ViT-L-14-proj.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.style.VisionModel 5 | clip_model: openai/clip-vit-large-patch14 6 | with_projection: true -------------------------------------------------------------------------------- /configs/lora/encoder/ViT-L-14.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.style.VisionModel 5 | clip_model: openai/clip-vit-large-patch14 6 | -------------------------------------------------------------------------------- /configs/lora/encoder/canny.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.canny.CannyDetector -------------------------------------------------------------------------------- /configs/lora/encoder/hed.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.hed.TorchHEDdetector 5 | size: ${size} -------------------------------------------------------------------------------- /configs/lora/encoder/identity.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: torch.nn.Identity -------------------------------------------------------------------------------- /configs/lora/encoder/midas.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.midas.DepthEstimator 5 | model: Intel/dpt-hybrid-midas 6 | size: ${size} 7 | local_files_only: ${local_files_only} 
-------------------------------------------------------------------------------- /configs/lora/encoder/openpose.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.annotators.openpose.OpenposeDetector -------------------------------------------------------------------------------- /configs/lora/mapper_network/asm15.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.mapper_network.AttentionStructureMapper15 5 | c_dim: ${c_dim} -------------------------------------------------------------------------------- /configs/lora/mapper_network/fsm15.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.mapper_network.FixedStructureMapper15 5 | c_dim: ${c_dim} -------------------------------------------------------------------------------- /configs/lora/mapper_network/fsmXL.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.mapper_network.FixedStructureMapperXL 5 | c_dim: ${c_dim} -------------------------------------------------------------------------------- /configs/lora/mapper_network/simple.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.mapper_network.SimpleMapper 5 | d_model: ??? 6 | c_dim: ??? -------------------------------------------------------------------------------- /configs/lora/struct.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mapper_network: fsm15 3 | - encoder: ??? 4 | - _self_ 5 | 6 | cfg: false 7 | 8 | config: 9 | c_dim: 128 10 | rank: 128 11 | adaption_mode: only_res_conv 12 | lora_cls: NewStructLoRAConv 13 | 14 | 15 | mapper_network: 16 | c_dim: ${..config.c_dim} -------------------------------------------------------------------------------- /configs/lora/struct_attn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mapper_network: asm15 3 | - encoder: ??? 4 | - _self_ 5 | 6 | cfg: false 7 | 8 | config: 9 | use_depth: true 10 | c_dim: 128 11 | rank: 128 12 | adaption_mode: only_self 13 | lora_cls: SimpleLoraLinear 14 | shift: true 15 | broadcast_tokens: false # important!! 
16 | 17 | mapper_network: 18 | c_dim: ${..config.c_dim} -------------------------------------------------------------------------------- /configs/lora/style.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mapper_network: simple 3 | - encoder: ViT-H-14-proj 4 | - _self_ 5 | 6 | cfg: true 7 | 8 | config: 9 | c_dim: 1024 10 | rank: 160 11 | adaption_mode: only_cross 12 | lora_cls: SimpleLoraLinear 13 | broadcast_tokens: true 14 | 15 | 16 | mapper_network: 17 | c_dim: ${..config.c_dim} 18 | d_model: ${..config.c_dim} 19 | -------------------------------------------------------------------------------- /configs/model/sd15.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.model.SD15 5 | pipeline_type: diffusers.StableDiffusionPipeline 6 | model_name: runwayml/stable-diffusion-v1-5 7 | local_files_only: ${local_files_only} -------------------------------------------------------------------------------- /configs/model/sdxl.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | _target_: src.model.SDXL 5 | pipeline_type: diffusers.StableDiffusionXLPipeline 6 | model_name: stabilityai/stable-diffusion-xl-base-1.0 7 | local_files_only: ${local_files_only} -------------------------------------------------------------------------------- /configs/sample.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: ??? 3 | - model: ??? 4 | - _self_ 5 | - experiment: null 6 | 7 | size: ??? 8 | prompt: null 9 | 10 | val_images: 9 11 | seed: 42 12 | n_samples: 3 13 | 14 | tag: '' 15 | 16 | hydra: 17 | run: 18 | dir: outputs/sample/${tag}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 19 | sweep: 20 | dir: outputs/sample/${tag}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 21 | subdir: ${hydra.job.num} 22 | job: 23 | chdir: true 24 | 25 | local_files_only: false -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: ??? 3 | - model: ??? 4 | - _self_ 5 | - experiment: null 6 | 7 | size: ??? 
8 | max_train_steps: null 9 | epochs: 20 10 | learning_rate: 1e-4 11 | 12 | lr_warmup_steps: 0 13 | lr_scheduler: constant 14 | 15 | prompt: null 16 | gradient_accumulation_steps: 1 17 | 18 | ckpt_steps: 1000 19 | val_steps: 1000 20 | val_images: 4 21 | seed: 42 22 | n_samples: 4 23 | 24 | 25 | 26 | tag: '' 27 | 28 | local_files_only: false 29 | 30 | hydra: 31 | run: 32 | dir: outputs/train/${tag}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 33 | sweep: 34 | dir: outputs/train/${tag}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 35 | subdir: ${hydra.job.num} 36 | job: 37 | chdir: true -------------------------------------------------------------------------------- /data/basin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/basin.jpg -------------------------------------------------------------------------------- /data/beach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/beach.png -------------------------------------------------------------------------------- /data/bedroom.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/bedroom.jpg -------------------------------------------------------------------------------- /data/coffe-cup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/coffe-cup.jpg -------------------------------------------------------------------------------- /data/deer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/deer.jpg -------------------------------------------------------------------------------- /data/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/dog.jpg -------------------------------------------------------------------------------- /data/door.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/door.jpg -------------------------------------------------------------------------------- /data/eagle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/eagle.png -------------------------------------------------------------------------------- /data/elephant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/elephant.jpg -------------------------------------------------------------------------------- /data/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/flowers.jpg -------------------------------------------------------------------------------- 
/data/giraffe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/giraffe.jpg -------------------------------------------------------------------------------- /data/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/girl.png -------------------------------------------------------------------------------- /data/house.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/house.jpg -------------------------------------------------------------------------------- /data/lake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/lake.jpg -------------------------------------------------------------------------------- /data/living-room.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/living-room.jpg -------------------------------------------------------------------------------- /data/motor-bike.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/motor-bike.jpg -------------------------------------------------------------------------------- /data/pizza.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/pizza.jpg -------------------------------------------------------------------------------- /data/shoes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/shoes.jpg -------------------------------------------------------------------------------- /data/statue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/statue.png -------------------------------------------------------------------------------- /data/stormstrooper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/stormstrooper.jpg -------------------------------------------------------------------------------- /data/vases.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/vases.jpg -------------------------------------------------------------------------------- /data/volleyball.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/volleyball.jpg -------------------------------------------------------------------------------- /data/woman-warrior.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/woman-warrior.png -------------------------------------------------------------------------------- /data/woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/data/woman.jpg -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 10 | 11 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 24 | 25 | 27 | 28 | 29 | 30 | 31 | 32 | Academic Project Page 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 |
55 |
56 |
57 |
58 |
59 |

CTRLorALTer: Conditional LoRAdapter for Efficient 0-Shot Control & 60 | Altering of T2I Models

61 |
62 | 63 | 64 | Nick Stracke1, 65 | 66 | 67 | Stefan Andreas 68 | Baumann1, 69 | 70 | 71 | Joshua Susskind2, 72 | 73 | 74 | Miguel Angel Bautista2, 75 | 76 | 77 | Björn Ommer1 78 | 79 |
80 | 81 |
82 | CompVis @ LMU Munich, MCML1
83 | Apple2 84 | 85 |
86 | 87 |
88 | ECCV 2024 89 |
90 | 91 |
92 | 137 |
138 |
139 |
140 |
141 |
142 |
143 | 144 | 145 | 146 | 159 | 160 | 161 |
162 |
163 |
164 | 165 |
166 |
167 | 168 | 169 |
170 |
171 |
172 |
173 |
174 | 175 | 176 |
177 |
178 |
179 |
180 |

Overview

181 |
182 |

183 | Text-to-image generative models have become a prominent and powerful tool that excels at generating 184 | high-resolution realistic images. However, guiding the generative process of these models to consider 185 | detailed forms of conditioning reflecting style and/or structure information remains an open problem. In 186 | this paper, we present LoRAdapter, an approach that unifies both style and structure 187 | conditioning under the same formulation using a novel conditional LoRA block that enables zero-shot 188 | control. LoRAdapter is an efficient, powerful, and architecture-agnostic approach to condition 189 | text-to-image diffusion models, which enables fine-grained control conditioning during generation and 190 | outperforms recent state-of-the-art approaches. 191 |

192 |
193 |
194 |
195 |
196 |
197 | 198 | 199 | 200 |
201 |
202 |
203 |
204 |

How it works

205 |
206 | 207 |

208 | Following the standard LoRA method, we keep the original weight matrix W_0 frozen and add two new 209 | trainable 210 | weight matrices A and B for each layer (i) that we want to adapt. 211 | Usually, we would train A and B on a small dataset to capture a specific style or subject, resulting in an 212 | adapter that is fixed at inference time. 213 | However, we propose to dynamically apply a transformation φ on the embedding of the first LoRA matrix A. 214 | In 215 | practice, we 216 | implement φ as an affine transformation with scale and shift parameters γ and β, respectively. These are 217 | predicted by a mapping network that depends on the conditioning c. 218 |
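In symbols, the conditional LoRA block described above can be written as follows (a sketch of the stated formulation; the per-layer index and the usual LoRA scaling factor are omitted):

$$ y \;=\; W_0\,x \;+\; B\,\varphi_c(A\,x), \qquad \varphi_c(h) \;=\; \gamma(c) \odot h \;+\; \beta(c), $$

where A and B are the trainable low-rank matrices, W_0 stays frozen, and the scale γ(c) and shift β(c) are predicted by the mapping network from the conditioning c.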

219 |
220 | 232 |
233 |
234 |
235 |
236 | 237 |
238 |
239 |
240 |
241 |

Qualitative Comparison

242 |
243 |

Style

244 | 245 |

246 | Samples from our method with style conditioning compared against other methods. We used an empty prompt 247 | and only conditioned on the image. We generally perform on par with IP-Adapter and outperform it on some 248 | samples. Note that the third image from the left is less degraded, and the third image from the right 249 | captures the mane of the horse better. 250 |

251 |

Structure

252 | 253 | 254 |

255 | Samples from our method with structural conditioning compared against other methods. Note that for our 256 | method, especially compared with T2I Adapter, the details of the images are substantially more closely 257 | aligned with the depth prompt (see e.g. the lamp in the background of the living room scene and the side 258 | table's legs, or the salad on the pizza) 259 |

260 |
261 |
262 |
263 |
264 |
265 | 266 | 267 |
268 |
269 |
270 |
271 |

Quantitative Comparison

272 |
273 |

Style

274 | 275 |

276 | Best results are in bold. LoRAdapter needs the fewest parameters and is able to achieve 277 | state-of-the-art 278 | performance while also enabling direct structure control. 279 |

280 |

Structure

281 | 282 | 283 |

284 | Best results are in bold. We evaluate cycle consistency (MSE-d), FID, and LPIPS. The difference 285 | between configurations A and B is the number of layers that are adapted, resulting in a different number of 286 | parameters. LoRAdapter outperforms all other methods in all metrics. 287 |

288 |
289 |
290 |
291 |
292 |
293 | 294 | 295 | 296 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 352 | 353 | 354 | 355 | 356 |
357 |
358 |

BibTeX

359 |

360 | @misc{stracke2024loradapter,
361 |   title={CTRLorALTer: Conditional LoRAdapter for Efficient 0-Shot Control & Altering of T2I Models}, 
362 |   author={Nick Stracke and Stefan Andreas Baumann and Joshua Susskind and Miguel Angel Bautista and Björn Ommer},
363 |   year={2024},
364 |   eprint={2405.07913},
365 |   archivePrefix={arXiv},
366 |   primaryClass={cs.CV}
367 | }
368 |       
369 |
370 |
371 | 372 | 373 | 374 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel 
.slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | padding: 20px; 121 | font-size: 0; 122 | } 123 | 124 | .results-carousel video { 125 | margin: 0; 126 | } 127 
| 128 | .slider-pagination .slider-page { 129 | background: #000000; 130 | } 131 | 132 | .eql-cntrb { 133 | font-size: smaller; 134 | } 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /docs/static/images/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/20.jpg -------------------------------------------------------------------------------- /docs/static/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/architecture.png -------------------------------------------------------------------------------- /docs/static/images/main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/main.jpg -------------------------------------------------------------------------------- /docs/static/images/qual_struct.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/qual_struct.png -------------------------------------------------------------------------------- /docs/static/images/qual_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/qual_style.png -------------------------------------------------------------------------------- /docs/static/images/quan_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/quan_depth.png -------------------------------------------------------------------------------- /docs/static/images/quan_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/docs/static/images/quan_style.png -------------------------------------------------------------------------------- /docs/static/js/bulma-carousel.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaCarousel=e():t.bulmaCarousel=e()}("undefined"!=typeof self?self:this,function(){return function(i){var n={};function s(t){if(n[t])return n[t].exports;var e=n[t]={i:t,l:!1,exports:{}};return i[t].call(e.exports,e,e.exports,s),e.l=!0,e.exports}return s.m=i,s.c=n,s.d=function(t,e,i){s.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:i})},s.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return s.d(e,"a",e),e},s.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},s.p="",s(s.s=5)}([function(t,e,i){"use strict";i.d(e,"d",function(){return s}),i.d(e,"e",function(){return r}),i.d(e,"b",function(){return o}),i.d(e,"c",function(){return a}),i.d(e,"a",function(){return l});var 
n=i(2),s=function(e,t){(t=Array.isArray(t)?t:t.split(" ")).forEach(function(t){e.classList.remove(t)})},r=function(t){return t.getBoundingClientRect().width||t.offsetWidth},o=function(t){return t.getBoundingClientRect().height||t.offsetHeight},a=function(t){var e=1=t._x&&this._x<=e._x&&this._y>=t._y&&this._y<=e._y}},{key:"constrain",value:function(t,e){if(t._x>e._x||t._y>e._y)return this;var i=this._x,n=this._y;return null!==t._x&&(i=Math.max(i,t._x)),null!==e._x&&(i=Math.min(i,e._x)),null!==t._y&&(n=Math.max(n,t._y)),null!==e._y&&(n=Math.min(n,e._y)),new s(i,n)}},{key:"reposition",value:function(t){t.style.top=this._y+"px",t.style.left=this._x+"px"}},{key:"toString",value:function(){return"("+this._x+","+this._y+")"}},{key:"x",get:function(){return this._x},set:function(){var t=0this.state.length-this.slidesToShow&&!this.options.centerMode?this.state.next=this.state.index:this.state.next=this.state.index+this.slidesToScroll,this.show()}},{key:"previous",value:function(){this.options.loop||this.options.infinite||0!==this.state.index?this.state.next=this.state.index-this.slidesToScroll:this.state.next=this.state.index,this.show()}},{key:"start",value:function(){this._autoplay.start()}},{key:"pause",value:function(){this._autoplay.pause()}},{key:"stop",value:function(){this._autoplay.stop()}},{key:"show",value:function(t){var e=1this.options.slidesToShow&&(this.options.slidesToScroll=this.slidesToShow),this._breakpoint.init(),this.state.index>=this.state.length&&0!==this.state.index&&(this.state.index=this.state.index-this.slidesToScroll),this.state.length<=this.slidesToShow&&(this.state.index=0),this._ui.wrapper.appendChild(this._navigation.init().render()),this._ui.wrapper.appendChild(this._pagination.init().render()),this.options.navigationSwipe?this._swipe.bindEvents():this._swipe._bindEvents(),this._breakpoint.apply(),this._slides.forEach(function(t){return e._ui.container.appendChild(t)}),this._transitioner.init().apply(!0,this._setHeight.bind(this)),this.options.autoplay&&this._autoplay.init().start()}},{key:"destroy",value:function(){var e=this;this._unbindEvents(),this._items.forEach(function(t){e.element.appendChild(t)}),this.node.remove()}},{key:"id",get:function(){return this._id}},{key:"index",set:function(t){this._index=t},get:function(){return this._index}},{key:"length",set:function(t){this._length=t},get:function(){return this._length}},{key:"slides",get:function(){return this._slides},set:function(t){this._slides=t}},{key:"slidesToScroll",get:function(){return"translate"===this.options.effect?this._breakpoint.getSlidesToScroll():1}},{key:"slidesToShow",get:function(){return"translate"===this.options.effect?this._breakpoint.getSlidesToShow():1}},{key:"direction",get:function(){return"rtl"===this.element.dir.toLowerCase()||"rtl"===this.element.style.direction?"rtl":"ltr"}},{key:"wrapper",get:function(){return this._ui.wrapper}},{key:"wrapperWidth",get:function(){return this._wrapperWidth||0}},{key:"container",get:function(){return this._ui.container}},{key:"containerWidth",get:function(){return this._containerWidth||0}},{key:"slideWidth",get:function(){return this._slideWidth||0}},{key:"transitioner",get:function(){return this._transitioner}}],[{key:"attach",value:function(){var i=this,t=0>t/4).toString(16)})}},function(t,e,i){"use strict";var n=i(3),s=i(8),r=function(){function n(t,e){for(var 
i=0;i=t.slider.state.length-t.slider.slidesToShow&&!t.slider.options.loop&&!t.slider.options.infinite?t.stop():t.slider.next())},this.slider.options.autoplaySpeed))}},{key:"stop",value:function(){this._interval=clearInterval(this._interval),this.emit("stop",this)}},{key:"pause",value:function(){var t=this,e=0parseInt(e.changePoint,10)}),this._currentBreakpoint=this._getActiveBreakpoint(),this}},{key:"destroy",value:function(){this._unbindEvents()}},{key:"_bindEvents",value:function(){window.addEventListener("resize",this[s]),window.addEventListener("orientationchange",this[s])}},{key:"_unbindEvents",value:function(){window.removeEventListener("resize",this[s]),window.removeEventListener("orientationchange",this[s])}},{key:"_getActiveBreakpoint",value:function(){var t=!0,e=!1,i=void 0;try{for(var n,s=this.options.breakpoints[Symbol.iterator]();!(t=(n=s.next()).done);t=!0){var r=n.value;if(r.changePoint>=window.innerWidth)return r}}catch(t){e=!0,i=t}finally{try{!t&&s.return&&s.return()}finally{if(e)throw i}}return this._defaultBreakpoint}},{key:"getSlidesToShow",value:function(){return this._currentBreakpoint?this._currentBreakpoint.slidesToShow:this._defaultBreakpoint.slidesToShow}},{key:"getSlidesToScroll",value:function(){return this._currentBreakpoint?this._currentBreakpoint.slidesToScroll:this._defaultBreakpoint.slidesToScroll}},{key:"apply",value:function(){this.slider.state.index>=this.slider.state.length&&0!==this.slider.state.index&&(this.slider.state.index=this.slider.state.index-this._currentBreakpoint.slidesToScroll),this.slider.state.length<=this._currentBreakpoint.slidesToShow&&(this.slider.state.index=0),this.options.loop&&this.slider._loop.init().apply(),this.options.infinite&&this.slider._infinite.init().apply(),this.slider._setDimensions(),this.slider._transitioner.init().apply(!0,this.slider._setHeight.bind(this.slider)),this.slider._setClasses(),this.slider._navigation.refresh(),this.slider._pagination.refresh()}},{key:s,value:function(t){var e=this._getActiveBreakpoint();e.slidesToShow!==this._currentBreakpoint.slidesToShow&&(this._currentBreakpoint=e,this.apply())}}]),e}();e.a=r},function(t,e,i){"use strict";var n=function(){function n(t,e){for(var i=0;ithis.slider.state.length-1-this._infiniteCount;i-=1)e=i-1,t.unshift(this._cloneSlide(this.slider.slides[e],e-this.slider.state.length));for(var n=[],s=0;s=this.slider.state.length?(this.slider.state.index=this.slider.state.next=this.slider.state.next-this.slider.state.length,this.slider.transitioner.apply(!0)):this.slider.state.next<0&&(this.slider.state.index=this.slider.state.next=this.slider.state.length+this.slider.state.next,this.slider.transitioner.apply(!0)))}},{key:"_cloneSlide",value:function(t,e){var i=t.cloneNode(!0);return i.dataset.sliderIndex=e,i.dataset.cloned=!0,(i.querySelectorAll("[id]")||[]).forEach(function(t){t.setAttribute("id","")}),i}}]),e}();e.a=s},function(t,e,i){"use strict";var n=i(12),s=function(){function n(t,e){for(var i=0;ithis.slider.state.length-this.slider.slidesToShow&&Object(n.a)(this.slider._slides[this.slider.state.length-1],this.slider.wrapper)?this.slider.state.next=0:this.slider.state.next=Math.min(Math.max(this.slider.state.next,0),this.slider.state.length-this.slider.slidesToShow):this.slider.state.next=0:this.slider.state.next<=0-this.slider.slidesToScroll?this.slider.state.next=this.slider.state.length-this.slider.slidesToShow:this.slider.state.next=0)}}]),e}();e.a=r},function(t,e,i){"use strict";i.d(e,"a",function(){return n});var n=function(t,e){var 
i=t.getBoundingClientRect();return e=e||document.documentElement,0<=i.top&&0<=i.left&&i.bottom<=(window.innerHeight||e.clientHeight)&&i.right<=(window.innerWidth||e.clientWidth)}},function(t,e,i){"use strict";var n=i(14),s=i(1),r=function(){function n(t,e){for(var i=0;ithis.slider.slidesToShow?(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.remove("is-hidden"),0===this.slider.state.next?(this._ui.previous.classList.add("is-hidden"),this._ui.next.classList.remove("is-hidden")):this.slider.state.next>=this.slider.state.length-this.slider.slidesToShow&&!this.slider.options.centerMode?(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.add("is-hidden")):this.slider.state.next>=this.slider.state.length-1&&this.slider.options.centerMode&&(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.add("is-hidden"))):(this._ui.previous.classList.add("is-hidden"),this._ui.next.classList.add("is-hidden")))}},{key:"render",value:function(){return this.node}}]),e}();e.a=o},function(t,e,i){"use strict";e.a=function(t){return'
'+t.previous+'
\n
'+t.next+"
"}},function(t,e,i){"use strict";var n=i(16),s=i(17),r=i(1),o=function(){function n(t,e){for(var i=0;ithis.slider.slidesToShow){for(var t=0;t<=this._count;t++){var e=document.createRange().createContextualFragment(Object(s.a)()).firstChild;e.dataset.index=t*this.slider.slidesToScroll,this._pages.push(e),this._ui.container.appendChild(e)}this._bindEvents()}}},{key:"onPageClick",value:function(t){this._supportsPassive||t.preventDefault(),this.slider.state.next=t.currentTarget.dataset.index,this.slider.show()}},{key:"onResize",value:function(){this._draw()}},{key:"refresh",value:function(){var e=this,t=void 0;(t=this.slider.options.infinite?Math.ceil(this.slider.state.length-1/this.slider.slidesToScroll):Math.ceil((this.slider.state.length-this.slider.slidesToShow)/this.slider.slidesToScroll))!==this._count&&(this._count=t,this._draw()),this._pages.forEach(function(t){t.classList.remove("is-active"),parseInt(t.dataset.index,10)===e.slider.state.next%e.slider.state.length&&t.classList.add("is-active")})}},{key:"render",value:function(){return this.node}}]),e}();e.a=a},function(t,e,i){"use strict";e.a=function(){return'
'}},function(t,e,i){"use strict";e.a=function(){return'
'}},function(t,e,i){"use strict";var n=i(4),s=i(1),r=function(){function n(t,e){for(var i=0;iMath.abs(this._lastTranslate.y)&&(this._supportsPassive||t.preventDefault(),t.stopPropagation())}}},{key:"onStopDrag",value:function(t){this._origin&&this._lastTranslate&&(Math.abs(this._lastTranslate.x)>.2*this.width?this._lastTranslate.x<0?this.slider.next():this.slider.previous():this.slider.show(!0)),this._origin=null,this._lastTranslate=null}}]),e}();e.a=o},function(t,e,i){"use strict";var n=i(20),s=i(21),r=function(){function n(t,e){for(var i=0;it.x?(s.x=0,this.slider.state.next=0):this.options.vertical&&Math.abs(this._position.y)>t.y&&(s.y=0,this.slider.state.next=0)),this._position.x=s.x,this._position.y=s.y,this.options.centerMode&&(this._position.x=this._position.x+this.slider.wrapperWidth/2-Object(o.e)(i)/2),"rtl"===this.slider.direction&&(this._position.x=-this._position.x,this._position.y=-this._position.y),this.slider.container.style.transform="translate3d("+this._position.x+"px, "+this._position.y+"px, 0)",n.x>t.x&&this.slider.transitioner.end()}}},{key:"onTransitionEnd",value:function(t){"translate"===this.options.effect&&(this.transitioner.isAnimating()&&t.target==this.slider.container&&this.options.infinite&&this.slider._infinite.onTransitionEnd(t),this.transitioner.end())}}]),n}();e.a=n},function(t,e,i){"use strict";e.a={initialSlide:0,slidesToScroll:1,slidesToShow:1,navigation:!0,navigationKeys:!0,navigationSwipe:!0,pagination:!0,loop:!1,infinite:!1,effect:"translate",duration:300,timing:"ease",autoplay:!1,autoplaySpeed:3e3,pauseOnHover:!0,breakpoints:[{changePoint:480,slidesToShow:1,slidesToScroll:1},{changePoint:640,slidesToShow:2,slidesToScroll:2},{changePoint:768,slidesToShow:3,slidesToScroll:3}],onReady:null,icons:{previous:'\n \n ',next:'\n \n '}}},function(t,e,i){"use strict";e.a=function(t){return'
\n
\n
'}},function(t,e,i){"use strict";e.a=function(){return'
'}}]).default}); -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define([], factory); 6 | else if(typeof exports === 'object') 7 | exports["bulmaSlider"] = factory(); 8 | else 9 | root["bulmaSlider"] = factory(); 10 | })(typeof self !== 'undefined' ? self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { 50 | /******/ configurable: false, 51 | /******/ enumerable: true, 52 | /******/ get: getter 53 | /******/ }); 54 | /******/ } 55 | /******/ }; 56 | /******/ 57 | /******/ // getDefaultExport function for compatibility with non-harmony modules 58 | /******/ __webpack_require__.n = function(module) { 59 | /******/ var getter = module && module.__esModule ? 
60 | /******/ function getDefault() { return module['default']; } : 61 | /******/ function getModuleExports() { return module; }; 62 | /******/ __webpack_require__.d(getter, 'a', getter); 63 | /******/ return getter; 64 | /******/ }; 65 | /******/ 66 | /******/ // Object.prototype.hasOwnProperty.call 67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 68 | /******/ 69 | /******/ // __webpack_public_path__ 70 | /******/ __webpack_require__.p = ""; 71 | /******/ 72 | /******/ // Load entry module and return exports 73 | /******/ return __webpack_require__(__webpack_require__.s = 0); 74 | /******/ }) 75 | /************************************************************************/ 76 | /******/ ([ 77 | /* 0 */ 78 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 79 | 80 | "use strict"; 81 | Object.defineProperty(__webpack_exports__, "__esModule", { value: true }); 82 | /* harmony export (binding) */ __webpack_require__.d(__webpack_exports__, "isString", function() { return isString; }); 83 | /* harmony import */ var __WEBPACK_IMPORTED_MODULE_0__events__ = __webpack_require__(1); 84 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; 85 | 86 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 87 | 88 | var _typeof = typeof Symbol === "function" && typeof Symbol.iterator === "symbol" ? function (obj) { return typeof obj; } : function (obj) { return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj; }; 89 | 90 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 91 | 92 | function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; } 93 | 94 | function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; } 95 | 96 | 97 | 98 | var isString = function isString(unknown) { 99 | return typeof unknown === 'string' || !!unknown && (typeof unknown === 'undefined' ? 
'undefined' : _typeof(unknown)) === 'object' && Object.prototype.toString.call(unknown) === '[object String]'; 100 | }; 101 | 102 | var bulmaSlider = function (_EventEmitter) { 103 | _inherits(bulmaSlider, _EventEmitter); 104 | 105 | function bulmaSlider(selector) { 106 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 107 | 108 | _classCallCheck(this, bulmaSlider); 109 | 110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this)); 111 | 112 | _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector; 113 | // An invalid selector or non-DOM node has been provided. 114 | if (!_this.element) { 115 | throw new Error('An invalid selector or non-DOM node has been provided.'); 116 | } 117 | 118 | _this._clickEvents = ['click']; 119 | /// Set default options and merge with instance defined 120 | _this.options = _extends({}, options); 121 | 122 | _this.onSliderInput = _this.onSliderInput.bind(_this); 123 | 124 | _this.init(); 125 | return _this; 126 | } 127 | 128 | /** 129 | * Initiate all DOM element containing selector 130 | * @method 131 | * @return {Array} Array of all slider instances 132 | */ 133 | 134 | 135 | _createClass(bulmaSlider, [{ 136 | key: 'init', 137 | 138 | 139 | /** 140 | * Initiate plugin 141 | * @method init 142 | * @return {void} 143 | */ 144 | value: function init() { 145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999)); 146 | this.output = this._findOutputForSlider(); 147 | 148 | this._bindEvents(); 149 | 150 | if (this.output) { 151 | if (this.element.classList.contains('has-output-tooltip')) { 152 | // Get new output position 153 | var newPosition = this._getSliderOutputPosition(); 154 | 155 | // Set output position 156 | this.output.style['left'] = newPosition.position; 157 | } 158 | } 159 | 160 | this.emit('bulmaslider:ready', this.element.value); 161 | } 162 | }, { 163 | key: '_findOutputForSlider', 164 | value: function _findOutputForSlider() { 165 | var _this2 = this; 166 | 167 | var result = null; 168 | var outputs = document.getElementsByTagName('output') || []; 169 | 170 | Array.from(outputs).forEach(function (output) { 171 | if (output.htmlFor == _this2.element.getAttribute('id')) { 172 | result = output; 173 | return true; 174 | } 175 | }); 176 | return result; 177 | } 178 | }, { 179 | key: '_getSliderOutputPosition', 180 | value: function _getSliderOutputPosition() { 181 | // Update output position 182 | var newPlace, minValue; 183 | 184 | var style = window.getComputedStyle(this.element, null); 185 | // Measure width of range input 186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10); 187 | 188 | // Figure out placement percentage between left and right of input 189 | if (!this.element.getAttribute('min')) { 190 | minValue = 0; 191 | } else { 192 | minValue = this.element.getAttribute('min'); 193 | } 194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue); 195 | 196 | // Prevent bubble from going beyond left or right (unsupported browsers) 197 | if (newPoint < 0) { 198 | newPlace = 0; 199 | } else if (newPoint > 1) { 200 | newPlace = sliderWidth; 201 | } else { 202 | newPlace = sliderWidth * newPoint; 203 | } 204 | 205 | return { 206 | 'position': newPlace + 'px' 207 | }; 208 | } 209 | 210 | /** 211 | * Bind all events 212 | * @method _bindEvents 213 | * @return {void} 214 | */ 215 | 216 | }, { 217 | key: 
'_bindEvents', 218 | value: function _bindEvents() { 219 | if (this.output) { 220 | // Add event listener to update output when slider value change 221 | this.element.addEventListener('input', this.onSliderInput, false); 222 | } 223 | } 224 | }, { 225 | key: 'onSliderInput', 226 | value: function onSliderInput(e) { 227 | e.preventDefault(); 228 | 229 | if (this.element.classList.contains('has-output-tooltip')) { 230 | // Get new output position 231 | var newPosition = this._getSliderOutputPosition(); 232 | 233 | // Set output position 234 | this.output.style['left'] = newPosition.position; 235 | } 236 | 237 | // Check for prefix and postfix 238 | var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : ''; 239 | var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : ''; 240 | 241 | // Update output with slider value 242 | this.output.value = prefix + this.element.value + postfix; 243 | 244 | this.emit('bulmaslider:ready', this.element.value); 245 | } 246 | }], [{ 247 | key: 'attach', 248 | value: function attach() { 249 | var _this3 = this; 250 | 251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider'; 252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 253 | 254 | var instances = new Array(); 255 | 256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector]; 257 | elements.forEach(function (element) { 258 | if (typeof element[_this3.constructor.name] === 'undefined') { 259 | var instance = new bulmaSlider(element, options); 260 | element[_this3.constructor.name] = instance; 261 | instances.push(instance); 262 | } else { 263 | instances.push(element[_this3.constructor.name]); 264 | } 265 | }); 266 | 267 | return instances; 268 | } 269 | }]); 270 | 271 | return bulmaSlider; 272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]); 273 | 274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider); 275 | 276 | /***/ }), 277 | /* 1 */ 278 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 279 | 280 | "use strict"; 281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 282 | 283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 284 | 285 | var EventEmitter = function () { 286 | function EventEmitter() { 287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : []; 288 | 289 | _classCallCheck(this, EventEmitter); 290 | 291 | this._listeners = new Map(listeners); 292 | this._middlewares = new Map(); 293 | } 294 | 295 | _createClass(EventEmitter, [{ 296 | key: "listenerCount", 297 | value: function listenerCount(eventName) { 298 | if (!this._listeners.has(eventName)) { 299 | return 0; 300 | } 301 | 302 | var eventListeners = this._listeners.get(eventName); 303 | return eventListeners.length; 304 | } 305 | }, { 306 | key: "removeListeners", 307 | value: function removeListeners() { 308 | var _this = this; 309 | 310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false; 312 | 313 | if (eventName !== null) { 314 | if (Array.isArray(eventName)) { 315 | name.forEach(function (e) { 316 | return _this.removeListeners(e, middleware); 317 | }); 318 | } else { 319 | this._listeners.delete(eventName); 320 | 321 | if (middleware) { 322 | this.removeMiddleware(eventName); 323 | } 324 | } 325 | } else { 326 | this._listeners = new Map(); 327 | } 328 | } 329 | }, { 330 | key: "middleware", 331 | value: function middleware(eventName, fn) { 332 | var _this2 = this; 333 | 334 | if (Array.isArray(eventName)) { 335 | name.forEach(function (e) { 336 | return _this2.middleware(e, fn); 337 | }); 338 | } else { 339 | if (!Array.isArray(this._middlewares.get(eventName))) { 340 | this._middlewares.set(eventName, []); 341 | } 342 | 343 | this._middlewares.get(eventName).push(fn); 344 | } 345 | } 346 | }, { 347 | key: "removeMiddleware", 348 | value: function removeMiddleware() { 349 | var _this3 = this; 350 | 351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 352 | 353 | if (eventName !== null) { 354 | if (Array.isArray(eventName)) { 355 | name.forEach(function (e) { 356 | return _this3.removeMiddleware(e); 357 | }); 358 | } else { 359 | this._middlewares.delete(eventName); 360 | } 361 | } else { 362 | this._middlewares = new Map(); 363 | } 364 | } 365 | }, { 366 | key: "on", 367 | value: function on(name, callback) { 368 | var _this4 = this; 369 | 370 | var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 371 | 372 | if (Array.isArray(name)) { 373 | name.forEach(function (e) { 374 | return _this4.on(e, callback); 375 | }); 376 | } else { 377 | name = name.toString(); 378 | var split = name.split(/,|, | /); 379 | 380 | if (split.length > 1) { 381 | split.forEach(function (e) { 382 | return _this4.on(e, callback); 383 | }); 384 | } else { 385 | if (!Array.isArray(this._listeners.get(name))) { 386 | this._listeners.set(name, []); 387 | } 388 | 389 | this._listeners.get(name).push({ once: once, callback: callback }); 390 | } 391 | } 392 | } 393 | }, { 394 | key: "once", 395 | value: function once(name, callback) { 396 | this.on(name, callback, true); 397 | } 398 | }, { 399 | key: "emit", 400 | value: function emit(name, data) { 401 | var _this5 = this; 402 | 403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? 
arguments[2] : false; 404 | 405 | name = name.toString(); 406 | var listeners = this._listeners.get(name); 407 | var middlewares = null; 408 | var doneCount = 0; 409 | var execute = silent; 410 | 411 | if (Array.isArray(listeners)) { 412 | listeners.forEach(function (listener, index) { 413 | // Start Middleware checks unless we're doing a silent emit 414 | if (!silent) { 415 | middlewares = _this5._middlewares.get(name); 416 | // Check and execute Middleware 417 | if (Array.isArray(middlewares)) { 418 | middlewares.forEach(function (middleware) { 419 | middleware(data, function () { 420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 421 | 422 | if (newData !== null) { 423 | data = newData; 424 | } 425 | doneCount++; 426 | }, name); 427 | }); 428 | 429 | if (doneCount >= middlewares.length) { 430 | execute = true; 431 | } 432 | } else { 433 | execute = true; 434 | } 435 | } 436 | 437 | // If Middleware checks have been passed, execute 438 | if (execute) { 439 | if (listener.once) { 440 | listeners[index] = null; 441 | } 442 | listener.callback(data); 443 | } 444 | }); 445 | 446 | // Dirty way of removing used Events 447 | while (listeners.indexOf(null) !== -1) { 448 | listeners.splice(listeners.indexOf(null), 1); 449 | } 450 | } 451 | } 452 | }]); 453 | 454 | return EventEmitter; 455 | }(); 456 | 457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter); 458 | 459 | /***/ }) 460 | /******/ ])["default"]; 461 | }); -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | 7 | var options = { 8 | slidesToScroll: 1, 9 | slidesToShow: 1, 10 | loop: true, 11 | infinite: true, 12 | autoplay: true, 13 | autoplaySpeed: 5000, 14 | } 15 | 16 | // Initialize all div with carousel class 17 | var carousels = bulmaCarousel.attach('.carousel', options); 18 | 19 | bulmaSlider.attach(); 20 | 21 | }) 22 | -------------------------------------------------------------------------------- /download_weights.sh: 
-------------------------------------------------------------------------------- 1 | 2 | paths=(sd15-depth-128-only-res/struct sd15-hed-128-only-res/struct sd15-style-cross-160-h/style sdxl_b-lora_256/style) 3 | 4 | for p in "${paths[@]}" 5 | do 6 | mkdir -p checkpoints/$p 7 | wget -O checkpoints/$p/lora-checkpoint.pt https://huggingface.co/kliyer/LoRAdapter/resolve/main/$p/lora-checkpoint.pt?download=true 8 | wget -O checkpoints/$p/mapper-checkpoint.pt https://huggingface.co/kliyer/LoRAdapter/resolve/main/$p/mapper-checkpoint.pt?download=true 9 | done -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: loradapter 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python=3.11 8 | - pip=23.3 9 | - pytorch=2.1.2 10 | - pytorch-cuda=12.1 11 | - pip: 12 | - diffusers==0.25.0 13 | - accelerate 14 | - torchvision 15 | - transformers 16 | - tensorboard 17 | - Pillow 18 | - hydra-core 19 | - jaxtyping 20 | - einops 21 | - numpy 22 | - open-clip-torch 23 | - torch-fidelity 24 | - basicsr 25 | - tqdm -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import math 3 | from src.utils import DataProvider 4 | from src.model import ModelBase 5 | from diffusers.optimization import get_scheduler 6 | import torch 7 | from accelerate import Accelerator 8 | from tqdm.auto import tqdm 9 | from torch import nn 10 | from pathlib import Path 11 | import numpy as np 12 | import torchvision.transforms.functional as TF 13 | from accelerate.logging import get_logger 14 | from PIL import Image 15 | from functools import reduce 16 | 17 | from src.utils import add_lora_from_config 18 | 19 | 20 | torch.set_float32_matmul_precision("high") 21 | 22 | 23 | def get_imgs_from_batch(batch: dict[str, torch.Tensor], is_video=False) -> torch.Tensor: 24 | if is_video: 25 | B, C, T, H, W = batch["sequence"].shape 26 | 27 | batch_selector = torch.linspace(0, B - 1, B, dtype=torch.int) 28 | frame_selector = torch.randint(0, T, (B,)) 29 | 30 | # imgs in [-1, 1] 31 | imgs = batch["sequence"] 32 | imgs = imgs[batch_selector, :, frame_selector] 33 | return imgs 34 | 35 | B, C, H, W = batch["jpg"].shape 36 | imgs = batch["jpg"] 37 | 38 | return imgs 39 | 40 | 41 | @hydra.main(config_path="configs", config_name="sample") 42 | def main(cfg): 43 | output_path = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) 44 | 45 | accelerator = Accelerator( 46 | project_dir=output_path / "logs", 47 | ) 48 | 49 | str_cfg = cfg 50 | print(str_cfg) 51 | cfg = hydra.utils.instantiate(cfg) 52 | model: ModelBase = cfg.model 53 | 54 | model = model.to(accelerator.device) 55 | model.pipe.to(accelerator.device) 56 | 57 | weight_type = torch.float32 58 | if cfg.get("bf16", False): 59 | weight_type = torch.bfloat16 60 | 61 | cfg_mask = add_lora_from_config(model, cfg, accelerator.device, weight_type) 62 | 63 | model.unet.to(accelerator.device, weight_type) 64 | model = model.to(accelerator.device, weight_type) 65 | model.pipe.to(accelerator.device, weight_type) 66 | 67 | print(cfg_mask) 68 | 69 | dm = cfg.data 70 | val_dataloader = dm.val_dataloader() 71 | print(val_dataloader) 72 | 73 | logger = get_logger(__name__) 74 | 75 | logger.info("==================================") 76 | logger.info(str_cfg) 77 | logger.info(output_path) 78 | 79 | 
logger.info("prepare network") 80 | val_dataloader = accelerator.prepare(val_dataloader) 81 | unet = model.unet 82 | # model.unet = unet 83 | 84 | unet.requires_grad_(False) 85 | unet.eval() 86 | 87 | cn = cfg.get("use_controlnet", False) 88 | if cn: 89 | encoder = cfg.cn_encoder 90 | encoder.to(accelerator.device) 91 | 92 | images = [] 93 | val_prompts = [] 94 | with torch.no_grad(): 95 | for i, val_batch in enumerate(tqdm(val_dataloader)): 96 | 97 | generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed) 98 | 99 | if cfg.get("prompt", None) is not None: 100 | if len(cfg.prompt) > 1: 101 | prompts = cfg.prompt 102 | else: 103 | prompts = [cfg.prompt] 104 | else: 105 | prompts = val_batch["caption"] 106 | 107 | print(prompts) 108 | val_prompts.append(prompts) 109 | 110 | imgs = get_imgs_from_batch(val_batch, cfg.get("is_video", False)) 111 | imgs = imgs.to(accelerator.device, weight_type) 112 | imgs = imgs.clip(-1.0, 1.0) 113 | 114 | cs = [imgs for _ in model.mappers] 115 | 116 | pipeline_args = { 117 | "prompt": prompts, 118 | "num_images_per_prompt": cfg.n_samples, 119 | "cs": cs, 120 | "generator": generator, 121 | "cfg_mask": cfg_mask, 122 | # "prompt_offset_step": cfg.get("prompt_offset_step", 0), 123 | } 124 | 125 | if cn: 126 | s_imgs = encoder(imgs) 127 | pipeline_args["image"] = s_imgs 128 | 129 | preds = model.sample(**pipeline_args) 130 | 131 | for j, pred in enumerate(preds): 132 | pred.save(f"{accelerator.process_index}-img_{i}_{j}_sample.jpg") 133 | 134 | if cfg.get("log_cond", False): 135 | # depth is in [0, 1] 136 | cond = model.encoders[-1](imgs) 137 | print(cond.shape) 138 | log_pils = [TF.to_pil_image(c.float().cpu()) for c in cond] 139 | else: 140 | log_pils = [TF.to_pil_image((img.float().cpu() + 1) / 2) for img in imgs] 141 | 142 | for j, log_pil in enumerate(log_pils): 143 | log_pil.save(f"{accelerator.process_index}-img_{i}_{j}_prompt.jpg") 144 | 145 | if cfg.get("save_grid", False): 146 | images.append( 147 | np.concatenate( 148 | [np.asarray(img.resize((512, 512))) for img in [*log_pils, *preds]], 149 | axis=1, 150 | ) 151 | ) 152 | 153 | if cfg.get("save_grid", False): 154 | np_images = np.concatenate(images, axis=0) 155 | Image.fromarray(np_images).save("test.jpg") 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /sample_two.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import math 3 | 4 | from src.utils import DataProvider 5 | from src.model import ModelBase 6 | from diffusers.optimization import get_scheduler 7 | import torch 8 | from accelerate import Accelerator 9 | from tqdm.auto import tqdm 10 | from torch import nn 11 | from pathlib import Path 12 | import numpy as np 13 | import torchvision.transforms.functional as TF 14 | from accelerate.logging import get_logger 15 | from PIL import Image 16 | from functools import reduce 17 | 18 | from src.utils import add_lora_from_config 19 | 20 | 21 | # only used this for HED so far 22 | torch.set_float32_matmul_precision("high") 23 | 24 | 25 | def get_imgs_from_batch(batch: dict[str, torch.Tensor], is_video=False) -> torch.Tensor: 26 | if is_video: 27 | B, C, T, H, W = batch["sequence"].shape 28 | 29 | batch_selector = torch.linspace(0, B - 1, B, dtype=torch.int) 30 | frame_selector = torch.randint(0, T, (B,)) 31 | 32 | # imgs in [-1, 1] 33 | imgs = batch["sequence"] 34 | imgs = imgs[batch_selector, :, frame_selector] 35 | return imgs 36 | 
37 | imgs = batch["jpg"] 38 | 39 | return imgs 40 | 41 | 42 | @hydra.main(config_path="configs", config_name="sample") 43 | def main(cfg): 44 | output_path = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) 45 | 46 | accelerator = Accelerator( 47 | project_dir=output_path / "logs", 48 | ) 49 | 50 | str_cfg = cfg 51 | print(str_cfg) 52 | cfg = hydra.utils.instantiate(cfg) 53 | model: ModelBase = cfg.model 54 | 55 | model = model.to(accelerator.device) 56 | model.pipe.to(accelerator.device) 57 | 58 | weight_type = torch.float32 59 | if cfg.get("bf16", False): 60 | weight_type = torch.bfloat16 61 | 62 | cfg_mask = add_lora_from_config(model, cfg, accelerator.device, weight_type) 63 | 64 | model.unet.to(accelerator.device, weight_type) 65 | model = model.to(accelerator.device, weight_type) 66 | model.pipe.to(accelerator.device, weight_type) 67 | 68 | print(cfg_mask) 69 | 70 | dm1 = cfg.data 71 | val_dataloader1 = dm1.val_dataloader() 72 | print(val_dataloader1) 73 | 74 | val_dataloader2 = dm1.val_dataloader() 75 | try: 76 | dm2 = cfg.data2 77 | val_dataloader2 = dm2.val_dataloader() 78 | except: 79 | print("no second dataloader provided") 80 | 81 | logger = get_logger(__name__) 82 | 83 | logger.info("==================================") 84 | logger.info(str_cfg) 85 | logger.info(output_path) 86 | 87 | logger.info("prepare network") 88 | val_dataloader1 = accelerator.prepare(val_dataloader1) 89 | unet = model.unet 90 | # model.unet = unet 91 | 92 | unet.requires_grad_(False) 93 | unet.eval() 94 | 95 | images = [] 96 | val_prompts = [] 97 | for it, val_batch in enumerate(tqdm(val_dataloader1)): 98 | 99 | if it < cfg.get("skip", 0): 100 | continue 101 | 102 | for ib, val_batch2 in enumerate(tqdm(val_dataloader2)): 103 | 104 | generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed) 105 | 106 | i = max(it, ib) 107 | 108 | if cfg.get("prompt", None) is not None: 109 | if len(cfg.prompt) > 1: 110 | prompts = cfg.prompt 111 | else: 112 | prompts = [cfg.prompt] 113 | else: 114 | prompts = val_batch["caption"] 115 | 116 | print(prompts) 117 | val_prompts.append(prompts) 118 | 119 | # B, C, T, H, W = batch["sequence"].shape 120 | # imgs = get_imgs_from_batch(val_batch, cfg.get("is_video", False)) 121 | 122 | imgs = val_batch["jpg"] 123 | imgs = imgs.to(accelerator.device, weight_type) 124 | imgs = imgs.clip(-1.0, 1.0) 125 | 126 | imgs2 = val_batch2["jpg"] 127 | imgs2 = imgs2.to(accelerator.device, weight_type) 128 | imgs2 = imgs2.clip(-1.0, 1.0) 129 | 130 | cs = [imgs, imgs2] 131 | 132 | pipeline_args = { 133 | "prompt": prompts, 134 | "num_images_per_prompt": cfg.n_samples, 135 | "cs": cs, 136 | "generator": generator, 137 | "cfg_mask": cfg_mask, 138 | # "prompt_offset_step": cfg.get("prompt_offset_step", 0), 139 | } 140 | 141 | preds = model.sample(**pipeline_args) 142 | 143 | for j, pred in enumerate(preds): 144 | pred.save(f"{accelerator.process_index}-img_{it}_{ib}_{j}_sample.jpg") 145 | 146 | if cfg.get("save_grid", False): 147 | 148 | if cfg.get("log_cond", False): 149 | # depth is in [0, 1] 150 | cond1 = (imgs + 1) / 2 151 | cond2 = model.encoders[-1](imgs2) 152 | log_pils = [TF.to_pil_image((torch.cat([c1, c2], dim=2)).float().cpu()) for c1, c2 in zip(cond1, cond2)] 153 | else: 154 | log_pils = [TF.to_pil_image((img.float().cpu() + 1) / 2) for img in imgs] 155 | 156 | for j, log_pil in enumerate(log_pils): 157 | log_pil.save(f"{accelerator.process_index}-img_{i}_{j}_prompt.jpg") 158 | 159 | images.append( 160 | np.concatenate( 161 | # we know height is 
constant 162 | [np.asarray(img.resize((int(cfg.size * img.width / img.height), cfg.size))) for img in [*log_pils, *preds]], 163 | axis=1, 164 | ) 165 | ) 166 | 167 | if cfg.get("save_grid", False): 168 | np_images = np.concatenate(images, axis=0) 169 | Image.fromarray(np_images).save("test.jpg") 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/src/__init__.py -------------------------------------------------------------------------------- /src/annotators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/src/annotators/__init__.py -------------------------------------------------------------------------------- /src/annotators/canny.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import torchvision.transforms.functional as TF 4 | import numpy as np 5 | 6 | 7 | class CannyDetector: 8 | 9 | def to(self, *args, **kwargs): 10 | return self 11 | 12 | def __call__(self, imgs, low_threshold=100, high_threshold=250): 13 | assert isinstance(imgs, torch.Tensor) 14 | assert imgs.ndim == 4 15 | assert imgs.shape[1] == 3 16 | assert imgs.dtype == torch.float32 17 | assert imgs.max() <= 1.0 18 | assert imgs.min() >= -1.0 19 | 20 | imgs = (imgs + 1.0) / 2.0 21 | edges = [] 22 | for img in imgs: 23 | img = TF.to_pil_image(img) 24 | img = np.array(img) 25 | edge = cv2.Canny(img, low_threshold, high_threshold) 26 | edge = TF.to_tensor(edge) 27 | edge = edge.repeat_interleave(3, dim=0) 28 | edges.append(edge) 29 | 30 | return torch.stack(edges).to(imgs.device) 31 | -------------------------------------------------------------------------------- /src/annotators/hed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torchvision import transforms as T 7 | from src.annotators.util import annotator_ckpts_path 8 | 9 | # Taken from https://github.com/lllyasviel/ControlNet/blob/main/annotator/hed/__init__.py 10 | # Thanks 11 | 12 | 13 | class DoubleConvBlock(torch.nn.Module): 14 | def __init__(self, input_channel, output_channel, layer_number): 15 | super().__init__() 16 | self.convs = torch.nn.Sequential() 17 | self.convs.append( 18 | torch.nn.Conv2d( 19 | in_channels=input_channel, 20 | out_channels=output_channel, 21 | kernel_size=(3, 3), 22 | stride=(1, 1), 23 | padding=1, 24 | ) 25 | ) 26 | for i in range(1, layer_number): 27 | self.convs.append( 28 | torch.nn.Conv2d( 29 | in_channels=output_channel, 30 | out_channels=output_channel, 31 | kernel_size=(3, 3), 32 | stride=(1, 1), 33 | padding=1, 34 | ) 35 | ) 36 | self.projection = torch.nn.Conv2d( 37 | in_channels=output_channel, 38 | out_channels=1, 39 | kernel_size=(1, 1), 40 | stride=(1, 1), 41 | padding=0, 42 | ) 43 | 44 | def __call__(self, x, down_sampling=False): 45 | h = x 46 | if down_sampling: 47 | h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) 48 | for conv in self.convs: 49 | h = conv(h) 50 | h = torch.nn.functional.relu(h) 51 | return h, self.projection(h) 52 | 53 | 54 
| class ControlNetHED_Apache2(torch.nn.Module): 55 | def __init__(self): 56 | super().__init__() 57 | self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) 58 | self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) 59 | self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) 60 | self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) 61 | self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) 62 | self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) 63 | 64 | def forward(self, x): 65 | h = x - self.norm 66 | h, projection1 = self.block1(h) 67 | h, projection2 = self.block2(h, down_sampling=True) 68 | h, projection3 = self.block3(h, down_sampling=True) 69 | h, projection4 = self.block4(h, down_sampling=True) 70 | h, projection5 = self.block5(h, down_sampling=True) 71 | return projection1, projection2, projection3, projection4, projection5 72 | 73 | 74 | class TorchHEDdetector(nn.Module): 75 | def __init__(self, size, return_without_channels: bool = False, local_files_only: bool = False): 76 | super().__init__() 77 | 78 | self.size = size 79 | self.return_without_channels = return_without_channels 80 | 81 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth" 82 | modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") 83 | if not os.path.exists(modelpath): 84 | from basicsr.utils.download_util import load_file_from_url 85 | 86 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 87 | 88 | self.netNetwork = ControlNetHED_Apache2().float().eval() 89 | self.netNetwork.load_state_dict(torch.load(modelpath)) 90 | 91 | self.netNetwork = self.netNetwork.eval() 92 | self.netNetwork.requires_grad_(False) 93 | 94 | # converts imgs from [-1, 1] to [0, 255] 95 | # returns back in [0, 1] 96 | # @torch.no_grad() 97 | def forward(self, image_hed, return_without_channels: bool = False): 98 | assert isinstance(image_hed, torch.Tensor) 99 | assert image_hed.ndim == 4 100 | assert image_hed.shape[1] == 3 101 | assert image_hed.dtype == torch.float32 102 | assert image_hed.max() <= 1.0 103 | assert image_hed.min() >= -1.0 104 | 105 | resize = T.Resize((self.size, self.size), T.InterpolationMode.BICUBIC) 106 | 107 | # yes it's supposed to be in [0, 255] as float32 108 | image_hed = (image_hed + 1.0) * 127.5 109 | 110 | edges = self.netNetwork(image_hed) 111 | edges = [e[:, 0] for e in edges] 112 | edges = [resize(e) for e in edges] 113 | edges = torch.stack(edges, dim=3) 114 | edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=3))) 115 | 116 | if return_without_channels or self.return_without_channels: 117 | return edge 118 | 119 | edge = edge[:, None, :, :].repeat_interleave(3, 1) 120 | return edge 121 | -------------------------------------------------------------------------------- /src/annotators/midas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from jaxtyping import Float 4 | from torchvision.transforms.functional import resize 5 | 6 | from transformers import ( 7 | DPTImageProcessor, 8 | DPTForDepthEstimation, 9 | ) 10 | 11 | from .util import better_resize 12 | 13 | 14 | class DepthEstimator(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | size: int, 19 | model: str = "Intel/dpt-hybrid-midas", 20 | local_files_only: bool = True, 21 | ) -> None: 22 | super().__init__() 23 | self.model = 
model 24 | self.size = size 25 | self.model_size = 384 26 | 27 | self.depth_estimator = DPTForDepthEstimation.from_pretrained(model, local_files_only=local_files_only) 28 | # self.feature_extractor = DPTImageProcessor.from_pretrained(model, local_files_only=local_files_only) 29 | 30 | self.depth_estimator.requires_grad_(False) 31 | self.depth_estimator.eval() 32 | 33 | @torch.no_grad() 34 | def forward( 35 | self, 36 | imgs: Float[torch.Tensor, "B C H W"], 37 | ) -> Float[torch.Tensor, "B C H W"]: 38 | assert imgs.min() >= -1.0 39 | assert imgs.max() <= 1.0 40 | assert len(imgs.shape) == 4 41 | 42 | imgs = (imgs + 1.0) / 2.0 43 | imgs = better_resize(imgs, self.model_size) 44 | # depth_dict = self.feature_extractor(imgs, do_rescale=False, return_tensors="pt") 45 | 46 | # for k, v in depth_dict.items(): 47 | # if isinstance(v, torch.Tensor): 48 | # depth_dict[k] = v.to(device=imgs.device) 49 | 50 | depth_map = self.depth_estimator(pixel_values=imgs).predicted_depth 51 | 52 | depth_map = torch.nn.functional.interpolate( 53 | depth_map.unsqueeze(1), 54 | size=(self.size, self.size), 55 | mode="bicubic", 56 | align_corners=False, 57 | ) 58 | depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) 59 | depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) 60 | depth_map = (depth_map - depth_min) / (depth_max - depth_min + 1e-6) 61 | depth_map = torch.cat([depth_map] * 3, dim=1) 62 | 63 | # in [0.0, 1.0] 64 | return depth_map 65 | -------------------------------------------------------------------------------- /src/annotators/openclip.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import open_clip 3 | import torch 4 | from transformers import ( 5 | CLIPImageProcessor, 6 | ) 7 | from torchvision.transforms.functional import normalize 8 | from .util import better_resize 9 | 10 | 11 | class VisionModel(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | clip_model: str, 16 | local_files_only: bool = True, 17 | ) -> None: 18 | super().__init__() 19 | self.clip_model = clip_model 20 | self.clip_vision_model = open_clip.create_model_from_pretrained("hf-hub:" + clip_model, return_transform=False) 21 | self.image_size = 224 22 | 23 | self.clip_vision_model.requires_grad_(False) 24 | self.clip_vision_model.eval() 25 | 26 | # @torch.no_grad() 27 | def forward(self, imgs: torch.Tensor) -> torch.Tensor: 28 | assert imgs.min() >= -1.0 29 | assert imgs.max() <= 1.0 30 | assert len(imgs.shape) == 4 31 | 32 | imgs = (imgs + 1.0) / 2.0 33 | imgs = better_resize(imgs, self.image_size) 34 | imgs = normalize( 35 | imgs, 36 | mean=[0.48145466, 0.4578275, 0.40821073], 37 | std=[0.26862954, 0.26130258, 0.27577711], 38 | ) 39 | 40 | image_features = self.clip_vision_model.encode_image(imgs) 41 | 42 | # image_features /= image_features.norm(dim=-1, keepdim=True) 43 | 44 | return image_features 45 | -------------------------------------------------------------------------------- /src/annotators/style.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import ( 4 | CLIPImageProcessor, 5 | CLIPVisionModel, 6 | CLIPVisionModelWithProjection, 7 | ) 8 | 9 | 10 | class VisionModel(nn.Module): 11 | 12 | def __init__( 13 | self, 14 | clip_model: str, 15 | with_projection: bool = False, 16 | local_files_only: bool = True, 17 | ) -> None: 18 | super().__init__() 19 | self.clip_model = clip_model 20 | self.with_projection = with_projection 21 | if 
with_projection: 22 | self.clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(clip_model, local_files_only=local_files_only) 23 | else: 24 | self.clip_vision_model = CLIPVisionModel.from_pretrained(clip_model, local_files_only=local_files_only) 25 | self.clip_vision_processor = CLIPImageProcessor.from_pretrained(clip_model, local_files_only=local_files_only) 26 | 27 | self.clip_vision_model.requires_grad_(False) 28 | self.clip_vision_model.eval() 29 | 30 | @torch.no_grad() 31 | def forward(self, imgs: torch.Tensor) -> torch.Tensor: 32 | assert imgs.min() >= -1.0 33 | assert imgs.max() <= 1.0 34 | assert len(imgs.shape) == 4 35 | 36 | imgs = (imgs + 1) * 127.5 37 | imgs = imgs.to(dtype=torch.uint8) 38 | clip_vision_inputs = self.clip_vision_processor(images=imgs, return_tensors="pt") 39 | 40 | for k, v in clip_vision_inputs.items(): 41 | if isinstance(v, torch.Tensor): 42 | clip_vision_inputs[k] = v.to(device=imgs.device) 43 | 44 | vision_outputs = self.clip_vision_model(**clip_vision_inputs) 45 | last_hidden_state = vision_outputs.last_hidden_state 46 | if self.with_projection: 47 | out = vision_outputs.image_embeds 48 | else: 49 | out = vision_outputs.pooler_output 50 | 51 | return out 52 | -------------------------------------------------------------------------------- /src/annotators/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.nn.functional import avg_pool2d, interpolate 3 | from torchvision.transforms.functional import center_crop 4 | import torch 5 | 6 | annotator_ckpts_path = os.path.join("~/.cache/custom", "ckpts") 7 | 8 | 9 | def better_resize(imgs: torch.Tensor, image_size: int) -> torch.Tensor: 10 | ss = imgs.shape 11 | assert ss[-3] == 3 12 | 13 | H, W = ss[-2:] 14 | 15 | if len(ss) == 3: 16 | imgs = imgs.unsqueeze(0) 17 | 18 | side = min(H, W) 19 | factor = side // image_size 20 | 21 | imgs = center_crop(imgs, [side, side]) 22 | if factor > 1: 23 | imgs = avg_pool2d(imgs, factor) 24 | imgs = interpolate(imgs, [image_size, image_size], mode="bilinear") 25 | 26 | if len(ss) == 3: 27 | imgs = imgs[0] 28 | return imgs 29 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/LoRAdapter/3aa731684f8ace1b9976c61a90b7295b663dd62a/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/local.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision import transforms 4 | from PIL import Image 5 | from pathlib import Path 6 | 7 | 8 | def sort_key(p: Path): 9 | try: 10 | return int(p.stem) 11 | except: 12 | return p.stem 13 | 14 | 15 | class ImageFolderDataset(Dataset): 16 | def __init__(self, directory: Path, transform, caption_from_name: bool, caption_prefix: str): 17 | """ 18 | Args: 19 | directory (string): Directory with all the images. 20 | transform (callable, optional): Optional transform to be applied on a sample. 
21 | """ 22 | self.directory = directory 23 | self.transform = transform 24 | self.image_paths = [directory / file for file in os.listdir(directory) if file.endswith(("jpg", "jpeg", "png"))] 25 | self.image_paths.sort(key=sort_key) 26 | self.caption_from_name = caption_from_name 27 | self.caption_prefix = caption_prefix 28 | 29 | def __len__(self): 30 | return len(self.image_paths) 31 | 32 | def __getitem__(self, idx: int): 33 | image_path = self.image_paths[idx] 34 | txt_path = image_path.with_suffix(".txt") 35 | 36 | if self.caption_from_name: 37 | label = self.caption_prefix + image_path.stem.split("_")[0].replace("-", " ") 38 | else: 39 | try: 40 | with open(txt_path, "r") as f: 41 | label = f.read() 42 | except: 43 | label = "" 44 | 45 | image = Image.open(image_path).convert("RGB") 46 | if self.transform: 47 | image = self.transform(image) 48 | return {"jpg": image, "caption": label} 49 | 50 | 51 | class ZipDataset(Dataset): 52 | def __init__(self, datasets: list[ImageFolderDataset]): 53 | # Ensure all datasets have the same length 54 | assert all(len(datasets[0]) == len(d) for d in datasets), "Datasets must all be the same length!" 55 | self.datasets = datasets 56 | 57 | def __len__(self): 58 | return len(self.datasets[0]) 59 | 60 | def __getitem__(self, idx: int): 61 | # Return a tuple containing elements from each dataset at the given index 62 | if len(self.datasets) == 1: 63 | return self.datasets[0][idx] 64 | 65 | return tuple(d[idx] for d in self.datasets) 66 | 67 | 68 | class ImageDataModule: 69 | def __init__( 70 | self, 71 | directories: list[str], 72 | transform: list, 73 | val_directories: list[str] = [], 74 | batch_size: int = 32, 75 | val_batch_size: int = 1, 76 | workers: int = 4, 77 | val_workers: int = 1, 78 | caption_from_name: bool = False, 79 | caption_prefix: str = "", 80 | ): 81 | super().__init__() 82 | 83 | self.batch_size = batch_size 84 | self.val_batch_size = val_batch_size 85 | self.workers = workers 86 | self.val_workers = val_workers 87 | 88 | project_root = Path(os.path.abspath(__file__)).parent.parent.parent 89 | 90 | self.train_dataset = ZipDataset( 91 | [ 92 | ImageFolderDataset( 93 | directory=Path(project_root, d), 94 | transform=transforms.Compose(transform), 95 | caption_from_name=caption_from_name, 96 | caption_prefix=caption_prefix, 97 | ) 98 | for d in directories 99 | ] 100 | ) 101 | 102 | self.val_dataset = ZipDataset( 103 | [ 104 | ImageFolderDataset( 105 | directory=Path(project_root, d), 106 | transform=transforms.Compose(transform), 107 | caption_from_name=caption_from_name, 108 | caption_prefix=caption_prefix, 109 | ) 110 | for d in val_directories 111 | ] 112 | ) 113 | 114 | def train_dataloader(self): 115 | return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers) 116 | 117 | def val_dataloader(self): 118 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.val_workers) 119 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms.v2 import Transform 2 | import torchvision.transforms.v2.functional as F 3 | 4 | 5 | class SquarePad(Transform): 6 | # use standard pad transform of v2 7 | # but always pads it to be a square 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | # self.fill = fill 13 | # self.padding_mode = padding_mode 14 | 15 | def _transform(self, inpt, 
params): 16 | h, w = inpt.shape[-2], inpt.shape[-1] 17 | 18 | if h > w: 19 | padding = [h - w // 2, 0, h - w // 2, 0] 20 | else: 21 | padding = [0, w - h // 2, 0, w - h // 2] 22 | 23 | return F.pad(inpt, padding, fill=255) 24 | 25 | 26 | class TopCrop(Transform): 27 | # use standard crop transform of v2 28 | # but always crops from the top 29 | 30 | def __init__(self, size): 31 | super().__init__() 32 | self.size = size 33 | 34 | def _transform(self, inpt, params): 35 | return F.crop(inpt, 0, 0, self.size, self.size) 36 | -------------------------------------------------------------------------------- /src/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Union, Tuple 4 | from src.utils import DataProvider 5 | from jaxtyping import Float 6 | from einops import rearrange 7 | 8 | 9 | class SimpleLoraLinear(torch.nn.Module): 10 | def __init__( 11 | self, 12 | out_features: int, 13 | in_features: int, 14 | c_dim: int, 15 | rank: int | float, 16 | data_provider: DataProvider, 17 | alpha: float = 1.0, 18 | lora_scale: float = 1.0, 19 | broadcast_tokens: bool = True, 20 | depth: int | None = None, 21 | use_depth: bool = False, 22 | n_transformations: int = 1, 23 | **kwargs, 24 | ): 25 | super().__init__() 26 | 27 | self.data_provider = data_provider 28 | self.lora_scale = lora_scale 29 | self.broadcast_tokens = broadcast_tokens 30 | self.depth = depth 31 | self.use_depth = use_depth 32 | self.n_transformations = n_transformations 33 | self.rank = rank 34 | 35 | # original weight of the matrix 36 | self.W = nn.Linear(in_features, out_features, bias=False) 37 | for p in self.W.parameters(): 38 | p.requires_grad_(False) 39 | 40 | if type(rank) == float: 41 | self.rank = int(in_features * self.rank) 42 | 43 | self.A = nn.Linear(in_features, self.rank, bias=False) 44 | self.B = nn.Linear(self.rank, out_features, bias=False) 45 | 46 | nn.init.zeros_(self.B.weight) 47 | nn.init.kaiming_normal_(self.A.weight, a=1) 48 | 49 | self.emb_gamma = nn.Linear(c_dim, self.rank * n_transformations, bias=False) 50 | self.emb_beta = nn.Linear(c_dim, self.rank * n_transformations, bias=False) 51 | 52 | def forward(self, x, *args, **kwargs): 53 | w_out = self.W(x) 54 | 55 | if self.lora_scale == 0.0: 56 | return w_out 57 | 58 | c = self.data_provider.get_batch() 59 | if self.use_depth: 60 | assert self.depth is not None, "block depth has to be provided" 61 | c = c[self.depth] 62 | 63 | scale = self.emb_gamma(c) + 1.0 64 | shift = self.emb_beta(c) 65 | 66 | # we need to do that when we only get a single embedding vector 67 | # e.g pooled clip img embedding 68 | # out is [B, 1, rank] 69 | if self.broadcast_tokens: 70 | scale = scale.unsqueeze(1) 71 | shift = shift.unsqueeze(1) 72 | 73 | if self.n_transformations > 1: 74 | # out is [B, 1, trans, rank] 75 | scale = scale.reshape(-1, 1, self.n_transformations, self.rank) 76 | shift = shift.reshape(-1, 1, self.n_transformations, self.rank) 77 | 78 | a_out = self.A(x) # [B, N, D] 79 | if self.n_transformations > 1: 80 | a_out = a_out.unsqueeze(-2).expand(-1, -1, self.n_transformations, -1) # [B, N, trans, rank] 81 | a_cond = scale * a_out 82 | 83 | a_cond = a_cond + shift 84 | 85 | if self.n_transformations > 1: 86 | a_cond = a_cond.mean(dim=-2) 87 | 88 | b_out = self.B(a_cond) 89 | 90 | return w_out + b_out * self.lora_scale 91 | 92 | 93 | # FiLM style LoRA conditioning 94 | class LoRAConv(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels: int, 98 | out_channels: 
int, 99 | kernel_size: Union[int, Tuple[int, int]], 100 | stride: Union[int, Tuple[int, int]], 101 | padding: Union[int, Tuple[int, int]], 102 | c_dim: int, 103 | rank: int, 104 | depth: int, 105 | data_provider: DataProvider, 106 | lora_scale: float = 1.0, 107 | *args, 108 | **kwargs, 109 | ): 110 | super().__init__() 111 | 112 | self.lora_scale = lora_scale 113 | self.data_provider = data_provider 114 | 115 | self.W = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 116 | for p in self.W.parameters(): 117 | p.requires_grad_(False) 118 | 119 | self.A = nn.Conv2d(in_channels, rank, kernel_size, stride, padding, bias=False) 120 | self.B = nn.Conv2d(rank, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 121 | 122 | nn.init.zeros_(self.B.weight) 123 | nn.init.kaiming_normal_(self.A.weight, a=1) 124 | 125 | self.emb_gamma = nn.Linear(c_dim, rank, bias=False) 126 | self.emb_beta = nn.Linear(c_dim, rank, bias=False) 127 | 128 | def forward(self, x: Float[torch.Tensor, "B C H W"], *args, **kwargs) -> Float[torch.Tensor, "B C H W"]: 129 | w_out = self.W(x) 130 | 131 | c = self.data_provider.get_batch() 132 | scale = self.emb_gamma(c) + 1.0 133 | shift = self.emb_beta(c) 134 | 135 | a_out = self.A(x) 136 | a_cond = scale[..., None, None] * a_out + shift[..., None, None] 137 | 138 | b_out = self.B(a_cond) 139 | 140 | return w_out + b_out * self.lora_scale 141 | 142 | 143 | class NewStructLoRAConv(nn.Module): 144 | def __init__( 145 | self, 146 | in_channels: int, 147 | out_channels: int, 148 | kernel_size: Union[int, Tuple[int, int]], 149 | stride: Union[int, Tuple[int, int]], 150 | padding: Union[int, Tuple[int, int]], 151 | c_dim: int, 152 | rank: int, 153 | depth: int, 154 | data_provider: DataProvider, 155 | lora_scale: float = 1.0, 156 | ): 157 | super().__init__() 158 | 159 | self.lora_scale = lora_scale 160 | self.depth = depth 161 | 162 | self.data_provider = data_provider 163 | 164 | self.W = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 165 | for p in self.W.parameters(): 166 | p.requires_grad_(False) 167 | 168 | self.A = nn.Conv2d(in_channels, rank, kernel_size, stride, padding, bias=False) 169 | self.B = nn.Conv2d(rank, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 170 | 171 | nn.init.zeros_(self.B.weight) 172 | nn.init.kaiming_normal_(self.A.weight, a=1) 173 | 174 | self.beta = nn.Conv2d(c_dim, rank, 1, bias=False) 175 | self.gamma = nn.Conv2d(c_dim, rank, 1, bias=False) 176 | 177 | def forward(self, x: Float[torch.Tensor, "B C H W"], *args, **kwargs) -> Float[torch.Tensor, "B C H W"]: 178 | w_out = self.W(x) 179 | 180 | if self.lora_scale == 0.0: 181 | return w_out 182 | 183 | cs = self.data_provider.get_batch() # tuple 184 | c = cs[self.depth] 185 | 186 | element_shift = self.beta(c) 187 | element_scale = self.gamma(c) + 1.0 188 | 189 | # check if norm is actually needed 190 | # if doesn't work add norm on a_out 191 | a_out = self.A(x) 192 | 193 | a_cond = a_out * element_scale + element_shift 194 | 195 | b_out = self.B(a_cond) 196 | 197 | return w_out + b_out * self.lora_scale 198 | -------------------------------------------------------------------------------- /src/mapper_network.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from functools import reduce 5 | from einops import rearrange 6 | 7 | 8 | class SimpleMapper(nn.Module): 9 | def __init__(self, d_model, c_dim): 10 | super().__init__() 11 | 
12 | self.ls = nn.Sequential(nn.Linear(d_model, c_dim), nn.LayerNorm(c_dim)) # just [b, d] (no n as it's a single vector) 13 | 14 | def forward(self, x): 15 | return self.ls(x) 16 | 17 | 18 | class FixedStructureMapper15(nn.Module): 19 | def __init__(self, c_dim: int): 20 | super().__init__() 21 | self.c_dim = c_dim 22 | 23 | self.down = nn.Sequential( 24 | nn.Conv2d(3, 16, 3, padding=1), 25 | nn.SiLU(), 26 | nn.Conv2d(16, 16, 3, padding=1), 27 | nn.SiLU(), 28 | nn.Conv2d(16, 32, 3, padding=1, stride=2), # 256 29 | nn.SiLU(), 30 | nn.Conv2d(32, 32, 3, padding=1), 31 | nn.SiLU(), 32 | nn.Conv2d(32, 64, 3, padding=1, stride=2), # 128 33 | nn.SiLU(), 34 | nn.Conv2d(64, 64, 3, padding=1), 35 | nn.SiLU(), 36 | nn.Conv2d(64, 128, 3, padding=1, stride=2), # 64 37 | nn.SiLU(), 38 | nn.Conv2d(128, 128, 3, padding=1), 39 | nn.SiLU(), 40 | # nn.Conv2d(128, 128, 3, padding=1), 41 | ) 42 | 43 | self.block0 = nn.Identity() 44 | 45 | self.block1 = nn.Sequential( 46 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 32 47 | nn.SiLU(), 48 | nn.Conv2d(128, 128, 3, padding=1), 49 | nn.SiLU(), 50 | ) 51 | self.block2 = nn.Sequential( 52 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 16 53 | nn.SiLU(), 54 | nn.Conv2d(128, 128, 3, padding=1), 55 | nn.SiLU(), 56 | ) 57 | self.block3 = nn.Sequential( 58 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 8 59 | nn.SiLU(), 60 | nn.Conv2d(128, 128, 3, padding=1), 61 | nn.SiLU(), 62 | ) 63 | 64 | self.out0 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 65 | self.out1 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 66 | self.out2 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 67 | self.out3 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 68 | 69 | def forward(self, x, *args, **kwargs): 70 | base = self.down(x) 71 | 72 | b0 = self.block0(base) 73 | b1 = self.block1(b0) 74 | b2 = self.block2(b1) 75 | b3 = self.block3(b2) 76 | 77 | out0 = self.out0(b0) 78 | out1 = self.out1(b1) 79 | out2 = self.out2(b2) 80 | out3 = self.out3(b3) 81 | 82 | return out0, out1, out2, out3 83 | 84 | 85 | class FixedStructureMapperXL(nn.Module): 86 | def __init__(self, c_dim: int): 87 | super().__init__() 88 | self.c_dim = c_dim 89 | 90 | self.down = nn.Sequential( 91 | nn.Conv2d(3, 32, 3, padding=1), 92 | nn.SiLU(), 93 | nn.Conv2d(32, 32, 3, padding=1), 94 | nn.SiLU(), 95 | nn.Conv2d(32, 128, 3, padding=1, stride=2), # 256 96 | nn.SiLU(), 97 | nn.Conv2d(128, 128, 3, padding=1), 98 | nn.SiLU(), 99 | ) 100 | 101 | self.block0 = nn.Identity() 102 | 103 | self.block1 = nn.Sequential( 104 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 32 105 | nn.SiLU(), 106 | nn.Conv2d(128, 128, 3, padding=1), 107 | nn.SiLU(), 108 | ) 109 | self.block2 = nn.Sequential( 110 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 16 111 | nn.SiLU(), 112 | nn.Conv2d(128, 128, 3, padding=1), 113 | nn.SiLU(), 114 | ) 115 | 116 | self.out0 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 117 | self.out1 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 118 | self.out2 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 119 | 120 | def forward(self, x, *args, **kwargs): 121 | base = self.down(x) 122 | 123 | b0 = self.block0(base) 124 | b1 = self.block1(b0) 125 | b2 = self.block2(b1) 126 | 127 | out0 = self.out0(b0) 128 | out1 = self.out1(b1) 129 | out2 = self.out2(b2) 130 | 131 | return out0, out1, out2 132 | 133 | 134 | # we don't have attention in the deepest blocks 135 | # so we only have three outputs here for SD15 136 | class AttentionStructureMapper15(nn.Module): 137 | def __init__(self, c_dim: int): 138 | super().__init__() 139 | self.c_dim = c_dim 140 | 141 | 
self.down = nn.Sequential( 142 | nn.Conv2d(3, 16, 3, padding=1), 143 | nn.SiLU(), 144 | nn.Conv2d(16, 16, 3, padding=1), 145 | nn.SiLU(), 146 | nn.Conv2d(16, 32, 3, padding=1, stride=2), # 256 147 | nn.SiLU(), 148 | nn.Conv2d(32, 32, 3, padding=1), 149 | nn.SiLU(), 150 | nn.Conv2d(32, 64, 3, padding=1, stride=2), # 128 151 | nn.SiLU(), 152 | nn.Conv2d(64, 64, 3, padding=1), 153 | nn.SiLU(), 154 | nn.Conv2d(64, 128, 3, padding=1, stride=2), # 64 155 | nn.SiLU(), 156 | nn.Conv2d(128, 128, 3, padding=1), 157 | nn.SiLU(), 158 | # nn.Conv2d(128, 128, 3, padding=1), 159 | ) 160 | 161 | # the output channels correspond to the token dim 162 | self.block0 = nn.Identity() 163 | 164 | self.block1 = nn.Sequential( 165 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 32 166 | nn.SiLU(), 167 | nn.Conv2d(128, 128, 3, padding=1), 168 | nn.SiLU(), 169 | ) 170 | self.block2 = nn.Sequential( 171 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 16 172 | nn.SiLU(), 173 | nn.Conv2d(128, 128, 3, padding=1), 174 | nn.SiLU(), 175 | ) 176 | 177 | self.block3 = nn.Sequential( 178 | nn.Conv2d(128, 128, 3, padding=1, stride=2), # 8 179 | nn.SiLU(), 180 | nn.Conv2d(128, 128, 3, padding=1), 181 | nn.SiLU(), 182 | ) 183 | 184 | # here we project them down again 185 | self.out0 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 186 | self.out1 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 187 | self.out2 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 188 | self.out3 = nn.Sequential(nn.Conv2d(128, c_dim, 1)) 189 | 190 | def forward(self, x, *args, **kwargs): 191 | base = self.down(x) 192 | 193 | b0 = self.block0(base) 194 | b1 = self.block1(b0) 195 | b2 = self.block2(b1) 196 | b3 = self.block3(b2) 197 | 198 | out0 = self.out0(b0) 199 | out1 = self.out1(b1) 200 | out2 = self.out2(b2) 201 | out3 = self.out3(b3) 202 | 203 | # convert to tokens 204 | ot0 = rearrange(out0, "B C H W -> B (H W) C") 205 | ot1 = rearrange(out1, "B C H W -> B (H W) C") 206 | ot2 = rearrange(out2, "B C H W -> B (H W) C") 207 | ot3 = rearrange(out3, "B C H W -> B (H W) C") 208 | 209 | return ot0, ot1, ot2, ot3 210 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Union, Literal 3 | import torch 4 | from torch import nn 5 | from src.utils import DataProvider 6 | import src.lora as loras 7 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 8 | from diffusers import AutoencoderTiny 9 | from src.utils import getattr_recursive 10 | 11 | import torch.nn.functional as F 12 | from pydoc import locate 13 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 14 | 15 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 16 | retrieve_timesteps, 17 | ) 18 | 19 | from tqdm.auto import tqdm 20 | import random 21 | 22 | from diffusers import ControlNetModel 23 | from torchvision.transforms import Compose 24 | from typing import Callable 25 | 26 | ATTENTION_MODULES = ["to_k", "to_v"] 27 | 28 | # only for SD15 29 | CONV_MODULES = ["conv1", "conv2"] 30 | 31 | ADAPTION_MODE = Literal[ 32 | "full_attention", 33 | "only_self", 34 | "only_cross", 35 | "only_conv", 36 | "only_first_conv", 37 | "only_res_conv", 38 | "full", 39 | "no_cross", 40 | "only_value", 41 | # below only works for sdxl 42 | "b-lora_style", 43 | "b-lora_content", 44 | "b-lora", 45 | "sdxl_cross", 46 | "sdxl_self", 47 | "sdxl_inner", 48 | ] 49 | 50 | CONDITION_MODE = 
Literal["style", "structure"] 51 | 52 | 53 | class ModelBase(ABC, nn.Module): 54 | 55 | def __init__( 56 | self, 57 | pipeline_type: str, 58 | model_name: str, 59 | local_files_only: bool = True, 60 | c_dropout: float = 0.05, 61 | guidance_scale: float = 7.5, 62 | use_controlnet: bool = False, 63 | annotator: None | nn.Module = None, 64 | tiny_vae: bool = False, 65 | ) -> None: 66 | super().__init__() 67 | self.params_to_optimize: list[nn.Parameter] = [] 68 | self.lora_state_dict_keys: dict[str, list[str]] = {} 69 | self.lora_layers: dict[str, list[nn.Module]] = {} 70 | self.lora_transforms: list[Compose | None] = [] 71 | 72 | self.encoders: list[nn.Module] = list() 73 | self.mappers: list[nn.Module] = list() 74 | self.dps: list[DataProvider] = [] 75 | 76 | self.tiny_vae = tiny_vae 77 | self.c_dropout = c_dropout 78 | self.guidance_scale = guidance_scale 79 | self.use_controlnet = use_controlnet 80 | 81 | addition_config = {} 82 | 83 | # Note that this requires the controlnet pipe which also has to be set in the config 84 | 85 | if tiny_vae: 86 | vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", local_files_only=local_files_only) 87 | addition_config["vae"] = vae 88 | 89 | if self.use_controlnet: 90 | assert annotator is not None, "Need annotator for controlnet" 91 | 92 | controlnet = ControlNetModel.from_pretrained( 93 | "lllyasviel/sd-controlnet-depth", 94 | use_safetensors=True, 95 | local_files_only=local_files_only, 96 | **addition_config, 97 | ) 98 | # controlnet.requires_grad_(False) 99 | # controlnet.eval() 100 | addition_config["controlnet"] = controlnet 101 | 102 | # fix this cheap work around! 103 | self.encoders.append(annotator) 104 | self.mappers.append(controlnet) 105 | 106 | self.pipe: DiffusionPipeline = locate(pipeline_type).from_pretrained( 107 | model_name, 108 | local_files_only=local_files_only, 109 | safety_checker=None, # too anoying 110 | safe_tensors=True, 111 | **addition_config, 112 | ) 113 | assert isinstance(self.pipe, DiffusionPipeline) 114 | 115 | self.noise_scheduler = DDPMScheduler.from_config( 116 | {**self.pipe.scheduler.config, "rescale_betas_zero_snr": False}, 117 | subfolder="scheduler", 118 | ) 119 | 120 | self.max_depth = len(self.pipe.unet.config["block_out_channels"]) - 1 121 | 122 | # we register all the individual pipeline modules here 123 | # such that all the typical calls like .to and .prepare effect them. 
124 | self.unet = self.pipe.unet 125 | self.unet.requires_grad_(False) 126 | 127 | self.vae = self.pipe.vae 128 | self.text_encoder = self.pipe.text_encoder 129 | self.tokenizer = self.pipe.tokenizer 130 | 131 | self.vae = self.pipe.vae 132 | self.text_encoder.requires_grad_(False) 133 | 134 | # handle sdxl case 135 | if hasattr(self.pipe, "text_encoder_2"): 136 | self.text_encoder_2 = self.pipe.text_encoder_2 137 | self.text_encoder_2.requires_grad_(False) 138 | 139 | def add_lora_to_unet( 140 | self, 141 | config: dict, 142 | name: str, 143 | data_provider: DataProvider, 144 | encoder: nn.Module, 145 | mapper: nn.Module, 146 | optimize: bool = True, 147 | transforms: list[Callable] = [], 148 | ): 149 | self.rank = config.rank 150 | self.c_dim = config.c_dim 151 | unet = self.unet 152 | sd = unet.state_dict() 153 | 154 | self.mappers.append(mapper) 155 | self.encoders.append(encoder) 156 | self.dps.append(data_provider) 157 | 158 | self.lora_transforms.append(Compose(transforms) if len(transforms) > 0 else None) 159 | 160 | print(f"adding {len(transforms)} transforms to LoRA {name}") 161 | 162 | lora_cls = config.lora_cls 163 | adaption_mode = config.adaption_mode 164 | 165 | if not optimize: 166 | mapper.eval() 167 | mapper.requires_grad_(False) 168 | 169 | local_lora_sd_keys: list[str] = [] 170 | 171 | for path, w in sd.items(): 172 | class_config = {**config} 173 | del class_config["lora_cls"] 174 | del class_config["adaption_mode"] 175 | 176 | _continue = True 177 | if adaption_mode == "full_attention" and "attn" in path: 178 | _continue = False 179 | 180 | if adaption_mode == "only_self" and "attn1" in path: 181 | _continue = False 182 | 183 | if adaption_mode == "only_cross" and "attn2" in path: 184 | _continue = False 185 | 186 | if adaption_mode == "only_conv" and ("conv1" in path or "conv2" in path): 187 | _continue = False 188 | 189 | # only the first conv layer in each resnet block 190 | if adaption_mode == "only_first_conv" and "0.conv1" in path: 191 | _continue = False 192 | 193 | if adaption_mode == "only_res_conv" and ("0.conv1" in path or "1.conv1" in path): 194 | _continue = False 195 | 196 | if adaption_mode == "full" and ("attn" in path or "conv" in path): 197 | _continue = False 198 | 199 | if adaption_mode == "no_cross" and "attn2" not in path: 200 | _continue = False 201 | 202 | if adaption_mode == "b-lora_content" and ("up_blocks.0.attentions.0" in path and "attn" in path): 203 | _continue = False 204 | 205 | if adaption_mode == "b-lora_style" and ("up_blocks.0.attentions.1" in path and "attn" in path): 206 | _continue = False 207 | 208 | if ( 209 | adaption_mode == "b-lora" 210 | and ("up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path) 211 | and "attn" in path 212 | ): 213 | _continue = False 214 | 215 | # supposed setting content to have no effect 216 | # if "up_blocks.0.attentions.0" in path: 217 | # class_config["lora_scale"] = 0.0 218 | 219 | # "down_blocks.2.attentions.1" in path or 220 | if ( 221 | adaption_mode == "sdxl_inner" 222 | and ("mid_block" in path or "up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path) 223 | and "attn2" in path 224 | ): 225 | _continue = False 226 | 227 | if ( 228 | adaption_mode == "sdxl_cross" 229 | and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path) 230 | and "attn2" in path 231 | ): 232 | _continue = False 233 | 234 | if ( 235 | adaption_mode == "sdxl_self" 236 | and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path) 237 | and "attn1" in 
path 238 | ): 239 | _continue = False 240 | 241 | if _continue: 242 | continue 243 | 244 | if "bias" in path: 245 | # we handle the bias together with the weight 246 | # this is only relevant for the conv layers 247 | continue 248 | 249 | parent_path = ".".join(path.split(".")[:-2]) 250 | target_path = ".".join(path.split(".")[:-1]) 251 | target_name = path.split(".")[-2] 252 | parent_module = getattr_recursive(unet, parent_path) 253 | target_module = getattr_recursive(unet, target_path) 254 | 255 | if "mid_block" in path: 256 | depth = self.max_depth 257 | elif "down_blocks" in path: 258 | depth = int(path.split("down_blocks.")[1][0]) 259 | elif "up_blocks" in path: 260 | depth = self.max_depth - int(path.split("up_blocks.")[1][0]) 261 | else: 262 | raise ValueError(f"Unknown module {path}") 263 | 264 | lora = None 265 | if "attn" in path: 266 | if not any([m in path for m in ATTENTION_MODULES]): 267 | continue 268 | 269 | lora = getattr(loras, lora_cls)( 270 | out_features=target_module.out_features, 271 | in_features=target_module.in_features, 272 | data_provider=data_provider, 273 | depth=depth, 274 | **class_config, 275 | ) 276 | 277 | # W is the original weight matrix 278 | # those layers have no bias 279 | lora.W.load_state_dict({path.split(".")[-1]: w}) 280 | 281 | if lora_cls == "IPLinear": 282 | # for faster convergence 283 | lora.W_IP.load_state_dict({path.split(".")[-1]: w}) 284 | 285 | if "conv" in path: 286 | lora = getattr(loras, lora_cls)( 287 | in_channels=target_module.in_channels, 288 | out_channels=target_module.out_channels, 289 | kernel_size=target_module.kernel_size, 290 | stride=target_module.stride, 291 | padding=target_module.padding, 292 | data_provider=data_provider, 293 | depth=depth, 294 | **class_config, 295 | ) 296 | 297 | # find bias term 298 | bias_path = ".".join(path.split(".")[:-1] + ["bias"]) 299 | b = sd[bias_path] 300 | lora.W.load_state_dict({path.split(".")[-1]: w, "bias": b}) 301 | 302 | if lora is None: 303 | raise ValueError(f"Unknown module {path}") 304 | 305 | for k in lora.state_dict().keys(): 306 | # W is by design the original weight matrix which we don't need to save 307 | if k.split(".")[0] == "W": 308 | continue 309 | 310 | local_lora_sd_keys.append(f"{target_path}.{k}") 311 | 312 | self.lora_state_dict_keys[name] = local_lora_sd_keys 313 | 314 | setattr( 315 | parent_module, 316 | target_name, 317 | lora, 318 | ) 319 | 320 | if optimize: 321 | for p in lora.parameters(): 322 | if p.requires_grad: 323 | self.params_to_optimize.append(p) 324 | else: 325 | lora.eval() 326 | for p in lora.parameters(): 327 | p.requires_grad_(False) 328 | 329 | self.lora_layers[name] = [lora] + self.lora_layers.get(name, []) 330 | 331 | def get_lora_state_dict(self, unet: Union[nn.Module, None] = None): 332 | lora_sd = {} 333 | 334 | if unet is None: 335 | unet = self.unet 336 | 337 | for k, v in unet.state_dict().items(): 338 | for n, keys in self.lora_state_dict_keys.items(): 339 | if n not in lora_sd: 340 | lora_sd[n] = {} 341 | 342 | if k in keys: 343 | lora_sd[n][k] = v.cpu() 344 | 345 | return lora_sd 346 | 347 | @abstractmethod 348 | def get_input(self, imgs: torch.Tensor, prompts: list[str]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: 349 | raise NotImplementedError() 350 | 351 | # -> epsilon, loss, x0 352 | def forward_easy(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 353 | return self(args, kwargs) 354 | 355 | def sample(self, *args, **kwargs): 356 | return self.pipe(*args, **kwargs).images 357 | 358 
| 359 | class SD15(ModelBase): 360 | def __init__(self, pipeline_type, model_name, *args, **kwargs) -> None: 361 | super().__init__(pipeline_type, model_name, *args, **kwargs) 362 | 363 | @torch.no_grad() 364 | def get_input(self, imgs: torch.Tensor, prompts: list[str]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: 365 | assert len(imgs.shape) == 4 366 | assert imgs.min() >= -1.0 367 | assert imgs.max() <= 1.0 368 | 369 | imgs = imgs.clip(-1.0, 1.0) 370 | 371 | # Convert images to latent space 372 | if self.tiny_vae: 373 | latents = self.vae.encode(imgs).latents 374 | else: 375 | latents = self.vae.encode(imgs).latent_dist.sample() 376 | 377 | latents = latents * self.vae.config.scaling_factor 378 | 379 | # prompt dropout 380 | prompts = ["" if random.random() < self.c_dropout else p for p in prompts] 381 | 382 | # prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt( 383 | # prompt=prompts, 384 | # device=self.unet.device, 385 | # num_images_per_prompt=1, 386 | # do_classifier_free_guidance=False, 387 | # ) 388 | 389 | # do it manually to avoid stupid warnings 390 | input_ids = self.tokenizer( 391 | prompts, 392 | truncation=True, 393 | padding="max_length", 394 | max_length=self.tokenizer.model_max_length, 395 | return_tensors="pt", 396 | ).input_ids 397 | prompt_embeds = self.text_encoder(input_ids.to(imgs.device))["last_hidden_state"] 398 | 399 | # assert (prompt_embeds - prompt_embeds).mean() < 1e-6 400 | # assert (prompt_embeds == prompt_embeds).all() 401 | 402 | c = { 403 | "prompt_embeds": prompt_embeds, 404 | } 405 | 406 | return latents, c 407 | 408 | def forward( 409 | self, 410 | latents: torch.Tensor, 411 | c: dict[str, torch.Tensor], 412 | cs: list[torch.Tensor], 413 | timesteps: torch.Tensor, 414 | noise: torch.Tensor, 415 | cfg_mask: list[bool] | None = None, 416 | skip_encode: bool = False, 417 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 418 | 419 | noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) 420 | prompt_embeds = c["prompt_embeds"] 421 | bsz = latents.shape[0] 422 | encoders = self.encoders 423 | mappers = self.mappers 424 | 425 | additional_inputs = {} 426 | if self.use_controlnet: 427 | # controlnet related stuff is always at index 0 428 | cn_input = cs[0] 429 | cs = cs[1:] 430 | 431 | controlnet = mappers[0] 432 | mappers = mappers[1:] 433 | 434 | annotator = encoders[0] 435 | encoders = encoders[1:] 436 | 437 | with torch.no_grad(): 438 | cn_cond = annotator(cn_input) 439 | 440 | down_block_res_samples, mid_block_res_sample = controlnet( 441 | noisy_latents, 442 | timesteps, 443 | encoder_hidden_states=prompt_embeds, 444 | controlnet_cond=cn_cond, 445 | conditioning_scale=1.0, 446 | return_dict=False, 447 | ) 448 | 449 | additional_inputs["down_block_additional_residuals"] = down_block_res_samples 450 | additional_inputs["mid_block_additional_residual"] = mid_block_res_sample 451 | 452 | # add our lora conditioning 453 | # cs in [-1, 1] 454 | for i, (encoder, dp, mapper, lora_c) in enumerate(zip(encoders, self.dps, mappers, cs)): 455 | if cfg_mask is None or cfg_mask[i]: 456 | dropout_mask = torch.rand(bsz, device=lora_c.device) < self.c_dropout 457 | 458 | # apply dropout for cfg 459 | lora_c[dropout_mask] = torch.zeros_like(lora_c[dropout_mask]) 460 | 461 | if skip_encode: 462 | cond = lora_c 463 | else: 464 | # some encoders we want to finetune 465 | # so no torch.no_grad() here 466 | # instead we set requires_grad in the corresponding classes/configs 467 | t = self.lora_transforms[i] 468 | if t 
is not None: 469 | lora_c = t(lora_c) 470 | cond = encoder(lora_c) 471 | mapped_cond = mapper(cond) 472 | dp.set_batch(mapped_cond) 473 | 474 | # Predict the noise residual 475 | model_pred = self.unet( 476 | noisy_latents, timesteps, encoder_hidden_states=prompt_embeds, **additional_inputs 477 | ).sample 478 | 479 | # get x0 prediction 480 | alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device=model_pred.device, dtype=model_pred.dtype) 481 | 482 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 483 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 484 | while len(sqrt_alpha_prod.shape) < len(model_pred.shape): 485 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 486 | 487 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 488 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 489 | while len(sqrt_one_minus_alpha_prod.shape) < len(model_pred.shape): 490 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 491 | 492 | x0 = (noisy_latents - sqrt_one_minus_alpha_prod * model_pred) / sqrt_alpha_prod 493 | 494 | # Get the target for loss depending on the prediction type 495 | if self.noise_scheduler.config.prediction_type == "epsilon": 496 | target = noise 497 | elif self.noise_scheduler.config.prediction_type == "v_prediction": 498 | target = self.noise_scheduler.get_velocity(latents, noise, timesteps) 499 | else: 500 | raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") 501 | 502 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 503 | 504 | return model_pred, loss, x0, cond 505 | 506 | def forward_easy( 507 | self, 508 | imgs: torch.Tensor, 509 | prompts: list[str], 510 | cs: list[torch.Tensor], 511 | cfg_mask: list[bool] | None = None, 512 | skip_encode: bool = False, 513 | batch: dict | None = None, 514 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 515 | 516 | latents, c = self.get_input(imgs, prompts) 517 | 518 | # Sample noise that we'll add to the latents 519 | noise = torch.randn_like(latents) 520 | bsz = latents.shape[0] 521 | # Sample a random timestep for each image 522 | timesteps = torch.randint( 523 | 0, 524 | self.noise_scheduler.config.num_train_timesteps, 525 | (bsz,), 526 | device=latents.device, 527 | ) 528 | # timesteps = timesteps.long() 529 | 530 | return self( 531 | latents=latents, 532 | c=c, 533 | cs=cs, 534 | timesteps=timesteps, 535 | noise=noise, 536 | cfg_mask=cfg_mask, 537 | skip_encode=skip_encode, 538 | ) 539 | 540 | @torch.no_grad() 541 | def sample_custom( 542 | self, 543 | prompt, 544 | num_images_per_prompt, 545 | cs: list[torch.Tensor], 546 | generator, 547 | cfg_mask: list[bool] | None = None, 548 | prompt_offset_step: int = 0, 549 | skip_encode: bool = False, 550 | **kwargs, 551 | ): 552 | height = self.unet.config.sample_size * self.pipe.vae_scale_factor 553 | width = self.unet.config.sample_size * self.pipe.vae_scale_factor 554 | 555 | if prompt is not None and isinstance(prompt, str): 556 | batch_size = 1 557 | elif prompt is not None and isinstance(prompt, list): 558 | batch_size = len(prompt) 559 | 560 | batch_size = batch_size * num_images_per_prompt 561 | 562 | device = self.unet.device 563 | 564 | prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt( 565 | prompt, device, num_images_per_prompt, True 566 | ) # do cfg 567 | dtype = prompt_embeds.dtype 568 | 569 | # for cfg 570 | c_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]).to(dtype) 571 | uc_prompt_embeds = 
torch.cat([negative_prompt_embeds, negative_prompt_embeds]).to(dtype) 572 | 573 | # we have to do two separate forward passes for the cfg with the loras 574 | # add our lora conditioning 575 | for i, (encoder, dp, mapper, c) in enumerate(zip(self.encoders, self.dps, self.mappers, cs)): 576 | 577 | if c.shape[0] != batch_size: 578 | assert c.shape[0] == 1 579 | c = torch.cat(batch_size * [c]) # repeat along batch dim 580 | 581 | neg_c = torch.zeros_like(c) 582 | if cfg_mask is not None and not cfg_mask[i]: 583 | print("no cfg for lora nr", i) 584 | c = torch.cat([c, c]) 585 | else: 586 | c = torch.cat([neg_c, c]) 587 | 588 | if skip_encode: 589 | cond = c 590 | else: 591 | cond = encoder(c) 592 | mapped_cond = mapper(cond) 593 | if isinstance(mapped_cond, tuple) or isinstance(mapped_cond, list): 594 | mapped_cond = [mc.to(dtype) for mc in mapped_cond] 595 | else: 596 | mapped_cond = mapped_cond.to(dtype) 597 | 598 | dp.set_batch(mapped_cond) 599 | 600 | # 4. Prepare timesteps 601 | timesteps, num_inference_steps = retrieve_timesteps(self.pipe.scheduler, 50, device) 602 | 603 | # 5. Prepare latent variables 604 | num_channels_latents = 4 # self.unet.config.in_channels 605 | latents = self.pipe.prepare_latents( 606 | batch_size, 607 | num_channels_latents, 608 | height, 609 | width, 610 | c_prompt_embeds.dtype, 611 | device, 612 | generator, 613 | ) 614 | 615 | for i, t in tqdm(enumerate(timesteps)): 616 | # cfg 617 | latent_model_input = torch.cat([latents] * 2) 618 | 619 | noise_pred = self.unet( 620 | latent_model_input, 621 | t, 622 | encoder_hidden_states=(c_prompt_embeds if i >= prompt_offset_step else uc_prompt_embeds), 623 | return_dict=False, 624 | )[0] 625 | 626 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 627 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 628 | 629 | latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 630 | 631 | latents = latents.to(torch.float32) 632 | image = self.vae.decode( 633 | latents / self.vae.config.scaling_factor, 634 | return_dict=False, 635 | generator=generator, 636 | )[0] 637 | do_denormalize = [True] * image.shape[0] 638 | 639 | image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize) 640 | 641 | return image 642 | 643 | # with self.progress_bar(total=num_inference_steps) as progress_bar: 644 | 645 | @torch.no_grad() 646 | def sample(self, *args, **kwargs): 647 | return self.sample_easy(*args, **kwargs) 648 | 649 | @torch.no_grad() 650 | def sample_easy( 651 | self, 652 | prompt, 653 | num_images_per_prompt, 654 | cs: list[torch.Tensor], 655 | generator, 656 | cfg_mask: list[bool] | None = None, 657 | prompt_offset_step: int = 0, 658 | # dtype=torch.float32, 659 | **kwargs, 660 | ): 661 | if prompt is not None and isinstance(prompt, str): 662 | batch_size = 1 663 | elif prompt is not None and isinstance(prompt, list): 664 | batch_size = len(prompt) 665 | 666 | batch_size = batch_size * num_images_per_prompt 667 | 668 | mappers = self.mappers 669 | encoders = self.encoders 670 | if self.use_controlnet: 671 | # controlnet related stuff is always at index 0 672 | cn_input = cs[0] 673 | cs = cs[1:] 674 | 675 | mappers = mappers[1:] 676 | 677 | annotator = encoders[0] 678 | encoders = encoders[1:] 679 | 680 | with torch.no_grad(): 681 | cn_cond = annotator(cn_input) 682 | 683 | kwargs["image"] = cn_cond 684 | 685 | # we have to do two separate forward passes for the cfg with the loras 686 | # add our lora conditioning 
687 | for i, (encoder, dp, mapper, c) in enumerate(zip(encoders, self.dps, mappers, cs)): 688 | 689 | if c.shape[0] != batch_size: 690 | assert c.shape[0] == 1 691 | c = torch.cat(batch_size * [c]) # repeat along batch dim 692 | 693 | neg_c = torch.zeros_like(c) 694 | if cfg_mask is not None and not cfg_mask[i]: 695 | print("no cfg for lora nr", i) 696 | c = torch.cat([c, c]) 697 | else: 698 | c = torch.cat([neg_c, c]) 699 | cond = encoder(c) 700 | mapped_cond = mapper(cond) 701 | # if isinstance(mapped_cond, tuple) or isinstance(mapped_cond, list): 702 | # mapped_cond = [mc.to(dtype) for mc in mapped_cond] 703 | # else: 704 | # mapped_cond = mapped_cond.to(dtype) 705 | 706 | dp.set_batch(mapped_cond) 707 | 708 | return self.pipe( 709 | prompt=prompt, 710 | num_images_per_prompt=num_images_per_prompt, 711 | generator=generator, 712 | **kwargs, 713 | ).images 714 | 715 | 716 | class SDXL(ModelBase): 717 | def __init__(self, pipeline_type, model_name, *args, **kwargs) -> None: 718 | super().__init__(pipeline_type, model_name, *args, **kwargs) 719 | 720 | def get_input(self, imgs: torch.Tensor, prompts: list[str]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: 721 | raise NotImplementedError() 722 | 723 | def compute_time_ids(self, device, weight_dtype): 724 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 725 | 726 | # we could adjust this if we knew that we had cropped / shifted images 727 | original_size = (1024, 1024) 728 | target_size = (1024, 1024) 729 | crops_coords_top_left = (0, 0) 730 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 731 | add_time_ids = torch.tensor([add_time_ids]) 732 | add_time_ids = add_time_ids.to(device, dtype=weight_dtype) 733 | return add_time_ids 734 | 735 | def get_conditioning( 736 | self, 737 | prompts: list[str], 738 | bsz: int, 739 | device: torch.device, 740 | dtype: torch.dtype, 741 | do_cfg=False, 742 | ): 743 | add_time_ids = self.compute_time_ids(device, dtype) 744 | negative_add_time_ids = add_time_ids # no conditioning for now 745 | 746 | ( 747 | prompt_embeds, 748 | negative_prompt_embeds, 749 | pooled_prompt_embeds, 750 | negative_pooled_prompt_embeds, 751 | ) = self.pipe.encode_prompt( 752 | prompt=prompts, 753 | device=device, 754 | num_images_per_prompt=1, 755 | do_classifier_free_guidance=do_cfg, 756 | ) 757 | 758 | if do_cfg: 759 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 760 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 761 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 762 | else: 763 | # prompt_embeds = prompt_embeds 764 | add_text_embeds = pooled_prompt_embeds 765 | # add_time_ids = add_time_ids.repeat 766 | 767 | prompt_embeds = prompt_embeds.to(device) 768 | add_text_embeds = add_text_embeds.to(device) 769 | add_time_ids = add_time_ids.to(device).repeat(bsz, 1) 770 | 771 | return { 772 | "prompt_embeds": prompt_embeds, 773 | "add_text_embeds": add_text_embeds, 774 | "add_time_ids": add_time_ids, 775 | } 776 | 777 | def forward_easy(self, *args, **kwargs): 778 | return self.forward(*args, **kwargs) 779 | 780 | def forward( 781 | self, 782 | imgs: torch.Tensor, 783 | prompts: list[str], 784 | cs: list[torch.Tensor], 785 | cfg_mask: list[bool] | None = None, 786 | skip_encode: bool = False, 787 | batch: dict | None = None, 788 | ) -> Union[torch.Tensor, torch.Tensor, torch.Tensor]: 789 | assert len(imgs.shape) == 4 790 | assert imgs.min() >= -1.0 791 | assert imgs.max() <= 1.0 792 | 
793 | B = imgs.shape[0] 794 | 795 | with torch.no_grad(): 796 | # Convert images to latent space 797 | imgs = imgs.to(self.unet.device) 798 | latents = self.pipe.vae.encode(imgs).latent_dist.sample() 799 | latents = latents * self.pipe.vae.config.scaling_factor 800 | 801 | # prompt dropout 802 | prompts = ["" if random.random() < self.c_dropout else p for p in prompts] 803 | 804 | c = self.get_conditioning(prompts, B, latents.device, latents.dtype) 805 | 806 | unet_added_conditions = { 807 | "time_ids": c["add_time_ids"], 808 | "text_embeds": c["add_text_embeds"], 809 | } 810 | prompt_embeds_input = c["prompt_embeds"] 811 | 812 | # Sample noise that we'll add to the latents 813 | noise = torch.randn_like(latents) 814 | bsz = latents.shape[0] 815 | # Sample a random timestep for each image 816 | timesteps = torch.randint( 817 | 0, 818 | self.noise_scheduler.config.num_train_timesteps, 819 | (B,), 820 | device=latents.device, 821 | ) 822 | timesteps = timesteps.long() 823 | 824 | # Add noise to the latents according to the noise magnitude at each timestep 825 | # (this is the forward diffusion process) 826 | noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) 827 | 828 | # add our lora conditioning 829 | for i, (encoder, dp, mapper, c) in enumerate(zip(self.encoders, self.dps, self.mappers, cs)): 830 | if cfg_mask is None or cfg_mask[i]: 831 | dropout_mask = torch.rand(bsz, device=c.device) < self.c_dropout 832 | 833 | # apply dropout for cfg 834 | c[dropout_mask] = torch.zeros_like(c[dropout_mask]) 835 | 836 | with torch.no_grad(): 837 | cond = encoder(c) 838 | mapped_cond = mapper(cond) 839 | dp.set_batch(mapped_cond) 840 | 841 | # Predict the noise residual 842 | model_pred = self.unet( 843 | noisy_latents, 844 | timesteps, 845 | prompt_embeds_input, 846 | added_cond_kwargs=unet_added_conditions, 847 | ).sample 848 | 849 | # get the x0 prediction in ddpm sampling 850 | alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device=model_pred.device, dtype=model_pred.dtype) 851 | 852 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 853 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 854 | while len(sqrt_alpha_prod.shape) < len(model_pred.shape): 855 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 856 | 857 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 858 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 859 | while len(sqrt_one_minus_alpha_prod.shape) < len(model_pred.shape): 860 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 861 | 862 | x0 = (noisy_latents - sqrt_one_minus_alpha_prod * model_pred) / sqrt_alpha_prod 863 | 864 | # Get the target for loss depending on the prediction type 865 | if self.noise_scheduler.config.prediction_type == "epsilon": 866 | target = noise 867 | 868 | elif self.noise_scheduler.config.prediction_type == "v_prediction": 869 | target = self.noise_scheduler.get_velocity(latents, noise, timesteps) 870 | else: 871 | raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") 872 | 873 | loss = F.mse_loss(model_pred, target, reduction="mean") 874 | 875 | return model_pred, loss, x0 876 | 877 | @torch.no_grad() 878 | def sample( 879 | self, 880 | prompt, 881 | num_images_per_prompt, 882 | cs: list[torch.Tensor], 883 | generator, 884 | cfg_mask: list[bool] | None = None, 885 | prompt_offset_step: int = 0, 886 | skip_encode: bool = False, 887 | dtype=torch.float32, 888 | batch: dict | None = None, 889 | **kwargs, 890 | ): 891 | if prompt is not None 
and isinstance(prompt, str): 892 | batch_size = 1 893 | elif prompt is not None and isinstance(prompt, list): 894 | batch_size = len(prompt) 895 | 896 | batch_size = batch_size * num_images_per_prompt 897 | 898 | device = self.unet.device 899 | 900 | prompt_embeds = None 901 | pooled_prompt_embeds = None 902 | 903 | # we have to do two separate forward passes for the cfg with the loras 904 | # add our lora conditioning 905 | for i, (encoder, dp, mapper, c) in enumerate(zip(self.encoders, self.dps, self.mappers, cs)): 906 | 907 | if c.shape[0] != batch_size: 908 | assert c.shape[0] == 1 909 | c = torch.cat(batch_size * [c]) # repeat along batch dim 910 | 911 | neg_c = torch.zeros_like(c) 912 | if self.guidance_scale > 1: 913 | if cfg_mask is not None and not cfg_mask[i]: 914 | print("no cfg for lora nr", i) 915 | c = torch.cat([c, c]) 916 | else: 917 | c = torch.cat([neg_c, c]) 918 | cond = encoder(c) 919 | mapped_cond = mapper(cond) 920 | # if isinstance(mapped_cond, tuple) or isinstance(mapped_cond, list): 921 | # mapped_cond = [mc.to(dtype) for mc in mapped_cond] 922 | # else: 923 | # mapped_cond = mapped_cond.to(dtype) 924 | 925 | dp.set_batch(mapped_cond) 926 | 927 | return self.pipe( 928 | prompt=prompt, 929 | num_images_per_prompt=num_images_per_prompt, 930 | generator=generator, 931 | guidance_scale=self.guidance_scale, 932 | prompt_embeds=prompt_embeds, 933 | pooled_prompt_embeds=pooled_prompt_embeds, 934 | **kwargs, 935 | ).images 936 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | import torch 3 | from accelerate import Accelerator 4 | from pathlib import Path 5 | from torch.nn.utils import clip_grad_norm_ 6 | from functools import reduce 7 | import os 8 | 9 | # from src.model import ModelBase 10 | 11 | MODE = Literal[ 12 | "train", 13 | "val", 14 | "always", 15 | ] 16 | 17 | 18 | class DataProvider: 19 | def __init__(self): 20 | self.batch = None 21 | 22 | def set_batch(self, batch): 23 | if self.batch is not None: 24 | if isinstance(self.batch, torch.Tensor): 25 | assert self.batch.shape[1:] == batch.shape[1:], "Check: shapes probably should not change during training" 26 | 27 | self.batch = batch 28 | 29 | def get_batch(self): 30 | assert self.batch is not None, "Error: need to set a batch first" 31 | 32 | return self.batch 33 | 34 | def reset(self): 35 | self.batch = None 36 | 37 | 38 | def getattr_recursive(obj: Any, path: str) -> Any: 39 | parts = path.split(".") 40 | for part in parts: 41 | if part.isnumeric(): 42 | obj = obj[int(part)] 43 | else: 44 | obj = getattr(obj, part) 45 | return obj 46 | 47 | 48 | def add_lora_from_config(model, cfg: Any, device: torch.device, dtype: torch.dtype = torch.float32) -> list[bool]: 49 | total_dict_keys: list[str] = [] 50 | cfg_mask: list[bool] = [] 51 | 52 | global_ckpt_path = cfg.get("ckpt_path", None) 53 | project_root = Path(os.path.abspath(__file__)).parent.parent 54 | 55 | for name, l in cfg.lora.items(): 56 | if l.get("enable", "always") == "never": 57 | continue 58 | 59 | optimize = l.get("optimize", False) 60 | lora_cfg = l.config 61 | print(f"Adding {name} lora! 
Optimize: {optimize}") 62 | 63 | dp = DataProvider() 64 | mapper_network = l.mapper_network.to(device, dtype) 65 | encoder = l.encoder.to(device, dtype) 66 | local_ckpt_path = l.get("ckpt_path", None) 67 | 68 | model.add_lora_to_unet( 69 | lora_cfg, 70 | name=name, 71 | data_provider=dp, 72 | mapper=mapper_network, 73 | encoder=encoder, 74 | optimize=optimize, 75 | transforms=l.get("transforms", []), 76 | ) 77 | 78 | cfg_mask.append(l.get("cfg", True)) 79 | 80 | p = None 81 | if global_ckpt_path is not None: 82 | p = Path(project_root, global_ckpt_path) / name 83 | 84 | # local checkpoints path always override global ones 85 | if local_ckpt_path is not None: 86 | p = Path(project_root, local_ckpt_path) / name 87 | 88 | if p is not None: 89 | print("loaded checkpoint for lora", name) 90 | mapper_sd = torch.load(p / "mapper-checkpoint.pt", map_location=device) 91 | lora_sd = torch.load(p / "lora-checkpoint.pt", map_location=device) 92 | 93 | if os.path.isfile(p / "encoder-checkpoint.pt"): 94 | encoder_sd = torch.load(p / "encoder-checkpoint.pt", map_location=device) 95 | encoder.load_state_dict(encoder_sd) 96 | 97 | mapper_network.load_state_dict(mapper_sd) 98 | 99 | if not optimize: 100 | mapper_network.requires_grad_(False) 101 | mapper_network.eval() 102 | 103 | model.unet.load_state_dict(lora_sd, strict=False) 104 | model.unet.to(device, dtype) 105 | total_dict_keys += list(lora_sd.keys()) 106 | 107 | if len(total_dict_keys) > 0 and not cfg.get("ignore_check", False): 108 | assert set([v for vs in model.lora_state_dict_keys.values() for v in vs]) == set( 109 | total_dict_keys 110 | ), "Probably missing or incorrect checkpoint file path. Otherwise set ignore_check=true in config." 111 | 112 | return cfg_mask 113 | 114 | 115 | def toggle_loras(model, cfg: Any, mode: MODE): 116 | for name, l in cfg.lora.items(): 117 | if l.get("enable", "always") in [mode, "always"]: 118 | for layer in model.lora_layers[name]: 119 | layer.lora_scale = l.config.get("lora_scale", 1.0) 120 | else: 121 | try: 122 | for layer in model.lora_layers[name]: 123 | layer.lora_scale = 0.0 124 | except: 125 | print(f"LoRA {name} is disabled. 
Ignoring...") 126 | 127 | 128 | def global_gradient_norm(model): 129 | mappers_params = list(filter(lambda p: p.requires_grad, reduce(lambda x, y: x + list(y.parameters()), model.mappers, []))) 130 | encoder_params = list(filter(lambda p: p.requires_grad, reduce(lambda x, y: x + list(y.parameters()), model.encoders, []))) 131 | 132 | total_norm = clip_grad_norm_(model.params_to_optimize + mappers_params + encoder_params, 1e9) 133 | return total_norm.item() 134 | 135 | 136 | def save_checkpoint(unet_sds: dict[str, dict[str, torch.Tensor]], mapper_network_sd: list[dict[str, torch.Tensor]], encoder_sd: list[dict[str, torch.Tensor]] | None, path: Path): 137 | for i, (name, sd) in enumerate(unet_sds.items()): 138 | p = path / name 139 | p.mkdir(parents=True, exist_ok=True) 140 | 141 | torch.save(sd, p / "lora-checkpoint.pt") 142 | torch.save(mapper_network_sd[i], p / f"mapper-checkpoint.pt") 143 | if encoder_sd is not None and len(encoder_sd[i]) > 0: 144 | torch.save(encoder_sd[i], p / f"encoder-checkpoint.pt") 145 | 146 | 147 | def roll_list(l, n): 148 | # consistent with torch.roll 149 | return l[-n:] + l[:-n] 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import math 3 | from src.model import ModelBase 4 | from diffusers.optimization import get_scheduler 5 | import torch 6 | from accelerate import Accelerator 7 | from tqdm.auto import tqdm 8 | from pathlib import Path 9 | import numpy as np 10 | import torchvision.transforms.functional as TF 11 | from accelerate.logging import get_logger 12 | import signal 13 | import einops 14 | import os 15 | import traceback 16 | from functools import reduce 17 | 18 | from src.utils import add_lora_from_config, save_checkpoint 19 | 20 | 21 | torch.set_float32_matmul_precision("high") 22 | 23 | 24 | stop_training = False 25 | 26 | 27 | def signal_handler(sig, frame): 28 | global stop_training 29 | stop_training = True 30 | print("got stop signal") 31 | 32 | 33 | @hydra.main(config_path="configs", config_name="train") 34 | def main(cfg): 35 | signal.signal(signal.SIGUSR1, signal_handler) 36 | # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning 37 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 38 | 39 | output_path = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) 40 | 41 | accelerator = Accelerator( 42 | project_dir=output_path / "logs", 43 | log_with="tensorboard", 44 | gradient_accumulation_steps=cfg.gradient_accumulation_steps, 45 | mixed_precision="bf16", 46 | ) 47 | 48 | logger = get_logger(__name__) 49 | 50 | logger.info("==================================") 51 | logger.info(cfg) 52 | logger.info(output_path) 53 | 54 | cfg = hydra.utils.instantiate(cfg) 55 | model: ModelBase = cfg.model 56 | 57 | model = model.to(accelerator.device) 58 | model.pipe.to(accelerator.device) 59 | n_loras = len(cfg.lora.keys()) 60 | 61 | cfg_mask = add_lora_from_config(model, cfg, accelerator.device) 62 | 63 | if cfg.get("gradient_checkpointing", False): 64 | model.unet.enable_gradient_checkpointing() 65 | 66 | dm = cfg.data 67 | 68 | train_dataloader = dm.train_dataloader() 69 | val_dataloader = dm.val_dataloader() 70 | 71 | mappers_params = list( 72 | filter(lambda p: p.requires_grad, reduce(lambda x, y: x + list(y.parameters()), model.mappers, [])) 73 | ) 74 | encoder_params = list( 75 | filter(lambda p: p.requires_grad, reduce(lambda x, y: x + 
list(y.parameters()), model.encoders, [])) 76 | ) 77 | 78 | optimizer = torch.optim.AdamW( 79 | model.params_to_optimize + mappers_params + encoder_params, 80 | lr=cfg.learning_rate, 81 | ) 82 | 83 | lr_scheduler = get_scheduler( 84 | cfg.lr_scheduler, 85 | optimizer=optimizer, 86 | ) 87 | 88 | logger.info(f"Number params Mapper Network(s) {sum(p.numel() for p in mappers_params):,}") 89 | logger.info(f"Number params Encoder Network(s) {sum(p.numel() for p in encoder_params):,}") 90 | logger.info(f"Number params all LoRAs(s) {sum(p.numel() for p in model.params_to_optimize):,}") 91 | 92 | logger.info("init trackers") 93 | if accelerator.is_main_process: 94 | accelerator.init_trackers("tensorboard") 95 | 96 | logger.info("prepare network") 97 | 98 | prepared = accelerator.prepare( 99 | *model.mappers, 100 | *model.encoders, 101 | model.unet, 102 | optimizer, 103 | train_dataloader, 104 | val_dataloader, 105 | lr_scheduler, 106 | ) 107 | 108 | mappers = prepared[: len(model.mappers)] 109 | encoders = prepared[len(model.mappers) : len(model.mappers) + len(model.encoders)] 110 | (unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) = prepared[ 111 | len(model.mappers) + len(model.encoders) : 112 | ] 113 | model.unet = unet 114 | model.mappers = mappers 115 | model.encoders = encoders 116 | 117 | try: 118 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) 119 | 120 | if cfg.get("max_train_steps", None) is None: 121 | max_train_steps = cfg.epochs * num_update_steps_per_epoch 122 | else: 123 | max_train_steps = cfg.max_train_steps 124 | except: 125 | max_train_steps = 10000000 126 | 127 | global_step = 0 128 | progress_bar = tqdm( 129 | range(global_step, max_train_steps), 130 | disable=not accelerator.is_main_process, 131 | ) 132 | progress_bar.set_description("Steps") 133 | 134 | logger.info("start training") 135 | for epoch in range(cfg.epochs): 136 | logger.info("new epoch") 137 | unet.train() 138 | map(lambda m: m.train(), mappers) 139 | map(lambda m: m.train(), encoders) 140 | 141 | for step, batch in enumerate(train_dataloader): 142 | with accelerator.accumulate(unet, *mappers, *encoders): 143 | imgs = batch["jpg"] 144 | imgs = imgs.to(accelerator.device) 145 | imgs = imgs.clip(-1.0, 1.0) 146 | B = imgs.shape[0] 147 | 148 | cs = [imgs] * n_loras 149 | 150 | if cfg.get("prompt", None) is not None: 151 | prompts = [cfg.prompt] * B 152 | else: 153 | prompts = batch["caption"] 154 | 155 | # cfg mask to always true such that the model always learns dropout 156 | model_pred, loss, x0, _ = model.forward_easy( 157 | imgs, 158 | prompts, 159 | cs, 160 | cfg_mask=[True for _ in cfg_mask], 161 | batch=batch, 162 | ) 163 | 164 | accelerator.backward(loss) 165 | 166 | optimizer.step() 167 | lr_scheduler.step() 168 | optimizer.zero_grad() 169 | 170 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 171 | progress_bar.set_postfix(**logs, refresh=False) 172 | accelerator.log(logs, step=global_step) 173 | 174 | # after every gradient update step 175 | if accelerator.sync_gradients: 176 | progress_bar.update(1) 177 | global_step += 1 178 | 179 | if global_step % cfg.val_steps != 0 and not stop_training: 180 | continue 181 | 182 | # VALIDATION 183 | with torch.no_grad(): 184 | try: 185 | unet.eval() 186 | map(lambda m: m.eval(), mappers) 187 | map(lambda m: m.eval(), encoders) 188 | 189 | generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed) 190 | 191 | val_prompts = [] 192 | for i, val_batch in 
enumerate(val_dataloader): 193 | 194 | B = val_batch["jpg"].shape[0] 195 | 196 | if i >= cfg.get("val_batches", 4): 197 | break 198 | 199 | if cfg.get("prompt", None) is not None: 200 | prompts = [cfg.prompt] * B 201 | else: 202 | prompts = val_batch["caption"] 203 | 204 | val_prompts = prompts 205 | 206 | imgs = val_batch["jpg"] 207 | imgs = imgs.to(accelerator.device) 208 | imgs = imgs.clip(-1.0, 1.0) 209 | 210 | cs = [imgs] * n_loras 211 | 212 | pipeline_args = { 213 | "prompt": prompts, 214 | "num_images_per_prompt": 1, 215 | "cs": cs, 216 | "generator": generator, 217 | "cfg_mask": cfg_mask, 218 | "batch": val_batch, 219 | } 220 | 221 | preds = model.sample(**pipeline_args) 222 | 223 | if accelerator.is_main_process: 224 | # IMAGE saving 225 | if cfg.get("log_c", False): 226 | # ALWAYS in [0, 1] 227 | lp = model.encoders[0](cs[-1]).cpu() 228 | else: 229 | lp = (imgs.cpu() + 1) / 2 230 | 231 | lp = torch.nn.functional.interpolate( 232 | lp, 233 | size=(cfg.size, cfg.size), 234 | mode="bicubic", 235 | align_corners=False, 236 | ) 237 | 238 | log_cond = TF.to_pil_image(einops.rearrange(lp, "b c h w -> c h (b w) ")) 239 | log_cond = log_cond.convert("RGB") 240 | log_cond = np.asarray(log_cond) 241 | 242 | log_pred = np.concatenate( 243 | [np.asarray(img.resize((cfg.size, cfg.size))) for img in preds], 244 | axis=1, 245 | ) 246 | 247 | for tracker in accelerator.trackers: 248 | if tracker.name == "tensorboard": 249 | np_images = np.concatenate([log_cond, log_pred], axis=0) 250 | tracker.writer.add_images( 251 | "validation", 252 | np_images, 253 | global_step, 254 | dataformats="HWC", 255 | ) 256 | tracker.writer.add_scalar("lr", lr_scheduler.get_last_lr()[0], global_step) 257 | tracker.writer.add_scalar("loss", loss.detach().item(), global_step) 258 | tracker.writer.add_text( 259 | "prompts", 260 | "------------".join(val_prompts), 261 | global_step, 262 | ) 263 | 264 | except Exception as e: 265 | print("!!!!!!!!!!!!!!!!!!!") 266 | print("ERROR IN VALIDATION") 267 | print(e) 268 | print(traceback.format_exc()) 269 | print("!!!!!!!!!!!!!!!!!!!") 270 | 271 | finally: 272 | if accelerator.is_main_process: 273 | save_checkpoint( 274 | model.get_lora_state_dict(accelerator.unwrap_model(unet)), 275 | [accelerator.unwrap_model(m).state_dict() for m in mappers], 276 | None, 277 | output_path / f"checkpoint-{global_step}", 278 | ) 279 | 280 | unet.train() 281 | map(lambda m: m.train(), mappers) 282 | map(lambda m: m.train(), encoders) 283 | 284 | if stop_training: 285 | break 286 | 287 | accelerator.wait_for_everyone() 288 | save_checkpoint( 289 | model.get_lora_state_dict(accelerator.unwrap_model(unet)), 290 | [accelerator.unwrap_model(m).state_dict() for m in mappers], 291 | None, 292 | output_path / f"checkpoint-{global_step}", 293 | ) 294 | 295 | 296 | if __name__ == "__main__": 297 | main() 298 | --------------------------------------------------------------------------------
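
As a quick orientation to the conditional LoRA block implemented in `src/lora.py` and `src/utils.py` above, here is a minimal, self-contained sketch of how a `DataProvider` feeds a conditioning embedding into a `SimpleLoraLinear` layer. It is illustrative only: the batch size, token count, `c_dim=768`, feature width 320, and `rank=8` are assumptions and not taken from any config in this repository, and it assumes the repository's conda environment (torch, einops, jaxtyping) is installed.

```python
# Minimal sketch: exercising the conditional LoRA block in isolation.
# The layer sizes below are illustrative assumptions, not repository defaults.
import torch

from src.lora import SimpleLoraLinear
from src.utils import DataProvider

dp = DataProvider()

# Conditional LoRA around a frozen 320 -> 320 linear layer, conditioned on a
# 768-dim embedding (e.g. a pooled CLIP image embedding).
lora = SimpleLoraLinear(
    out_features=320,
    in_features=320,
    c_dim=768,
    rank=8,
    data_provider=dp,
)

x = torch.randn(2, 77, 320)  # [B, N, D] token features entering the adapted layer
c = torch.randn(2, 768)      # [B, c_dim] conditioning embedding

dp.set_batch(c)              # the LoRA block pulls the condition from the provider
y = lora(x)                  # frozen W(x) + low-rank update modulated by c
print(y.shape)               # torch.Size([2, 77, 320])
```

Because `B` is zero-initialized, the output initially equals the frozen `W(x)`; the condition only influences the result once `A`, `B`, and the `emb_gamma`/`emb_beta` projections have been trained, which is what `add_lora_to_unet` and `train.py` set up for the full model.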