├── __init__.py ├── src ├── core │ ├── scripts │ │ ├── __init__.py │ │ └── cli.py │ ├── templates │ │ ├── __init__.py │ │ └── diffusion.py │ ├── utils │ │ ├── __init__.py │ │ ├── save_and_load.py │ │ └── base_dto.py │ └── data │ │ ├── __init__.py │ │ └── bucketeer.py ├── modules │ ├── cnet_modules │ │ ├── pidinet │ │ │ ├── ckpts │ │ │ │ └── table5_pidinet.pth │ │ │ ├── __init__.py │ │ │ └── util.py │ │ ├── inpainting │ │ │ └── saliency_model.py │ │ └── face_id │ │ │ └── arcface.py │ ├── __init__.py │ ├── effnet.py │ ├── previewer.py │ ├── lora.py │ ├── common.py │ ├── stage_a.py │ ├── stage_b.py │ └── stage_c.py ├── train │ ├── __init__.py │ ├── example_train.sh │ ├── readme.md │ ├── train_c.py │ └── train_b.py ├── gdf │ ├── targets.py │ ├── scalers.py │ ├── samplers.py │ ├── readme.md │ ├── loss_weights.py │ ├── noise_conditions.py │ ├── __init__.py │ └── schedulers.py ├── inference │ ├── utils.py │ ├── generate.py │ └── text2img.py ├── webui.py └── model_downloader.py ├── figures ├── collage_1.jpg ├── collage_2.jpg ├── collage_3.jpg ├── collage_4.jpg ├── comparison.png ├── fernando.jpg ├── original.jpg ├── controlnet-face.jpg ├── controlnet-sr.jpg ├── model-overview.jpg ├── reconstructed.jpg ├── controlnet-canny.jpg ├── controlnet-paint.jpg ├── fernando_original.jpg ├── comparison-inference-speed.jpg ├── image-to-image-example-rodent.jpg ├── text-to-image-example-penguin.jpg └── image-variations-example-headset.jpg ├── configs ├── inference │ ├── stage_c_full_fp32.yaml │ ├── stage_c_full_bf16.yaml │ ├── stage_c_lite_fp32.yaml │ ├── stage_c_lite_bf16.yaml │ ├── stage_b_full_fp32.yaml │ ├── stage_b_full_bf16.yaml │ ├── stage_b_lite_fp32.yaml │ ├── stage_b_lite_bf16.yaml │ ├── controlnet_c_3b_canny.yaml │ ├── controlnet_c_3b_inpainting.yaml │ ├── controlnet_c_3b_sr.yaml │ ├── lora_c_3b.yaml │ └── controlnet_c_3b_identity.yaml └── training │ ├── finetune_c_3b.yaml │ ├── finetune_c_1b.yaml │ ├── finetune_b_3b.yaml │ ├── finetune_b_700m.yaml │ ├── finetune_c_3b_v.yaml │ ├── finetune_c_3b_lowres.yaml │ ├── controlnet_c_3b_canny.yaml │ ├── finetune_c_3b_lora.yaml │ ├── controlnet_c_3b_inpainting.yaml │ ├── controlnet_c_3b_sr.yaml │ └── controlnet_c_3b_identity.yaml ├── .gitignore ├── webui.sh ├── requirements.txt ├── LICENSE ├── readme.md ├── models ├── readme.md └── download_models.sh └── webui.bat /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/templates/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import DiffusionCore -------------------------------------------------------------------------------- /figures/collage_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/collage_1.jpg -------------------------------------------------------------------------------- /figures/collage_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/collage_2.jpg -------------------------------------------------------------------------------- /figures/collage_3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/collage_3.jpg -------------------------------------------------------------------------------- /figures/collage_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/collage_4.jpg -------------------------------------------------------------------------------- /figures/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/comparison.png -------------------------------------------------------------------------------- /figures/fernando.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/fernando.jpg -------------------------------------------------------------------------------- /figures/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/original.jpg -------------------------------------------------------------------------------- /figures/controlnet-face.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/controlnet-face.jpg -------------------------------------------------------------------------------- /figures/controlnet-sr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/controlnet-sr.jpg -------------------------------------------------------------------------------- /figures/model-overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/model-overview.jpg -------------------------------------------------------------------------------- /figures/reconstructed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/reconstructed.jpg -------------------------------------------------------------------------------- /figures/controlnet-canny.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/controlnet-canny.jpg -------------------------------------------------------------------------------- /figures/controlnet-paint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/controlnet-paint.jpg -------------------------------------------------------------------------------- /figures/fernando_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/fernando_original.jpg -------------------------------------------------------------------------------- /figures/comparison-inference-speed.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/comparison-inference-speed.jpg -------------------------------------------------------------------------------- /figures/image-to-image-example-rodent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/image-to-image-example-rodent.jpg -------------------------------------------------------------------------------- /figures/text-to-image-example-penguin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/text-to-image-example-penguin.jpg -------------------------------------------------------------------------------- /figures/image-variations-example-headset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/figures/image-variations-example-headset.jpg -------------------------------------------------------------------------------- /src/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umisetokikaze/StableCascade-webui/HEAD/src/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth -------------------------------------------------------------------------------- /src/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_b import WurstCore as WurstCoreB 2 | from .train_c import WurstCore as WurstCoreC 3 | from .train_c_controlnet import WurstCore as ControlNetCore 4 | from .train_c_lora import WurstCore as LoraCore -------------------------------------------------------------------------------- /configs/inference/stage_c_full_fp32.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: float32 4 | 5 | effnet_checkpoint_path: models/effnet_encoder.safetensors 6 | previewer_checkpoint_path: models/previewer.safetensors 7 | generator_checkpoint_path: models/stage_c.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_c_full_bf16.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | effnet_checkpoint_path: models/effnet_encoder.safetensors 6 | previewer_checkpoint_path: models/previewer.safetensors 7 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_c_lite_fp32.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 1B 3 | dtype: float32 4 | 5 | effnet_checkpoint_path: models/effnet_encoder.safetensors 6 | previewer_checkpoint_path: models/previewer.safetensors 7 | generator_checkpoint_path: models/stage_c_lite.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_c_lite_bf16.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 1B 3 | dtype: bfloat16 4 | 5 | effnet_checkpoint_path: models/effnet_encoder.safetensors 6 | 
previewer_checkpoint_path: models/previewer.safetensors 7 | generator_checkpoint_path: models/stage_c_lite_bf16.safetensors -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.yml 2 | *.out 3 | dist_file_* 4 | __pycache__/* 5 | */__pycache__/* 6 | */**/__pycache__/* 7 | *_latest_output.jpg 8 | *_sample.jpg 9 | jobs/*.sh 10 | .ipynb_checkpoints 11 | *.safetensors 12 | *_test.yaml 13 | *.pt 14 | venv/* 15 | /temp_models 16 | test 17 | *.png 18 | /temp -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .effnet import EfficientNetEncoder 2 | from .stage_c import StageC 3 | from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock 4 | from .previewer import Previewer 5 | from .controlnet import ControlNet, ControlNetDeliverer 6 | from . import controlnet as controlnet_filters -------------------------------------------------------------------------------- /webui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 仮想環境ディレクトリの名前 4 | VENV_DIR="venv" 5 | 6 | # 仮想環境が存在しない場合は作成 7 | if [ ! -d "$VENV_DIR" ]; then 8 | python -m venv $VENV_DIR 9 | echo "仮想環境を作成しました。" 10 | fi 11 | 12 | # 仮想環境をアクティベート 13 | source ./$VENV_DIR/bin/activate 14 | 15 | # requirements.txt から必要なモジュールをインストール 16 | pip install -r requirements.txt 17 | 18 | # srcディレクトリ内のwebui.pyを実行 19 | python src/webui.py "$@" -------------------------------------------------------------------------------- /configs/inference/stage_b_full_fp32.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3B 3 | dtype: float32 4 | 5 | # For demonstration purposes in reconstruct_images.ipynb 6 | webdataset_path: file:inference/imagenet_1024.tar 7 | batch_size: 4 8 | image_size: 1024 9 | grad_accum_steps: 1 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | stage_a_checkpoint_path: models/stage_a.safetensors 13 | generator_checkpoint_path: models/stage_b.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_b_full_bf16.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3B 3 | dtype: bfloat16 4 | 5 | # For demonstration purposes in reconstruct_images.ipynb 6 | webdataset_path: file:inference/imagenet_1024.tar 7 | batch_size: 4 8 | image_size: 1024 9 | grad_accum_steps: 1 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | stage_a_checkpoint_path: models/stage_a.safetensors 13 | generator_checkpoint_path: models/stage_b_bf16.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_b_lite_fp32.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 700M 3 | dtype: float32 4 | 5 | # For demonstration purposes in reconstruct_images.ipynb 6 | webdataset_path: file:inference/imagenet_1024.tar 7 | batch_size: 4 8 | image_size: 1024 9 | grad_accum_steps: 1 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | stage_a_checkpoint_path: models/stage_a.safetensors 13 | generator_checkpoint_path: models/stage_b_lite.safetensors 
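The Stage B and Stage C inference configs above share a small, uniform set of keys: `model_version`, `dtype`, and a handful of `*_checkpoint_path` entries. As a minimal sketch only (the helper name and the dtype mapping are assumptions for illustration, not this repository's actual loader, which presumably sits behind `src/inference/generate.py`), reading such a file is a plain `yaml.safe_load`:

```python
import yaml
import torch

# Hypothetical helper for illustration; not part of this repository.
_DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16}

def load_inference_config(path):
    """Read an inference YAML such as configs/inference/stage_b_lite_bf16.yaml."""
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)  # e.g. {'model_version': '700M', 'dtype': 'bfloat16', ...}
    # Map the dtype string to a torch dtype; per models/readme.md, float16 is not supported.
    cfg["torch_dtype"] = _DTYPES[cfg["dtype"]]
    return cfg
```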
-------------------------------------------------------------------------------- /configs/inference/stage_b_lite_bf16.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 700M 3 | dtype: bfloat16 4 | 5 | # For demonstration purposes in reconstruct_images.ipynb 6 | webdataset_path: file:inference/imagenet_1024.tar 7 | batch_size: 4 8 | image_size: 1024 9 | grad_accum_steps: 1 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | stage_a_checkpoint_path: models/stage_a.safetensors 13 | generator_checkpoint_path: models/stage_b_lite_bf16.safetensors -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_canny.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 7 | controlnet_filter: CannyFilter 8 | controlnet_filter_params: 9 | resize: 224 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | previewer_checkpoint_path: models/previewer.safetensors 13 | generator_checkpoint_path: models/stage_c_bf16.safetensors 14 | controlnet_checkpoint_path: models/canny.safetensors 15 | -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_inpainting.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 7 | controlnet_filter: InpaintFilter 8 | controlnet_filter_params: 9 | thresold: [0.04, 0.4] 10 | p_outpaint: 0.4 11 | 12 | effnet_checkpoint_path: models/effnet_encoder.safetensors 13 | previewer_checkpoint_path: models/previewer.safetensors 14 | generator_checkpoint_path: models/stage_c_bf16.safetensors 15 | controlnet_checkpoint_path: models/inpainting.safetensors 16 | -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_sr.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_bottleneck_mode: 'large' 7 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 8 | controlnet_filter: SREffnetFilter 9 | controlnet_filter_params: 10 | scale_factor: 0.5 11 | 12 | effnet_checkpoint_path: models/effnet_encoder.safetensors 13 | previewer_checkpoint_path: models/previewer.safetensors 14 | generator_checkpoint_path: models/stage_c_bf16.safetensors 15 | controlnet_checkpoint_path: models/super_resolution.safetensors 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | accelerate>=0.25.0 3 | torch==2.1.2+cu118 4 | torchvision==0.16.2+cu118 5 | transformers>=4.30.0 6 | numpy>=1.23.5 7 | kornia>=0.7.0 8 | insightface>=0.7.3 9 | opencv-python>=4.8.1.78 10 | tqdm>=4.66.1 11 | matplotlib>=3.7.4 12 | webdataset>=0.2.79 13 | wandb>=0.16.2 14 | munch>=4.0.0 15 | onnxruntime>=1.16.3 16 | einops>=0.7.0 17 | gradio 18 | onnx2torch>=1.5.13 19 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 20 | torchtools @ 
git+https://github.com/pabloppp/pytorch-tools 21 | -------------------------------------------------------------------------------- /src/modules/effnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | 4 | 5 | # EfficientNet 6 | class EfficientNetEncoder(nn.Module): 7 | def __init__(self, c_latent=16): 8 | super().__init__() 9 | self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() 10 | self.mapper = nn.Sequential( 11 | nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), 12 | nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 13 | ) 14 | 15 | def forward(self, x): 16 | return self.mapper(self.backbone(x)) 17 | 18 | -------------------------------------------------------------------------------- /configs/inference/lora_c_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # LoRA specific 6 | module_filters: ['.attn'] 7 | rank: 4 8 | train_tokens: 9 | # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized 10 | - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails 11 | 12 | effnet_checkpoint_path: models/effnet_encoder.safetensors 13 | previewer_checkpoint_path: models/previewer.safetensors 14 | generator_checkpoint_path: models/stage_c_bf16.safetensors 15 | lora_checkpoint_path: models/lora_fernando_10k.safetensors 16 | -------------------------------------------------------------------------------- /src/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN 2 | from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail 3 | 4 | # MOVE IT SOMERWHERE ELSE 5 | def update_weights_ema(tgt_model, src_model, beta=0.999): 6 | for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): 7 | self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) 8 | for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): 9 | self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_identity.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_bottleneck_mode: 'simple' 7 | controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] 8 | controlnet_filter: IdentityFilter 9 | controlnet_filter_params: 10 | max_faces: 4 11 | p_drop: 0.00 12 | p_full: 0.0 13 | 14 | effnet_checkpoint_path: models/effnet_encoder.safetensors 15 | previewer_checkpoint_path: models/previewer.safetensors 16 | generator_checkpoint_path: models/stage_c_bf16.safetensors 17 | controlnet_checkpoint_path: 18 | -------------------------------------------------------------------------------- /configs/training/finetune_c_3b.yaml: 
-------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 512 14 | image_size: 768 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 100000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | adaptive_loss_weight: True 25 | 26 | # ema_start_iters: 5000 27 | # ema_iters: 100 28 | # ema_beta: 0.9 29 | 30 | webdataset_path: 31 | - s3://path/to/your/first/dataset/on/s3 32 | - s3://path/to/your/second/dataset/on/s3 33 | effnet_checkpoint_path: models/effnet_encoder.safetensors 34 | previewer_checkpoint_path: models/previewer.safetensors 35 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /configs/training/finetune_c_1b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_1b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 1B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 1024 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 10000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | # adaptive_loss_weight: True 25 | 26 | # ema_start_iters: 5000 27 | # ema_iters: 100 28 | # ema_beta: 0.9 29 | 30 | webdataset_path: 31 | - s3://path/to/your/first/dataset/on/s3 32 | - s3://path/to/your/second/dataset/on/s3 33 | effnet_checkpoint_path: models/effnet_encoder.safetensors 34 | previewer_checkpoint_path: models/previewer.safetensors 35 | generator_checkpoint_path: models/stage_c_lite_bf16.safetensors -------------------------------------------------------------------------------- /configs/training/finetune_b_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_b_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 1024 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | shift: 4 17 | grad_accum_steps: 1 18 | updates: 100000 19 | backup_every: 20000 20 | save_every: 1000 21 | warmup_updates: 1 22 | use_fsdp: True 23 | 24 | # GDF 25 | adaptive_loss_weight: True 26 | 27 | # ema_start_iters: 5000 28 | # ema_iters: 100 29 | # ema_beta: 0.9 30 | 31 | webdataset_path: 32 | - s3://path/to/your/first/dataset/on/s3 33 | - s3://path/to/your/second/dataset/on/s3 34 | effnet_checkpoint_path: models/effnet_encoder.safetensors 35 | stage_a_checkpoint_path: models/stage_a.safetensors 36 | generator_checkpoint_path: models/stage_b_bf16.safetensors 37 | -------------------------------------------------------------------------------- 
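One detail worth noting about the training configs above: YAML parses the `multi_aspect_ratio` entries such as `1/2` or `9/16` as plain strings rather than numbers, so they have to be converted before any aspect-ratio bucketing (presumably handled around `src/core/data/bucketeer.py`). A minimal sketch, with the helper name being an assumption for illustration:

```python
from fractions import Fraction

def parse_aspect_ratios(ratios):
    """Convert config entries like '3/4' or '9/16' into floats (0.75, 0.5625)."""
    return [float(Fraction(r)) for r in ratios]

# Example with values from finetune_c_3b.yaml:
# parse_aspect_ratios(['1/1', '1/2', '9/16']) -> [1.0, 0.5, 0.5625]
```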
/configs/training/finetune_b_700m.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_b_700m_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 700M 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 512 14 | image_size: 1024 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | shift: 4 17 | grad_accum_steps: 1 18 | updates: 10000 19 | backup_every: 20000 20 | save_every: 2000 21 | warmup_updates: 1 22 | use_fsdp: True 23 | 24 | # GDF 25 | adaptive_loss_weight: True 26 | 27 | # ema_start_iters: 5000 28 | # ema_iters: 100 29 | # ema_beta: 0.9 30 | 31 | webdataset_path: 32 | - s3://path/to/your/first/dataset/on/s3 33 | - s3://path/to/your/second/dataset/on/s3 34 | effnet_checkpoint_path: models/effnet_encoder.safetensors 35 | stage_a_checkpoint_path: models/stage_a.safetensors 36 | generator_checkpoint_path: models/stage_b_lite_bf16.safetensors 37 | -------------------------------------------------------------------------------- /configs/training/finetune_c_3b_v.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 512 14 | image_size: 768 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 100000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | adaptive_loss_weight: True 25 | edm_objective: True 26 | 27 | # ema_start_iters: 5000 28 | # ema_iters: 100 29 | # ema_beta: 0.9 30 | 31 | webdataset_path: 32 | - s3://path/to/your/first/dataset/on/s3 33 | - s3://path/to/your/second/dataset/on/s3 34 | effnet_checkpoint_path: models/effnet_encoder.safetensors 35 | previewer_checkpoint_path: models/previewer.safetensors 36 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /src/core/scripts/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from .. import WarpCore 4 | from .. 
import templates 5 | 6 | 7 | def template_init(args): 8 | return '''' 9 | 10 | 11 | '''.strip() 12 | 13 | 14 | def init_template(args): 15 | parser = argparse.ArgumentParser(description='WarpCore template init tool') 16 | parser.add_argument('-t', '--template', type=str, default='WarpCore') 17 | args = parser.parse_args(args) 18 | 19 | if args.template == 'WarpCore': 20 | template_cls = WarpCore 21 | else: 22 | try: 23 | template_cls = __import__(args.template) 24 | except ModuleNotFoundError: 25 | template_cls = getattr(templates, args.template) 26 | print(template_cls) 27 | 28 | 29 | def main(): 30 | if len(sys.argv) < 2: 31 | print('Usage: core ') 32 | sys.exit(1) 33 | if sys.argv[1] == 'init': 34 | init_template(sys.argv[2:]) 35 | else: 36 | print('Unknown command') 37 | sys.exit(1) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /configs/training/finetune_c_3b_lowres.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 1024 14 | image_size: 384 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 100000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | adaptive_loss_weight: True 25 | 26 | # CUSTOM CAPTIONS GETTER & FILTERS 27 | # captions_getter: ['json', captions_getter] 28 | # dataset_filters: 29 | # - ['normalized_score', 'lambda s: s > 9.0'] 30 | # - ['pgen_normalized_score', 'lambda s: s > 3.0'] 31 | 32 | # ema_start_iters: 5000 33 | # ema_iters: 100 34 | # ema_beta: 0.9 35 | 36 | webdataset_path: 37 | - s3://path/to/your/first/dataset/on/s3 38 | - s3://path/to/your/second/dataset/on/s3 39 | effnet_checkpoint_path: models/effnet_encoder.safetensors 40 | previewer_checkpoint_path: models/previewer.safetensors 41 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /src/train/example_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=A100 3 | #SBATCH --nodes=1 4 | #SBATCH --gpus-per-node=8 5 | #SBATCH --ntasks-per-node=8 6 | #SBATCH --exclusive 7 | #SBATCH --job-name=your_job_name 8 | #SBATCH --account your_account_name 9 | 10 | module load openmpi 11 | module load cuda/11.8 12 | export NCCL_PROTO=simple 13 | 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | 18 | export NCCL_DEBUG=info 19 | export PYTHONFAULTHANDLER=1 20 | 21 | export CUDA_LAUNCH_BLOCKING=0 22 | export OMPI_MCA_mtl_base_verbose=1 23 | export FI_EFA_ENABLE_SHM_TRANSFER=0 24 | export FI_PROVIDER=efa 25 | export FI_EFA_TX_MIN_CREDITS=64 26 | export NCCL_TREE_THRESHOLD=0 27 | 28 | export PYTHONWARNINGS="ignore" 29 | export CXX=g++ 30 | 31 | source /path/to/your/python/environment/bin/activate 32 | 33 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 34 | export MASTER_ADDR=$master_addr 35 | export MASTER_PORT=33751 36 | export PYTHONPATH=./StableWurst 37 | echo "r$SLURM_NODEID master: $MASTER_ADDR" 38 | echo "r$SLURM_NODEID Launching python script" 39 | 40 | cd /path/to/your/directory 41 | rm dist_file 42 | srun python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml -------------------------------------------------------------------------------- /configs/training/controlnet_c_3b_canny.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_canny 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 10000 18 | 
backup_every: 2000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 25 | controlnet_filter: CannyFilter 26 | controlnet_filter_params: 27 | resize: 224 28 | # offset_noise: 0.1 29 | 30 | # CUSTOM CAPTIONS GETTER & FILTERS 31 | captions_getter: ['txt', identity] 32 | dataset_filters: 33 | - ['width', 'lambda w: w >= 768'] 34 | - ['height', 'lambda h: h >= 768'] 35 | 36 | # ema_start_iters: 5000 37 | # ema_iters: 100 38 | # ema_beta: 0.9 39 | 40 | webdataset_path: 41 | - s3://path/to/your/first/dataset/on/s3 42 | - s3://path/to/your/second/dataset/on/s3 43 | effnet_checkpoint_path: models/effnet_encoder.safetensors 44 | previewer_checkpoint_path: models/previewer.safetensors 45 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /configs/training/finetune_c_3b_lora.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_lora 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 32 14 | image_size: 768 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 4 17 | updates: 10000 18 | backup_every: 1000 19 | save_every: 100 20 | warmup_updates: 1 21 | # use_fsdp: True -> FSDP doesn't work at the moment for LoRA 22 | use_fsdp: False 23 | 24 | # GDF 25 | # adaptive_loss_weight: True 26 | 27 | # LoRA specific 28 | module_filters: ['.attn'] 29 | rank: 4 30 | train_tokens: 31 | # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized 32 | - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails 33 | 34 | 35 | # ema_start_iters: 5000 36 | # ema_iters: 100 37 | # ema_beta: 0.9 38 | 39 | webdataset_path: 40 | - s3://path/to/your/first/dataset/on/s3 41 | - s3://path/to/your/second/dataset/on/s3 42 | effnet_checkpoint_path: models/effnet_encoder.safetensors 43 | previewer_checkpoint_path: models/previewer.safetensors 44 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /src/gdf/targets.py: -------------------------------------------------------------------------------- 1 | class EpsilonTarget(): 2 | def __call__(self, x0, epsilon, logSNR, a, b): 3 | return epsilon 4 | 5 | def x0(self, noised, pred, logSNR, a, b): 6 | return (noised - pred * b) / a 7 | 8 | def epsilon(self, noised, pred, logSNR, a, b): 9 | return pred 10 | 11 | class X0Target(): 12 | def __call__(self, x0, epsilon, logSNR, a, b): 13 | return x0 14 | 15 | def x0(self, noised, pred, logSNR, a, b): 16 | return pred 17 | 18 | def epsilon(self, noised, pred, logSNR, a, b): 19 | return (noised - pred * a) / b 20 | 21 | class VTarget(): 22 | def __call__(self, x0, epsilon, logSNR, a, b): 23 | return a * epsilon - b * x0 24 | 25 | def x0(self, noised, pred, logSNR, a, b): 26 | squared_sum = a**2 + b**2 27 | return a/squared_sum * noised - b/squared_sum * pred 28 | 29 | def epsilon(self, noised, pred, logSNR, a, b): 30 | squared_sum = a**2 + b**2 31 | return b/squared_sum * noised + a/squared_sum * pred 32 | 33 | class RectifiedFlowsTarget(): 34 | def __call__(self, x0, epsilon, 
logSNR, a, b): 35 | return epsilon - x0 36 | 37 | def x0(self, noised, pred, logSNR, a, b): 38 | return noised - pred * b 39 | 40 | def epsilon(self, noised, pred, logSNR, a, b): 41 | return noised + pred * a 42 | -------------------------------------------------------------------------------- /configs/training/controlnet_c_3b_inpainting.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_inpainting 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 10000 18 | backup_every: 2000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 25 | controlnet_filter: InpaintFilter 26 | controlnet_filter_params: 27 | thresold: [0.04, 0.4] 28 | p_outpaint: 0.4 29 | offset_noise: 0.1 30 | 31 | # CUSTOM CAPTIONS GETTER & FILTERS 32 | captions_getter: ['txt', identity] 33 | dataset_filters: 34 | - ['width', 'lambda w: w >= 768'] 35 | - ['height', 'lambda h: h >= 768'] 36 | 37 | # ema_start_iters: 5000 38 | # ema_iters: 100 39 | # ema_beta: 0.9 40 | 41 | webdataset_path: 42 | - s3://path/to/your/first/dataset/on/s3 43 | - s3://path/to/your/second/dataset/on/s3 44 | effnet_checkpoint_path: models/effnet_encoder.safetensors 45 | previewer_checkpoint_path: models/previewer.safetensors 46 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /configs/training/controlnet_c_3b_sr.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_sr 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 30000 18 | backup_every: 5000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_bottleneck_mode: 'large' 25 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 26 | controlnet_filter: SREffnetFilter 27 | controlnet_filter_params: 28 | scale_factor: 0.5 29 | offset_noise: 0.1 30 | 31 | # CUSTOM CAPTIONS GETTER & FILTERS 32 | captions_getter: ['txt', identity] 33 | dataset_filters: 34 | - ['width', 'lambda w: w >= 768'] 35 | - ['height', 'lambda h: h >= 768'] 36 | 37 | # ema_start_iters: 5000 38 | # ema_iters: 100 39 | # ema_beta: 0.9 40 | 41 | webdataset_path: 42 | - s3://path/to/your/first/dataset/on/s3 43 | - s3://path/to/your/second/dataset/on/s3 44 | effnet_checkpoint_path: models/effnet_encoder.safetensors 45 | previewer_checkpoint_path: models/previewer.safetensors 46 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /src/inference/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 
| import torch 3 | import requests 4 | import torchvision 5 | from math import ceil 6 | from io import BytesIO 7 | import matplotlib.pyplot as plt 8 | import torchvision.transforms.functional as F 9 | 10 | 11 | def download_image(url): 12 | return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") 13 | 14 | 15 | def resize_image(image, size=768): 16 | tensor_image = F.to_tensor(image) 17 | resized_image = F.resize(tensor_image, size, antialias=True) 18 | return resized_image 19 | 20 | 21 | def downscale_images(images, factor=3/4): 22 | scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) 23 | scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) 24 | return scaled_image 25 | 26 | 27 | def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): 28 | resolution_multiple = 42.67 29 | latent_height = ceil(height / compression_factor_b) 30 | latent_width = ceil(width / compression_factor_b) 31 | stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) 32 | 33 | latent_height = ceil(height / compression_factor_a) 34 | latent_width = ceil(width / compression_factor_a) 35 | stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) 36 | 37 | return stage_c_latent_shape, stage_b_latent_shape 38 | -------------------------------------------------------------------------------- /configs/training/controlnet_c_3b_identity.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_identity 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 200000 18 | backup_every: 2000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_bottleneck_mode: 'simple' 25 | controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] 26 | controlnet_filter: IdentityFilter 27 | controlnet_filter_params: 28 | max_faces: 4 29 | p_drop: 0.05 30 | p_full: 0.3 31 | # offset_noise: 0.1 32 | 33 | # CUSTOM CAPTIONS GETTER & FILTERS 34 | captions_getter: ['txt', identity] 35 | dataset_filters: 36 | - ['width', 'lambda w: w >= 768'] 37 | - ['height', 'lambda h: h >= 768'] 38 | 39 | # ema_start_iters: 5000 40 | # ema_iters: 100 41 | # ema_beta: 0.9 42 | 43 | webdataset_path: 44 | - s3://path/to/your/first/dataset/on/s3 45 | - s3://path/to/your/second/dataset/on/s3 46 | effnet_checkpoint_path: models/effnet_encoder.safetensors 47 | previewer_checkpoint_path: models/previewer.safetensors 48 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /src/gdf/scalers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class BaseScaler(): 4 | def __init__(self): 5 | 
self.stretched_limits = None 6 | 7 | def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): 8 | min_logSNR = schedule(torch.ones(1), shift=shift) 9 | max_logSNR = schedule(torch.zeros(1), shift=shift) 10 | 11 | min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] 12 | max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] 13 | self.stretched_limits = [min_a, max_a, min_b, max_b] 14 | return self.stretched_limits 15 | 16 | def stretch_limits(self, a, b): 17 | min_a, max_a, min_b, max_b = self.stretched_limits 18 | return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b) 19 | 20 | def scalers(self, logSNR): 21 | raise NotImplementedError("this method needs to be overridden") 22 | 23 | def __call__(self, logSNR): 24 | a, b = self.scalers(logSNR) 25 | if self.stretched_limits is not None: 26 | a, b = self.stretch_limits(a, b) 27 | return a, b 28 | 29 | class VPScaler(BaseScaler): 30 | def scalers(self, logSNR): 31 | a_squared = logSNR.sigmoid() 32 | a = a_squared.sqrt() 33 | b = (1-a_squared).sqrt() 34 | return a, b 35 | 36 | class LERPScaler(BaseScaler): 37 | def scalers(self, logSNR): 38 | _a = logSNR.exp() - 1 39 | _a[_a == 0] = 1e-3 # Avoid division by zero 40 | a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) 41 | b = 1-a 42 | return a, b 43 | -------------------------------------------------------------------------------- /src/modules/cnet_modules/pidinet/__init__.py: -------------------------------------------------------------------------------- 1 | # Pidinet 2 | # https://github.com/hellozhuo/pidinet 3 | 4 | import os 5 | import torch 6 | import numpy as np 7 | from einops import rearrange 8 | from .model import pidinet 9 | from .util import annotator_ckpts_path, safe_step 10 | 11 | 12 | class PidiNetDetector: 13 | def __init__(self, device): 14 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" 15 | modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") 16 | if not os.path.exists(modelpath): 17 | from basicsr.utils.download_util import load_file_from_url 18 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 19 | self.netNetwork = pidinet() 20 | self.netNetwork.load_state_dict( 21 | {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) 22 | self.netNetwork.to(device).eval().requires_grad_(False) 23 | 24 | def __call__(self, input_image): # , safe=False): 25 | return self.netNetwork(input_image)[-1] 26 | # assert input_image.ndim == 3 27 | # input_image = input_image[:, :, ::-1].copy() 28 | # with torch.no_grad(): 29 | # image_pidi = torch.from_numpy(input_image).float().cuda() 30 | # image_pidi = image_pidi / 255.0 31 | # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') 32 | # edge = self.netNetwork(image_pidi)[-1] 33 | 34 | # if safe: 35 | # edge = safe_step(edge) 36 | # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 37 | # return edge[0][0] 38 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Welcome to the Stable Cascade Web UI 4 | ===================================== 5 | 6 | This repository implements Stability AI's Stable Cascade within a Web UI. It is still in the very early stages and is expected to gradually improve over time. 
7 | 8 | ### Prerequisites: 9 | 10 | Due to library dependencies, Windows users will need to download the C++ build tools. 11 | 12 | => https://visualstudio.microsoft.com/visual-cpp-build-tools/ 13 | 14 | ### How to Run: 15 | 16 | **For Windows users**: 17 | 18 | > git clone https://github.com/umisetokikaze/StableCascade-webui 19 | 20 | > Execute webui.bat 21 | 22 | **For Linux users**: 23 | 24 | > git clone https://github.com/umisetokikaze/StableCascade-webui 25 | 26 | > Execute webui.sh 27 | 28 | Current Features: 29 | - A simple UI based on Gradio 30 | - Modifiable generation parameters 31 | - Basic image generation from text 32 | - Model download support 33 | 34 | 35 | ____________________________ 36 | 37 | 日本語: 38 | 39 | ## Welcome Stable Cascade Webui 40 | =============================== 41 | 42 | 43 | #このリポジトリはStability AI社のStable CascadeをWebuiに実装したものです。 44 | 未だに、非常に初期の段階にあり、徐々に改良されていく予定です。 45 | 46 | 47 | 48 | ### 前提: 49 | 50 | 使用ライブラリの関係で、Windowsユーザーの場合はCpp build toolsをダウンロードする必要があります。 51 | 52 | => https://visualstudio.microsoft.com/visual-cpp-build-tools/ 53 | 54 | ### 実行方法: 55 | 56 | ** windowsユーザー **: 57 | 58 | > git clone https://github.com/umisetokikaze/StableCascade-webui 59 | 60 | > webui.batを実行 61 | 62 | ** linuxユーザー **: 63 | 64 | > git clone https://github.com/umisetokikaze/StableCascade-webui 65 | > webui.shを実行 66 | 67 | 68 | 69 | 現在の機能 70 | - GradioベースのシンプルなUI 71 | - 変更可能な生成パラメーター 72 | - ベーシックなテキストからの画像生成 73 | - モデルダウンロード支援 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /src/modules/previewer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 5 | class Previewer(nn.Module): 6 | def __init__(self, c_in=16, c_hidden=512, c_out=3): 7 | super().__init__() 8 | self.blocks = nn.Sequential( 9 | nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels 10 | nn.GELU(), 11 | nn.BatchNorm2d(c_hidden), 12 | 13 | nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), 14 | nn.GELU(), 15 | nn.BatchNorm2d(c_hidden), 16 | 17 | nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 18 | nn.GELU(), 19 | nn.BatchNorm2d(c_hidden // 2), 20 | 21 | nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), 22 | nn.GELU(), 23 | nn.BatchNorm2d(c_hidden // 2), 24 | 25 | nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 26 | nn.GELU(), 27 | nn.BatchNorm2d(c_hidden // 4), 28 | 29 | nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), 30 | nn.GELU(), 31 | nn.BatchNorm2d(c_hidden // 4), 32 | 33 | nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 34 | nn.GELU(), 35 | nn.BatchNorm2d(c_hidden // 4), 36 | 37 | nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), 38 | nn.GELU(), 39 | nn.BatchNorm2d(c_hidden // 4), 40 | 41 | nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), 42 | ) 43 | 44 | def forward(self, x): 45 | return self.blocks(x) 46 | -------------------------------------------------------------------------------- /src/gdf/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SimpleSampler(): 4 | def __init__(self, gdf): 5 | self.gdf = gdf 6 | self.current_step = -1 7 | 8 | def __call__(self, *args, **kwargs): 9 | self.current_step += 1 10 | return self.step(*args, **kwargs) 11 | 12 | 
def init_x(self, shape): 13 | return torch.randn(*shape) 14 | 15 | def step(self, x, x0, epsilon, logSNR, logSNR_prev): 16 | raise NotImplementedError("You should override the 'apply' function.") 17 | 18 | class DDIMSampler(SimpleSampler): 19 | def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): 20 | a, b = self.gdf.input_scaler(logSNR) 21 | if len(a.shape) == 1: 22 | a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) 23 | 24 | a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) 25 | if len(a_prev.shape) == 1: 26 | a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) 27 | 28 | sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 29 | # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) 30 | x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) 31 | return x 32 | 33 | class DDPMSampler(DDIMSampler): 34 | def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): 35 | return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) 36 | 37 | class LCMSampler(SimpleSampler): 38 | def step(self, x, x0, epsilon, logSNR, logSNR_prev): 39 | a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) 40 | if len(a_prev.shape) == 1: 41 | a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) 42 | return x0 * a_prev + torch.randn_like(epsilon) * b_prev 43 | -------------------------------------------------------------------------------- /models/readme.md: -------------------------------------------------------------------------------- 1 | # Download Models 2 | 3 | As there are many models provided, let's make sure you only download the ones you need. 4 | The ``download_models.sh`` will make that very easy. The basic usage looks like this:
5 | ```bash 6 | bash download_models.sh essential variant bfloat16 7 | ``` 8 | 9 | **essential**
10 | This is optional and determines whether to download the EfficientNet, Stage A & Previewer. 11 | Include it the first time you run the command, since these models are always required. 12 | 13 | **variant**
14 | This determines which variant you want to use for **Stage B** and **Stage C**. 15 | There are four options: 16 | 17 | | | Stage C (Large) | Stage C (Lite) | 18 | |---------------------|-----------------|----------------| 19 | | **Stage B (Large)** | big-big | big-small | 20 | | **Stage B (Lite)** | small-big | small-small | 21 | 22 | 23 | So if you want to download the large Stage B & the large Stage C, you can execute:
24 | ```bash 25 | bash download_models.sh essential big-big bfloat16 26 | ``` 27 | 28 | **bfloat16**
29 | The last argument is optional as well, and simply determines in which precision you download Stage B & Stage C. 30 | If you want a faster download, choose _bfloat16_ (if your machine supports it), otherwise use _float32_. 31 | 32 | ### Recommendation 33 | If your GPU allows for it, you should definitely go for the **large** Stage C, which has 3.6 billion parameters. 34 | It is a lot better and was finetuned a lot more. Also, the ControlNet and LoRA examples are only for the large Stage C at the moment. 35 | For Stage B the difference is not so big. The **large** Stage B is better at reconstructing small details, 36 | but if your GPU is not so powerful, just go for the smaller one. 37 | 38 | ### Remark 39 | Unfortunately, you cannot run the models in float16 at the moment. Only bfloat16 or float32 work for now. However, 40 | with some investigation, it should be possible to fix the overflow and allow for inference in float16 as well. -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if not defined PYTHON (set PYTHON=python) 4 | if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") 5 | 6 | set ERROR_REPORTING=FALSE 7 | mkdir tmp 2>NUL 8 | 9 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt 10 | if %ERRORLEVEL% == 0 goto :check_pip 11 | echo Couldn't launch python 12 | goto :show_stdout_stderr 13 | 14 | :check_pip 15 | %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt 16 | if %ERRORLEVEL% == 0 goto :start_venv 17 | if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr 18 | %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt 19 | if %ERRORLEVEL% == 0 goto :start_venv 20 | echo Couldn't install pip 21 | goto :show_stdout_stderr 22 | 23 | :start_venv 24 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt 25 | if %ERRORLEVEL% == 0 goto :activate_venv 26 | 27 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" 28 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% 29 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt 30 | if %ERRORLEVEL% == 0 goto :activate_venv 31 | 32 | echo Unable to create venv in directory "%VENV_DIR%" 33 | goto :show_stdout_stderr 34 | 35 | :activate_venv 36 | set PYTHON="%VENV_DIR%\Scripts\Python.exe" 37 | CALL %VENV_DIR%\Scripts\activate 38 | echo venv %PYTHON% 39 | 40 | if exist "%VENV_DIR%\share\" ( 41 | goto :launch 42 | ) else ( 43 | pip install -r requirements.txt 44 | ) 45 | 46 | :launch 47 | %PYTHON% ./src/webui.py %* 48 | pause 49 | exit /b 50 | 51 | 52 | 53 | :show_stdout_stderr 54 | 55 | echo. 56 | echo exit code: %errorlevel% 57 | 58 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi 59 | if %size% equ 0 goto :show_stderr 60 | echo. 61 | echo stdout: 62 | type tmp\stdout.txt 63 | 64 | :show_stderr 65 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi 66 | if %size% equ 0 goto :endofscript 67 | echo. 68 | echo stderr: 69 | type tmp\stderr.txt 70 | 71 | :endofscript 72 | 73 | echo. 74 | echo Launch unsuccessful. Exiting.
75 | pause -------------------------------------------------------------------------------- /src/core/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from pathlib import Path 5 | import safetensors 6 | import wandb 7 | 8 | 9 | def create_folder_if_necessary(path): 10 | path = "/".join(path.split("/")[:-1]) 11 | Path(path).mkdir(parents=True, exist_ok=True) 12 | 13 | 14 | def safe_save(ckpt, path): 15 | try: 16 | os.remove(f"{path}.bak") 17 | except OSError: 18 | pass 19 | try: 20 | os.rename(path, f"{path}.bak") 21 | except OSError: 22 | pass 23 | if path.endswith(".pt") or path.endswith(".ckpt"): 24 | torch.save(ckpt, path) 25 | elif path.endswith(".json"): 26 | with open(path, "w", encoding="utf-8") as f: 27 | json.dump(ckpt, f, indent=4) 28 | elif path.endswith(".safetensors"): 29 | safetensors.torch.save_file(ckpt, path) 30 | else: 31 | raise ValueError(f"File extension not supported: {path}") 32 | 33 | 34 | def load_or_fail(path, wandb_run_id=None): 35 | accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] 36 | try: 37 | assert any( 38 | [path.endswith(ext) for ext in accepted_extensions] 39 | ), f"Automatic loading not supported for this extension: {path}" 40 | if not os.path.exists(path): 41 | checkpoint = None 42 | elif path.endswith(".pt") or path.endswith(".ckpt"): 43 | checkpoint = torch.load(path, map_location="cpu") 44 | elif path.endswith(".json"): 45 | with open(path, "r", encoding="utf-8") as f: 46 | checkpoint = json.load(f) 47 | elif path.endswith(".safetensors"): 48 | checkpoint = {} 49 | with safetensors.safe_open(path, framework="pt", device="cpu") as f: 50 | for key in f.keys(): 51 | checkpoint[key] = f.get_tensor(key) 52 | return checkpoint 53 | except Exception as e: 54 | if wandb_run_id is not None: 55 | wandb.alert( 56 | title=f"Corrupt checkpoint for run {wandb_run_id}", 57 | text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", 58 | ) 59 | raise e 60 | -------------------------------------------------------------------------------- /src/webui.py: -------------------------------------------------------------------------------- 1 | import random 2 | import gradio as gr 3 | from inference.generate import t2i 4 | import argparse # argparseをインポート 5 | 6 | # コマンドライン引数を解析するための関数を定義 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description="Launch Gradio UI for text-to-image generation.") 9 | parser.add_argument("--share", action="store_true", help="If set, share the Gradio app publicly.") 10 | parser.add_argument("--listen", action="store_true", help="If set, listen on all network interfaces (0.0.0.0).") 11 | return parser.parse_args() 12 | 13 | def webui(share=False, listen=False): 14 | with gr.Blocks() as ui: 15 | with gr.Row(): 16 | precision = gr.Dropdown(value="bf16", choices=["bf16", "fp32"], label="Precision") 17 | model_size = gr.Dropdown(value="big-small", choices=["big-big", "big-small", "small-big", "small-small"], label="Model size") 18 | essential = gr.Checkbox(label="Download essential models", value=True) 19 | with gr.Row(): 20 | with gr.Column(): 21 | caption = gr.TextArea(label="Caption") 22 | batch_size = gr.Slider(1, 10, 4, step=1, label="Batch Size") 23 | height = gr.Slider(64, 2048, 1024, step=2, label="Height") 24 | width = gr.Slider(64, 2048, 1024, step=2, label="Width") 25 | seed = gr.Number(-1, 9999, -1, step=1, label="Seed") 26 | cfg_c = gr.Slider(1, 20, 4, step=0.1, label="cfg_c") 27 | cfg_b 
= gr.Slider(1, 20, 1.1, step=0.1, label="cfg_b") 28 | shift_c = gr.Slider(1, 7, 1, step=1, label="shift_c") 29 | shift_b = gr.Slider(1, 7, 2, step=1, label="shift_b") 30 | step_c = gr.Slider(1, 200, 20, step=2, label="step_c") 31 | step_b = gr.Slider(1, 200, 10, step=2, label="step_b") 32 | 33 | outdir = gr.Textbox(label="Output Directory", value="output") 34 | with gr.Column(): 35 | output = gr.Gallery(label="Output Image") 36 | run = gr.Button(value="Run") 37 | 38 | run.click(fn=t2i, inputs=[batch_size, caption, height, width, precision, model_size, essential, outdir, seed, cfg_c, cfg_b, shift_c, shift_b, step_c, step_b], outputs=[output]) 39 | 40 | 41 | ui.launch(share=share, server_name="0.0.0.0" if listen else None) 42 | 43 | if __name__ == "__main__": 44 | args = parse_args() # parse the command-line arguments 45 | webui(share=args.share, listen=args.listen) # pass the share and listen flags on to webui() -------------------------------------------------------------------------------- /src/core/data/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import yaml 4 | import os 5 | from .bucketeer import Bucketeer 6 | 7 | class MultiFilter(): 8 | def __init__(self, rules, default=False): 9 | self.rules = rules 10 | self.default = default 11 | 12 | def __call__(self, x): 13 | try: 14 | x_json = x['json'] 15 | if isinstance(x_json, bytes): 16 | x_json = json.loads(x_json) 17 | validations = [] 18 | for k, r in self.rules.items(): 19 | if isinstance(k, tuple): 20 | v = r(*[x_json[kv] for kv in k]) 21 | else: 22 | v = r(x_json[k]) 23 | validations.append(v) 24 | return all(validations) 25 | except Exception: 26 | return False 27 | 28 | class MultiGetter(): 29 | def __init__(self, rules): 30 | self.rules = rules 31 | 32 | def __call__(self, x_json): 33 | if isinstance(x_json, bytes): 34 | x_json = json.loads(x_json) 35 | outputs = [] 36 | for k, r in self.rules.items(): 37 | if isinstance(k, tuple): 38 | v = r(*[x_json[kv] for kv in k]) 39 | else: 40 | v = r(x_json[k]) 41 | outputs.append(v) 42 | if len(outputs) == 1: 43 | outputs = outputs[0] 44 | return outputs 45 | 46 | def setup_webdataset_path(paths, cache_path=None): 47 | if cache_path is None or not os.path.exists(cache_path): 48 | tar_paths = [] 49 | if isinstance(paths, str): 50 | paths = [paths] 51 | for path in paths: 52 | if path.strip().endswith(".tar"): 53 | # Avoid looking up s3 if we already have a tar file 54 | tar_paths.append(path) 55 | continue 56 | bucket = "/".join(path.split("/")[:3]) 57 | result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) 58 | files = result.stdout.decode('utf-8').split() 59 | files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")] 60 | tar_paths += files 61 | 62 | with open(cache_path, 'w', encoding='utf-8') as outfile: 63 | yaml.dump(tar_paths, outfile, default_flow_style=False) 64 | else: 65 | with open(cache_path, 'r', encoding='utf-8') as file: 66 | tar_paths = yaml.safe_load(file) 67 | 68 | tar_paths_str = ",".join([f"{p}" for p in tar_paths]) 69 | return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" 70 | -------------------------------------------------------------------------------- /src/core/utils/base_dto.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass, _MISSING_TYPE 3 | from munch import Munch 4 | 5 | EXPECTED = "___REQUIRED___" 6 | EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
7 | 8 | # pylint: disable=invalid-field-call 9 | def nested_dto(x, raw=False): 10 | return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) 11 | 12 | @dataclass(frozen=True) 13 | class Base: 14 | training: bool = None 15 | def __new__(cls, **kwargs): 16 | training = kwargs.get('training', True) 17 | setteable_fields = cls.setteable_fields(**kwargs) 18 | mandatory_fields = cls.mandatory_fields(**kwargs) 19 | invalid_kwargs = [ 20 | {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) 21 | ] 22 | print(mandatory_fields) 23 | assert ( 24 | len(invalid_kwargs) == 0 25 | ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." 26 | missing_kwargs = [f for f in mandatory_fields if f not in kwargs] 27 | assert ( 28 | len(missing_kwargs) == 0 29 | ), f"Required fields missing initializing this DTO: {missing_kwargs}." 30 | return object.__new__(cls) 31 | 32 | 33 | @classmethod 34 | def setteable_fields(cls, **kwargs): 35 | return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] 36 | 37 | @classmethod 38 | def mandatory_fields(cls, **kwargs): 39 | training = kwargs.get('training', True) 40 | return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] 41 | 42 | @classmethod 43 | def from_dict(cls, kwargs): 44 | for k in kwargs: 45 | if isinstance(kwargs[k], (dict, list, tuple)): 46 | kwargs[k] = Munch.fromDict(kwargs[k]) 47 | return cls(**kwargs) 48 | 49 | def to_dict(self): 50 | # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes 51 | selfdict = {} 52 | for k in dataclasses.fields(self): 53 | selfdict[k.name] = getattr(self, k.name) 54 | if isinstance(selfdict[k.name], Munch): 55 | selfdict[k.name] = selfdict[k.name].toDict() 56 | return selfdict 57 | -------------------------------------------------------------------------------- /src/modules/cnet_modules/inpainting/saliency_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | 8 | 9 | # MICRO RESNET 10 | class ResBlock(nn.Module): 11 | def __init__(self, channels): 12 | super(ResBlock, self).__init__() 13 | 14 | self.resblock = nn.Sequential( 15 | nn.ReflectionPad2d(1), 16 | nn.Conv2d(channels, channels, kernel_size=3), 17 | nn.InstanceNorm2d(channels, affine=True), 18 | nn.ReLU(), 19 | nn.ReflectionPad2d(1), 20 | nn.Conv2d(channels, channels, kernel_size=3), 21 | nn.InstanceNorm2d(channels, affine=True), 22 | ) 23 | 24 | def forward(self, x): 25 | out = self.resblock(x) 26 | return out + x 27 | 28 | 29 | class Upsample2d(nn.Module): 30 | def __init__(self, scale_factor): 31 | super(Upsample2d, self).__init__() 32 | 33 | self.interp = nn.functional.interpolate 34 | self.scale_factor = scale_factor 35 | 36 | def forward(self, x): 37 | x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') 38 | return x 39 | 40 | 41 | class MicroResNet(nn.Module): 42 | def __init__(self): 43 | super(MicroResNet, self).__init__() 44 | 45 | self.downsampler = 
nn.Sequential( 46 | nn.ReflectionPad2d(4), 47 | nn.Conv2d(3, 8, kernel_size=9, stride=4), 48 | nn.InstanceNorm2d(8, affine=True), 49 | nn.ReLU(), 50 | nn.ReflectionPad2d(1), 51 | nn.Conv2d(8, 16, kernel_size=3, stride=2), 52 | nn.InstanceNorm2d(16, affine=True), 53 | nn.ReLU(), 54 | nn.ReflectionPad2d(1), 55 | nn.Conv2d(16, 32, kernel_size=3, stride=2), 56 | nn.InstanceNorm2d(32, affine=True), 57 | nn.ReLU(), 58 | ) 59 | 60 | self.residual = nn.Sequential( 61 | ResBlock(32), 62 | nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), 63 | ResBlock(64), 64 | ) 65 | 66 | self.segmentator = nn.Sequential( 67 | nn.ReflectionPad2d(1), 68 | nn.Conv2d(64, 16, kernel_size=3), 69 | nn.InstanceNorm2d(16, affine=True), 70 | nn.ReLU(), 71 | Upsample2d(scale_factor=2), 72 | nn.ReflectionPad2d(4), 73 | nn.Conv2d(16, 1, kernel_size=9), 74 | nn.Sigmoid() 75 | ) 76 | 77 | def forward(self, x): 78 | out = self.downsampler(x) 79 | out = self.residual(out) 80 | out = self.segmentator(out) 81 | return out 82 | -------------------------------------------------------------------------------- /src/model_downloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import requests 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | 9 | def download_file(url, path="./models"): 10 | local_filename = url.split('/')[-1] 11 | full_path = f"{path}/{local_filename}" 12 | 13 | # Check whether the file already exists 14 | if os.path.exists(full_path): 15 | return 16 | 17 | # Get the file size from the response headers 18 | response = requests.head(url) 19 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 20 | 21 | with requests.get(url, stream=True) as r: 22 | r.raise_for_status() 23 | with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, 24 | desc=f"{local_filename}", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}') as progress: 25 | with open(full_path, 'wb') as f: 26 | for chunk in r.iter_content(chunk_size=8192): 27 | progress.update(len(chunk)) 28 | f.write(chunk) 29 | print(f"Downloaded {local_filename} to {path}") 30 | 31 | def download_model(essential, model_size, presion): 32 | # Check for the optional "essential" argument and download the essential models if present 33 | if essential: 34 | print("Downloading Essential Models (EfficientNet, Stage A, Previewer)") 35 | download_file("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors") 36 | download_file("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors") 37 | download_file("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors") 38 | 39 | base_url = "https://huggingface.co/stabilityai/StableWurst/resolve/main/" 40 | model_sizes = { 41 | "big-big": [("stage_b_bf16.safetensors", "stage_c_bf16.safetensors") if presion == "bf16" else ("stage_b.safetensors", "stage_c.safetensors")], 42 | "big-small": [("stage_b_bf16.safetensors", "stage_c_lite_bf16.safetensors") if presion == "bf16" else ("stage_b.safetensors", "stage_c_lite.safetensors")], 43 | "small-big": [("stage_b_lite_bf16.safetensors", "stage_c_bf16.safetensors") if presion == "bf16" else ("stage_b_lite.safetensors", "stage_c.safetensors")], 44 | "small-small": [("stage_b_lite_bf16.safetensors", "stage_c_lite_bf16.safetensors") if presion == "bf16" else ("stage_b_lite.safetensors", "stage_c_lite.safetensors")] 45 | } 46 | 47 | if model_size in model_sizes: 48 | for filename in model_sizes[model_size][0]: 49 | download_file(base_url + filename) 50 |
else: 51 | print("Invalid second argument. Please provide a valid argument: big-big, big-small, small-big, or small-small.") 52 | sys.exit(2) 53 | 54 | -------------------------------------------------------------------------------- /src/modules/cnet_modules/pidinet/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | 7 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 8 | 9 | 10 | def HWC3(x): 11 | assert x.dtype == np.uint8 12 | if x.ndim == 2: 13 | x = x[:, :, None] 14 | assert x.ndim == 3 15 | H, W, C = x.shape 16 | assert C == 1 or C == 3 or C == 4 17 | if C == 3: 18 | return x 19 | if C == 1: 20 | return np.concatenate([x, x, x], axis=2) 21 | if C == 4: 22 | color = x[:, :, 0:3].astype(np.float32) 23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 24 | y = color * alpha + 255.0 * (1.0 - alpha) 25 | y = y.clip(0, 255).astype(np.uint8) 26 | return y 27 | 28 | 29 | def resize_image(input_image, resolution): 30 | H, W, C = input_image.shape 31 | H = float(H) 32 | W = float(W) 33 | k = float(resolution) / min(H, W) 34 | H *= k 35 | W *= k 36 | H = int(np.round(H / 64.0)) * 64 37 | W = int(np.round(W / 64.0)) * 64 38 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 39 | return img 40 | 41 | 42 | def nms(x, t, s): 43 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 44 | 45 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 46 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 47 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 48 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 49 | 50 | y = np.zeros_like(x) 51 | 52 | for f in [f1, f2, f3, f4]: 53 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 54 | 55 | z = np.zeros_like(y, dtype=np.uint8) 56 | z[y > t] = 255 57 | return z 58 | 59 | 60 | def make_noise_disk(H, W, C, F): 61 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 62 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 63 | noise = noise[F: F + H, F: F + W] 64 | noise -= np.min(noise) 65 | noise /= np.max(noise) 66 | if C == 1: 67 | noise = noise[:, :, None] 68 | return noise 69 | 70 | 71 | def min_max_norm(x): 72 | x -= np.min(x) 73 | x /= np.maximum(np.max(x), 1e-5) 74 | return x 75 | 76 | 77 | def safe_step(x, step=2): 78 | y = x.astype(np.float32) * float(step + 1) 79 | y = y.astype(np.int32).astype(np.float32) / float(step) 80 | return y 81 | 82 | 83 | def img2mask(img, H, W, low=10, high=90): 84 | assert img.ndim == 3 or img.ndim == 2 85 | assert img.dtype == np.uint8 86 | 87 | if img.ndim == 3: 88 | y = img[:, :, random.randrange(0, img.shape[2])] 89 | else: 90 | y = img 91 | 92 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 93 | 94 | if random.uniform(0, 1) < 0.5: 95 | y = 255 - y 96 | 97 | return y < np.percentile(y, random.randrange(low, high)) 98 | -------------------------------------------------------------------------------- /src/modules/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LoRA(nn.Module): 6 | def __init__(self, layer, name='weight', rank=16, alpha=1): 7 | super().__init__() 8 | weight = getattr(layer, name) 9 | self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) 10 | self.lora_up = 
nn.Parameter(torch.zeros((weight.size(0), rank))) 11 | nn.init.normal_(self.lora_up, mean=0, std=1) 12 | 13 | self.scale = alpha / rank 14 | self.enabled = True 15 | 16 | def forward(self, original_weights): 17 | if self.enabled: 18 | lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) 19 | lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale 20 | return original_weights + lora_weights 21 | else: 22 | return original_weights 23 | 24 | 25 | def apply_lora(model, filters=None, rank=16): 26 | def check_parameter(module, name): 27 | return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( 28 | getattr(module, name), nn.Parameter) 29 | 30 | for name, module in model.named_modules(): 31 | if filters is None or any([f in name for f in filters]): 32 | if check_parameter(module, "weight"): 33 | device, dtype = module.weight.device, module.weight.dtype 34 | torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) 35 | elif check_parameter(module, "in_proj_weight"): 36 | device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype 37 | torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) 38 | 39 | 40 | class ReToken(nn.Module): 41 | def __init__(self, indices=None): 42 | super().__init__() 43 | assert indices is not None 44 | self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) 45 | self.register_buffer('indices', torch.tensor(indices)) 46 | self.enabled = True 47 | 48 | def forward(self, embeddings): 49 | if self.enabled: 50 | embeddings = embeddings.clone() 51 | for i, idx in enumerate(self.indices): 52 | embeddings[idx] += self.embeddings[i] 53 | return embeddings 54 | 55 | 56 | def apply_retoken(module, indices=None): 57 | def check_parameter(module, name): 58 | return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( 59 | getattr(module, name), nn.Parameter) 60 | 61 | if check_parameter(module, "weight"): 62 | device, dtype = module.weight.device, module.weight.dtype 63 | torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) 64 | 65 | 66 | def remove_lora(model, leave_parametrized=True): 67 | for module in model.modules(): 68 | if torch.nn.utils.parametrize.is_parametrized(module, "weight"): 69 | nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) 70 | elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): 71 | nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) 72 | -------------------------------------------------------------------------------- /src/gdf/readme.md: -------------------------------------------------------------------------------- 1 | # Generic Diffusion Framework (GDF) 2 | 3 | # Basic usage 4 | GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM 5 | , EDM, Rectified Flows, etc.) 
and makes it very easy to switch between them or combine different parts of different 6 | frameworks. 7 | 8 | Using GDF is very straightforward. First, define an instance of the GDF class: 9 | 10 | ```python 11 | from gdf import GDF 12 | from gdf import CosineSchedule 13 | from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight 14 | 15 | gdf = GDF( 16 | schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), 17 | input_scaler=VPScaler(), target=EpsilonTarget(), 18 | noise_cond=CosineTNoiseCond(), 19 | loss_weight=P2LossWeight(), 20 | ) 21 | ``` 22 | 23 | You need to define the following components: 24 | * **Train Schedule**: This returns the logSNR schedule used during training; some of the schedulers can be configured. A train schedule is then called with a batch size and randomly samples values from the defined distribution. 25 | * **Sample Schedule**: This is the schedule used later on when sampling. It might be different from the training schedule. 26 | * **Input Scaler**: Whether to use Variance Preserving scaling or LERP (rectified flows). 27 | * **Target**: What the target is during training; usually epsilon, x0 or v. 28 | * **Noise Conditioning**: You could pass the logSNR directly to your model, but usually a normalized value is used instead; for example, the EDM framework proposes using `-logSNR/8`. 29 | * **Loss Weight**: There are many proposed loss-weighting strategies; here you define which one you'll use. 30 | 31 | All of these classes are very simple logSNR-centric definitions. For example, the VPScaler is defined as just: 32 | ```python 33 | class VPScaler(): 34 | def __call__(self, logSNR): 35 | a_squared = logSNR.sigmoid() 36 | a = a_squared.sqrt() 37 | b = (1-a_squared).sqrt() 38 | return a, b 39 | 40 | ``` 41 | 42 | So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc. (a sketch of a custom loss weight follows the training example below). 43 | 44 | ### Training 45 | 46 | When you define your training loop, you can get everything you need by just doing: 47 | ```python 48 | shift, loss_shift = 1, 1 # these can be set to higher values, as the Simple Diffusion paper suggested for high resolutions 49 | for inputs, extra_conditions in dataloader_iterator: 50 | noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift) 51 | pred = diffusion_model(noised, noise_cond, extra_conditions) 52 | 53 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 54 | loss_adjusted = (loss * loss_weight).mean() 55 | 56 | loss_adjusted.backward() 57 | optimizer.step() 58 | optimizer.zero_grad(set_to_none=True) 59 | ``` 60 | 61 | And that's all: you have a diffusion model training loop where it's very easy to customize the different elements of the 62 | training from the GDF class.
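To make that extension point concrete, here is a minimal sketch of a custom loss weight. It is illustrative only (the class name `ClampedSNRLossWeight` and its clamp values are not part of the library); following `gdf/loss_weights.py`, the only requirement is to subclass `BaseLossWeight` and override `weight(logSNR)`:

```python
from gdf.loss_weights import BaseLossWeight


class ClampedSNRLossWeight(BaseLossWeight):
    """Hypothetical example: SNR weighting clamped to a fixed range."""
    def __init__(self, min_snr=0.1, max_snr=5.0):
        self.min_snr, self.max_snr = min_snr, max_snr

    def weight(self, logSNR):
        # BaseLossWeight.__call__ already applies shift and clamp_range;
        # subclasses only define the raw per-sample weight.
        return logSNR.exp().clamp(self.min_snr, self.max_snr)


# It plugs into the GDF constructor like any built-in weight:
# gdf = GDF(..., loss_weight=ClampedSNRLossWeight(max_snr=5.0))
```

The built-in `MinSNRLossWeight` in `loss_weights.py` already covers the max-only clamp; this variant just illustrates the subclassing pattern.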
63 | 64 | ### Sampling 65 | 66 | The other important part is sampling, when you want to use this framework to sample you can just do the following: 67 | 68 | ```python 69 | from gdf import DDPMSampler 70 | 71 | shift = 1 72 | sampling_configs = { 73 | "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift, 74 | "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]) 75 | } 76 | 77 | *_, (sampled, _, _) = gdf.sample( 78 | diffusion_model, {"cond": extra_conditions}, latents.shape, 79 | unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)}, 80 | device=device, **sampling_configs 81 | ) 82 | ``` 83 | 84 | # Available modules 85 | 86 | TODO 87 | -------------------------------------------------------------------------------- /src/gdf/loss_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # --- Loss Weighting 5 | class BaseLossWeight(): 6 | def weight(self, logSNR): 7 | raise NotImplementedError("this method needs to be overridden") 8 | 9 | def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): 10 | clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range 11 | if shift != 1: 12 | logSNR = logSNR.clone() + 2 * np.log(shift) 13 | return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) 14 | 15 | class ComposedLossWeight(BaseLossWeight): 16 | def __init__(self, div, mul): 17 | self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul 18 | self.div = [div] if isinstance(div, BaseLossWeight) else div 19 | 20 | def weight(self, logSNR): 21 | prod, div = 1, 1 22 | for m in self.mul: 23 | prod *= m.weight(logSNR) 24 | for d in self.div: 25 | div *= d.weight(logSNR) 26 | return prod/div 27 | 28 | class ConstantLossWeight(BaseLossWeight): 29 | def __init__(self, v=1): 30 | self.v = v 31 | 32 | def weight(self, logSNR): 33 | return torch.ones_like(logSNR) * self.v 34 | 35 | class SNRLossWeight(BaseLossWeight): 36 | def weight(self, logSNR): 37 | return logSNR.exp() 38 | 39 | class P2LossWeight(BaseLossWeight): 40 | def __init__(self, k=1.0, gamma=1.0, s=1.0): 41 | self.k, self.gamma, self.s = k, gamma, s 42 | 43 | def weight(self, logSNR): 44 | return (self.k + (logSNR * self.s).exp()) ** -self.gamma 45 | 46 | class SNRPlusOneLossWeight(BaseLossWeight): 47 | def weight(self, logSNR): 48 | return logSNR.exp() + 1 49 | 50 | class MinSNRLossWeight(BaseLossWeight): 51 | def __init__(self, max_snr=5): 52 | self.max_snr = max_snr 53 | 54 | def weight(self, logSNR): 55 | return logSNR.exp().clamp(max=self.max_snr) 56 | 57 | class MinSNRPlusOneLossWeight(BaseLossWeight): 58 | def __init__(self, max_snr=5): 59 | self.max_snr = max_snr 60 | 61 | def weight(self, logSNR): 62 | return (logSNR.exp() + 1).clamp(max=self.max_snr) 63 | 64 | class TruncatedSNRLossWeight(BaseLossWeight): 65 | def __init__(self, min_snr=1): 66 | self.min_snr = min_snr 67 | 68 | def weight(self, logSNR): 69 | return logSNR.exp().clamp(min=self.min_snr) 70 | 71 | class SechLossWeight(BaseLossWeight): 72 | def __init__(self, div=2): 73 | self.div = div 74 | 75 | def weight(self, logSNR): 76 | return 1/(logSNR/self.div).cosh() 77 | 78 | class DebiasedLossWeight(BaseLossWeight): 79 | def weight(self, logSNR): 80 | return 1/logSNR.exp().sqrt() 81 | 82 | class SigmoidLossWeight(BaseLossWeight): 83 | def __init__(self, s=1): 84 | self.s = s 85 | 86 | def weight(self, logSNR): 87 | return (logSNR * self.s).sigmoid() 88 | 89 | class AdaptiveLossWeight(BaseLossWeight): 90 | def __init__(self, 
logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): 91 | self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) 92 | self.bucket_losses = torch.ones(buckets) 93 | self.weight_range = weight_range 94 | 95 | def weight(self, logSNR): 96 | indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) 97 | return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) 98 | 99 | def update_buckets(self, logSNR, loss, beta=0.99): 100 | indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() 101 | self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) 102 | -------------------------------------------------------------------------------- /src/core/data/bucketeer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from torchtools.transforms import SmartCrop 5 | import math 6 | 7 | class Bucketeer(): 8 | def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): 9 | assert crop_mode in ['center', 'random', 'smart'] 10 | self.crop_mode = crop_mode 11 | self.ratios = ratios 12 | if reverse_list: 13 | for r in list(ratios): 14 | if 1/r not in self.ratios: 15 | self.ratios.append(1/r) 16 | self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios] 17 | self.batch_size = dataloader.batch_size 18 | self.iterator = iter(dataloader) 19 | self.buckets = {s: [] for s in self.sizes} 20 | self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None 21 | self.p_random_ratio = p_random_ratio 22 | self.interpolate_nearest = interpolate_nearest 23 | 24 | def get_available_batch(self): 25 | for b in self.buckets: 26 | if len(self.buckets[b]) >= self.batch_size: 27 | batch = self.buckets[b][:self.batch_size] 28 | self.buckets[b] = self.buckets[b][self.batch_size:] 29 | return batch 30 | return None 31 | 32 | def get_closest_size(self, x): 33 | if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: 34 | best_size_idx = np.random.randint(len(self.ratios)) 35 | else: 36 | w, h = x.size(-1), x.size(-2) 37 | best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) 38 | return self.sizes[best_size_idx] 39 | 40 | def get_resize_size(self, orig_size, tgt_size): 41 | if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: 42 | alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) 43 | resize_size = max(alt_min, min(tgt_size)) 44 | else: 45 | alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) 46 | resize_size = max(alt_max, max(tgt_size)) 47 | return resize_size 48 | 49 | def __next__(self): 50 | batch = self.get_available_batch() 51 | while batch is None: 52 | elements = next(self.iterator) 53 | for dct in elements: 54 | img = dct['images'] 55 | size = self.get_closest_size(img) 56 | resize_size = self.get_resize_size(img.shape[-2:], size) 57 | if self.interpolate_nearest: 58 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) 59 | else: 60 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, 
antialias=True) 61 | if self.crop_mode == 'center': 62 | img = torchvision.transforms.functional.center_crop(img, size) 63 | elif self.crop_mode == 'random': 64 | img = torchvision.transforms.RandomCrop(size)(img) 65 | elif self.crop_mode == 'smart': 66 | self.smartcrop.output_size = size 67 | img = self.smartcrop(img) 68 | self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) 69 | batch = self.get_available_batch() 70 | 71 | out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} 72 | return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} 73 | -------------------------------------------------------------------------------- /src/gdf/noise_conditions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class BaseNoiseCond(): 5 | def __init__(self, *args, shift=1, clamp_range=None, **kwargs): 6 | clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range 7 | self.shift = shift 8 | self.clamp_range = clamp_range 9 | self.setup(*args, **kwargs) 10 | 11 | def setup(self, *args, **kwargs): 12 | pass # this method is optional, override it if required 13 | 14 | def cond(self, logSNR): 15 | raise NotImplementedError("this method needs to be overriden") 16 | 17 | def __call__(self, logSNR): 18 | if self.shift != 1: 19 | logSNR = logSNR.clone() + 2 * np.log(self.shift) 20 | return self.cond(logSNR).clamp(*self.clamp_range) 21 | 22 | class CosineTNoiseCond(BaseNoiseCond): 23 | def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] 24 | self.s = torch.tensor([s]) 25 | self.clamp_range = clamp_range 26 | self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 27 | 28 | def cond(self, logSNR): 29 | var = logSNR.sigmoid() 30 | var = var.clamp(*self.clamp_range) 31 | s, min_var = self.s.to(var.device), self.min_var.to(var.device) 32 | t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s 33 | return t 34 | 35 | class EDMNoiseCond(BaseNoiseCond): 36 | def cond(self, logSNR): 37 | return -logSNR/8 38 | 39 | class SigmoidNoiseCond(BaseNoiseCond): 40 | def cond(self, logSNR): 41 | return (-logSNR).sigmoid() 42 | 43 | class LogSNRNoiseCond(BaseNoiseCond): 44 | def cond(self, logSNR): 45 | return logSNR 46 | 47 | class EDMSigmaNoiseCond(BaseNoiseCond): 48 | def setup(self, sigma_data=1): 49 | self.sigma_data = sigma_data 50 | 51 | def cond(self, logSNR): 52 | return torch.exp(-logSNR / 2) * self.sigma_data 53 | 54 | class RectifiedFlowsNoiseCond(BaseNoiseCond): 55 | def cond(self, logSNR): 56 | _a = logSNR.exp() - 1 57 | _a[_a == 0] = 1e-3 # Avoid division by zero 58 | a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) 59 | return a 60 | 61 | # Any NoiseCond that cannot be described easily as a continuous function of t 62 | # It needs to define self.x and self.y in the setup() method 63 | class PiecewiseLinearNoiseCond(BaseNoiseCond): 64 | def setup(self): 65 | self.x = None 66 | self.y = None 67 | 68 | def piecewise_linear(self, y, xs, ys): 69 | indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) 70 | x_min, x_max = xs[indices], xs[indices+1] 71 | y_min, y_max = ys[indices], ys[indices+1] 72 | x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) 73 | return x 74 | 75 | def cond(self, logSNR): 76 | var = logSNR.sigmoid() 77 | t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) 78 | return t 79 | 80 | class 
StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): 81 | def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): 82 | self.total_steps = total_steps 83 | linear_range_sqrt = [r**0.5 for r in linear_range] 84 | self.x = torch.linspace(0, 1, total_steps+1) 85 | 86 | alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 87 | self.y = alphas.cumprod(dim=-1) 88 | 89 | def cond(self, logSNR): 90 | return super().cond(logSNR).clamp(0, 1) 91 | 92 | class DiscreteNoiseCond(BaseNoiseCond): 93 | def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): 94 | self.noise_cond = noise_cond 95 | self.steps = steps 96 | self.continuous_range = continuous_range 97 | 98 | def cond(self, logSNR): 99 | cond = self.noise_cond(logSNR) 100 | cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) 101 | return cond.mul(self.steps).long() 102 | -------------------------------------------------------------------------------- /models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if at least two arguments were provided (excluding the optional first one) 4 | if [ $# -lt 2 ]; then 5 | echo "Insufficient arguments provided. At least two arguments are required." 6 | exit 1 7 | fi 8 | 9 | # Check for the optional "essential" argument and download the essential models if present 10 | if [ "$1" == "essential" ]; then 11 | echo "Downloading Essential Models (EfficientNet, Stage A, Previewer)" 12 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors -P . -q --show-progress 13 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors -P . -q --show-progress 14 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors -P . -q --show-progress 15 | shift # Move the arguments, $2 becomes $1, $3 becomes $2, etc. 16 | fi 17 | 18 | # Now, $1 is the second argument due to the potential shift above 19 | second_argument="$1" 20 | binary_decision="${2:-bfloat16}" # Use default or specific binary value if provided 21 | 22 | case $second_argument in 23 | big-big) 24 | if [ "$binary_decision" == "bfloat16" ]; then 25 | echo "Downloading Large Stage B & Large Stage C" 26 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_bf16.safetensors -P . -q --show-progress 27 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -P . -q --show-progress 28 | else 29 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b.safetensors -P . -q --show-progress 30 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c.safetensors -P . -q --show-progress 31 | fi 32 | ;; 33 | big-small) 34 | if [ "$binary_decision" == "bfloat16" ]; then 35 | echo "Downloading Large Stage B & Small Stage C (BFloat16)" 36 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_bf16.safetensors -P . -q --show-progress 37 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite_bf16.safetensors -P . -q --show-progress 38 | else 39 | echo "Downloading Large Stage B & Small Stage C" 40 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b.safetensors -P . -q --show-progress 41 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite.safetensors -P . 
-q --show-progress 42 | fi 43 | ;; 44 | small-big) 45 | if [ "$binary_decision" == "bfloat16" ]; then 46 | echo "Downloading Small Stage B & Large Stage C (BFloat16)" 47 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -P . -q --show-progress 48 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -P . -q --show-progress 49 | else 50 | echo "Downloading Small Stage B & Large Stage C" 51 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite.safetensors -P . -q --show-progress 52 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c.safetensors -P . -q --show-progress 53 | fi 54 | ;; 55 | small-small) 56 | if [ "$binary_decision" == "bfloat16" ]; then 57 | echo "Downloading Small Stage B & Small Stage C (BFloat16)" 58 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -P . -q --show-progress 59 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite_bf16.safetensors -P . -q --show-progress 60 | else 61 | echo "Downloading Small Stage B & Small Stage C" 62 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite.safetensors -P . -q --show-progress 63 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite.safetensors -P . -q --show-progress 64 | fi 65 | ;; 66 | *) 67 | echo "Invalid second argument. Please provide a valid argument: big-big, big-small, small-big, or small-small." 68 | exit 2 69 | ;; 70 | esac 71 | -------------------------------------------------------------------------------- /src/modules/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Linear(torch.nn.Linear): 5 | def reset_parameters(self): 6 | return None 7 | 8 | class Conv2d(torch.nn.Conv2d): 9 | def reset_parameters(self): 10 | return None 11 | 12 | 13 | class Attention2D(nn.Module): 14 | def __init__(self, c, nhead, dropout=0.0): 15 | super().__init__() 16 | self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) 17 | 18 | def forward(self, x, kv, self_attn=False): 19 | orig_shape = x.shape 20 | x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 21 | if self_attn: 22 | kv = torch.cat([x, kv], dim=1) 23 | x = self.attn(x, kv, kv, need_weights=False)[0] 24 | x = x.permute(0, 2, 1).view(*orig_shape) 25 | return x 26 | 27 | 28 | class LayerNorm2d(nn.LayerNorm): 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | 32 | def forward(self, x): 33 | return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 34 | 35 | 36 | class GlobalResponseNorm(nn.Module): 37 | "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" 38 | def __init__(self, dim): 39 | super().__init__() 40 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 41 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 42 | 43 | def forward(self, x): 44 | Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) 45 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 46 | return self.gamma * (x * Nx) + self.beta + x 47 | 48 | 49 | class ResBlock(nn.Module): 50 | def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): 51 | super().__init__() 52 | self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, 
groups=c) 53 | # self.depthwise = SAMBlock(c, num_heads, expansion) 54 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 55 | self.channelwise = nn.Sequential( 56 | Linear(c + c_skip, c * 4), 57 | nn.GELU(), 58 | GlobalResponseNorm(c * 4), 59 | nn.Dropout(dropout), 60 | Linear(c * 4, c) 61 | ) 62 | 63 | def forward(self, x, x_skip=None): 64 | x_res = x 65 | x = self.norm(self.depthwise(x)) 66 | if x_skip is not None: 67 | x = torch.cat([x, x_skip], dim=1) 68 | x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 69 | return x + x_res 70 | 71 | 72 | class AttnBlock(nn.Module): 73 | def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): 74 | super().__init__() 75 | self.self_attn = self_attn 76 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 77 | self.attention = Attention2D(c, nhead, dropout) 78 | self.kv_mapper = nn.Sequential( 79 | nn.SiLU(), 80 | Linear(c_cond, c) 81 | ) 82 | 83 | def forward(self, x, kv): 84 | kv = self.kv_mapper(kv) 85 | x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) 86 | return x 87 | 88 | 89 | class FeedForwardBlock(nn.Module): 90 | def __init__(self, c, dropout=0.0): 91 | super().__init__() 92 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 93 | self.channelwise = nn.Sequential( 94 | Linear(c, c * 4), 95 | nn.GELU(), 96 | GlobalResponseNorm(c * 4), 97 | nn.Dropout(dropout), 98 | Linear(c * 4, c) 99 | ) 100 | 101 | def forward(self, x): 102 | x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 103 | return x 104 | 105 | 106 | class TimestepBlock(nn.Module): 107 | def __init__(self, c, c_timestep, conds=['sca']): 108 | super().__init__() 109 | self.mapper = Linear(c_timestep, c * 2) 110 | self.conds = conds 111 | for cname in conds: 112 | setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) 113 | 114 | def forward(self, x, t): 115 | t = t.chunk(len(self.conds) + 1, dim=1) 116 | a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) 117 | for i, c in enumerate(self.conds): 118 | ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) 119 | a, b = a + ac, b + bc 120 | return x * (1 + a) + b 121 | -------------------------------------------------------------------------------- /src/gdf/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .scalers import * 3 | from .targets import * 4 | from .schedulers import * 5 | from .noise_conditions import * 6 | from .loss_weights import * 7 | from .samplers import * 8 | 9 | class GDF(): 10 | def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0): 11 | self.schedule = schedule 12 | self.input_scaler = input_scaler 13 | self.target = target 14 | self.noise_cond = noise_cond 15 | self.loss_weight = loss_weight 16 | self.offset_noise = offset_noise 17 | 18 | def setup_limits(self, stretch_max=True, stretch_min=True, shift=1): 19 | stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift) 20 | return stretched_limits 21 | 22 | def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None): 23 | if epsilon is None: 24 | epsilon = torch.randn_like(x0) 25 | if self.offset_noise > 0: 26 | if offset is None: 27 | offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device) 28 | epsilon = epsilon + offset * self.offset_noise 29 | logSNR = self.schedule(x0.size(0) if t is None else t, 
shift=shift).to(x0.device) 30 | a, b = self.input_scaler(logSNR) # B 31 | if len(a.shape) == 1: 32 | a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW 33 | target = self.target(x0, epsilon, logSNR, a, b) 34 | 35 | # noised, noise, logSNR, t_cond 36 | return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift) 37 | 38 | def undiffuse(self, x, logSNR, pred): 39 | a, b = self.input_scaler(logSNR) 40 | if len(a.shape) == 1: 41 | a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1)) 42 | return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b) 43 | 44 | def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): 45 | sampler_params = {} if sampler_params is None else sampler_params 46 | if sampler is None: 47 | sampler = DDPMSampler(self) 48 | r_range = torch.linspace(t_start, t_end, timesteps+1) 49 | schedule = self.schedule if schedule is None else schedule 50 | logSNR_range = schedule(r_range, shift=shift)[:, None].expand( 51 | -1, shape[0] if x_init is None else x_init.size(0) 52 | ).to(device) 53 | 54 | x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() 55 | if cfg is not None: 56 | if unconditional_inputs is None: 57 | unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} 58 | model_inputs = { 59 | k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) 60 | else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) 61 | else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) 62 | else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) 63 | } 64 | for i in range(0, timesteps): 65 | noise_cond = self.noise_cond(logSNR_range[i]) 66 | if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): 67 | cfg_val = cfg 68 | if isinstance(cfg_val, (list, tuple)): 69 | assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" 70 | cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) 71 | pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2) 72 | pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) 73 | if cfg_rho > 0: 74 | std_pos, std_cfg = pred.std(), pred_cfg.std() 75 | pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) 76 | else: 77 | pred = pred_cfg 78 | else: 79 | pred = model(x, noise_cond, **model_inputs) 80 | x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) 81 | x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) 82 | altered_vars = yield (x0, x, pred) 83 | 84 | # Update some running variables if the user wants 85 | if altered_vars is not None: 86 | cfg = altered_vars.get('cfg', cfg) 87 | cfg_rho = altered_vars.get('cfg_rho', cfg_rho) 88 | sampler = altered_vars.get('sampler', sampler) 89 | model_inputs = altered_vars.get('model_inputs', model_inputs) 90 | x = altered_vars.get('x', x) 91 | x_init = altered_vars.get('x_init', x_init) 92 | 93 | 
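Because `GDF.sample` above is a generator (`altered_vars = yield (x0, x, pred)`), sampling parameters can be changed mid-run by sending a dict back into it. A minimal consumption sketch, assuming `model`, `conditions`, `unconditions`, `shape`, `device` and `sampling_configs` are prepared as in the inference scripts further down:

```python
sampling = gdf.sample(model, conditions, shape,
                      unconditional_inputs=unconditions,
                      device=device, **sampling_configs)

outputs = next(sampling)  # runs the first denoising step, yields (x0, x, pred)
while True:
    try:
        # Any of the keys read from altered_vars above can be sent back:
        # cfg, cfg_rho, sampler, model_inputs, x, x_init.
        outputs = sampling.send({'cfg': 3.0})  # e.g. lower the guidance scale
    except StopIteration:
        break

sampled_x0, sampled_x, last_pred = outputs
```

Plain iteration, as used in `inference/generate.py` and `inference/text2img.py` below, is equivalent when nothing needs to be altered mid-sampling.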
-------------------------------------------------------------------------------- /src/inference/generate.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import sys 5 | import time 6 | import torch 7 | import yaml 8 | from tqdm import tqdm 9 | import torchvision 10 | 11 | # Assuming model_downloader, inference.utils, core, and train modules are correctly set up 12 | from model_downloader import download_model 13 | from inference.utils import * 14 | from core import load_or_fail 15 | from train import WurstCoreC, WurstCoreB 16 | 17 | def load_config(config_name, config_path='./configs/inference'): 18 | config_file = os.path.join(config_path, f'{config_name}.yaml') 19 | with open(config_file, "r", encoding="utf-8") as file: 20 | return yaml.safe_load(file) 21 | 22 | def initialize_model(core_class, config, device, training=False,Bmode=False): 23 | core = core_class(config_dict=config, device=device, training=training) 24 | extras = core.setup_extras_pre() 25 | models = core.setup_models(extras) 26 | models.generator.eval().requires_grad_(False) 27 | 28 | if Bmode: 29 | models = WurstCoreB.Models( 30 | **{**models.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} 31 | ) 32 | models.generator.bfloat16().eval().requires_grad_(False) 33 | print("STAGE B READY") 34 | return core,models,extras 35 | print("STAGE C READY") 36 | return core,models,extras 37 | 38 | def determine_model_sizes(model_type): 39 | sizes = { 40 | "big-big": ("full", "full"), 41 | "big-small": ("lite", "full"), 42 | "small-big": ("full", "lite"), 43 | "small-small": ("lite", "lite") 44 | } 45 | return sizes.get(model_type, ("Error: Invalid model type specified.",)) 46 | 47 | def setup_sampling_configs(extras, cfg=4, shift=2, timesteps=20, t_start=1.0): 48 | extras.sampling_configs.update({ 49 | 'cfg': cfg, 50 | 'shift': shift, 51 | 'timesteps': timesteps, 52 | 't_start': t_start 53 | }) 54 | 55 | def generate(core, core_b, models, models_b, extras, extras_b, caption,batch_size, stage_c_latent_shape, stage_b_latent_shape, device,seed=42, outdir='output'): 56 | os.makedirs(outdir, exist_ok=True) 57 | seed = random.randint(0, 999999) if seed == -1 else seed 58 | # PREPARE CONDITIONS 59 | batch = {'captions': [caption] * batch_size} 60 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 61 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 62 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 63 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 64 | 65 | with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16): 66 | torch.manual_seed(seed) 67 | sampling_c = extras.gdf.sample( 68 | models.generator, conditions, stage_c_latent_shape, 69 | unconditions, device=device, **extras.sampling_configs, 70 | ) 71 | for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']): 72 | sampled_c = sampled_c 73 | 74 | # preview_c = models.previewer(sampled_c).float() 75 | # show_images(preview_c) 76 | 77 | conditions_b['effnet'] = sampled_c 78 | unconditions_b['effnet'] = torch.zeros_like(sampled_c) 79 | 80 | sampling_b = extras_b.gdf.sample( 81 | models_b.generator, conditions_b, stage_b_latent_shape, 82 | unconditions_b, device=device, **extras_b.sampling_configs 83 | 
) 84 | for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']): 85 | sampled_b = sampled_b 86 | sampled = models_b.stage_a.decode(sampled_b).float() 87 | # save_image 88 | times = time.strftime(r"%Y%m%d%H%M%S") 89 | imgs = [] 90 | for i, img in enumerate(sampled): 91 | img = torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)) 92 | fileName = f"img-{times}-{i+1}.png" 93 | img.save(f"output/{fileName}") 94 | imgs.append(img) 95 | return imgs 96 | 97 | def t2i(batch_size, caption, height, width, presion,model_size,essential,outdir,seed,cfg_c,cfg_b,shift_c,shift_b,step_c,step_b): 98 | device= torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 99 | download_model(essential, model_size, presion) 100 | c_model_size, b_model_size = determine_model_sizes(model_size) 101 | config_c,config_b = f'stage_c_{c_model_size}_{presion}',f'stage_b_{b_model_size}_{presion}' 102 | config_c, config_b = load_config(config_c), load_config(config_b) 103 | 104 | core, models, extras = initialize_model(WurstCoreC, config_c, device) 105 | core_b, models_b, extras_b = initialize_model(WurstCoreB, config_b, device,False,True) 106 | 107 | setup_sampling_configs(extras,cfg=cfg_c, shift=shift_c, timesteps=step_c) 108 | setup_sampling_configs(extras_b, cfg=cfg_b, shift=shift_b, timesteps=step_b) 109 | stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) 110 | images = generate(core, core_b, models, models_b, extras, extras_b, caption,batch_size, stage_c_latent_shape, stage_b_latent_shape, device,seed,outdir) 111 | return images -------------------------------------------------------------------------------- /src/inference/text2img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | from tqdm import tqdm 5 | import sys 6 | import time 7 | 8 | from model_downloader import download_model 9 | 10 | current_dir = os.path.dirname(os.path.abspath(__file__)) 11 | parent_dir = os.path.dirname(current_dir) 12 | 13 | if parent_dir not in sys.path: 14 | sys.path.append(parent_dir) 15 | 16 | from inference.utils import * 17 | from core import load_or_fail 18 | from train import WurstCoreC, WurstCoreB 19 | 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | print(device) 22 | 23 | 24 | def load_config(c_model_config, b_model_config): 25 | # SETUP STAGE C 26 | config_file = f'./configs/inference/{c_model_config}.yaml' 27 | with open(config_file, "r", encoding="utf-8") as file: 28 | loaded_config = yaml.safe_load(file) 29 | 30 | core = WurstCoreC(config_dict=loaded_config, device=device, training=False) 31 | 32 | # SETUP STAGE B 33 | config_file_b = f'./configs/inference/{b_model_config}.yaml' 34 | with open(config_file_b, "r", encoding="utf-8") as file: 35 | config_file_b = yaml.safe_load(file) 36 | 37 | core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) 38 | return core, core_b 39 | 40 | def load_models(core, core_b): 41 | # SETUP MODELS & DATA 42 | extras = core.setup_extras_pre() 43 | models = core.setup_models(extras) 44 | models.generator.eval().requires_grad_(False) 45 | print("STAGE C READY") 46 | 47 | extras_b = core_b.setup_extras_pre() 48 | models_b = core_b.setup_models(extras_b, skip_clip=True) 49 | models_b = WurstCoreB.Models( 50 | **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} 51 | ) 52 | 
models_b.generator.bfloat16().eval().requires_grad_(False) 53 | print("STAGE B READY") 54 | return models, models_b, extras, extras_b 55 | 56 | def determine_model_sizes(model_type): 57 | if model_type == "big-big": 58 | c_model_size = "full" 59 | b_model_size = "full" 60 | elif model_type == "big-small": 61 | c_model_size = "lite" 62 | b_model_size = "full" 63 | elif model_type == "small-big": 64 | c_model_size = "full" 65 | b_model_size = "lite" 66 | elif model_type == "small-small": 67 | c_model_size = "lite" 68 | b_model_size = "lite" 69 | else: 70 | # 不正なmodel_typeが指定された場合はエラーメッセージを返す 71 | return "Error: Invalid model type specified." 72 | return c_model_size,b_model_size 73 | 74 | def generate(batch_size, caption, height, width, presion,model_size,essential): 75 | download_model(essential, model_size, presion) 76 | c_model_size, b_model_size = determine_model_sizes(model_size) 77 | os.makedirs('output', exist_ok=True) 78 | core, core_b = load_config(f'stage_c_{c_model_size}_{presion}',f'stage_b_{b_model_size}_{presion}') 79 | models, models_b, extras, extras_b = load_models(core, core_b) 80 | stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) 81 | 82 | # Stage C Parameters 83 | extras.sampling_configs['cfg'] = 4 84 | extras.sampling_configs['shift'] = 2 85 | extras.sampling_configs['timesteps'] = 20 86 | extras.sampling_configs['t_start'] = 1.0 87 | 88 | # Stage B Parameters 89 | extras_b.sampling_configs['cfg'] = 1.1 90 | extras_b.sampling_configs['shift'] = 1 91 | extras_b.sampling_configs['timesteps'] = 10 92 | extras_b.sampling_configs['t_start'] = 1.0 93 | 94 | # PREPARE CONDITIONS 95 | batch = {'captions': [caption] * batch_size} 96 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 97 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 98 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 99 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 100 | 101 | with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16): 102 | # torch.manual_seed(42) 103 | sampling_c = extras.gdf.sample( 104 | models.generator, conditions, stage_c_latent_shape, 105 | unconditions, device=device, **extras.sampling_configs, 106 | ) 107 | for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']): 108 | sampled_c = sampled_c 109 | 110 | # preview_c = models.previewer(sampled_c).float() 111 | # show_images(preview_c) 112 | 113 | conditions_b['effnet'] = sampled_c 114 | unconditions_b['effnet'] = torch.zeros_like(sampled_c) 115 | 116 | sampling_b = extras_b.gdf.sample( 117 | models_b.generator, conditions_b, stage_b_latent_shape, 118 | unconditions_b, device=device, **extras_b.sampling_configs 119 | ) 120 | for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']): 121 | sampled_b = sampled_b 122 | sampled = models_b.stage_a.decode(sampled_b).float() 123 | # save_image 124 | times = time.strftime(r"%Y%m%d%H%M%S") 125 | imgs = [] 126 | for i, img in enumerate(sampled): 127 | img = torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)) 128 | fileName = f"img-{times}-{i+1}.png" 129 | img.save(f"output/{fileName}") 130 | imgs.append(img) 131 | return imgs 132 | 133 | 
-------------------------------------------------------------------------------- /src/modules/stage_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchtools.nn import VectorQuantize 4 | 5 | 6 | class ResBlock(nn.Module): 7 | def __init__(self, c, c_hidden): 8 | super().__init__() 9 | # depthwise/attention 10 | self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) 11 | self.depthwise = nn.Sequential( 12 | nn.ReplicationPad2d(1), 13 | nn.Conv2d(c, c, kernel_size=3, groups=c) 14 | ) 15 | 16 | # channelwise 17 | self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) 18 | self.channelwise = nn.Sequential( 19 | nn.Linear(c, c_hidden), 20 | nn.GELU(), 21 | nn.Linear(c_hidden, c), 22 | ) 23 | 24 | self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) 25 | 26 | # Init weights 27 | def _basic_init(module): 28 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 29 | torch.nn.init.xavier_uniform_(module.weight) 30 | if module.bias is not None: 31 | nn.init.constant_(module.bias, 0) 32 | 33 | self.apply(_basic_init) 34 | 35 | def _norm(self, x, norm): 36 | return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 37 | 38 | def forward(self, x): 39 | mods = self.gammas 40 | 41 | x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] 42 | x = x + self.depthwise(x_temp) * mods[2] 43 | 44 | x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] 45 | x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] 46 | 47 | return x 48 | 49 | 50 | class StageA(nn.Module): 51 | def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, 52 | scale_factor=0.43): # 0.3764 53 | super().__init__() 54 | self.c_latent = c_latent 55 | self.scale_factor = scale_factor 56 | c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] 57 | 58 | # Encoder blocks 59 | self.in_block = nn.Sequential( 60 | nn.PixelUnshuffle(2), 61 | nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) 62 | ) 63 | down_blocks = [] 64 | for i in range(levels): 65 | if i > 0: 66 | down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) 67 | block = ResBlock(c_levels[i], c_levels[i] * 4) 68 | down_blocks.append(block) 69 | down_blocks.append(nn.Sequential( 70 | nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), 71 | nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 72 | )) 73 | self.down_blocks = nn.Sequential(*down_blocks) 74 | self.down_blocks[0] 75 | 76 | self.codebook_size = codebook_size 77 | self.vquantizer = VectorQuantize(c_latent, k=codebook_size) 78 | 79 | # Decoder blocks 80 | up_blocks = [nn.Sequential( 81 | nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) 82 | )] 83 | for i in range(levels): 84 | for j in range(bottleneck_blocks if i == 0 else 1): 85 | block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) 86 | up_blocks.append(block) 87 | if i < levels - 1: 88 | up_blocks.append( 89 | nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, 90 | padding=1)) 91 | self.up_blocks = nn.Sequential(*up_blocks) 92 | self.out_block = nn.Sequential( 93 | nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), 94 | nn.PixelShuffle(2), 95 | ) 96 | 97 | def encode(self, x, quantize=False): 98 | x = self.in_block(x) 99 | x = self.down_blocks(x) 100 | if quantize: 101 | qe, (vq_loss, commit_loss), indices = 
self.vquantizer.forward(x, dim=1) 102 | return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 103 | else: 104 | return x / self.scale_factor, None, None, None 105 | 106 | def decode(self, x): 107 | x = x * self.scale_factor 108 | x = self.up_blocks(x) 109 | x = self.out_block(x) 110 | return x 111 | 112 | def forward(self, x, quantize=False): 113 | qe, x, _, vq_loss = self.encode(x, quantize) 114 | x = self.decode(qe) 115 | return x, vq_loss 116 | 117 | 118 | class Discriminator(nn.Module): 119 | def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): 120 | super().__init__() 121 | d = max(depth - 3, 3) 122 | layers = [ 123 | nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), 124 | nn.LeakyReLU(0.2), 125 | ] 126 | for i in range(depth - 1): 127 | c_in = c_hidden // (2 ** max((d - i), 0)) 128 | c_out = c_hidden // (2 ** max((d - 1 - i), 0)) 129 | layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) 130 | layers.append(nn.InstanceNorm2d(c_out)) 131 | layers.append(nn.LeakyReLU(0.2)) 132 | self.encoder = nn.Sequential(*layers) 133 | self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) 134 | self.logits = nn.Sigmoid() 135 | 136 | def forward(self, x, cond=None): 137 | x = self.encoder(x) 138 | if cond is not None: 139 | cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) 140 | x = torch.cat([x, cond], dim=1) 141 | x = self.shuffle(x) 142 | x = self.logits(x) 143 | return x 144 | -------------------------------------------------------------------------------- /src/gdf/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class BaseSchedule(): 5 | def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): 6 | self.setup(*args, **kwargs) 7 | self.limits = None 8 | self.discrete_steps = discrete_steps 9 | self.shift = shift 10 | if force_limits: 11 | self.reset_limits() 12 | 13 | def reset_limits(self, shift=1, disable=False): 14 | try: 15 | self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max 16 | return self.limits 17 | except Exception: 18 | print("WARNING: this schedule doesn't support t and will be unbounded") 19 | return None 20 | 21 | def setup(self, *args, **kwargs): 22 | raise NotImplementedError("this method needs to be overriden") 23 | 24 | def schedule(self, *args, **kwargs): 25 | raise NotImplementedError("this method needs to be overriden") 26 | 27 | def __call__(self, t, *args, shift=1, **kwargs): 28 | if isinstance(t, torch.Tensor): 29 | batch_size = None 30 | if self.discrete_steps is not None: 31 | if t.dtype != torch.long: 32 | t = (t * (self.discrete_steps-1)).round().long() 33 | t = t / (self.discrete_steps-1) 34 | t = t.clamp(0, 1) 35 | else: 36 | batch_size = t 37 | t = None 38 | logSNR = self.schedule(t, batch_size, *args, **kwargs) 39 | if shift*self.shift != 1: 40 | logSNR += 2 * np.log(1/(shift*self.shift)) 41 | if self.limits is not None: 42 | logSNR = logSNR.clamp(*self.limits) 43 | return logSNR 44 | 45 | class CosineSchedule(BaseSchedule): 46 | def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): 47 | self.s = torch.tensor([s]) 48 | self.clamp_range = clamp_range 49 | self.norm_instead = norm_instead 50 | self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 
51 | 52 | def schedule(self, t, batch_size): 53 | if t is None: 54 | t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) 55 | s, min_var = self.s.to(t.device), self.min_var.to(t.device) 56 | var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var 57 | if self.norm_instead: 58 | var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] 59 | else: 60 | var = var.clamp(*self.clamp_range) 61 | logSNR = (var/(1-var)).log() 62 | return logSNR 63 | 64 | class CosineSchedule2(BaseSchedule): 65 | def setup(self, logsnr_range=[-15, 15]): 66 | self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1])) 67 | self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0])) 68 | 69 | def schedule(self, t, batch_size): 70 | if t is None: 71 | t = 1-torch.rand(batch_size) 72 | return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log() 73 | 74 | class SqrtSchedule(BaseSchedule): 75 | def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False): 76 | self.s = s 77 | self.clamp_range = clamp_range 78 | self.norm_instead = norm_instead 79 | 80 | def schedule(self, t, batch_size): 81 | if t is None: 82 | t = 1-torch.rand(batch_size) 83 | var = 1 - (t + self.s)**0.5 84 | if self.norm_instead: 85 | var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] 86 | else: 87 | var = var.clamp(*self.clamp_range) 88 | logSNR = (var/(1-var)).log() 89 | return logSNR 90 | 91 | class RectifiedFlowsSchedule(BaseSchedule): 92 | def setup(self, logsnr_range=[-15, 15]): 93 | self.logsnr_range = logsnr_range 94 | 95 | def schedule(self, t, batch_size): 96 | if t is None: 97 | t = 1-torch.rand(batch_size) 98 | logSNR = (((1-t)**2)/(t**2)).log() 99 | logSNR = logSNR.clamp(*self.logsnr_range) 100 | return logSNR 101 | 102 | class EDMSampleSchedule(BaseSchedule): 103 | def setup(self, sigma_range=[0.002, 80], p=7): 104 | self.sigma_range = sigma_range 105 | self.p = p 106 | 107 | def schedule(self, t, batch_size): 108 | if t is None: 109 | t = 1-torch.rand(batch_size) 110 | smin, smax, p = *self.sigma_range, self.p 111 | sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p 112 | logSNR = (1/sigma**2).log() 113 | return logSNR 114 | 115 | class EDMTrainSchedule(BaseSchedule): 116 | def setup(self, mu=-1.2, std=1.2): 117 | self.mu = mu 118 | self.std = std 119 | 120 | def schedule(self, t, batch_size): 121 | if t is not None: 122 | raise Exception("EDMTrainSchedule doesn't support passing timesteps: t") 123 | logSNR = -2*(torch.randn(batch_size) * self.std - self.mu) 124 | return logSNR 125 | 126 | class LinearSchedule(BaseSchedule): 127 | def setup(self, logsnr_range=[-10, 10]): 128 | self.logsnr_range = logsnr_range 129 | 130 | def schedule(self, t, batch_size): 131 | if t is None: 132 | t = 1-torch.rand(batch_size) 133 | logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1] 134 | return logSNR 135 | 136 | # Any schedule that cannot be described easily as a continuous function of t 137 | # It needs to define self.x and self.y in the setup() method 138 | class PiecewiseLinearSchedule(BaseSchedule): 139 | def setup(self): 140 | self.x = None 141 | self.y = None 142 | 143 | def piecewise_linear(self, x, xs, ys): 144 | indices = torch.searchsorted(xs[:-1], x) - 1 145 | x_min, x_max = xs[indices], xs[indices+1] 146 | y_min, y_max = ys[indices], ys[indices+1] 147 | var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min) 148 | return var 149 | 150 | def schedule(self, t, batch_size): 151 | if t is None: 152 | t = 
1-torch.rand(batch_size) 153 | var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device)) 154 | logSNR = (var/(1-var)).log() 155 | return logSNR 156 | 157 | class StableDiffusionSchedule(PiecewiseLinearSchedule): 158 | def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): 159 | linear_range_sqrt = [r**0.5 for r in linear_range] 160 | self.x = torch.linspace(0, 1, total_steps+1) 161 | 162 | alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 163 | self.y = alphas.cumprod(dim=-1) 164 | 165 | class AdaptiveTrainSchedule(BaseSchedule): 166 | def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0): 167 | th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1) 168 | self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)]) 169 | self.bucket_probs = torch.ones(buckets) 170 | self.min_probs = min_probs 171 | 172 | def schedule(self, t, batch_size): 173 | if t is not None: 174 | raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t") 175 | norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum()) 176 | buckets = torch.multinomial(norm_probs, batch_size, replacement=True) 177 | ranges = self.bucket_ranges[buckets] 178 | logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0] 179 | return logSNR 180 | 181 | def update_buckets(self, logSNR, loss, beta=0.99): 182 | range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device) 183 | range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float() 184 | range_idx = range_mask.argmax(-1).cpu() 185 | self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta) 186 | 187 | class InterpolatedSchedule(BaseSchedule): 188 | def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]): 189 | self.scheduler1 = scheduler1 190 | self.scheduler2 = scheduler2 191 | self.shifts = shifts 192 | 193 | def schedule(self, t, batch_size): 194 | if t is None: 195 | t = 1-torch.rand(batch_size) 196 | t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan 197 | low_logSNR = self.scheduler1(t, shift=self.shifts[0]) 198 | high_logSNR = self.scheduler2(t, shift=self.shifts[1]) 199 | return low_logSNR * t + high_logSNR * (1-t) 200 | 201 | -------------------------------------------------------------------------------- /src/core/templates/diffusion.py: -------------------------------------------------------------------------------- 1 | from .. 
import WarpCore 2 | from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary 3 | from abc import abstractmethod 4 | from dataclasses import dataclass 5 | import torch 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | from gdf import GDF 9 | import numpy as np 10 | from tqdm import tqdm 11 | import wandb 12 | 13 | import webdataset as wds 14 | from webdataset.handlers import warn_and_continue 15 | from torch.distributed import barrier 16 | from enum import Enum 17 | 18 | class TargetReparametrization(Enum): 19 | EPSILON = 'epsilon' 20 | X0 = 'x0' 21 | 22 | class DiffusionCore(WarpCore): 23 | @dataclass(frozen=True) 24 | class Config(WarpCore.Config): 25 | # TRAINING PARAMS 26 | lr: float = EXPECTED_TRAIN 27 | grad_accum_steps: int = EXPECTED_TRAIN 28 | batch_size: int = EXPECTED_TRAIN 29 | updates: int = EXPECTED_TRAIN 30 | warmup_updates: int = EXPECTED_TRAIN 31 | save_every: int = 500 32 | backup_every: int = 20000 33 | use_fsdp: bool = True 34 | 35 | # EMA UPDATE 36 | ema_start_iters: int = None 37 | ema_iters: int = None 38 | ema_beta: float = None 39 | 40 | # GDF setting 41 | gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0 42 | 43 | @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED 44 | class Info(WarpCore.Info): 45 | ema_loss: float = None 46 | 47 | @dataclass(frozen=True) 48 | class Models(WarpCore.Models): 49 | generator : nn.Module = EXPECTED 50 | generator_ema : nn.Module = None # optional 51 | 52 | @dataclass(frozen=True) 53 | class Optimizers(WarpCore.Optimizers): 54 | generator : any = EXPECTED 55 | 56 | @dataclass(frozen=True) 57 | class Schedulers(WarpCore.Schedulers): 58 | generator: any = None 59 | 60 | @dataclass(frozen=True) 61 | class Extras(WarpCore.Extras): 62 | gdf: GDF = EXPECTED 63 | sampling_configs: dict = EXPECTED 64 | 65 | # -------------------------------------------- 66 | info: Info 67 | config: Config 68 | 69 | @abstractmethod 70 | def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 71 | raise NotImplementedError("This method needs to be overriden") 72 | 73 | @abstractmethod 74 | def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 75 | raise NotImplementedError("This method needs to be overriden") 76 | 77 | @abstractmethod 78 | def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False): 79 | raise NotImplementedError("This method needs to be overriden") 80 | 81 | @abstractmethod 82 | def webdataset_path(self, extras: Extras): 83 | raise NotImplementedError("This method needs to be overriden") 84 | 85 | @abstractmethod 86 | def webdataset_filters(self, extras: Extras): 87 | raise NotImplementedError("This method needs to be overriden") 88 | 89 | @abstractmethod 90 | def webdataset_preprocessors(self, extras: Extras): 91 | raise NotImplementedError("This method needs to be overriden") 92 | 93 | @abstractmethod 94 | def sample(self, models: Models, data: WarpCore.Data, extras: Extras): 95 | raise NotImplementedError("This method needs to be overriden") 96 | # ------------- 97 | 98 | def setup_data(self, extras: Extras) -> WarpCore.Data: 99 | # SETUP DATASET 100 | dataset_path = self.webdataset_path(extras) 101 | preprocessors = self.webdataset_preprocessors(extras) 102 | filters = self.webdataset_filters(extras) 103 | 104 | handler = warn_and_continue # None 105 | # handler = None 106 | dataset = 
wds.WebDataset( 107 | dataset_path, resampled=True, handler=handler 108 | ).select(filters).shuffle(690, handler=handler).decode( 109 | "pilrgb", handler=handler 110 | ).to_tuple( 111 | *[p[0] for p in preprocessors], handler=handler 112 | ).map_tuple( 113 | *[p[1] for p in preprocessors], handler=handler 114 | ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)}) 115 | 116 | # SETUP DATALOADER 117 | real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps) 118 | dataloader = DataLoader( 119 | dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True 120 | ) 121 | 122 | return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader)) 123 | 124 | def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): 125 | batch = next(data.iterator) 126 | 127 | with torch.no_grad(): 128 | conditions = self.get_conditions(batch, models, extras) 129 | latents = self.encode_latents(batch, models, extras) 130 | noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) 131 | 132 | # FORWARD PASS 133 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 134 | pred = models.generator(noised, noise_cond, **conditions) 135 | if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON: 136 | pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss 137 | target = noise 138 | elif self.config.gdf_target_reparametrization == TargetReparametrization.X0: 139 | pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss 140 | target = latents 141 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 142 | loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps 143 | 144 | return loss, loss_adjusted 145 | 146 | def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): 147 | start_iter = self.info.iter+1 148 | max_iters = self.config.updates * self.config.grad_accum_steps 149 | if self.is_main_node: 150 | print(f"STARTING AT STEP: {start_iter}/{max_iters}") 151 | 152 | pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP 153 | models.generator.train() 154 | for i in pbar: 155 | # FORWARD PASS 156 | loss, loss_adjusted = self.forward_pass(data, extras, models) 157 | 158 | # BACKWARD PASS 159 | if i % self.config.grad_accum_steps == 0 or i == max_iters: 160 | loss_adjusted.backward() 161 | grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) 162 | optimizers_dict = optimizers.to_dict() 163 | for k in optimizers_dict: 164 | optimizers_dict[k].step() 165 | schedulers_dict = schedulers.to_dict() 166 | for k in schedulers_dict: 167 | schedulers_dict[k].step() 168 | models.generator.zero_grad(set_to_none=True) 169 | self.info.total_steps += 1 170 | else: 171 | with models.generator.no_sync(): 172 | loss_adjusted.backward() 173 | self.info.iter = i 174 | 175 | # UPDATE EMA 176 | if models.generator_ema is not None and i % self.config.ema_iters == 0: 177 | update_weights_ema( 178 | models.generator_ema, models.generator, 179 | beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) 180 | ) 181 | 182 | # UPDATE LOSS METRICS 183 | self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 184 | 185 | 
if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())): 186 | wandb.alert( 187 | title=f"NaN value encountered in training run {self.info.wandb_run_id}", 188 | text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", 189 | wait_duration=60*30 190 | ) 191 | 192 | if self.is_main_node: 193 | logs = { 194 | 'loss': self.info.ema_loss, 195 | 'raw_loss': loss.mean().item(), 196 | 'grad_norm': grad_norm.item(), 197 | 'lr': optimizers.generator.param_groups[0]['lr'], 198 | 'total_steps': self.info.total_steps, 199 | } 200 | 201 | pbar.set_postfix(logs) 202 | if self.config.wandb_project is not None: 203 | wandb.log(logs) 204 | 205 | if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters: 206 | # SAVE AND CHECKPOINT STUFF 207 | if np.isnan(loss.mean().item()): 208 | if self.is_main_node and self.config.wandb_project is not None: 209 | tqdm.write("Skipping sampling & checkpoint because the loss is NaN") 210 | wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") 211 | else: 212 | self.save_checkpoints(models, optimizers) 213 | if self.is_main_node: 214 | create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') 215 | self.sample(models, data, extras) 216 | 217 | def models_to_save(self): 218 | return ['generator', 'generator_ema'] 219 | 220 | def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): 221 | barrier() 222 | suffix = '' if suffix is None else suffix 223 | self.save_info(self.info, suffix=suffix) 224 | models_dict = models.to_dict() 225 | optimizers_dict = optimizers.to_dict() 226 | for key in self.models_to_save(): 227 | model = models_dict[key] 228 | if model is not None: 229 | self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) 230 | for key in optimizers_dict: 231 | optimizer = optimizers_dict[key] 232 | if optimizer is not None: 233 | self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None) 234 | if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: 235 | self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k") 236 | torch.cuda.empty_cache() 237 | -------------------------------------------------------------------------------- /src/modules/cnet_modules/face_id/arcface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx, onnx2torch, cv2 3 | import torch 4 | from insightface.utils import face_align 5 | 6 | 7 | class ArcFaceRecognizer: 8 | def __init__(self, model_file=None, device='cpu', dtype=torch.float32): 9 | assert model_file is not None 10 | self.model_file = model_file 11 | 12 | self.device = device 13 | self.dtype = dtype 14 | self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) 15 | for param in self.model.parameters(): 16 | param.requires_grad = False 17 | self.model.eval() 18 | 19 | self.input_mean = 127.5 20 | self.input_std = 127.5 21 | self.input_size = (112, 112) 22 | self.input_shape = ['None', 3, 112, 112] 23 | 24 | def get(self, img, face): 25 | aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) 26 |
face.embedding = self.get_feat(aimg).flatten() 27 | return face.embedding 28 | 29 | def compute_sim(self, feat1, feat2): 30 | from numpy.linalg import norm 31 | feat1 = feat1.ravel() 32 | feat2 = feat2.ravel() 33 | sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) 34 | return sim 35 | 36 | def get_feat(self, imgs): 37 | if not isinstance(imgs, list): 38 | imgs = [imgs] 39 | input_size = self.input_size 40 | 41 | blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, 42 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 43 | 44 | blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) 45 | net_out = self.model(blob_torch) 46 | return net_out[0].float().cpu() 47 | 48 | 49 | def distance2bbox(points, distance, max_shape=None): 50 | """Decode distance prediction to bounding box. 51 | 52 | Args: 53 | points (Tensor): Shape (n, 2), [x, y]. 54 | distance (Tensor): Distance from the given point to 4 55 | boundaries (left, top, right, bottom). 56 | max_shape (tuple): Shape of the image. 57 | 58 | Returns: 59 | Tensor: Decoded bboxes. 60 | """ 61 | x1 = points[:, 0] - distance[:, 0] 62 | y1 = points[:, 1] - distance[:, 1] 63 | x2 = points[:, 0] + distance[:, 2] 64 | y2 = points[:, 1] + distance[:, 3] 65 | if max_shape is not None: 66 | x1 = x1.clamp(min=0, max=max_shape[1]) 67 | y1 = y1.clamp(min=0, max=max_shape[0]) 68 | x2 = x2.clamp(min=0, max=max_shape[1]) 69 | y2 = y2.clamp(min=0, max=max_shape[0]) 70 | return np.stack([x1, y1, x2, y2], axis=-1) 71 | 72 | 73 | def distance2kps(points, distance, max_shape=None): 74 | """Decode distance prediction to bounding box. 75 | 76 | Args: 77 | points (Tensor): Shape (n, 2), [x, y]. 78 | distance (Tensor): Distance from the given point to 4 79 | boundaries (left, top, right, bottom). 80 | max_shape (tuple): Shape of the image. 81 | 82 | Returns: 83 | Tensor: Decoded bboxes. 
84 | """ 85 | preds = [] 86 | for i in range(0, distance.shape[1], 2): 87 | px = points[:, i % 2] + distance[:, i] 88 | py = points[:, i % 2 + 1] + distance[:, i + 1] 89 | if max_shape is not None: 90 | px = px.clamp(min=0, max=max_shape[1]) 91 | py = py.clamp(min=0, max=max_shape[0]) 92 | preds.append(px) 93 | preds.append(py) 94 | return np.stack(preds, axis=-1) 95 | 96 | 97 | class FaceDetector: 98 | def __init__(self, model_file=None, dtype=torch.float32, device='cuda'): 99 | self.model_file = model_file 100 | self.taskname = 'detection' 101 | self.center_cache = {} 102 | self.nms_thresh = 0.4 103 | self.det_thresh = 0.5 104 | 105 | self.device = device 106 | self.dtype = dtype 107 | self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) 108 | for param in self.model.parameters(): 109 | param.requires_grad = False 110 | self.model.eval() 111 | 112 | input_shape = (320, 320) 113 | self.input_size = input_shape 114 | self.input_shape = input_shape 115 | 116 | self.input_mean = 127.5 117 | self.input_std = 128.0 118 | self._anchor_ratio = 1.0 119 | self._num_anchors = 1 120 | self.fmc = 3 121 | self._feat_stride_fpn = [8, 16, 32] 122 | self._num_anchors = 2 123 | self.use_kps = True 124 | 125 | self.det_thresh = 0.5 126 | self.nms_thresh = 0.4 127 | 128 | def forward(self, img, threshold): 129 | scores_list = [] 130 | bboxes_list = [] 131 | kpss_list = [] 132 | input_size = tuple(img.shape[0:2][::-1]) 133 | blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size, 134 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 135 | blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) 136 | net_outs_torch = self.model(blob_torch) 137 | # print(list(map(lambda x: x.shape, net_outs_torch))) 138 | net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch)) 139 | 140 | input_height = blob.shape[2] 141 | input_width = blob.shape[3] 142 | fmc = self.fmc 143 | for idx, stride in enumerate(self._feat_stride_fpn): 144 | scores = net_outs[idx] 145 | bbox_preds = net_outs[idx + fmc] 146 | bbox_preds = bbox_preds * stride 147 | if self.use_kps: 148 | kps_preds = net_outs[idx + fmc * 2] * stride 149 | height = input_height // stride 150 | width = input_width // stride 151 | K = height * width 152 | key = (height, width, stride) 153 | if key in self.center_cache: 154 | anchor_centers = self.center_cache[key] 155 | else: 156 | # solution-1, c style: 157 | # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) 158 | # for i in range(height): 159 | # anchor_centers[i, :, 1] = i 160 | # for i in range(width): 161 | # anchor_centers[:, i, 0] = i 162 | 163 | # solution-2: 164 | # ax = np.arange(width, dtype=np.float32) 165 | # ay = np.arange(height, dtype=np.float32) 166 | # xv, yv = np.meshgrid(np.arange(width), np.arange(height)) 167 | # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) 168 | 169 | # solution-3: 170 | anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) 171 | # print(anchor_centers.shape) 172 | 173 | anchor_centers = (anchor_centers * stride).reshape((-1, 2)) 174 | if self._num_anchors > 1: 175 | anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2)) 176 | if len(self.center_cache) < 100: 177 | self.center_cache[key] = anchor_centers 178 | 179 | pos_inds = np.where(scores >= threshold)[0] 180 | bboxes = distance2bbox(anchor_centers, bbox_preds) 181 | pos_scores = scores[pos_inds] 182 | pos_bboxes = bboxes[pos_inds] 
183 | scores_list.append(pos_scores) 184 | bboxes_list.append(pos_bboxes) 185 | if self.use_kps: 186 | kpss = distance2kps(anchor_centers, kps_preds) 187 | # kpss = kps_preds 188 | kpss = kpss.reshape((kpss.shape[0], -1, 2)) 189 | pos_kpss = kpss[pos_inds] 190 | kpss_list.append(pos_kpss) 191 | return scores_list, bboxes_list, kpss_list 192 | 193 | def detect(self, img, input_size=None, max_num=0, metric='default'): 194 | assert input_size is not None or self.input_size is not None 195 | input_size = self.input_size if input_size is None else input_size 196 | 197 | im_ratio = float(img.shape[0]) / img.shape[1] 198 | model_ratio = float(input_size[1]) / input_size[0] 199 | if im_ratio > model_ratio: 200 | new_height = input_size[1] 201 | new_width = int(new_height / im_ratio) 202 | else: 203 | new_width = input_size[0] 204 | new_height = int(new_width * im_ratio) 205 | det_scale = float(new_height) / img.shape[0] 206 | resized_img = cv2.resize(img, (new_width, new_height)) 207 | det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8) 208 | det_img[:new_height, :new_width, :] = resized_img 209 | 210 | scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) 211 | 212 | scores = np.vstack(scores_list) 213 | scores_ravel = scores.ravel() 214 | order = scores_ravel.argsort()[::-1] 215 | bboxes = np.vstack(bboxes_list) / det_scale 216 | if self.use_kps: 217 | kpss = np.vstack(kpss_list) / det_scale 218 | pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) 219 | pre_det = pre_det[order, :] 220 | keep = self.nms(pre_det) 221 | det = pre_det[keep, :] 222 | if self.use_kps: 223 | kpss = kpss[order, :, :] 224 | kpss = kpss[keep, :, :] 225 | else: 226 | kpss = None 227 | if max_num > 0 and det.shape[0] > max_num: 228 | area = (det[:, 2] - det[:, 0]) * (det[:, 3] - 229 | det[:, 1]) 230 | img_center = img.shape[0] // 2, img.shape[1] // 2 231 | offsets = np.vstack([ 232 | (det[:, 0] + det[:, 2]) / 2 - img_center[1], 233 | (det[:, 1] + det[:, 3]) / 2 - img_center[0] 234 | ]) 235 | offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) 236 | if metric == 'max': 237 | values = area 238 | else: 239 | values = area - offset_dist_squared * 2.0 # some extra weight on the centering 240 | bindex = np.argsort( 241 | values)[::-1] # some extra weight on the centering 242 | bindex = bindex[0:max_num] 243 | det = det[bindex, :] 244 | if kpss is not None: 245 | kpss = kpss[bindex, :] 246 | return det, kpss 247 | 248 | def nms(self, dets): 249 | thresh = self.nms_thresh 250 | x1 = dets[:, 0] 251 | y1 = dets[:, 1] 252 | x2 = dets[:, 2] 253 | y2 = dets[:, 3] 254 | scores = dets[:, 4] 255 | 256 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 257 | order = scores.argsort()[::-1] 258 | 259 | keep = [] 260 | while order.size > 0: 261 | i = order[0] 262 | keep.append(i) 263 | xx1 = np.maximum(x1[i], x1[order[1:]]) 264 | yy1 = np.maximum(y1[i], y1[order[1:]]) 265 | xx2 = np.minimum(x2[i], x2[order[1:]]) 266 | yy2 = np.minimum(y2[i], y2[order[1:]]) 267 | 268 | w = np.maximum(0.0, xx2 - xx1 + 1) 269 | h = np.maximum(0.0, yy2 - yy1 + 1) 270 | inter = w * h 271 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 272 | 273 | inds = np.where(ovr <= thresh)[0] 274 | order = order[inds + 1] 275 | 276 | return keep 277 | -------------------------------------------------------------------------------- /src/train/readme.md: -------------------------------------------------------------------------------- 1 | # Training 2 |

3 | 4 |

5 | 6 | This directory provides training code for Stable Cascade, as well as guides to download the models you need. 7 | Specifically, you can find training scripts for the following use-cases: 8 | - Text-to-Image 9 | - ControlNet 10 | - LoRA 11 | - Image Reconstruction 12 | 13 | #### Note: 14 | A quick clarification: Stable Cascade uses Stages A & B to compress images, and Stage C is used for the text-conditional 15 | learning. Therefore, it makes sense to train a LoRA or ControlNet **only** for Stage C. You also don't train a LoRA or 16 | ControlNet for the Stable Diffusion VAE, right? 17 | 18 | ## Basics 19 | In the [training configs](../configs/training) folder we provide config files for all trainings. All config files 20 | follow a similar structure and only contain the most essential parameters you need to set. Let's take a look at the 21 | structure each config follows: 22 | 23 | First, you set the run name, the checkpoint & output folders, and which model version you want to train. 24 | ```yaml 25 | experiment_id: stage_c_3b_controlnet_base 26 | checkpoint_path: /path/to/checkpoint 27 | output_path: /path/to/output 28 | model_version: 3.6B 29 | ``` 30 | 31 | Next, you can set your [Weights & Biases](https://wandb.ai) information if you want to use it for logging. 32 | ```yaml 33 | wandb_project: StableCascade 34 | wandb_entity: wandb_username 35 | ``` 36 | 37 | Afterwards, you define the training parameters. 38 | ```yaml 39 | lr: 1.0e-4 40 | batch_size: 256 41 | image_size: 768 42 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 43 | grad_accum_steps: 1 44 | updates: 500000 45 | backup_every: 50000 46 | save_every: 2000 47 | warmup_updates: 1 48 | use_fsdp: False 49 | ``` 50 | 51 | Most of them will probably be familiar to you already. A few clarifications, though: `updates` refers to the number of 52 | training steps, `backup_every` creates additional checkpoints so you can revert to earlier ones if you want, and 53 | `save_every` determines how often models will be saved and sampling will be done. Furthermore, since distributed training 54 | is essential when training large models from scratch or doing large finetunes, we have an option to use PyTorch's 55 | [**Fully Sharded Data Parallel (FSDP)**](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/). You 56 | can use it by setting `use_fsdp: True`. Note that you will need multiple GPUs for FSDP. However, as mentioned 57 | above, this is only needed for large runs. You can still train and finetune our largest models on a powerful single 58 | machine.
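To make the relationship between `batch_size`, `grad_accum_steps` and the number of GPUs concrete: the per-GPU dataloader batch size is derived exactly as in `setup_data()` in `src/core/templates/diffusion.py`; the world size of 8 below is only an assumption for this example.

```python
# real_batch_size = batch_size // (world_size * grad_accum_steps), as computed in setup_data()
batch_size = 256       # from the config above
grad_accum_steps = 1   # from the config above
world_size = 8         # assumed number of GPUs for this example
real_batch_size = batch_size // (world_size * grad_accum_steps)
print(real_batch_size)  # -> 32 samples per GPU per forward pass
```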

59 | Another thing we provide is training with **Multi-Aspect-Ratio**. You can set the aspect ratios you want in the list 60 | for `multi_aspect_ratio`.
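Multi-aspect-ratio training typically buckets images into resolutions that match a target aspect ratio while keeping the pixel area roughly constant. The helper below is only a hypothetical sketch of that idea; it is not the repo's `bucketeer.py`, and the rounding multiple is an assumption.

```python
def bucket_resolution(base_size=768, ratio=2 / 3, multiple=32):
    """Pick a (height, width) with width/height ~= ratio and area ~= base_size**2."""
    def snap(v):
        return max(multiple, round(v / multiple) * multiple)

    area = base_size * base_size
    width = (area * ratio) ** 0.5
    height = width / ratio
    return snap(height), snap(width)

print(bucket_resolution(768, 2 / 3))  # -> (928, 640), roughly 768x768 pixels at ~2:3
```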

61 | 62 | For diffusion models, having an EMA (Exponential Moving Average) model can drastically improve the performance of 63 | your model. To include an EMA model in your training, you can set the following parameters; otherwise you can just 64 | leave them out. 65 | ```yaml 66 | ema_start_iters: 5000 67 | ema_iters: 100 68 | ema_beta: 0.9 69 | ``` 70 | 71 | Next, you can define the dataset that you want to use. Note that the code uses 72 | [webdataset](https://github.com/webdataset/webdataset) for this. 73 | ```yaml 74 | webdataset_path: 75 | - s3://path/to/your/first/dataset/on/s3 76 | - file:/path/to/your/local/dataset.tar 77 | ``` 78 | You can set as many dataset paths as you want, and they can either be on 79 | [Amazon S3 storage](https://aws.amazon.com/s3/) or just local. 80 |
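If you want to sanity-check that a local tar is readable before starting a run, a minimal snippet like the following should work, assuming the tar contains `.jpg`/`.txt` pairs as described in the Dataset section below.

```python
import webdataset as wds

# Iterate the first sample of a local webdataset tar (the path is a placeholder)
dataset = wds.WebDataset("file:/path/to/your/local/dataset.tar").decode("pilrgb").to_tuple("jpg", "txt")
image, caption = next(iter(dataset))
print(image.size, caption[:80])
```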

81 | There are a few more specifics to each kind of training and to datasets in general. These will be discussed below. 82 | 83 | ## Starting a Training 84 | You can start an actual training very easily by first moving to the root directory of this repository (so [here](..)). 85 | Next, the Python command looks like the following: 86 | ```python 87 | python3 training_file training_config 88 | ``` 89 | For example, if you want to train a LoRA model, the command would look like this: 90 | ```python 91 | python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml 92 | ``` 93 | 94 | We also provide a [bash script](example_train.sh) for working with Slurm. Note that this assumes you have access to a cluster 95 | that runs Slurm as the cluster manager. 96 | 97 | ## Dataset 98 | As mentioned above, the code uses [webdataset](https://github.com/webdataset/webdataset) for working with datasets, 99 | because this library supports working with large amounts of data very easily. In case you want to **finetune** a model, 100 | train a **LoRA** or train a **ControlNet**, you might not have your data in webdataset format yet. Therefore, here is 101 | a simple example of how you can convert your dataset into the appropriate format. 102 | 1. Put all your images and captions into a folder 103 | 2. Rename them so that each image and its caption share the same number / id in the file name. For example: 104 | `0000.jpg, 0000.txt, 0001.jpg, 0001.txt, 0002.jpg, 0002.txt, 0003.jpg, 0003.txt` 105 | 3. Run the following command: ``tar --sort=name -cf dataset.tar dataset/`` or manually create a tar file from the folder 106 | 4. Set the `webdataset_path: file:/path/to/your/local/dataset.tar` in the config file 107 | 108 | Next, there are a few more settings that might be helpful to you, especially when working with large datasets that 109 | contain additional per-image metadata you may want to filter on. You can apply 110 | dataset filters like the following in the config file: 111 | ```yaml 112 | dataset_filters: 113 | - ['aesthetic_score', 'lambda s: s > 4.5'] 114 | - ['nsfw_probability', 'lambda s: s < 0.01'] 115 | ``` 116 | In this case, you would have `0000.json, 0001.json, 0002.json, 0003.json` in your dataset as well, with keys for 117 | `aesthetic_score` and `nsfw_probability`. 118 | 119 | ## Starting from a Pretrained Model 120 | If you want to finetune any model, you need the pretrained models. You can find details on how to download them in the 121 | [models](../models) section. After downloading them, you need to modify the checkpoint paths in the config file too. 122 | See below for example config files. 123 | 124 | ## Text-to-Image Training 125 | You can use the following configs for finetuning Stage C on your own datasets. All necessary parameters were already 126 | explained above, so there is nothing new here. Take a look at the config for finetuning the 127 | [3.6B Stage C](../configs/training/finetune_c_3b.yaml) and the [1B Stage C](../configs/training/finetune_c_1b.yaml). 128 | 129 | ## ControlNet Training 130 | Training a ControlNet requires setting some extra parameters as well as adding the specific ControlNet filter you want. 131 | By filter, we simply mean a class that, for example, performs Canny Edge Detection, Human Pose Detection, etc.
132 | ```yaml 133 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 134 | controlnet_filter: CannyFilter 135 | controlnet_filter_params: 136 | resize: 224 137 | ``` 138 | Here we need to give a little more detail on what Stage C's architecture looks like. It is basically just a stack of 139 | residual blocks (convolutional and attention) that all work at the same latent resolution. We **do not** use a UNet. 140 | And this is where `controlnet_blocks` comes in. It determines at which blocks you want to inject the controlling 141 | information. In this way, the ControlNet architecture differs from the common one used in Stable Diffusion, where you 142 | create an entire copy of the UNet's encoder. With Stable Cascade it is a bit simpler and comes with the great 143 | benefit of using far fewer parameters.
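For intuition only, here is a toy sketch of what injecting control features at selected block indices can look like. It is a simplified illustration with made-up module names, not the actual Stage C / ControlNet code in this repository.

```python
import torch
from torch import nn

class ToyBackboneWithControl(nn.Module):
    """Stand-in for a stack of residual blocks working at a single latent resolution."""
    def __init__(self, c=64, num_blocks=8, controlnet_blocks=(0, 4)):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Conv2d(c, c, kernel_size=3, padding=1) for _ in range(num_blocks)])
        self.controlnet_blocks = set(controlnet_blocks)
        # one 1x1 projection per injection point, mapping the control signal into the latent space
        self.ctrl_proj = nn.ModuleDict({str(i): nn.Conv2d(c, c, kernel_size=1) for i in controlnet_blocks})

    def forward(self, x, control):
        for i, block in enumerate(self.blocks):
            if i in self.controlnet_blocks:
                x = x + self.ctrl_proj[str(i)](control)  # inject the controlling information here
            x = x + block(x)  # residual block
        return x

x = torch.randn(1, 64, 24, 24)
control = torch.randn(1, 64, 24, 24)
print(ToyBackboneWithControl()(x, control).shape)  # torch.Size([1, 64, 24, 24])
```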
144 | Next, you define the class that filters the images and extracts the information you want to condition Stage C on 145 | (Canny Edge Detection, Human Pose Detection, etc.) with the `controlnet_filter` parameter. In the example, we use the 146 | CannyFilter defined in the [controlnet.py](../modules/controlnet.py) file. This is the place where you can add your own 147 | ControlNet Filters. Lastly, `controlnet_filter_params` simply sets additional parameters for your `controlnet_filter` 148 | class. That's it. You can view the example ControlNet configs for 149 | [Inpainting / Outpainting](../configs/training/controlnet_c_3b_inpainting.yaml), 150 | [Face Identity](../configs/training/controlnet_c_3b_identity.yaml), 151 | [Canny](../configs/training/controlnet_c_3b_canny.yaml) and 152 | [Super Resolution](../configs/training/controlnet_c_3b_sr.yaml). 153 | 154 | ## LoRA Training 155 | To train a LoRA on Stage C, you have a few more parameters available to set for the training. 156 | ```yaml 157 | module_filters: ['.attn'] 158 | rank: 4 159 | train_tokens: 160 | # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", doesn't need to be reinitialized 161 | - ['[fernando]', '^dog'] # custom token [fernando], initialized as the average of tokens matching ^dog 162 | ``` 163 | These include `module_filters`, which simply determines on which modules you want to train LoRA layers. In the 164 | example above, it uses the attention layers (`.attn`). Currently, only linear layers can be lora'd. 165 | However, adding different layers (like convolutions) is possible as well.
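As a rough illustration of what a low-rank adapter on a linear layer does (a generic sketch, not the implementation in `modules/lora.py`):

```python
import torch
from torch import nn

class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable low-rank update: y = W x + (alpha / rank) * B A x."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)                                  # frozen pretrained weights
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)  # A: rank x in_features
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))        # B: out_features x rank, zero-init -> starts as a no-op
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scale

layer = LoRALinear(nn.Linear(2048, 2048), rank=4)
print(layer(torch.randn(1, 2048)).shape)  # torch.Size([1, 2048])
```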
166 | You can also set the `rank` and decide whether you want to learn a specific token for your training. The latter can be done by 167 | setting `train_tokens`, which expects a list of two things for each element: the token you want to train and a regex for 168 | the token / tokens that you want to use for initializing the token. In the example above, a token `[fernando]` is 169 | created and is initialized with the average of all tokens that include the word `dog`. Note that in order to **add** a new 170 | token, **it has to start with `[` and end with `]`**. There is also the option of using existing tokens, which will be 171 | trained. For this, you just enter the token **without** placing `[ ]` around it, like in the commented example above 172 | for the token `snail`. The second element is `null`, because we don't initialize this token and just finetune the 173 | `snail` token.
174 | You can find an example config for training a LoRA [here](../configs/training/finetune_c_3b_lora.yaml). 175 | Additionally, you can also download an 176 | [example dataset](https://huggingface.co/dome272/stable-cascade/blob/main/fernando.tar) for a cute little good boy dog. 177 | Simply download it and set the path in the config file to your destination path. 178 | 179 | ## Image Reconstruction Training 180 | Here we mainly focus on training **Stage B**, because it does most of the heavy lifting for the compression, while 181 | Stage A only applies a very small compression and thus its results are near perfect. Why do we even use Stage A? The 182 | reason is just to make the training and inference of Stage B cheaper and faster. With Stage A in place, Stage B works 183 | in a 4x smaller space (for example `1 x 4 x 256 x 256` instead of `1 x 3 x 1024 x 1024`). Furthermore, we observed that 184 | Stage B learns faster when using Stage A compared to learning Stage B directly in pixel space. Anyway, why would you 185 | even want to train Stage B? Either you want to try to create an even higher compression or finetune it on something 186 | very specific. But this is probably a rare occasion. If you do want to, you can take a look at the training config 187 | for the large Stage B [here](../configs/training/finetune_b_3b.yaml) or for the small Stage B 188 | [here](../configs/training/finetune_b_700m.yaml). 189 | 190 | ## Remarks 191 | The codebase is in early development. You might encounter unexpected errors or training and 192 | inference code that is not perfectly optimized. We apologize for that in advance. If there is interest, we will continue releasing updates, 193 | aiming to bring in the latest improvements and optimizations. Moreover, we would be more than happy to receive 194 | ideas, feedback, or even updates from people who would like to contribute. Cheers.
-------------------------------------------------------------------------------- /src/modules/stage_b.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock 6 | 7 | 8 | class StageB(nn.Module): 9 | def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], 10 | nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], 11 | block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, 12 | c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True, 13 | t_conds=['sca']): 14 | super().__init__() 15 | self.c_r = c_r 16 | self.t_conds = t_conds 17 | self.c_clip_seq = c_clip_seq 18 | if not isinstance(dropout, list): 19 | dropout = [dropout] * len(c_hidden) 20 | if not isinstance(self_attn, list): 21 | self_attn = [self_attn] * len(c_hidden) 22 | 23 | # CONDITIONING 24 | self.effnet_mapper = nn.Sequential( 25 | nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), 26 | nn.GELU(), 27 | nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), 28 | LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) 29 | ) 30 | self.pixels_mapper = nn.Sequential( 31 | nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), 32 | nn.GELU(), 33 | nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), 34 | LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) 35 | ) 36 | self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) 37 | self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) 38 | 39 | self.embedding = nn.Sequential( 40 | nn.PixelUnshuffle(patch_size), 41 | nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), 42 | LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) 43 | ) 44 | 45 | def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): 46 | if block_type == 'C': 47 | return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) 48 | elif block_type == 'A': 49 | return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) 50 | elif block_type == 'F': 51 | return FeedForwardBlock(c_hidden, dropout=dropout) 52 | elif block_type == 'T': 53 | return TimestepBlock(c_hidden, c_r, conds=t_conds) 54 | else: 55 | raise Exception(f'Block type {block_type} not supported') 56 | 57 | # BLOCKS 58 | # -- down blocks 59 | self.down_blocks = nn.ModuleList() 60 | self.down_downscalers = nn.ModuleList() 61 | self.down_repeat_mappers = nn.ModuleList() 62 | for i in range(len(c_hidden)): 63 | if i > 0: 64 | self.down_downscalers.append(nn.Sequential( 65 | LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), 66 | nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), 67 | )) 68 | else: 69 | self.down_downscalers.append(nn.Identity()) 70 | down_block = nn.ModuleList() 71 | for _ in range(blocks[0][i]): 72 | for block_type in level_config[i]: 73 | block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) 74 | down_block.append(block) 75 | self.down_blocks.append(down_block) 76 | if block_repeat is not None: 77 | block_repeat_mappers = nn.ModuleList() 78 | for _ in range(block_repeat[0][i] - 1): 79 | block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) 80 | self.down_repeat_mappers.append(block_repeat_mappers) 81 | 82 | 
# -- up blocks 83 | self.up_blocks = nn.ModuleList() 84 | self.up_upscalers = nn.ModuleList() 85 | self.up_repeat_mappers = nn.ModuleList() 86 | for i in reversed(range(len(c_hidden))): 87 | if i > 0: 88 | self.up_upscalers.append(nn.Sequential( 89 | LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), 90 | nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), 91 | )) 92 | else: 93 | self.up_upscalers.append(nn.Identity()) 94 | up_block = nn.ModuleList() 95 | for j in range(blocks[1][::-1][i]): 96 | for k, block_type in enumerate(level_config[i]): 97 | c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 98 | block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], 99 | self_attn=self_attn[i]) 100 | up_block.append(block) 101 | self.up_blocks.append(up_block) 102 | if block_repeat is not None: 103 | block_repeat_mappers = nn.ModuleList() 104 | for _ in range(block_repeat[1][::-1][i] - 1): 105 | block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) 106 | self.up_repeat_mappers.append(block_repeat_mappers) 107 | 108 | # OUTPUT 109 | self.clf = nn.Sequential( 110 | LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), 111 | nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), 112 | nn.PixelShuffle(patch_size), 113 | ) 114 | 115 | # --- WEIGHT INIT --- 116 | self.apply(self._init_weights) # General init 117 | nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings 118 | nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings 119 | nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings 120 | nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings 121 | nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings 122 | torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs 123 | nn.init.constant_(self.clf[1].weight, 0) # outputs 124 | 125 | # blocks 126 | for level_block in self.down_blocks + self.up_blocks: 127 | for block in level_block: 128 | if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): 129 | block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) 130 | elif isinstance(block, TimestepBlock): 131 | for layer in block.modules(): 132 | if isinstance(layer, nn.Linear): 133 | nn.init.constant_(layer.weight, 0) 134 | 135 | def _init_weights(self, m): 136 | if isinstance(m, (nn.Conv2d, nn.Linear)): 137 | torch.nn.init.xavier_uniform_(m.weight) 138 | if m.bias is not None: 139 | nn.init.constant_(m.bias, 0) 140 | 141 | def gen_r_embedding(self, r, max_positions=10000): 142 | r = r * max_positions 143 | half_dim = self.c_r // 2 144 | emb = math.log(max_positions) / (half_dim - 1) 145 | emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() 146 | emb = r[:, None] * emb[None, :] 147 | emb = torch.cat([emb.sin(), emb.cos()], dim=1) 148 | if self.c_r % 2 == 1: # zero pad 149 | emb = nn.functional.pad(emb, (0, 1), mode='constant') 150 | return emb 151 | 152 | def gen_c_embeddings(self, clip): 153 | if len(clip.shape) == 2: 154 | clip = clip.unsqueeze(1) 155 | clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) 156 | clip = self.clip_norm(clip) 157 | return clip 158 | 159 | def _down_encode(self, x, r_embed, clip): 160 | level_outputs = [] 161 | block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) 162 | for down_block, downscaler, repmap in block_group: 163 | x = downscaler(x) 164 | for 
i in range(len(repmap) + 1): 165 | for block in down_block: 166 | if isinstance(block, ResBlock) or ( 167 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 168 | ResBlock)): 169 | x = block(x) 170 | elif isinstance(block, AttnBlock) or ( 171 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 172 | AttnBlock)): 173 | x = block(x, clip) 174 | elif isinstance(block, TimestepBlock) or ( 175 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 176 | TimestepBlock)): 177 | x = block(x, r_embed) 178 | else: 179 | x = block(x) 180 | if i < len(repmap): 181 | x = repmap[i](x) 182 | level_outputs.insert(0, x) 183 | return level_outputs 184 | 185 | def _up_decode(self, level_outputs, r_embed, clip): 186 | x = level_outputs[0] 187 | block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) 188 | for i, (up_block, upscaler, repmap) in enumerate(block_group): 189 | for j in range(len(repmap) + 1): 190 | for k, block in enumerate(up_block): 191 | if isinstance(block, ResBlock) or ( 192 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 193 | ResBlock)): 194 | skip = level_outputs[i] if k == 0 and i > 0 else None 195 | if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): 196 | x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', 197 | align_corners=True) 198 | x = block(x, skip) 199 | elif isinstance(block, AttnBlock) or ( 200 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 201 | AttnBlock)): 202 | x = block(x, clip) 203 | elif isinstance(block, TimestepBlock) or ( 204 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 205 | TimestepBlock)): 206 | x = block(x, r_embed) 207 | else: 208 | x = block(x) 209 | if j < len(repmap): 210 | x = repmap[j](x) 211 | x = upscaler(x) 212 | return x 213 | 214 | def forward(self, x, r, effnet, clip, pixels=None, **kwargs): 215 | if pixels is None: 216 | pixels = x.new_zeros(x.size(0), 3, 8, 8) 217 | 218 | # Process the conditioning embeddings 219 | r_embed = self.gen_r_embedding(r) 220 | for c in self.t_conds: 221 | t_cond = kwargs.get(c, torch.zeros_like(r)) 222 | r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) 223 | clip = self.gen_c_embeddings(clip) 224 | 225 | # Model Blocks 226 | x = self.embedding(x) 227 | x = x + self.effnet_mapper( 228 | nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True)) 229 | x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear', 230 | align_corners=True) 231 | level_outputs = self._down_encode(x, r_embed, clip) 232 | x = self._up_decode(level_outputs, r_embed, clip) 233 | return self.clf(x) 234 | 235 | def update_weights_ema(self, src_model, beta=0.999): 236 | for self_params, src_params in zip(self.parameters(), src_model.parameters()): 237 | self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) 238 | for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): 239 | self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) 240 | -------------------------------------------------------------------------------- /src/train/train_c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torchvision 3 | from torch import nn, optim 4 | from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | import sys 8 | import os 9 | from dataclasses import dataclass 10 | 11 | from gdf import GDF, EpsilonTarget, CosineSchedule 12 | from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight 13 | from torchtools.transforms import SmartCrop 14 | 15 | from modules.effnet import EfficientNetEncoder 16 | from modules.stage_c import StageC 17 | from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock 18 | from modules.previewer import Previewer 19 | 20 | from train.base import DataCore, TrainingCore 21 | 22 | from core import WarpCore 23 | from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail 24 | 25 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 26 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy 27 | from accelerate import init_empty_weights 28 | from accelerate.utils import set_module_tensor_to_device 29 | from contextlib import contextmanager 30 | 31 | class WurstCore(TrainingCore, DataCore, WarpCore): 32 | @dataclass(frozen=True) 33 | class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): 34 | # TRAINING PARAMS 35 | lr: float = EXPECTED_TRAIN 36 | warmup_updates: int = EXPECTED_TRAIN 37 | dtype: str = None 38 | 39 | # MODEL VERSION 40 | model_version: str = EXPECTED # 3.6B or 1B 41 | clip_image_model_name: str = 'openai/clip-vit-large-patch14' 42 | clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' 43 | 44 | # CHECKPOINT PATHS 45 | effnet_checkpoint_path: str = EXPECTED 46 | previewer_checkpoint_path: str = EXPECTED 47 | generator_checkpoint_path: str = None 48 | 49 | # gdf customization 50 | adaptive_loss_weight: str = None 51 | 52 | @dataclass(frozen=True) 53 | class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): 54 | effnet: nn.Module = EXPECTED 55 | previewer: nn.Module = EXPECTED 56 | 57 | @dataclass(frozen=True) 58 | class Schedulers(WarpCore.Schedulers): 59 | generator: any = None 60 | 61 | @dataclass(frozen=True) 62 | class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): 63 | gdf: GDF = EXPECTED 64 | sampling_configs: dict = EXPECTED 65 | effnet_preprocess: torchvision.transforms.Compose = EXPECTED 66 | 67 | info: TrainingCore.Info 68 | config: Config 69 | 70 | def setup_extras_pre(self) -> Extras: 71 | gdf = GDF( 72 | schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), 73 | input_scaler=VPScaler(), target=EpsilonTarget(), 74 | noise_cond=CosineTNoiseCond(), 75 | loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), 76 | ) 77 | sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} 78 | 79 | if self.info.adaptive_loss is not None: 80 | gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) 81 | gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) 82 | 83 | effnet_preprocess = torchvision.transforms.Compose([ 84 | torchvision.transforms.Normalize( 85 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 86 | ) 87 | ]) 88 | 89 | clip_preprocess = torchvision.transforms.Compose([ 90 | torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), 91 | torchvision.transforms.CenterCrop(224), 92 | torchvision.transforms.Normalize( 93 | 
mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) 94 | ) 95 | ]) 96 | 97 | if self.config.training: 98 | transforms = torchvision.transforms.Compose([ 99 | torchvision.transforms.ToTensor(), 100 | torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), 101 | SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) 102 | ]) 103 | else: 104 | transforms = None 105 | 106 | return self.Extras( 107 | gdf=gdf, 108 | sampling_configs=sampling_configs, 109 | transforms=transforms, 110 | effnet_preprocess=effnet_preprocess, 111 | clip_preprocess=clip_preprocess 112 | ) 113 | 114 | def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, 115 | eval_image_embeds=False, return_fields=None): 116 | conditions = super().get_conditions( 117 | batch, models, extras, is_eval, is_unconditional, 118 | eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] 119 | ) 120 | return conditions 121 | 122 | def setup_models(self, extras: Extras) -> Models: 123 | dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 124 | 125 | # EfficientNet encoder 126 | effnet = EfficientNetEncoder() 127 | effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) 128 | effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) 129 | effnet.eval().requires_grad_(False).to(self.device) 130 | del effnet_checkpoint 131 | 132 | # Previewer 133 | previewer = Previewer() 134 | previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) 135 | previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) 136 | previewer.eval().requires_grad_(False).to(self.device) 137 | del previewer_checkpoint 138 | 139 | @contextmanager 140 | def dummy_context(): 141 | yield None 142 | 143 | loading_context = dummy_context if self.config.training else init_empty_weights 144 | 145 | # Diffusion models 146 | with loading_context(): 147 | generator_ema = None 148 | if self.config.model_version == '3.6B': 149 | generator = StageC() 150 | if self.config.ema_start_iters is not None: 151 | generator_ema = StageC() 152 | elif self.config.model_version == '1B': 153 | generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) 154 | if self.config.ema_start_iters is not None: 155 | generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) 156 | else: 157 | raise ValueError(f"Unknown model version {self.config.model_version}") 158 | 159 | if self.config.generator_checkpoint_path is not None: 160 | if loading_context is dummy_context: 161 | generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) 162 | else: 163 | for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): 164 | set_module_tensor_to_device(generator, param_name, "cpu", value=param) 165 | generator = generator.to(dtype).to(self.device) 166 | generator = self.load_model(generator, 'generator') 167 | 168 | if generator_ema is not None: 169 | if loading_context is dummy_context: 170 | generator_ema.load_state_dict(generator.state_dict()) 171 | else: 172 | for param_name, param in generator.state_dict().items(): 173 | set_module_tensor_to_device(generator_ema, param_name, "cpu", 
value=param) 174 | generator_ema = self.load_model(generator_ema, 'generator_ema') 175 | generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) 176 | 177 | if self.config.use_fsdp: 178 | fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) 179 | generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) 180 | if generator_ema is not None: 181 | generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) 182 | 183 | # CLIP encoders 184 | tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) 185 | text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) 186 | image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) 187 | 188 | return self.Models( 189 | effnet=effnet, previewer=previewer, 190 | generator=generator, generator_ema=generator_ema, 191 | tokenizer=tokenizer, text_model=text_model, image_model=image_model 192 | ) 193 | 194 | def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: 195 | optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) 196 | optimizer = self.load_optimizer(optimizer, 'generator_optim', 197 | fsdp_model=models.generator if self.config.use_fsdp else None) 198 | return self.Optimizers(generator=optimizer) 199 | 200 | def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: 201 | scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) 202 | scheduler.last_epoch = self.info.total_steps 203 | return self.Schedulers(generator=scheduler) 204 | 205 | # Training loop -------------------------------- 206 | def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): 207 | batch = next(data.iterator) 208 | 209 | with torch.no_grad(): 210 | conditions = self.get_conditions(batch, models, extras) 211 | latents = self.encode_latents(batch, models, extras) 212 | noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) 213 | 214 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 215 | pred = models.generator(noised, noise_cond, **conditions) 216 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 217 | loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps 218 | 219 | if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): 220 | extras.gdf.loss_weight.update_buckets(logSNR, loss) 221 | 222 | return loss, loss_adjusted 223 | 224 | def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): 225 | if update: 226 | loss_adjusted.backward() 227 | grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) 228 | optimizers_dict = optimizers.to_dict() 229 | for k in optimizers_dict: 230 | if k != 'training': 231 | optimizers_dict[k].step() 232 | schedulers_dict = schedulers.to_dict() 233 | for k in schedulers_dict: 234 | if k != 'training': 235 | schedulers_dict[k].step() 236 | for k in optimizers_dict: 237 | if k != 'training': 238 | optimizers_dict[k].zero_grad(set_to_none=True) 239 | self.info.total_steps += 1 240 | else: 241 | 
loss_adjusted.backward() 242 | grad_norm = torch.tensor(0.0).to(self.device) 243 | 244 | return grad_norm 245 | 246 | def models_to_save(self): 247 | return ['generator', 'generator_ema'] 248 | 249 | def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 250 | images = batch['images'].to(self.device) 251 | return models.effnet(extras.effnet_preprocess(images)) 252 | 253 | def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 254 | return models.previewer(latents) 255 | 256 | 257 | if __name__ == '__main__': 258 | print("Launching Script") 259 | warpcore = WurstCore( 260 | config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, 261 | device=torch.device(int(os.environ.get("SLURM_LOCALID"))) 262 | ) 263 | # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD 264 | 265 | # RUN TRAINING 266 | warpcore() 267 | -------------------------------------------------------------------------------- /src/modules/stage_c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import math 5 | from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock 6 | from .controlnet import ControlNetDeliverer 7 | 8 | 9 | class UpDownBlock2d(nn.Module): 10 | def __init__(self, c_in, c_out, mode, enabled=True): 11 | super().__init__() 12 | assert mode in ['up', 'down'] 13 | interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', 14 | align_corners=True) if enabled else nn.Identity() 15 | mapping = nn.Conv2d(c_in, c_out, kernel_size=1) 16 | self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) 17 | 18 | def forward(self, x): 19 | for block in self.blocks: 20 | x = block(x.float()) 21 | return x 22 | 23 | 24 | class StageC(nn.Module): 25 | def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], 26 | blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], 27 | c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, 28 | dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]): 29 | super().__init__() 30 | self.c_r = c_r 31 | self.t_conds = t_conds 32 | self.c_clip_seq = c_clip_seq 33 | if not isinstance(dropout, list): 34 | dropout = [dropout] * len(c_hidden) 35 | if not isinstance(self_attn, list): 36 | self_attn = [self_attn] * len(c_hidden) 37 | 38 | # CONDITIONING 39 | self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) 40 | self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) 41 | self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) 42 | self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) 43 | 44 | self.embedding = nn.Sequential( 45 | nn.PixelUnshuffle(patch_size), 46 | nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), 47 | LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) 48 | ) 49 | 50 | def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): 51 | if block_type == 'C': 52 | return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) 53 | elif block_type == 'A': 54 | return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) 55 | elif block_type == 'F': 56 | return FeedForwardBlock(c_hidden, dropout=dropout) 57 | elif 
block_type == 'T': 58 | return TimestepBlock(c_hidden, c_r, conds=t_conds) 59 | else: 60 | raise Exception(f'Block type {block_type} not supported') 61 | 62 | # BLOCKS 63 | # -- down blocks 64 | self.down_blocks = nn.ModuleList() 65 | self.down_downscalers = nn.ModuleList() 66 | self.down_repeat_mappers = nn.ModuleList() 67 | for i in range(len(c_hidden)): 68 | if i > 0: 69 | self.down_downscalers.append(nn.Sequential( 70 | LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), 71 | UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) 72 | )) 73 | else: 74 | self.down_downscalers.append(nn.Identity()) 75 | down_block = nn.ModuleList() 76 | for _ in range(blocks[0][i]): 77 | for block_type in level_config[i]: 78 | block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) 79 | down_block.append(block) 80 | self.down_blocks.append(down_block) 81 | if block_repeat is not None: 82 | block_repeat_mappers = nn.ModuleList() 83 | for _ in range(block_repeat[0][i] - 1): 84 | block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) 85 | self.down_repeat_mappers.append(block_repeat_mappers) 86 | 87 | # -- up blocks 88 | self.up_blocks = nn.ModuleList() 89 | self.up_upscalers = nn.ModuleList() 90 | self.up_repeat_mappers = nn.ModuleList() 91 | for i in reversed(range(len(c_hidden))): 92 | if i > 0: 93 | self.up_upscalers.append(nn.Sequential( 94 | LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), 95 | UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) 96 | )) 97 | else: 98 | self.up_upscalers.append(nn.Identity()) 99 | up_block = nn.ModuleList() 100 | for j in range(blocks[1][::-1][i]): 101 | for k, block_type in enumerate(level_config[i]): 102 | c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 103 | block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], 104 | self_attn=self_attn[i]) 105 | up_block.append(block) 106 | self.up_blocks.append(up_block) 107 | if block_repeat is not None: 108 | block_repeat_mappers = nn.ModuleList() 109 | for _ in range(block_repeat[1][::-1][i] - 1): 110 | block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) 111 | self.up_repeat_mappers.append(block_repeat_mappers) 112 | 113 | # OUTPUT 114 | self.clf = nn.Sequential( 115 | LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), 116 | nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), 117 | nn.PixelShuffle(patch_size), 118 | ) 119 | 120 | # --- WEIGHT INIT --- 121 | self.apply(self._init_weights) # General init 122 | nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings 123 | nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings 124 | nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings 125 | torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs 126 | nn.init.constant_(self.clf[1].weight, 0) # outputs 127 | 128 | # blocks 129 | for level_block in self.down_blocks + self.up_blocks: 130 | for block in level_block: 131 | if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): 132 | block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) 133 | elif isinstance(block, TimestepBlock): 134 | for layer in block.modules(): 135 | if isinstance(layer, nn.Linear): 136 | nn.init.constant_(layer.weight, 0) 137 | 138 | def _init_weights(self, m): 139 | if isinstance(m, (nn.Conv2d, nn.Linear)): 
140 | torch.nn.init.xavier_uniform_(m.weight) 141 | if m.bias is not None: 142 | nn.init.constant_(m.bias, 0) 143 | 144 | def gen_r_embedding(self, r, max_positions=10000): 145 | r = r * max_positions 146 | half_dim = self.c_r // 2 147 | emb = math.log(max_positions) / (half_dim - 1) 148 | emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() 149 | emb = r[:, None] * emb[None, :] 150 | emb = torch.cat([emb.sin(), emb.cos()], dim=1) 151 | if self.c_r % 2 == 1: # zero pad 152 | emb = nn.functional.pad(emb, (0, 1), mode='constant') 153 | return emb 154 | 155 | def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): 156 | clip_txt = self.clip_txt_mapper(clip_txt) 157 | if len(clip_txt_pooled.shape) == 2: 158 | clip_txt_pooled = clip_txt_pooled.unsqueeze(1) # assign back so the .view below sees a (B, 1, C) tensor 159 | if len(clip_img.shape) == 2: 160 | clip_img = clip_img.unsqueeze(1) 161 | clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) 162 | clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) 163 | clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) 164 | clip = self.clip_norm(clip) 165 | return clip 166 | 167 | def _down_encode(self, x, r_embed, clip, cnet=None): 168 | level_outputs = [] 169 | block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) 170 | for down_block, downscaler, repmap in block_group: 171 | x = downscaler(x) 172 | for i in range(len(repmap) + 1): 173 | for block in down_block: 174 | if isinstance(block, ResBlock) or ( 175 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 176 | ResBlock)): 177 | if cnet is not None: 178 | next_cnet = cnet() 179 | if next_cnet is not None: 180 | x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', 181 | align_corners=True) 182 | x = block(x) 183 | elif isinstance(block, AttnBlock) or ( 184 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 185 | AttnBlock)): 186 | x = block(x, clip) 187 | elif isinstance(block, TimestepBlock) or ( 188 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 189 | TimestepBlock)): 190 | x = block(x, r_embed) 191 | else: 192 | x = block(x) 193 | if i < len(repmap): 194 | x = repmap[i](x) 195 | level_outputs.insert(0, x) 196 | return level_outputs 197 | 198 | def _up_decode(self, level_outputs, r_embed, clip, cnet=None): 199 | x = level_outputs[0] 200 | block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) 201 | for i, (up_block, upscaler, repmap) in enumerate(block_group): 202 | for j in range(len(repmap) + 1): 203 | for k, block in enumerate(up_block): 204 | if isinstance(block, ResBlock) or ( 205 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 206 | ResBlock)): 207 | skip = level_outputs[i] if k == 0 and i > 0 else None 208 | if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): 209 | x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', 210 | align_corners=True) 211 | if cnet is not None: 212 | next_cnet = cnet() 213 | if next_cnet is not None: 214 | x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', 215 | align_corners=True) 216 | x = block(x, skip) 217 | elif isinstance(block, AttnBlock) or ( 218 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 219 |
AttnBlock)): 220 | x = block(x, clip) 221 | elif isinstance(block, TimestepBlock) or ( 222 | hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, 223 | TimestepBlock)): 224 | x = block(x, r_embed) 225 | else: 226 | x = block(x) 227 | if j < len(repmap): 228 | x = repmap[j](x) 229 | x = upscaler(x) 230 | return x 231 | 232 | def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): 233 | # Process the conditioning embeddings 234 | r_embed = self.gen_r_embedding(r) 235 | for c in self.t_conds: 236 | t_cond = kwargs.get(c, torch.zeros_like(r)) 237 | r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) 238 | clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) 239 | 240 | # Model Blocks 241 | x = self.embedding(x) 242 | if cnet is not None: 243 | cnet = ControlNetDeliverer(cnet) 244 | level_outputs = self._down_encode(x, r_embed, clip, cnet) 245 | x = self._up_decode(level_outputs, r_embed, clip, cnet) 246 | return self.clf(x) 247 | 248 | def update_weights_ema(self, src_model, beta=0.999): 249 | for self_params, src_params in zip(self.parameters(), src_model.parameters()): 250 | self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) 251 | for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): 252 | self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) 253 | -------------------------------------------------------------------------------- /src/train/train_b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn, optim 4 | from transformers import AutoTokenizer, CLIPTextModelWithProjection 5 | from warmup_scheduler import GradualWarmupScheduler 6 | import numpy as np 7 | 8 | import sys 9 | import os 10 | from dataclasses import dataclass 11 | 12 | from gdf import GDF, EpsilonTarget, CosineSchedule 13 | from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight 14 | from torchtools.transforms import SmartCrop 15 | 16 | from modules.effnet import EfficientNetEncoder 17 | from modules.stage_a import StageA 18 | 19 | from modules.stage_b import StageB 20 | from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock 21 | 22 | from train.base import DataCore, TrainingCore 23 | 24 | from core import WarpCore 25 | from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail 26 | 27 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 28 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy 29 | from accelerate import init_empty_weights 30 | from accelerate.utils import set_module_tensor_to_device 31 | from contextlib import contextmanager 32 | 33 | class WurstCore(TrainingCore, DataCore, WarpCore): 34 | @dataclass(frozen=True) 35 | class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): 36 | # TRAINING PARAMS 37 | lr: float = EXPECTED_TRAIN 38 | warmup_updates: int = EXPECTED_TRAIN 39 | shift: float = EXPECTED_TRAIN 40 | dtype: str = None 41 | 42 | # MODEL VERSION 43 | model_version: str = EXPECTED # 3B or 700M 44 | clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' 45 | 46 | # CHECKPOINT PATHS 47 | stage_a_checkpoint_path: str = EXPECTED 48 | effnet_checkpoint_path: str = EXPECTED 49 | generator_checkpoint_path: str = None 50 | 51 | # gdf customization 52 | adaptive_loss_weight: str = None 53 |
54 | @dataclass(frozen=True) 55 | class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): 56 | effnet: nn.Module = EXPECTED 57 | stage_a: nn.Module = EXPECTED 58 | 59 | @dataclass(frozen=True) 60 | class Schedulers(WarpCore.Schedulers): 61 | generator: any = None 62 | 63 | @dataclass(frozen=True) 64 | class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): 65 | gdf: GDF = EXPECTED 66 | sampling_configs: dict = EXPECTED 67 | effnet_preprocess: torchvision.transforms.Compose = EXPECTED 68 | 69 | info: TrainingCore.Info 70 | config: Config 71 | 72 | def setup_extras_pre(self) -> Extras: 73 | gdf = GDF( 74 | schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), 75 | input_scaler=VPScaler(), target=EpsilonTarget(), 76 | noise_cond=CosineTNoiseCond(), 77 | loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), 78 | ) 79 | sampling_configs = {"cfg": 1.5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10} 80 | 81 | if self.info.adaptive_loss is not None: 82 | gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) 83 | gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) 84 | 85 | effnet_preprocess = torchvision.transforms.Compose([ 86 | torchvision.transforms.Normalize( 87 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 88 | ) 89 | ]) 90 | 91 | transforms = torchvision.transforms.Compose([ 92 | torchvision.transforms.ToTensor(), 93 | torchvision.transforms.Resize(self.config.image_size, 94 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR, 95 | antialias=True), 96 | SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if self.config.training else torchvision.transforms.CenterCrop(self.config.image_size) 97 | ]) 98 | 99 | return self.Extras( 100 | gdf=gdf, 101 | sampling_configs=sampling_configs, 102 | transforms=transforms, 103 | effnet_preprocess=effnet_preprocess, 104 | clip_preprocess=None 105 | ) 106 | 107 | def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None): 108 | images = batch.get('images', None) 109 | 110 | if images is not None: 111 | images = images.to(self.device) 112 | if is_eval and not is_unconditional: 113 | effnet_embeddings = models.effnet(extras.effnet_preprocess(images)) 114 | else: 115 | if is_eval: 116 | effnet_factor = 1 117 | else: 118 | effnet_factor = np.random.uniform(0.5, 1) # f64 to f32 119 | effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32) 120 | 121 | effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device) 122 | if not is_eval: 123 | effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) 124 | rand_idx = np.random.rand(len(images)) <= 0.9 125 | if any(rand_idx): 126 | effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx])) 127 | else: 128 | effnet_embeddings = None 129 | 130 | conditions = super().get_conditions( 131 | batch, models, extras, is_eval, is_unconditional, 132 | eval_image_embeds, return_fields=return_fields or ['clip_text_pooled'] 133 | ) 134 | 135 | return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']} 136 | 137 | def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models: 
138 | dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 139 | 140 | # EfficientNet encoder 141 | effnet = EfficientNetEncoder().to(self.device) 142 | effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) 143 | effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) 144 | effnet.eval().requires_grad_(False) 145 | del effnet_checkpoint 146 | 147 | # vqGAN 148 | stage_a = StageA().to(self.device) 149 | stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path) 150 | stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict']) 151 | stage_a.eval().requires_grad_(False) 152 | del stage_a_checkpoint 153 | 154 | @contextmanager 155 | def dummy_context(): 156 | yield None 157 | 158 | loading_context = dummy_context if self.config.training else init_empty_weights 159 | 160 | # Diffusion models 161 | with loading_context(): 162 | generator_ema = None 163 | if self.config.model_version == '3B': 164 | generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) 165 | if self.config.ema_start_iters is not None: 166 | generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) 167 | elif self.config.model_version == '700M': 168 | generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) 169 | if self.config.ema_start_iters is not None: 170 | generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) 171 | else: 172 | raise ValueError(f"Unknown model version {self.config.model_version}") 173 | 174 | if self.config.generator_checkpoint_path is not None: 175 | if loading_context is dummy_context: 176 | generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) 177 | else: 178 | for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): 179 | set_module_tensor_to_device(generator, param_name, "cpu", value=param) 180 | generator = generator.to(dtype).to(self.device) 181 | generator = self.load_model(generator, 'generator') 182 | 183 | if generator_ema is not None: 184 | if loading_context is dummy_context: 185 | generator_ema.load_state_dict(generator.state_dict()) 186 | else: 187 | for param_name, param in generator.state_dict().items(): 188 | set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) 189 | generator_ema = self.load_model(generator_ema, 'generator_ema') 190 | generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) 191 | 192 | if self.config.use_fsdp: 193 | fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) 194 | generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) 195 | if generator_ema is not None: 196 | generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) 197 | 198 | if skip_clip: 199 | tokenizer = None 200 | text_model = None 201 | else: 202 | tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) 203 | text_model = 
CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) 204 | 205 | return self.Models( 206 | effnet=effnet, stage_a=stage_a, 207 | generator=generator, generator_ema=generator_ema, 208 | tokenizer=tokenizer, text_model=text_model 209 | ) 210 | 211 | def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: 212 | optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) 213 | optimizer = self.load_optimizer(optimizer, 'generator_optim', 214 | fsdp_model=models.generator if self.config.use_fsdp else None) 215 | return self.Optimizers(generator=optimizer) 216 | 217 | def setup_schedulers(self, extras: Extras, models: Models, 218 | optimizers: TrainingCore.Optimizers) -> Schedulers: 219 | scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) 220 | scheduler.last_epoch = self.info.total_steps 221 | return self.Schedulers(generator=scheduler) 222 | 223 | def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'): 224 | epsilon = epsilon.clone() 225 | multipliers = [1] 226 | for i in range(1, levels): 227 | m = 0.75 ** i 228 | h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-1) // (2 ** i) # width comes from the last dim 229 | if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]): 230 | offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device) 231 | epsilon = epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:], 232 | mode=scale_mode) * m 233 | multipliers.append(m) 234 | if h <= 1 or w <= 1: 235 | break 236 | epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5 237 | # epsilon = epsilon / epsilon.std() 238 | return epsilon 239 | 240 | def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): 241 | batch = next(data.iterator) 242 | 243 | with torch.no_grad(): 244 | conditions = self.get_conditions(batch, models, extras) 245 | latents = self.encode_latents(batch, models, extras) 246 | epsilon = torch.randn_like(latents) 247 | epsilon = self._pyramid_noise(epsilon, size_range=[1, 16]) 248 | noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1, 249 | epsilon=epsilon) 250 | 251 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 252 | pred = models.generator(noised, noise_cond, **conditions) 253 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 254 | loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps 255 | 256 | if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): 257 | extras.gdf.loss_weight.update_buckets(logSNR, loss) 258 | 259 | return loss, loss_adjusted 260 | 261 | def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, 262 | schedulers: Schedulers): 263 | if update: 264 | loss_adjusted.backward() 265 | grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) 266 | optimizers_dict = optimizers.to_dict() 267 | for k in optimizers_dict: 268 | if k != 'training': 269 | optimizers_dict[k].step() 270 | schedulers_dict = schedulers.to_dict() 271 | for k in schedulers_dict: 272 | if k != 'training': 273 | schedulers_dict[k].step() 274 | for k in optimizers_dict: 275 | if k != 'training': 276 | optimizers_dict[k].zero_grad(set_to_none=True) 277 | self.info.total_steps += 1 278 | else: 279 |
loss_adjusted.backward() 280 | grad_norm = torch.tensor(0.0).to(self.device) 281 | 282 | return grad_norm 283 | 284 | def models_to_save(self): 285 | return ['generator', 'generator_ema'] 286 | 287 | def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 288 | images = batch['images'].to(self.device) 289 | return models.stage_a.encode(images)[0] 290 | 291 | def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 292 | return models.stage_a.decode(latents.float()).clamp(0, 1) 293 | 294 | 295 | if __name__ == '__main__': 296 | print("Launching Script") 297 | warpcore = WurstCore( 298 | config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, 299 | device=torch.device(int(os.environ.get("SLURM_LOCALID"))) 300 | ) 301 | # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD 302 | 303 | # RUN TRAINING 304 | warpcore() 305 | --------------------------------------------------------------------------------
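A minimal local launch sketch, under stated assumptions: both train_c.py and train_b.py above read the training YAML path from sys.argv[1] and derive their CUDA device index from the SLURM_LOCALID environment variable, so a single-GPU run outside SLURM can emulate a one-process job by setting that variable itself. The PYTHONPATH=src setting is an assumption so that the train, core, modules and gdf packages under src/ resolve; adjust it to however the repository is actually installed.

import os
import subprocess

# Emulate a one-process SLURM job: the scripts call
# torch.device(int(os.environ.get("SLURM_LOCALID"))), so the variable must be set,
# and their imports (train.*, core.*, modules.*, gdf.*) resolve relative to src/.
env = {**os.environ, "PYTHONPATH": "src", "SLURM_LOCALID": "0"}
subprocess.run(
    ["python", "src/train/train_c.py", "configs/training/finetune_c_3b.yaml"],
    env=env,
    check=True,
)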