├── .github ├── FUNDING.yml └── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature.yml │ └── bug.yml ├── .gitattributes ├── requirements.txt ├── resources ├── icons │ ├── icon.ico │ ├── icon.png │ ├── icon_discord.png │ └── icon_small.png ├── images │ └── OneTrainerGUI.gif ├── sd_model_spec │ ├── sana.json │ ├── sana-lora.json │ ├── chroma.json │ ├── qwen.json │ ├── z_image.json │ ├── flux_dev_1.0.json │ ├── sana-embedding.json │ ├── chroma-lora.json │ ├── hi_dream_full.json │ ├── hunyuan_video.json │ ├── qwen-lora.json │ ├── z_image-lora.json │ ├── flux_dev_1.0-lora.json │ ├── flux_dev_fill_1.0.json │ ├── pixart_alpha_1.0.json │ ├── pixart_sigma_1.0.json │ ├── sd_1.5.json │ ├── wuerstchen_2.0.json │ ├── chroma-embedding.json │ ├── hi_dream_full-lora.json │ ├── hunyuan_video-lora.json │ ├── sd_2.0.json │ ├── sd_2.1.json │ ├── sd_3.5_1.0.json │ ├── stable_cascade_1.0.json │ ├── flux_dev_1.0-embedding.json │ ├── flux_dev_fill_1.0-lora.json │ ├── pixart_alpha_1.0-lora.json │ ├── pixart_sigma_1.0-lora.json │ ├── sd_1.5-lora.json │ ├── sd_2.0-lora.json │ ├── sd_2.1-lora.json │ ├── sd_3_2b_1.0.json │ ├── wuerstchen_2.0-lora.json │ ├── hi_dream_full-embedding.json │ ├── hunyuan_video-embedding.json │ ├── sd_3.5_1.0-lora.json │ ├── sd_2.0_depth.json │ ├── stable_cascade_1.0-lora.json │ ├── flux_dev_fill_1.0-embedding.json │ ├── pixart_alpha_1.0-embedding.json │ ├── pixart_sigma_1.0-embedding.json │ ├── sd_1.5-embedding.json │ ├── sd_2.0-embedding.json │ ├── sd_2.0_depth-lora.json │ ├── sd_2.1-embedding.json │ ├── sd_3_2b_1.0-lora.json │ ├── wuerstchen_2.0-embedding.json │ ├── sd_1.5_inpainting.json │ ├── sd_3.5_1.0-embedding.json │ ├── sd_2.0_inpainting.json │ ├── stable_cascade_1.0-embedding.json │ ├── sd_1.5_inpainting-lora.json │ ├── sd_2.0_depth-embedding.json │ ├── sd_2.0_inpainting-lora.json │ ├── sd_3_2b_1.0-embedding.json │ ├── sd_xl_base_1.0-embedding.json │ ├── sd_1.5_inpainting-embedding.json │ ├── sd_2.0_inpainting-embedding.json │ ├── sd_xl_base_1.0_inpainting.json │ ├── sd_xl_base_1.0_inpainting-lora.json │ └── sd_xl_base_1.0_Inpainting-embedding.json ├── docker │ ├── RunPod-NVIDIA-CLI-start.sh.patch │ ├── Vast-NVIDIA-CLI.Dockerfile │ └── RunPod-NVIDIA-CLI.Dockerfile └── model_config │ └── stable_cascade │ ├── stable_cascade_prior_1.0b.json │ └── stable_cascade_prior_3.6b.json ├── embedding_templates ├── .gitignore └── subject.txt ├── requirements-default.txt ├── install.sh ├── modules ├── util │ ├── time_util.py │ ├── enum │ │ ├── CloudType.py │ │ ├── EMAMode.py │ │ ├── CloudFileSync.py │ │ ├── BalancingStrategy.py │ │ ├── CloudAction.py │ │ ├── FileType.py │ │ ├── ConfigPart.py │ │ ├── GenerateCaptionsModel.py │ │ ├── ConceptType.py │ │ ├── GenerateMasksModel.py │ │ ├── TrainingMethod.py │ │ ├── AudioFormat.py │ │ ├── TimestepDistribution.py │ │ ├── LearningRateScheduler.py │ │ ├── LossWeight.py │ │ ├── NoiseScheduler.py │ │ ├── TimeUnit.py │ │ ├── GradientCheckpointingMethod.py │ │ ├── ImageFormat.py │ │ ├── VideoFormat.py │ │ ├── ModelFormat.py │ │ ├── LossScaler.py │ │ ├── LearningRateScaler.py │ │ └── GradientReducePrecision.py │ ├── args │ │ ├── arg_type_util.py │ │ ├── CalculateLossArgs.py │ │ └── CaptionUIArgs.py │ ├── type_util.py │ ├── image_util.py │ ├── conv_util.py │ ├── git_util.py │ ├── TrainProgress.py │ ├── config │ │ └── SecretsConfig.py │ ├── loss │ │ └── masked_loss.py │ └── convert │ │ └── rescale_noise_scheduler_to_zero_terminal_snr.py ├── module │ ├── quantized │ │ └── mixin │ │ │ ├── QuantizedModuleMixin.py │ │ │ └── QuantizedLinearMixin.py │ ├── RembgModel.py │ 
├── RembgHumanModel.py │ ├── Blip2Model.py │ └── BlipModel.py ├── modelLoader │ ├── QwenFineTuneModelLoader.py │ ├── flux │ │ ├── FluxEmbeddingLoader.py │ │ └── FluxLoRALoader.py │ ├── sana │ │ ├── SanaEmbeddingLoader.py │ │ └── SanaLoRALoader.py │ ├── chroma │ │ ├── ChromaEmbeddingLoader.py │ │ └── ChromaLoRALoader.py │ ├── hiDream │ │ ├── HiDreamEmbeddingLoader.py │ │ └── HiDreamLoRALoader.py │ ├── wuerstchen │ │ ├── WuerstchenEmbeddingLoader.py │ │ └── WuerstchenLoRALoader.py │ ├── pixartAlpha │ │ ├── PixArtAlphaEmbeddingLoader.py │ │ └── PixArtAlphaLoRALoader.py │ ├── hunyuanVideo │ │ ├── HunyuanVideoEmbeddingLoader.py │ │ └── HunyuanVideoLoRALoader.py │ ├── stableDiffusion │ │ ├── StableDiffusionEmbeddingLoader.py │ │ └── StableDiffusionLoRALoader.py │ ├── stableDiffusion3 │ │ ├── StableDiffusion3EmbeddingLoader.py │ │ └── StableDiffusion3LoRALoader.py │ ├── SanaFineTuneModelLoader.py │ ├── stableDiffusionXL │ │ ├── StableDiffusionXLEmbeddingLoader.py │ │ └── StableDiffusionXLLoRALoader.py │ ├── QwenLoRAModelLoader.py │ ├── SanaEmbeddingModelLoader.py │ ├── ChromaFineTuneModelLoader.py │ ├── ChromaEmbeddingModelLoader.py │ ├── HiDreamFineTuneModelLoader.py │ ├── HiDreamEmbeddingModelLoader.py │ ├── SanaLoRAModelLoader.py │ ├── HunyuanVideoFineTuneModelLoader.py │ ├── FluxFineTuneModelLoader.py │ ├── HunyuanVideoEmbeddingModelLoader.py │ ├── qwen │ │ └── QwenLoRALoader.py │ ├── ChromaLoRAModelLoader.py │ ├── FluxEmbeddingModelLoader.py │ ├── HiDreamLoRAModelLoader.py │ ├── WuerstchenFineTuneModelLoader.py │ ├── PixArtAlphaFineTuneModelLoader.py │ ├── FluxLoRAModelLoader.py │ ├── WuerstchenEmbeddingModelLoader.py │ ├── PixArtAlphaEmbeddingModelLoader.py │ ├── HunyuanVideoLoRAModelLoader.py │ ├── StableDiffusion3FineTuneModelLoader.py │ ├── StableDiffusion3EmbeddingModelLoader.py │ ├── WuerstchenLoRAModelLoader.py │ ├── StableDiffusionXLFineTuneModelLoader.py │ ├── PixArtAlphaLoRAModelLoader.py │ ├── StableDiffusionXLEmbeddingModelLoader.py │ ├── StableDiffusion3LoRAModelLoader.py │ ├── StableDiffusionXLLoRAModelLoader.py │ ├── mixin │ │ └── ModelSpecModelLoaderMixin.py │ ├── StableDiffusionFineTuneModelLoader.py │ └── StableDiffusionEmbeddingModelLoader.py ├── modelSaver │ ├── BaseModelSaver.py │ ├── QwenLoRAModelSaver.py │ ├── QwenFineTuneModelSaver.py │ ├── ZImageLoRAModelSaver.py │ ├── ZImageFineTuneModelSaver.py │ ├── FluxEmbeddingModelSaver.py │ ├── SanaEmbeddingModelSaver.py │ ├── ChromaEmbeddingModelSaver.py │ ├── HiDreamEmbeddingModelSaver.py │ ├── zImage │ │ └── ZImageLoRASaver.py │ ├── WuerstchenEmbeddingModelSaver.py │ ├── PixArtAlphaEmbeddingModelSaver.py │ ├── HunyuanVideoEmbeddingModelSaver.py │ ├── SanaFineTuneModelSaver.py │ ├── StableDiffusionEmbeddingModelSaver.py │ ├── FluxFineTuneModelSaver.py │ ├── StableDiffusion3EmbeddingModelSaver.py │ ├── StableDiffusionXLEmbeddingModelSaver.py │ ├── ChromaFineTuneModelSaver.py │ ├── qwen │ │ └── QwenLoRASaver.py │ ├── WuerstchenFineTuneModelSaver.py │ ├── PixArtAlphaFineTuneModelSaver.py │ ├── HunyuanVideoFineTuneModelSaver.py │ ├── SanaLoRAModelSaver.py │ ├── FluxLoRAModelSaver.py │ ├── ChromaLoRAModelSaver.py │ ├── HiDreamLoRAModelSaver.py │ ├── StableDiffusion3FineTuneModelSaver.py │ └── StableDiffusionFineTuneModelSaver.py ├── dataLoader │ └── BaseDataLoader.py ├── model │ └── util │ │ ├── t5_util.py │ │ ├── gemma_util.py │ │ └── llama_util.py ├── ui │ └── SampleParamsWindow.py └── modelSetup │ └── mixin │ └── ModelSetupFlowMatchingMixin.py ├── scripts ├── README.md ├── train_ui.py ├── video_tool_ui.py ├── 
convert_model_ui.py ├── caption_ui.py ├── install_zluda.py ├── calculate_loss.py ├── util │ └── import_util.py ├── generate_captions.py └── train.py ├── requirements-dev.txt ├── docs ├── Overview.md ├── DockerImage.md ├── CliTraining.md └── Contributing.md ├── training_presets ├── .gitignore ├── #pixart alpha 1.0.json ├── #sd 2.1.json ├── #sd 2.0 inpaint.json ├── #sd 1.5 inpaint.json ├── #sana 1.6b.json ├── #sd 1.5.json ├── #sd 1.5 inpaint masked.json ├── #sd 1.5 masked.json ├── #sd 2.1 LoRA.json ├── #sd 1.5 LoRA.json ├── #pixart sigma 1.0.json ├── #sd 2.0 inpaint LoRA.json ├── #sd 2.1 embedding.json ├── #sd 1.5 embedding.json ├── #pixart sigma 1.0 LoRA.json ├── #sdxl 1.0.json ├── #sdxl 1.0 embedding.json ├── #sdxl 1.0 inpaint LoRA.json ├── #sd 3.json ├── #sdxl 1.0 LoRA.json ├── #flux LoRA.json ├── #chroma LoRA 16GB.json ├── #chroma LoRA 24GB.json ├── #sd 3 LoRA.json ├── #hunyuan video LoRA.json ├── #wuerstchen 2.0 LoRA.json ├── #qwen LoRA 16GB.json ├── #qwen LoRA 24GB.json ├── #chroma LoRA 8GB.json ├── #wuerstchen 2.0 embedding.json ├── #z-image LoRA 16GB.json ├── #hidream LoRA.json ├── #z-image LoRA 8GB.json ├── #chroma Finetune 24GB.json ├── #qwen Finetune 16GB.json └── #qwen Finetune 24GB.json ├── requirements-cuda.txt ├── start-ui.sh ├── requirements-rocm.txt ├── .editorconfig ├── update.sh ├── .gitignore ├── .pre-commit-config.yaml ├── run-cmd.sh ├── export_debug.bat └── pyproject.toml /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: Nerogar 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | *.bat text eol=crlf 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements-cuda.txt 2 | -r requirements-global.txt 3 | -------------------------------------------------------------------------------- /resources/icons/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/HEAD/resources/icons/icon.ico -------------------------------------------------------------------------------- /resources/icons/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/HEAD/resources/icons/icon.png -------------------------------------------------------------------------------- /resources/icons/icon_discord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/HEAD/resources/icons/icon_discord.png -------------------------------------------------------------------------------- /resources/icons/icon_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/HEAD/resources/icons/icon_small.png -------------------------------------------------------------------------------- /resources/images/OneTrainerGUI.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/HEAD/resources/images/OneTrainerGUI.gif -------------------------------------------------------------------------------- /embedding_templates/.gitignore: 
-------------------------------------------------------------------------------- 1 | # ignores everything except the builtin template files 2 | * 3 | !.gitignore 4 | !subject.txt 5 | -------------------------------------------------------------------------------- /requirements-default.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | torch==2.8.0 3 | torchvision==0.23.0 4 | onnxruntime==1.22.1 5 | 6 | # optimizers 7 | # TODO 8 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | source "${BASH_SOURCE[0]%/*}/lib.include.sh" 6 | 7 | prepare_runtime_environment 8 | -------------------------------------------------------------------------------- /modules/util/time_util.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | def get_string_timestamp(): 5 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 6 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # OneTrainer Scripts 2 | 3 | For an overview of these scripts, see the 4 | [CLI Mode subsection](../README.md#cli-mode) in the OneTrainer README. 5 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Git commit hook handler. 2 | # NOTE: It's a user-wide tool installed outside venv. Don't pin exact versions. 3 | # SEE: https://pre-commit.com/ 4 | pre-commit>=4.0.1 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Discord 4 | url: https://discord.gg/KwgcQd5scF 5 | about: Please ask and answer questions here. 
6 | -------------------------------------------------------------------------------- /modules/util/enum/CloudType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class CloudType(Enum): 5 | RUNPOD = 'RUNPOD' 6 | LINUX = 'LINUX' 7 | def __str__(self): 8 | return self.value 9 | -------------------------------------------------------------------------------- /modules/util/enum/EMAMode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class EMAMode(Enum): 5 | OFF = 'OFF' 6 | GPU = 'GPU' 7 | CPU = 'CPU' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /modules/util/enum/CloudFileSync.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class CloudFileSync(Enum): 5 | FABRIC_SFTP = 'FABRIC_SFTP' 6 | NATIVE_SCP = 'NATIVE_SCP' 7 | def __str__(self): 8 | return self.value 9 | -------------------------------------------------------------------------------- /modules/util/enum/BalancingStrategy.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BalancingStrategy(Enum): 5 | REPEATS = 'REPEATS' 6 | SAMPLES = 'SAMPLES' 7 | 8 | def __str__(self): 9 | return self.value 10 | -------------------------------------------------------------------------------- /modules/util/enum/CloudAction.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class CloudAction(Enum): 5 | NONE = 'NONE' 6 | STOP = 'STOP' 7 | DELETE = 'DELETE' 8 | def __str__(self): 9 | return self.value 10 | -------------------------------------------------------------------------------- /modules/util/enum/FileType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class FileType(Enum): 5 | IMAGE = 'IMAGE' 6 | VIDEO = 'VIDEO' 7 | AUDIO = 'AUDIO' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sana.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "sana", 4 | "modelspec.implementation": "https://github.com/NVlabs/Sana", 5 | "modelspec.title": "Sana" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/enum/ConfigPart.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ConfigPart(Enum): 5 | NONE = 'NONE' 6 | SETTINGS = 'SETTINGS' 7 | ALL = 'ALL' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /docs/Overview.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | [Quick Start Guide](QuickStartGuide.md) 4 | 5 | [Training from CLI](CliTraining.md) 6 | 7 | [Embedding Training](EmbeddingTraining.md) 8 | 9 | [Captioning and Masking](CaptioningAndMasking.md) 10 | -------------------------------------------------------------------------------- /resources/docker/RunPod-NVIDIA-CLI-start.sh.patch: 
-------------------------------------------------------------------------------- 1 | 83a84,88 2 | > link_onetrainer() { 3 | > if [[ ! -f /workspace/OneTrainer ]]; then 4 | > ln -snf /OneTrainer /workspace/OneTrainer 5 | > fi 6 | > } 7 | 93a100 8 | > link_onetrainer 9 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sana-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "sana/lora", 4 | "modelspec.implementation": "https://github.com/NVlabs/Sana", 5 | "modelspec.title": "Sana LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /training_presets/.gitignore: -------------------------------------------------------------------------------- 1 | # ignores everything except the builtin config files starting with # 2 | * 3 | !.gitignore 4 | !\#* 5 | 6 | # #.json is an exception (from the exception). This file stores the ui state from a previous run. 7 | \#.json 8 | -------------------------------------------------------------------------------- /modules/util/args/arg_type_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def torch_device(device_name: str) -> torch.device: 5 | return torch.device(device_name) 6 | 7 | def nullable_bool(bool_value: str) -> bool: 8 | return bool_value.lower() == "true" 9 | -------------------------------------------------------------------------------- /resources/sd_model_spec/chroma.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Chroma1", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Chroma1" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Qwen-Image", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Qwen Image" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/z_image.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Z-Image", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Z-Image" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-dev", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sana-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "sana/embedding", 4 | "modelspec.implementation": "https://github.com/NVlabs/Sana", 5 | "modelspec.title": "Sana 
Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/enum/GenerateCaptionsModel.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GenerateCaptionsModel(Enum): 5 | BLIP = 'BLIP' 6 | BLIP2 = 'BLIP2' 7 | WD14_VIT_2 = 'WD14_VIT_2' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /resources/sd_model_spec/chroma-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Chroma1/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Chroma1 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hi_dream_full.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hidream-i1", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HiDream I1 Full" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hunyuan_video.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hunyuan-video", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HunyuanVideo" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/qwen-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Qwen-Image/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Qwen Image LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/z_image-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Z-Image/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Z-Image LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/enum/ConceptType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ConceptType(Enum): 5 | STANDARD = 'STANDARD' 6 | VALIDATION = 'VALIDATION' 7 | PRIOR_PREDICTION = 'PRIOR_PREDICTION' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-dev/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_fill_1.0.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-fill-dev", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev Fill 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_alpha_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-alpha", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-alpha", 5 | "modelspec.title": "PixArt Alpha 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_sigma_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-sigma", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-sigma", 5 | "modelspec.title": "PixArt Sigma 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1", 4 | "modelspec.implementation": "https://github.com/CompVis/stable-diffusion", 5 | "modelspec.title": "Stable Diffusion 1.5" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/wuerstchen_2.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "wuerstchen-v2-prior", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Wuerstchen 2.0" 6 | } 7 | -------------------------------------------------------------------------------- /scripts/train_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.TrainUI import TrainUI 6 | 7 | 8 | def main(): 9 | ui = TrainUI() 10 | ui.mainloop() 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /resources/sd_model_spec/chroma-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Chroma1/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Chroma1 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hi_dream_full-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hidream-i1/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HiDream I1 Full LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hunyuan_video-lora.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hunyuan-video/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HunyuanVideo LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.1.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.1" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3.5_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3.5", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3.5" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/stable_cascade_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-cascade-v1-prior", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Cascade 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /modules/module/quantized/mixin/QuantizedModuleMixin.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class QuantizedModuleMixin(metaclass=ABCMeta): 7 | @abstractmethod 8 | def quantize(self, device: torch.device | None = None): 9 | pass 10 | -------------------------------------------------------------------------------- /modules/util/enum/GenerateMasksModel.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GenerateMasksModel(Enum): 5 | CLIPSEG = 'CLIPSEG' 6 | REMBG = 'REMBG' 7 | REMBG_HUMAN = 'REMBG_HUMAN' 8 | COLOR = 'COLOR' 9 | 10 | def __str__(self): 11 | return self.value 12 | -------------------------------------------------------------------------------- /modules/util/type_util.py: -------------------------------------------------------------------------------- 1 | from typing import get_origin 2 | 3 | 4 | def issubclass_safe(x, t): 5 | # if x is defined as a generic list or dict (e.g. 
`list[int]`), issubclass will throw an error 6 | return get_origin(x) is not list and get_origin(x) is not dict and issubclass(x, t) 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-dev/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_fill_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-fill-dev/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev Fill 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_alpha_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-alpha/lora", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-alpha", 5 | "modelspec.title": "PixArt Alpha 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_sigma_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-sigma/lora", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-sigma", 5 | "modelspec.title": "PixArt Sigma 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.1-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.1 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3_2b_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | 
"modelspec.architecture": "stable-diffusion-v3-medium", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3 Medium" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/wuerstchen_2.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "wuerstchen-v2-prior/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Wuerstchen 2.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/enum/TrainingMethod.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TrainingMethod(Enum): 5 | FINE_TUNE = 'FINE_TUNE' 6 | LORA = 'LORA' 7 | EMBEDDING = 'EMBEDDING' 8 | FINE_TUNE_VAE = 'FINE_TUNE_VAE' 9 | 10 | def __str__(self): 11 | return self.value 12 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hi_dream_full-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hidream-i1/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HiDream I1 Full Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hunyuan_video-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hunyuan-video/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HunyuanVideo Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3.5_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3.5/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3.5 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_depth.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-depth", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.0 Depth" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/stable_cascade_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-cascade-v1-prior/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Cascade 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /scripts/video_tool_ui.py: -------------------------------------------------------------------------------- 1 | from 
util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.VideoToolUI import VideoToolUI 6 | 7 | 8 | def main(): 9 | ui = VideoToolUI(None) 10 | ui.mainloop() 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_fill_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-fill-dev/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev Fill 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_alpha_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-alpha/embedding", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-alpha", 5 | "modelspec.title": "PixArt Alpha 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_sigma_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-sigma/embedding", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-sigma", 5 | "modelspec.title": "PixArt Sigma 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_depth-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-depth/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Depth LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.1-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.1 Embedding" 6 | } 7 | 
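The sd_model_spec templates in this directory follow the SAI ModelSpec convention: a flat mapping of "modelspec.*" string keys, with one template per architecture and training-method combination (base, -lora, -embedding). As a minimal sketch of how such a template can end up in a checkpoint, the snippet below loads one template and attaches it as safetensors header metadata; the template choice, tensor dict, and output path are illustrative placeholders, not OneTrainer's actual saving logic (see modules/modelSaver for that).

import json

import torch
from safetensors.torch import save_file

# Load one of the flat ModelSpec templates (every value is a plain string).
with open("resources/sd_model_spec/sd_1.5-lora.json") as f:
    model_spec: dict[str, str] = json.load(f)

# Placeholder tensors; a real saver would pass the trained state dict here.
tensors = {"lora_down.weight": torch.zeros(4, 320)}

# safetensors metadata must be a str -> str mapping, which the flat template
# already satisfies, so it can be passed through unchanged.
save_file(tensors, "models/model.safetensors", metadata=model_spec)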
-------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3_2b_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3-medium/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3 Medium LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/wuerstchen_2.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "wuerstchen-v2-prior/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Wuerstchen 2.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5_inpainting.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1-inpainting", 4 | "modelspec.implementation": "https://github.com/CompVis/stable-diffusion", 5 | "modelspec.title": "Stable Diffusion 1.5 Inpainting" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3.5_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3.5/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3.5 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /scripts/convert_model_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.ConvertModelUI import ConvertModelUI 6 | 7 | 8 | def main(): 9 | ui = ConvertModelUI(None) 10 | ui.mainloop() 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_inpainting.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-inpainting", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.0 Inpainting" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/stable_cascade_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-cascade-v1-prior/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Cascade 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5_inpainting-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | 
"modelspec.architecture": "stable-diffusion-v1-inpainting/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 Inpainting LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_depth-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-depth/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Depth Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_inpainting-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-inpainting/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Inpainting LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3_2b_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3-medium/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3 Medium Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/image_util.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageOps 2 | 3 | 4 | def load_image(path: str, convert_mode: str = 'RGB') -> Image.Image: 5 | image = Image.open(path) 6 | image = ImageOps.exif_transpose(image) 7 | if convert_mode: 8 | image = image.convert(convert_mode) 9 | return image 10 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/conv_util.py: -------------------------------------------------------------------------------- 1 | from modules.module.LoRAModule import LoRAModuleWrapper 2 | 3 | from torch import nn 4 | 5 | 6 | def apply_circular_padding_to_conv2d(module: nn.Module | LoRAModuleWrapper): 7 | for m in module.modules(): 8 | if isinstance(m, nn.Conv2d): 9 | m.padding_mode = 'circular' 10 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5_inpainting-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1-inpainting/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 Inpainting Embedding" 
6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_inpainting-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-inpainting/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Inpainting Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0_inpainting.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base-inpainting", 4 | "modelspec.implementation": "https://github.com/Stability-AI/generative-models", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Inpainting" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0_inpainting-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base-inpainting/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Inpainting LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0_Inpainting-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base-inpainting/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Inpainting Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /modules/util/enum/AudioFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class AudioFormat(Enum): 5 | MP3 = 'MP3' 6 | 7 | def __str__(self): 8 | return self.value 9 | 10 | def extension(self) -> str: 11 | match self: 12 | case AudioFormat.MP3: 13 | return ".mp3" 14 | case _: 15 | return "" 16 | -------------------------------------------------------------------------------- /modules/util/enum/TimestepDistribution.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TimestepDistribution(Enum): 5 | UNIFORM = 'UNIFORM' 6 | SIGMOID = 'SIGMOID' 7 | LOGIT_NORMAL = 'LOGIT_NORMAL' 8 | HEAVY_TAIL = 'HEAVY_TAIL' 9 | COS_MAP = 'COS_MAP' 10 | INVERTED_PARABOLA = 'INVERTED_PARABOLA' 11 | 12 | def __str__(self): 13 | return self.value 14 | -------------------------------------------------------------------------------- /requirements-cuda.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | --extra-index-url https://download.pytorch.org/whl/cu128 3 | torch==2.8.0+cu128 4 | torchvision==0.23.0+cu128 5 | onnxruntime-gpu==1.22.0 6 | nvidia-nccl-cu12==2.27.3; sys_platform == "linux" 7 | triton-windows==3.4.0.post20; sys_platform == "win32" 8 | 9 | # optimizers 10 | bitsandbytes==0.46.0 # bitsandbytes for 8-bit optimizers and weight 
quantization 11 | -------------------------------------------------------------------------------- /start-ui.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | source "${BASH_SOURCE[0]%/*}/lib.include.sh" 6 | 7 | # Xet is buggy. Disabled by default unless already defined - https://github.com/Nerogar/OneTrainer/issues/949 8 | if [[ -z "${HF_HUB_DISABLE_XET+x}" ]]; then 9 | export HF_HUB_DISABLE_XET=1 10 | fi 11 | 12 | prepare_runtime_environment 13 | 14 | run_python_in_active_env "scripts/train_ui.py" "$@" 15 | -------------------------------------------------------------------------------- /modules/module/quantized/mixin/QuantizedLinearMixin.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class QuantizedLinearMixin(metaclass=ABCMeta): 7 | @abstractmethod 8 | def original_weight_shape(self) -> tuple[int, ...]: 9 | pass 10 | 11 | @abstractmethod 12 | def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: 13 | pass 14 | -------------------------------------------------------------------------------- /training_presets/#pixart alpha 1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "PixArt-alpha/PixArt-XL-2-1024-MS", 3 | "model_type": "PIXART_ALPHA", 4 | "output_model_destination": "models/model", 5 | "output_model_format": "DIFFUSERS", 6 | "resolution": "1024", 7 | "text_encoder": { 8 | "train": false, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "training_method": "FINE_TUNE" 12 | } 13 | -------------------------------------------------------------------------------- /training_presets/#sd 2.1.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "sd2-community/stable-diffusion-2-1", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_21", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "vae": { 10 | "weight_dtype": "FLOAT_32" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /scripts/caption_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.CaptionUI import CaptionUI 6 | from modules.util.args.CaptionUIArgs import CaptionUIArgs 7 | 8 | 9 | def main(): 10 | args = CaptionUIArgs.parse_args() 11 | 12 | ui = CaptionUI(None, args.dir, args.include_subdirectories) 13 | ui.mainloop() 14 | 15 | 16 | if __name__ == '__main__': 17 | main() 18 | -------------------------------------------------------------------------------- /requirements-rocm.txt: -------------------------------------------------------------------------------- 1 | # Note: AMD requirements might be outdated. 
If you can provide information about running OneTrainer on AMD, 2 | # please open an issue or pull request on github 3 | 4 | # pytorch 5 | --extra-index-url https://download.pytorch.org/whl/rocm6.3 6 | torch==2.7.1+rocm6.3 #intentionally not upgraded because of reported problems 7 | torchvision==0.22.1+rocm6.3 8 | onnxruntime==1.22.1 9 | 10 | # optimizers 11 | # TODO 12 | -------------------------------------------------------------------------------- /training_presets/#sd 2.0 inpaint.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "sd2-community/stable-diffusion-2-inpainting", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_20_INPAINTING", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "vae": { 10 | "weight_dtype": "FLOAT_32" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 inpaint.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-inpainting", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_15_INPAINTING", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "vae": { 10 | "weight_dtype": "FLOAT_32" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /modules/util/enum/LearningRateScheduler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LearningRateScheduler(Enum): 5 | CONSTANT = 'CONSTANT' 6 | LINEAR = 'LINEAR' 7 | COSINE = 'COSINE' 8 | COSINE_WITH_RESTARTS = 'COSINE_WITH_RESTARTS' 9 | COSINE_WITH_HARD_RESTARTS = 'COSINE_WITH_HARD_RESTARTS' 10 | REX = 'REX' 11 | ADAFACTOR = 'ADAFACTOR' 12 | CUSTOM = 'CUSTOM' 13 | 14 | def __str__(self): 15 | return self.value 16 | -------------------------------------------------------------------------------- /training_presets/#sana 1.6b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", 3 | "model_type": "SANA", 4 | "output_model_destination": "models/model", 5 | "output_model_format": "DIFFUSERS", 6 | "resolution": "1024", 7 | "text_encoder": { 8 | "train": false, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "training_method": "FINE_TUNE", 12 | "train_dtype": "BFLOAT_16" 13 | } 14 | -------------------------------------------------------------------------------- /modules/util/enum/LossWeight.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LossWeight(Enum): 5 | CONSTANT = 'CONSTANT' 6 | P2 = 'P2' 7 | MIN_SNR_GAMMA = 'MIN_SNR_GAMMA' 8 | DEBIASED_ESTIMATION = 'DEBIASED_ESTIMATION' 9 | SIGMA = 'SIGMA' 10 | 11 | def supports_flow_matching(self) -> bool: 12 | return self == LossWeight.CONSTANT \ 13 | or self == LossWeight.SIGMA 14 | 15 | def __str__(self): 16 | return self.value 17 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": 
"stable-diffusion-v1-5/stable-diffusion-v1-5", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_15", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "unet": { 10 | "train": true 11 | }, 12 | "vae": { 13 | "weight_dtype": "FLOAT_32" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /modules/util/enum/NoiseScheduler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class NoiseScheduler(Enum): 5 | DDIM = 'DDIM' 6 | 7 | EULER = 'EULER' 8 | EULER_A = 'EULER_A' 9 | DPMPP = 'DPMPP' 10 | DPMPP_SDE = 'DPMPP_SDE' 11 | UNIPC = 'UNIPC' 12 | 13 | EULER_KARRAS = 'EULER_KARRAS' 14 | DPMPP_KARRAS = 'DPMPP_KARRAS' 15 | DPMPP_SDE_KARRAS = 'DPMPP_SDE_KARRAS' 16 | UNIPC_KARRAS = 'UNIPC_KARRAS' 17 | 18 | def __str__(self): 19 | return self.value 20 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | [*] 2 | charset = utf-8 3 | end_of_line = lf 4 | indent_size = 4 5 | indent_style = space 6 | trim_trailing_whitespace = true 7 | insert_final_newline = true 8 | max_line_length = 120 9 | tab_width = 4 10 | ij_continuation_indent_size = 8 11 | ij_formatter_off_tag = @formatter:off 12 | ij_formatter_on_tag = @formatter:on 13 | ij_formatter_tags_enabled = true 14 | ij_smart_tabs = false 15 | ij_visual_guides = none 16 | ij_wrap_on_typing = false 17 | 18 | [*.bat] 19 | end_of_line = crlf 20 | -------------------------------------------------------------------------------- /modules/util/enum/TimeUnit.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TimeUnit(Enum): 5 | EPOCH = 'EPOCH' 6 | STEP = 'STEP' 7 | SECOND = 'SECOND' 8 | MINUTE = 'MINUTE' 9 | HOUR = 'HOUR' 10 | 11 | NEVER = 'NEVER' 12 | ALWAYS = 'ALWAYS' 13 | 14 | def __str__(self): 15 | return self.value 16 | 17 | def is_time_unit(self) -> bool: 18 | return self == TimeUnit.SECOND \ 19 | or self == TimeUnit.MINUTE \ 20 | or self == TimeUnit.HOUR 21 | -------------------------------------------------------------------------------- /update.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Change our working dir to the root of the project. 6 | cd -- "$(dirname -- "${BASH_SOURCE[0]}")" 7 | 8 | # Pull the latest changes via Git. 9 | echo "[OneTrainer] Updating OneTrainer to latest version from Git repository..." 10 | git pull 11 | 12 | # Load the newest version of the function library. 13 | source "lib.include.sh" 14 | 15 | # Prepare runtime and upgrade all dependencies to latest compatible version. 
16 | prepare_runtime_environment upgrade 17 | -------------------------------------------------------------------------------- /modules/util/enum/GradientCheckpointingMethod.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GradientCheckpointingMethod(Enum): 5 | OFF = 'OFF' 6 | ON = 'ON' 7 | CPU_OFFLOADED = 'CPU_OFFLOADED' 8 | 9 | def __str__(self): 10 | return self.value 11 | 12 | def enabled(self): 13 | return self == GradientCheckpointingMethod.ON \ 14 | or self == GradientCheckpointingMethod.CPU_OFFLOADED 15 | 16 | def offload(self): 17 | return self == GradientCheckpointingMethod.CPU_OFFLOADED 18 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 inpaint masked.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-inpainting", 3 | "batch_size": 4, 4 | "masked_training": true, 5 | "model_type": "STABLE_DIFFUSION_15_INPAINTING", 6 | "normalize_masked_area_loss": true, 7 | "output_model_destination": "models/model.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "512", 10 | "training_method": "FINE_TUNE", 11 | "vae": { 12 | "weight_dtype": "FLOAT_32" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /modules/module/RembgModel.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseRembgModel import BaseRembgModel 2 | 3 | import torch 4 | 5 | 6 | class RembgModel(BaseRembgModel): 7 | def __init__(self, device: torch.device, dtype: torch.dtype): 8 | super().__init__( 9 | model_filename="u2net.onnx", 10 | model_path="https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx", 11 | model_md5="md5:60024c5c889badc19c04ad937298a77b", 12 | device=device, 13 | dtype=dtype, 14 | ) 15 | -------------------------------------------------------------------------------- /scripts/install_zluda.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports(allow_zluda=False) 4 | 5 | import sys 6 | 7 | from modules.zluda import ZLUDAInstaller 8 | 9 | if __name__ == '__main__': 10 | try: 11 | zluda_path = ZLUDAInstaller.get_path() 12 | ZLUDAInstaller.install(zluda_path) 13 | ZLUDAInstaller.make_copy(zluda_path) 14 | except Exception as e: 15 | print(f'Failed to install ZLUDA: {e}') 16 | sys.exit(1) 17 | 18 | print(f'ZLUDA installed: {zluda_path}') 19 | -------------------------------------------------------------------------------- /modules/util/git_util.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def get_git_branch() -> str: 5 | try: 6 | return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 7 | except subprocess.CalledProcessError: 8 | return "git not installed" 9 | 10 | 11 | def get_git_revision() -> str: 12 | try: 13 | return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 14 | except subprocess.CalledProcessError: 15 | return "git not installed" 16 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 masked.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 3 | "batch_size": 4, 4 | "masked_training": true, 5 | "max_noising_strength": 0.8, 6 | "model_type": "STABLE_DIFFUSION_15", 7 | "normalize_masked_area_loss": true, 8 | "output_model_destination": "models/model.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "training_method": "FINE_TUNE", 12 | "vae": { 13 | "weight_dtype": "FLOAT_32" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /modules/modelLoader/QwenFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.QwenModel import QwenModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.qwen.QwenModelLoader import QwenModelLoader 4 | from modules.util.enum.ModelType import ModelType 5 | 6 | QwenFineTuneModelLoader = make_fine_tune_model_loader( 7 | model_spec_map={ModelType.QWEN: "resources/sd_model_spec/qwen.json"}, 8 | model_class=QwenModel, 9 | model_loader_class=QwenModelLoader, 10 | embedding_loader_class=None, 11 | ) 12 | -------------------------------------------------------------------------------- /modules/module/RembgHumanModel.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseRembgModel import BaseRembgModel 2 | 3 | import torch 4 | 5 | 6 | class RembgHumanModel(BaseRembgModel): 7 | def __init__(self, device: torch.device, dtype: torch.dtype): 8 | super().__init__( 9 | model_filename="u2net_human_seg.onnx", 10 | model_path="https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx", 11 | model_md5="md5:c09ddc2e0104f800e3e1bb4652583d1f", 12 | device=device, 13 | dtype=dtype, 14 | ) 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # development 2 | .idea 3 | *.bak 4 | *.pyc 5 | *.swp 6 | *~ 7 | debug*.ipynb 8 | debug*.py 9 | /debug* 10 | 11 | # user data 12 | /workspace* 13 | /models* 14 | /training_concepts 15 | /training_samples 16 | /training_user_settings 17 | /external 18 | /update.var 19 | config.json 20 | secrets.json 21 | *.zip 22 | 23 | # environments 24 | /.venv* 25 | /venv* 26 | /conda_env* 27 | .python-version 28 | *.egg-info 29 | 30 | # pixi environments 31 | .pixi 32 | pixi.lock 33 | pixi.toml 34 | 35 | # misc files 36 | /src 37 | train.bat 38 | debug_report.log 39 | config_diff.txt 40 | -------------------------------------------------------------------------------- /modules/modelLoader/flux/FluxEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class FluxEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: FluxModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/sana/SanaEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import 
SanaModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class SanaEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: SanaModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/chroma/ChromaEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class ChromaEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: ChromaModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/hiDream/HiDreamEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class HiDreamEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: HiDreamModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /docs/DockerImage.md: -------------------------------------------------------------------------------- 1 | # Docker Image 2 | 3 | A Dockerfile based on the `nvidia/cuda:11.8.0-devel-ubuntu22.04` image is provided. 4 | 5 | This image requires `nvidia-driver-525` and `nvidia-docker2` to be installed on the host. 6 | 7 | ## Building Image 8 | 9 | Build using: 10 | 11 | ``` 12 | docker build -t myuser/onetrainer:latest -f Dockerfile . 
13 | ``` 14 | 15 | ## Running Image 16 | 17 | This is an example of running the image: 18 | 19 | ``` 20 | docker run \ 21 | --gpus all \ 22 | -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix \ 23 | -i \ 24 | --tty \ 25 | --shm-size=512m \ 26 | myuser/onetrainer:latest 27 | ``` 28 | -------------------------------------------------------------------------------- /modules/modelLoader/wuerstchen/WuerstchenEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class WuerstchenEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: WuerstchenModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/pixartAlpha/PixArtAlphaEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class PixArtAlphaEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: PixArtAlphaModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/hunyuanVideo/HunyuanVideoEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class HunyuanVideoEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: HunyuanVideoModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /docs/CliTraining.md: -------------------------------------------------------------------------------- 1 | # Training from CLI 2 | 3 | All training functionality is available through the CLI command `python scripts/train.py`. The training configuration is 4 | stored in a `.json` file that is passed to this script. 5 | 6 | Some options require specifying paths to files with a specific 7 | layout. These files can be created using the `create_train_files.py` script. You can call the script like 8 | this: `python scripts/create_train_files.py -h`. 9 | 10 | To simplify the creation of the training config, you can export your settings from the UI by using the export button. 11 | This will create a single file that contains every setting.
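A minimal run could then look like the sketch below. This assumes the exported configuration was saved as `config.json`; the `--config-path` flag name is taken from the current train script and may differ between versions, so check `python scripts/train.py -h` if in doubt:

```
python scripts/train.py --config-path config.json
```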
12 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion/StableDiffusionEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class StableDiffusionEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: StableDiffusionModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion3/StableDiffusion3EmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class StableDiffusion3EmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: StableDiffusion3Model, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/SanaFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.sana.SanaEmbeddingLoader import SanaEmbeddingLoader 4 | from modules.modelLoader.sana.SanaModelLoader import SanaModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | SanaFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ModelType.SANA: "resources/sd_model_spec/sana.json"}, 9 | model_class=SanaModel, 10 | model_loader_class=SanaModelLoader, 11 | embedding_loader_class=SanaEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusionXL/StableDiffusionXLEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class StableDiffusionXLEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: StableDiffusionXLModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/QwenLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.QwenModel import QwenModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.qwen.QwenLoRALoader import QwenLoRALoader 4 | from modules.modelLoader.qwen.QwenModelLoader 
import QwenModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | QwenLoRAModelLoader = make_lora_model_loader( 8 | model_spec_map={ModelType.QWEN: "resources/sd_model_spec/qwen-lora.json"}, 9 | model_class=QwenModel, 10 | model_loader_class=QwenModelLoader, 11 | embedding_loader_class=None, 12 | lora_loader_class=QwenLoRALoader, 13 | ) 14 | -------------------------------------------------------------------------------- /training_presets/#sd 2.1 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "sd2-community/stable-diffusion-2-1", 4 | "batch_size": 4, 5 | "layer_filter_preset": "attn-mlp", 6 | "learning_rate": 0.0003, 7 | "model_type": "STABLE_DIFFUSION_21", 8 | "output_model_destination": "models/lora.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "training_method": "LORA", 12 | "vae": { 13 | "weight_dtype": "FLOAT_16" 14 | }, 15 | "unet": { 16 | "weight_dtype": "FLOAT_16" 17 | }, 18 | "text_encoder": { 19 | "weight_dtype": "FLOAT_16" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /modules/modelLoader/SanaEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.sana.SanaEmbeddingLoader import SanaEmbeddingLoader 4 | from modules.modelLoader.sana.SanaModelLoader import SanaModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | SanaEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ModelType.SANA: "resources/sd_model_spec/sana-embedding.json"}, 9 | model_class=SanaModel, 10 | model_loader_class=SanaModelLoader, 11 | embedding_loader_class=SanaEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 4 | "batch_size": 4, 5 | "layer_filter_preset": "attn-mlp", 6 | "learning_rate": 0.0003, 7 | "model_type": "STABLE_DIFFUSION_15", 8 | "output_model_destination": "models/lora.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "training_method": "LORA", 12 | "vae": { 13 | "weight_dtype": "FLOAT_32" 14 | }, 15 | "unet": { 16 | "weight_dtype": "FLOAT_16" 17 | }, 18 | "text_encoder": { 19 | "weight_dtype": "FLOAT_16" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /training_presets/#pixart sigma 1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "PixArt-sigma/PixArt-XL-2-1024-MS", 3 | "batch_size": 8, 4 | "model_type": "PIXART_SIGMA", 5 | "epochs": 100, 6 | "learning_rate": 0.00002, 7 | "output_model_destination": "models/model", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false, 12 | "weight_dtype": "BFLOAT_16" 13 | }, 14 | "training_method": "FINE_TUNE", 15 | "transformer": { 16 | "weight_dtype": "BFLOAT_16" 17 | }, 18 | "vae": { 19 | "weight_dtype": "FLOAT_32" 20 | } 21 | } 22 | 
-------------------------------------------------------------------------------- /modules/modelLoader/ChromaFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelLoader.chroma.ChromaEmbeddingLoader import ChromaEmbeddingLoader 3 | from modules.modelLoader.chroma.ChromaModelLoader import ChromaModelLoader 4 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | ChromaFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ModelType.CHROMA_1: "resources/sd_model_spec/chroma.json"}, 9 | model_class=ChromaModel, 10 | model_loader_class=ChromaModelLoader, 11 | embedding_loader_class=ChromaEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /modules/modelSaver/BaseModelSaver.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from modules.model.BaseModel import BaseModel 4 | from modules.util.enum.ModelFormat import ModelFormat 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | import torch 8 | 9 | 10 | class BaseModelSaver(metaclass=ABCMeta): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | @abstractmethod 15 | def save( 16 | self, 17 | model: BaseModel, 18 | model_type: ModelType, 19 | output_model_format: ModelFormat, 20 | output_model_destination: str, 21 | dtype: torch.dtype | None, 22 | ): 23 | pass 24 | -------------------------------------------------------------------------------- /training_presets/#sd 2.0 inpaint LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "sd2-community/stable-diffusion-2-inpainting", 4 | "batch_size": 4, 5 | "layer_filter_preset": "attn-mlp", 6 | "learning_rate": 0.0003, 7 | "model_type": "STABLE_DIFFUSION_20_INPAINTING", 8 | "output_model_destination": "models/lora.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "training_method": "LORA", 12 | "vae": { 13 | "weight_dtype": "FLOAT_16" 14 | }, 15 | "unet": { 16 | "weight_dtype": "FLOAT_16" 17 | }, 18 | "text_encoder": { 19 | "weight_dtype": "FLOAT_16" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /modules/modelLoader/ChromaEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelLoader.chroma.ChromaEmbeddingLoader import ChromaEmbeddingLoader 3 | from modules.modelLoader.chroma.ChromaModelLoader import ChromaModelLoader 4 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | ChromaEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ModelType.CHROMA_1: "resources/sd_model_spec/chroma-embedding.json"}, 9 | model_class=ChromaModel, 10 | model_loader_class=ChromaModelLoader, 11 | embedding_loader_class=ChromaEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /scripts/calculate_loss.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 
5 | import json 6 | 7 | from modules.module.GenerateLossesModel import GenerateLossesModel 8 | from modules.util.args.CalculateLossArgs import CalculateLossArgs 9 | from modules.util.config.TrainConfig import TrainConfig 10 | 11 | 12 | def main(): 13 | args = CalculateLossArgs.parse_args() 14 | 15 | train_config = TrainConfig.default_values() 16 | with open(args.config_path, "r") as f: 17 | train_config.from_dict(json.load(f)) 18 | 19 | trainer = GenerateLossesModel(train_config, args.output_path) 20 | trainer.start() 21 | 22 | 23 | if __name__ == '__main__': 24 | main() 25 | -------------------------------------------------------------------------------- /modules/modelLoader/HiDreamFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.hiDream.HiDreamEmbeddingLoader import HiDreamEmbeddingLoader 4 | from modules.modelLoader.hiDream.HiDreamModelLoader import HiDreamModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | HiDreamFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ModelType.HI_DREAM_FULL: "resources/sd_model_spec/hi_dream_full.json"}, 9 | model_class=HiDreamModel, 10 | model_loader_class=HiDreamModelLoader, 11 | embedding_loader_class=HiDreamEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /training_presets/#sd 2.1 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "sd2-community/stable-diffusion-2-1", 4 | "latent_caching": false, 5 | "learning_rate": 0.0003, 6 | "learning_rate_warmup_steps": 20, 7 | "model_type": "STABLE_DIFFUSION_21", 8 | "output_model_destination": "models/embedding.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "sample_after": 1, 12 | "training_method": "EMBEDDING", 13 | "vae": { 14 | "weight_dtype": "FLOAT_32" 15 | }, 16 | "unet": { 17 | "weight_dtype": "FLOAT_16" 18 | }, 19 | "text_encoder": { 20 | "weight_dtype": "FLOAT_16" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /modules/modelLoader/HiDreamEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.hiDream.HiDreamEmbeddingLoader import HiDreamEmbeddingLoader 4 | from modules.modelLoader.hiDream.HiDreamModelLoader import HiDreamModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | HiDreamEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ModelType.HI_DREAM_FULL: "resources/sd_model_spec/hi_dream_full-embedding.json"}, 9 | model_class=HiDreamModel, 10 | model_loader_class=HiDreamModelLoader, 11 | embedding_loader_class=HiDreamEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /modules/util/enum/ImageFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ImageFormat(Enum): 5 | PNG = 'PNG' 6 | JPG = 'JPG' 7 | 8 | def __str__(self): 9 | return self.value 10 | 11 | def extension(self) -> str: 12 | 
match self: 13 | case ImageFormat.PNG: 14 | return ".png" 15 | case ImageFormat.JPG: 16 | return ".jpg" 17 | case _: 18 | return "" 19 | 20 | def pil_format(self) -> str: 21 | match self: 22 | case ImageFormat.PNG: 23 | return "PNG" 24 | case ImageFormat.JPG: 25 | return "JPEG" 26 | case _: 27 | return "" 28 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 4 | "latent_caching": false, 5 | "learning_rate": 0.0003, 6 | "learning_rate_warmup_steps": 20, 7 | "model_type": "STABLE_DIFFUSION_15", 8 | "output_model_destination": "models/embedding.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "sample_after": 1, 12 | "training_method": "EMBEDDING", 13 | "vae": { 14 | "weight_dtype": "FLOAT_32" 15 | }, 16 | "unet": { 17 | "weight_dtype": "FLOAT_16" 18 | }, 19 | "text_encoder": { 20 | "weight_dtype": "FLOAT_16" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /training_presets/#pixart sigma 1.0 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "PixArt-sigma/PixArt-XL-2-1024-MS", 3 | "batch_size": 4, 4 | "model_type": "PIXART_SIGMA", 5 | "epochs": 50, 6 | "learning_rate": 0.00001, 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false, 12 | "weight_dtype": "BFLOAT_16" 13 | }, 14 | "training_method": "LORA", 15 | "lora_rank": 16, 16 | "lora_alpha": 1.0, 17 | "transformer": { 18 | "weight_dtype": "BFLOAT_16" 19 | }, 20 | "vae": { 21 | "weight_dtype": "FLOAT_32" 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v6.0.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-case-conflict 7 | - id: check-illegal-windows-names 8 | - id: destroyed-symlinks 9 | - id: fix-byte-order-marker 10 | - id: mixed-line-ending 11 | - id: trailing-whitespace 12 | - id: end-of-file-fixer 13 | - id: check-executables-have-shebangs 14 | - id: check-yaml 15 | 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: v0.14.9 18 | hooks: 19 | # Run the Ruff linter, but not the formatter. 
20 | - id: ruff 21 | args: ["--fix"] 22 | types_or: [ python, pyi, jupyter ] 23 | 24 | ci: 25 | autofix_prs: false 26 | -------------------------------------------------------------------------------- /modules/modelLoader/SanaLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.sana.SanaEmbeddingLoader import SanaEmbeddingLoader 4 | from modules.modelLoader.sana.SanaLoRALoader import SanaLoRALoader 5 | from modules.modelLoader.sana.SanaModelLoader import SanaModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | SanaLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ModelType.SANA: "resources/sd_model_spec/sana-lora.json"}, 10 | model_class=SanaModel, 11 | model_loader_class=SanaModelLoader, 12 | embedding_loader_class=SanaEmbeddingLoader, 13 | lora_loader_class=SanaLoRALoader, 14 | ) 15 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0", 3 | "model_type": "STABLE_DIFFUSION_XL_10_BASE", 4 | "output_model_destination": "models/model.safetensors", 5 | "output_model_format": "SAFETENSORS", 6 | "resolution": "1024", 7 | "text_encoder": { 8 | "train": false, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "text_encoder_2": { 12 | "train": false, 13 | "weight_dtype": "BFLOAT_16" 14 | }, 15 | "training_method": "FINE_TUNE", 16 | "vae": { 17 | "weight_dtype": "FLOAT_32" 18 | }, 19 | "unet": { 20 | "weight_dtype": "BFLOAT_16" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /modules/modelLoader/HunyuanVideoFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.hunyuanVideo.HunyuanVideoEmbeddingLoader import HunyuanVideoEmbeddingLoader 4 | from modules.modelLoader.hunyuanVideo.HunyuanVideoModelLoader import HunyuanVideoModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | HunyuanVideoFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ModelType.HUNYUAN_VIDEO: "resources/sd_model_spec/hunyuan_video.json"}, 9 | model_class=HunyuanVideoModel, 10 | model_loader_class=HunyuanVideoModelLoader, 11 | embedding_loader_class=HunyuanVideoEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /modules/modelLoader/FluxFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelLoader.flux.FluxEmbeddingLoader import FluxEmbeddingLoader 3 | from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader 4 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | FluxFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ 9 | ModelType.FLUX_DEV_1: "resources/sd_model_spec/flux_dev_1.0.json", 10 | 
ModelType.FLUX_FILL_DEV_1: "resources/sd_model_spec/flux_dev_fill_1.0.json", 11 | }, 12 | model_class=FluxModel, 13 | model_loader_class=FluxModelLoader, 14 | embedding_loader_class=FluxEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /modules/modelLoader/HunyuanVideoEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.hunyuanVideo.HunyuanVideoEmbeddingLoader import HunyuanVideoEmbeddingLoader 4 | from modules.modelLoader.hunyuanVideo.HunyuanVideoModelLoader import HunyuanVideoModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | HunyuanVideoEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ModelType.HUNYUAN_VIDEO: "resources/sd_model_spec/hunyuan_video-embedding.json"}, 9 | model_class=HunyuanVideoModel, 10 | model_loader_class=HunyuanVideoModelLoader, 11 | embedding_loader_class=HunyuanVideoEmbeddingLoader, 12 | ) 13 | -------------------------------------------------------------------------------- /modules/modelLoader/qwen/QwenLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.QwenModel import QwenModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | 8 | 9 | class QwenLoRALoader( 10 | LoRALoaderMixin 11 | ): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 16 | return None 17 | 18 | def load( 19 | self, 20 | model: QwenModel, 21 | model_names: ModelNames, 22 | ): 23 | return self._load(model, model_names) 24 | -------------------------------------------------------------------------------- /modules/modelLoader/ChromaLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelLoader.chroma.ChromaEmbeddingLoader import ChromaEmbeddingLoader 3 | from modules.modelLoader.chroma.ChromaLoRALoader import ChromaLoRALoader 4 | from modules.modelLoader.chroma.ChromaModelLoader import ChromaModelLoader 5 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | ChromaLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ModelType.CHROMA_1: "resources/sd_model_spec/chroma-lora.json"}, 10 | model_class=ChromaModel, 11 | model_loader_class=ChromaModelLoader, 12 | embedding_loader_class=ChromaEmbeddingLoader, 13 | lora_loader_class=ChromaLoRALoader, 14 | ) 15 | -------------------------------------------------------------------------------- /modules/dataLoader/BaseDataLoader.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from modules.dataLoader.mixin.DataLoaderMgdsMixin import DataLoaderMgdsMixin 4 | 5 | from mgds.MGDS import MGDS, TrainDataLoader 6 | 7 | import torch 8 | 9 | 10 | class BaseDataLoader( 11 | DataLoaderMgdsMixin, 12 | metaclass=ABCMeta, 13 | ): 14 | 15 | def __init__( 16 | 
self, 17 | train_device: torch.device, 18 | temp_device: torch.device, 19 | ): 20 | super().__init__() 21 | 22 | self.train_device = train_device 23 | self.temp_device = temp_device 24 | 25 | @abstractmethod 26 | def get_data_set(self) -> MGDS: 27 | pass 28 | 29 | @abstractmethod 30 | def get_data_loader(self) -> TrainDataLoader: 31 | pass 32 | -------------------------------------------------------------------------------- /modules/modelLoader/FluxEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelLoader.flux.FluxEmbeddingLoader import FluxEmbeddingLoader 3 | from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader 4 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | FluxEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ 9 | ModelType.FLUX_DEV_1: "resources/sd_model_spec/flux_dev_1.0-embedding.json", 10 | ModelType.FLUX_FILL_DEV_1: "resources/sd_model_spec/flux_dev_fill_1.0-embedding.json", 11 | }, 12 | model_class=FluxModel, 13 | model_loader_class=FluxModelLoader, 14 | embedding_loader_class=FluxEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0", 4 | "latent_caching": false, 5 | "learning_rate": 0.0003, 6 | "learning_rate_warmup_steps": 20, 7 | "model_type": "STABLE_DIFFUSION_XL_10_BASE", 8 | "output_model_destination": "models/embedding.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "1024", 11 | "sample_after": 1, 12 | "training_method": "EMBEDDING", 13 | "vae": { 14 | "weight_dtype": "FLOAT_32" 15 | }, 16 | "unet": { 17 | "weight_dtype": "FLOAT_16" 18 | }, 19 | "text_encoder": { 20 | "weight_dtype": "FLOAT_16" 21 | }, 22 | "text_encoder_2": { 23 | "weight_dtype": "FLOAT_16" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /modules/modelLoader/sana/SanaLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.SanaModel import SanaModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | 8 | 9 | class SanaLoRALoader( 10 | LoRALoaderMixin 11 | ): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 16 | return None # TODO: not yet implemented 17 | 18 | def load( 19 | self, 20 | model: SanaModel, 21 | model_names: ModelNames, 22 | ): 23 | return self._load(model, model_names) 24 | -------------------------------------------------------------------------------- /modules/util/TrainProgress.py: -------------------------------------------------------------------------------- 1 | class TrainProgress: 2 | def __init__( 3 | self, 4 | epoch: int = 0, 5 | epoch_step: int = 0, 6 | epoch_sample: int = 0, 7 | global_step: int = 0, 8 | ): 9 | self.epoch = epoch 10 | self.epoch_step = epoch_step 11 | self.epoch_sample = 
epoch_sample 12 | self.global_step = global_step 13 | 14 | def next_step(self, batch_size: int): 15 | self.epoch_step += 1 16 | self.epoch_sample += batch_size 17 | self.global_step += 1 18 | 19 | def next_epoch(self): 20 | self.epoch_step = 0 21 | self.epoch_sample = 0 22 | self.epoch += 1 23 | 24 | def filename_string(self): 25 | return f"{self.global_step}-{self.epoch}-{self.epoch_step}" 26 | -------------------------------------------------------------------------------- /modules/modelLoader/HiDreamLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.hiDream.HiDreamEmbeddingLoader import HiDreamEmbeddingLoader 4 | from modules.modelLoader.hiDream.HiDreamLoRALoader import HiDreamLoRALoader 5 | from modules.modelLoader.hiDream.HiDreamModelLoader import HiDreamModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | HiDreamLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ModelType.HI_DREAM_FULL: "resources/sd_model_spec/hi_dream_full-lora.json"}, 10 | model_class=HiDreamModel, 11 | model_loader_class=HiDreamModelLoader, 12 | embedding_loader_class=HiDreamEmbeddingLoader, 13 | lora_loader_class=HiDreamLoRALoader, 14 | ) 15 | -------------------------------------------------------------------------------- /modules/util/config/SecretsConfig.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from modules.util.config.BaseConfig import BaseConfig 4 | from modules.util.config.CloudConfig import CloudSecretsConfig 5 | 6 | 7 | class SecretsConfig(BaseConfig): 8 | huggingface_token: str 9 | cloud: CloudSecretsConfig 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def default_values() -> 'SecretsConfig': 16 | data = [] 17 | 18 | # name, default value, data type, nullable 19 | data.append(("huggingface_token", "", str, False)) 20 | 21 | # cloud 22 | cloud = CloudSecretsConfig.default_values() 23 | data.append(("cloud", cloud, CloudSecretsConfig, False)) 24 | 25 | return SecretsConfig(data) 26 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0 inpaint LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", 4 | "batch_size": 4, 5 | "layer_filter_preset": "attn-mlp", 6 | "learning_rate": 0.0003, 7 | "model_type": "STABLE_DIFFUSION_XL_10_BASE_INPAINTING", 8 | "output_model_destination": "models/lora.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "1024", 11 | "text_encoder": { 12 | "train": false, 13 | "weight_dtype": "FLOAT_16" 14 | }, 15 | "text_encoder_2": { 16 | "train": false, 17 | "weight_dtype": "FLOAT_16" 18 | }, 19 | "unet": { 20 | "weight_dtype": "FLOAT_16" 21 | }, 22 | "training_method": "LORA", 23 | "vae": { 24 | "weight_dtype": "FLOAT_32" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /modules/modelLoader/WuerstchenFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelLoader.GenericFineTuneModelLoader 
import make_fine_tune_model_loader 3 | from modules.modelLoader.wuerstchen.WuerstchenEmbeddingLoader import WuerstchenEmbeddingLoader 4 | from modules.modelLoader.wuerstchen.WuerstchenModelLoader import WuerstchenModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | WuerstchenFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ 9 | ModelType.WUERSTCHEN_2: "resources/sd_model_spec/wuerstchen_2.0.json", 10 | ModelType.STABLE_CASCADE_1: "resources/sd_model_spec/stable_cascade_1.0.json", 11 | }, 12 | model_class=WuerstchenModel, 13 | model_loader_class=WuerstchenModelLoader, 14 | embedding_loader_class=WuerstchenEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /modules/modelLoader/PixArtAlphaFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.pixartAlpha.PixArtAlphaEmbeddingLoader import PixArtAlphaEmbeddingLoader 4 | from modules.modelLoader.pixartAlpha.PixArtAlphaModelLoader import PixArtAlphaModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | PixArtAlphaFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ 9 | ModelType.PIXART_ALPHA: "resources/sd_model_spec/pixart_alpha_1.0.json", 10 | ModelType.PIXART_SIGMA: "resources/sd_model_spec/pixart_sigma_1.0.json", 11 | }, 12 | model_class=PixArtAlphaModel, 13 | model_loader_class=PixArtAlphaModelLoader, 14 | embedding_loader_class=PixArtAlphaEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /training_presets/#sd 3.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-diffusion-3-medium-diffusers", 3 | "model_type": "STABLE_DIFFUSION_3", 4 | "output_dtype": "FLOAT_16", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "transformer": { 8 | "weight_dtype": "FLOAT_32" 9 | }, 10 | "resolution": "1024", 11 | "text_encoder": { 12 | "train": false, 13 | "weight_dtype": "FLOAT_16" 14 | }, 15 | "text_encoder_2": { 16 | "train": false, 17 | "weight_dtype": "FLOAT_16" 18 | }, 19 | "text_encoder_3": { 20 | "train": false, 21 | "weight_dtype": "FLOAT_16" 22 | }, 23 | "training_method": "FINE_TUNE", 24 | "vae": { 25 | "weight_dtype": "FLOAT_32" 26 | }, 27 | "timestep_distribution": "LOGIT_NORMAL" 28 | } 29 | -------------------------------------------------------------------------------- /modules/modelLoader/FluxLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelLoader.flux.FluxEmbeddingLoader import FluxEmbeddingLoader 3 | from modules.modelLoader.flux.FluxLoRALoader import FluxLoRALoader 4 | from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader 5 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | FluxLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ 10 | ModelType.FLUX_DEV_1: "resources/sd_model_spec/flux_dev_1.0-lora.json", 11 | ModelType.FLUX_FILL_DEV_1: "resources/sd_model_spec/flux_dev_fill_1.0-lora.json", 12 | }, 13 | model_class=FluxModel, 14 | 
model_loader_class=FluxModelLoader, 15 | embedding_loader_class=FluxEmbeddingLoader, 16 | lora_loader_class=FluxLoRALoader, 17 | ) 18 | -------------------------------------------------------------------------------- /modules/modelLoader/WuerstchenEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.wuerstchen.WuerstchenEmbeddingLoader import WuerstchenEmbeddingLoader 4 | from modules.modelLoader.wuerstchen.WuerstchenModelLoader import WuerstchenModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | WuerstchenEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ 9 | ModelType.WUERSTCHEN_2: "resources/sd_model_spec/wuerstchen_2.0-embedding.json", 10 | ModelType.STABLE_CASCADE_1: "resources/sd_model_spec/stable_cascade_1.0-embedding.json", 11 | }, 12 | model_class=WuerstchenModel, 13 | model_loader_class=WuerstchenModelLoader, 14 | embedding_loader_class=WuerstchenEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /modules/modelLoader/PixArtAlphaEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.pixartAlpha.PixArtAlphaEmbeddingLoader import PixArtAlphaEmbeddingLoader 4 | from modules.modelLoader.pixartAlpha.PixArtAlphaModelLoader import PixArtAlphaModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | PixArtAlphaEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ 9 | ModelType.PIXART_ALPHA: "resources/sd_model_spec/pixart_alpha_1.0-embedding.json", 10 | ModelType.PIXART_SIGMA: "resources/sd_model_spec/pixart_sigma_1.0-embedding.json", 11 | }, 12 | model_class=PixArtAlphaModel, 13 | model_loader_class=PixArtAlphaModelLoader, 14 | embedding_loader_class=PixArtAlphaEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0", 4 | "batch_size": 4, 5 | "layer_filter_preset": "attn-mlp", 6 | "learning_rate": 0.0003, 7 | "model_type": "STABLE_DIFFUSION_XL_10_BASE", 8 | "output_model_destination": "models/lora.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "1024", 11 | "text_encoder": { 12 | "train": false, 13 | "weight_dtype": "FLOAT_16" 14 | }, 15 | "text_encoder_2": { 16 | "train": false, 17 | "weight_dtype": "FLOAT_16" 18 | }, 19 | "training_method": "LORA", 20 | "unet": { 21 | "weight_dtype": "FLOAT_16" 22 | }, 23 | "vae": { 24 | "weight_dtype": "FLOAT_32" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /modules/modelLoader/HunyuanVideoLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from
modules.modelLoader.hunyuanVideo.HunyuanVideoEmbeddingLoader import HunyuanVideoEmbeddingLoader 4 | from modules.modelLoader.hunyuanVideo.HunyuanVideoLoRALoader import HunyuanVideoLoRALoader 5 | from modules.modelLoader.hunyuanVideo.HunyuanVideoModelLoader import HunyuanVideoModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | HunyuanVideoLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ModelType.HUNYUAN_VIDEO: "resources/sd_model_spec/hunyuan_video-lora.json"}, 10 | model_class=HunyuanVideoModel, 11 | model_loader_class=HunyuanVideoModelLoader, 12 | embedding_loader_class=HunyuanVideoEmbeddingLoader, 13 | lora_loader_class=HunyuanVideoLoRALoader, 14 | ) 15 | -------------------------------------------------------------------------------- /modules/modelLoader/flux/FluxLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.FluxModel import FluxModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class FluxLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_flux_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: FluxModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /training_presets/#flux LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "black-forest-labs/FLUX.1-dev", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "FLUX_DEV_1", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "768", 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "NFLOAT_4" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "BFLOAT_16" 17 | }, 18 | "text_encoder_2": { 19 | "train": false, 20 | "weight_dtype": "NFLOAT_4" 21 | }, 22 | "training_method": "LORA", 23 | "vae": { 24 | "weight_dtype": "FLOAT_32" 25 | }, 26 | "train_dtype": "BFLOAT_16", 27 | "timestep_distribution": "LOGIT_NORMAL", 28 | "dynamic_timestep_shifting": false 29 | } 30 | -------------------------------------------------------------------------------- /modules/modelLoader/chroma/ChromaLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.ChromaModel import ChromaModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_chroma_lora import convert_chroma_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class ChromaLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return 
convert_chroma_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: ChromaModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /run-cmd.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Xet is buggy. Disabled by default unless already defined - https://github.com/Nerogar/OneTrainer/issues/949 6 | if [[ -z "${HF_HUB_DISABLE_XET+x}" ]]; then 7 | export HF_HUB_DISABLE_XET=1 8 | fi 9 | 10 | source "${BASH_SOURCE[0]%/*}/lib.include.sh" 11 | 12 | # Fetch and validate the name of the target script. 13 | if [[ -z "${1}" ]]; then 14 | print_error "You must provide the name of the script to execute, such as \"train\"." 15 | exit 1 16 | fi 17 | 18 | OT_CUSTOM_SCRIPT_FILE="scripts/${1}.py" 19 | if [[ ! -f "${OT_CUSTOM_SCRIPT_FILE}" ]]; then 20 | print_error "Custom script file \"${OT_CUSTOM_SCRIPT_FILE}\" does not exist." 21 | exit 1 22 | fi 23 | 24 | prepare_runtime_environment 25 | 26 | # Remove $1 (name of the script) and pass all remaining arguments to the script. 27 | shift 28 | run_python_in_active_env "${OT_CUSTOM_SCRIPT_FILE}" "$@" 29 | -------------------------------------------------------------------------------- /modules/modelLoader/hiDream/HiDreamLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.HiDreamModel import HiDreamModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_hidream_lora import convert_hidream_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class HiDreamLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_hidream_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: HiDreamModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/util/enum/VideoFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class VideoFormat(Enum): 5 | PNG_IMAGE_SEQUENCE = 'PNG_IMAGE_SEQUENCE' 6 | JPG_IMAGE_SEQUENCE = 'JPG_IMAGE_SEQUENCE' 7 | MP4 = 'MP4' 8 | 9 | def __str__(self): 10 | return self.value 11 | 12 | def extension(self) -> str: 13 | match self: 14 | case VideoFormat.PNG_IMAGE_SEQUENCE: 15 | return ".png" 16 | case VideoFormat.JPG_IMAGE_SEQUENCE: 17 | return ".jpg" 18 | case VideoFormat.MP4: 19 | return ".mp4" 20 | case _: 21 | return "" 22 | 23 | def pil_format(self) -> str: 24 | match self: 25 | case VideoFormat.PNG_IMAGE_SEQUENCE: 26 | return "PNG" 27 | case VideoFormat.JPG_IMAGE_SEQUENCE: 28 | return "JPEG" 29 | case _: 30 | return "" 31 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusion3FineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelLoader.GenericFineTuneModelLoader import 
make_fine_tune_model_loader 3 | from modules.modelLoader.stableDiffusion3.StableDiffusion3EmbeddingLoader import StableDiffusion3EmbeddingLoader 4 | from modules.modelLoader.stableDiffusion3.StableDiffusion3ModelLoader import StableDiffusion3ModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | StableDiffusion3FineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ 9 | ModelType.STABLE_DIFFUSION_3: "resources/sd_model_spec/sd_3_2b_1.0.json", 10 | ModelType.STABLE_DIFFUSION_35: "resources/sd_model_spec/sd_3.5_1.0.json", 11 | }, 12 | model_class=StableDiffusion3Model, 13 | model_loader_class=StableDiffusion3ModelLoader, 14 | embedding_loader_class=StableDiffusion3EmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /training_presets/#chroma LoRA 16GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "lodestones/Chroma1-HD", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "CHROMA_1", 6 | "resolution": "512", 7 | "transformer": { 8 | "train": true, 9 | "weight_dtype": "FLOAT_8" 10 | }, 11 | "text_encoder": { 12 | "train": false, 13 | "weight_dtype": "BFLOAT_16" 14 | }, 15 | "training_method": "LORA", 16 | "vae": { 17 | "weight_dtype": "FLOAT_32" 18 | }, 19 | "train_dtype": "BFLOAT_16", 20 | "weight_dtype": "BFLOAT_16", 21 | "output_dtype": "BFLOAT_16", 22 | "timestep_distribution": "INVERTED_PARABOLA", 23 | "noising_weight": 7.7, 24 | "layer_filter": "attn,ff.net", 25 | "layer_filter_preset": "attn-mlp", 26 | "quantization": { 27 | "layer_filter": "transformer_block", 28 | "layer_filter_preset": "blocks" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /training_presets/#chroma LoRA 24GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "lodestones/Chroma1-HD", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "CHROMA_1", 6 | "resolution": "512", 7 | "transformer": { 8 | "train": true, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "text_encoder": { 12 | "train": false, 13 | "weight_dtype": "BFLOAT_16" 14 | }, 15 | "training_method": "LORA", 16 | "vae": { 17 | "weight_dtype": "FLOAT_32" 18 | }, 19 | "train_dtype": "BFLOAT_16", 20 | "weight_dtype": "BFLOAT_16", 21 | "output_dtype": "BFLOAT_16", 22 | "timestep_distribution": "INVERTED_PARABOLA", 23 | "noising_weight": 7.7, 24 | "layer_filter": "attn,ff.net", 25 | "layer_filter_preset": "attn-mlp", 26 | "quantization": { 27 | "layer_filter": "transformer_block", 28 | "layer_filter_preset": "blocks" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /modules/modelLoader/pixartAlpha/PixArtAlphaLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_pixart_lora import convert_pixart_lora_key_sets 8 | 9 | 10 | class PixArtAlphaLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> 
list[LoraConversionKeySet] | None: 17 | return convert_pixart_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: PixArtAlphaModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/util/enum/ModelFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ModelFormat(Enum): 5 | DIFFUSERS = 'DIFFUSERS' 6 | CKPT = 'CKPT' 7 | SAFETENSORS = 'SAFETENSORS' 8 | LEGACY_SAFETENSORS = 'LEGACY_SAFETENSORS' 9 | 10 | INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training 11 | 12 | def __str__(self): 13 | return self.value 14 | 15 | 16 | def file_extension(self) -> str: 17 | match self: 18 | case ModelFormat.DIFFUSERS: 19 | return '' 20 | case ModelFormat.CKPT: 21 | return '.ckpt' 22 | case ModelFormat.SAFETENSORS: 23 | return '.safetensors' 24 | case ModelFormat.LEGACY_SAFETENSORS: 25 | return '.safetensors' 26 | case _: 27 | return '' 28 | 29 | def is_single_file(self) -> bool: 30 | return self.file_extension() != '' 31 | -------------------------------------------------------------------------------- /training_presets/#sd 3 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-3-medium-diffusers", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_3", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "transformer": { 11 | "train": false, 12 | "weight_dtype": "FLOAT_16" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "FLOAT_16" 17 | }, 18 | "text_encoder_2": { 19 | "train": false, 20 | "weight_dtype": "FLOAT_16" 21 | }, 22 | "text_encoder_3": { 23 | "train": false, 24 | "weight_dtype": "FLOAT_16" 25 | }, 26 | "training_method": "LORA", 27 | "vae": { 28 | "weight_dtype": "FLOAT_32" 29 | }, 30 | "timestep_distribution": "LOGIT_NORMAL" 31 | } 32 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusion3EmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.stableDiffusion3.StableDiffusion3EmbeddingLoader import StableDiffusion3EmbeddingLoader 4 | from modules.modelLoader.stableDiffusion3.StableDiffusion3ModelLoader import StableDiffusion3ModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | StableDiffusion3EmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ 9 | ModelType.STABLE_DIFFUSION_3: "resources/sd_model_spec/sd_3_2b_1.0-embedding.json", 10 | ModelType.STABLE_DIFFUSION_35: "resources/sd_model_spec/sd_3.5_1.0-embedding.json", 11 | }, 12 | model_class=StableDiffusion3Model, 13 | model_loader_class=StableDiffusion3ModelLoader, 14 | embedding_loader_class=StableDiffusion3EmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion/StableDiffusionLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel 
import BaseModel 2 | from modules.model.StableDiffusionModel import StableDiffusionModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_sd_lora import convert_sd_lora_key_sets 8 | 9 | 10 | class StableDiffusionLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_sd_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: StableDiffusionModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /embedding_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a 2 | a rendering of a 3 | a cropped photo of the 4 | the photo of a 5 | a photo of a clean 6 | a photo of a dirty 7 | a dark photo of the 8 | a photo of my 9 | a photo of the cool 10 | a close-up photo of a 11 | a bright photo of the 12 | a cropped photo of a 13 | a photo of the 14 | a good photo of the 15 | a photo of one 16 | a close-up photo of the 17 | a rendition of the 18 | a photo of the clean 19 | a rendition of a 20 | a photo of a nice 21 | a good photo of a 22 | a photo of the nice 23 | a photo of the small 24 | a photo of the weird 25 | a photo of the large 26 | a photo of a cool 27 | a photo of a small 28 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion3/StableDiffusion3LoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_sd3_lora import convert_sd3_lora_key_sets 8 | 9 | 10 | class StableDiffusion3LoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_sd3_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: StableDiffusion3Model, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /training_presets/#hunyuan video LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "hunyuanvideo-community/HunyuanVideo", 4 | "batch_size": 4, 5 | "gradient_checkpointing": "CPU_OFFLOADED", 6 | "layer_offload_fraction": 0.5, 7 | "dataloader_threads": 1, 8 | "learning_rate": 0.0003, 9 | "model_type": "HUNYUAN_VIDEO", 10 | "output_model_destination": "models/lora.safetensors", 11 | "output_model_format": "SAFETENSORS", 12 | "resolution": "512", 13 | "timestep_distribution": "LOGIT_NORMAL", 14 | "dynamic_timestep_shifting": false, 15 | "train_dtype": "BFLOAT_16", 16 | "training_method": "LORA", 17 | "transformer": { 18 | "train": true, 19 | 
"weight_dtype": "FLOAT_8" 20 | }, 21 | "text_encoder": { 22 | "train": false, 23 | "weight_dtype": "FLOAT_8" 24 | }, 25 | "text_encoder_2": { 26 | "train": false, 27 | "weight_dtype": "FLOAT_8" 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /modules/modelLoader/hunyuanVideo/HunyuanVideoLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_hunyuan_video_lora import convert_hunyuan_video_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class HunyuanVideoLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_hunyuan_video_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: HunyuanVideoModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusionXL/StableDiffusionXLLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_sdxl_lora import convert_sdxl_lora_key_sets 8 | 9 | 10 | class StableDiffusionXLLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_sdxl_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: StableDiffusionXLModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/WuerstchenLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.wuerstchen.WuerstchenEmbeddingLoader import WuerstchenEmbeddingLoader 4 | from modules.modelLoader.wuerstchen.WuerstchenLoRALoader import WuerstchenLoRALoader 5 | from modules.modelLoader.wuerstchen.WuerstchenModelLoader import WuerstchenModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | WuerstchenLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ 10 | ModelType.WUERSTCHEN_2: "resources/sd_model_spec/wuerstchen_2.0-lora.json", 11 | ModelType.STABLE_CASCADE_1: "resources/sd_model_spec/stable_cascade_1.0-lora.json", 12 | }, 13 | model_class=WuerstchenModel, 14 | model_loader_class=WuerstchenModelLoader, 15 | embedding_loader_class=WuerstchenEmbeddingLoader, 16 | lora_loader_class=WuerstchenLoRALoader, 17 | ) 18 | 
-------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusionXLFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLEmbeddingLoader import StableDiffusionXLEmbeddingLoader 4 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | StableDiffusionXLFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ 9 | ModelType.STABLE_DIFFUSION_XL_10_BASE: "resources/sd_model_spec/sd_xl_base_1.0.json", 10 | ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING: "resources/sd_model_spec/sd_xl_base_1.0_inpainting.json", 11 | }, 12 | model_class=StableDiffusionXLModel, 13 | model_loader_class=StableDiffusionXLModelLoader, 14 | embedding_loader_class=StableDiffusionXLEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /modules/modelLoader/PixArtAlphaLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.pixartAlpha.PixArtAlphaEmbeddingLoader import PixArtAlphaEmbeddingLoader 4 | from modules.modelLoader.pixartAlpha.PixArtAlphaLoRALoader import PixArtAlphaLoRALoader 5 | from modules.modelLoader.pixartAlpha.PixArtAlphaModelLoader import PixArtAlphaModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | PixArtAlphaLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ 10 | ModelType.PIXART_ALPHA: "resources/sd_model_spec/pixart_alpha_1.0-lora.json", 11 | ModelType.PIXART_SIGMA: "resources/sd_model_spec/pixart_sigma_1.0-lora.json", 12 | }, 13 | model_class=PixArtAlphaModel, 14 | model_loader_class=PixArtAlphaModelLoader, 15 | embedding_loader_class=PixArtAlphaEmbeddingLoader, 16 | lora_loader_class=PixArtAlphaLoRALoader, 17 | ) 18 | -------------------------------------------------------------------------------- /training_presets/#wuerstchen 2.0 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "warp-ai/wuerstchen-prior", 4 | "prior": { 5 | "weight_dtype": "FLOAT_16" 6 | }, 7 | "text_encoder": { 8 | "weight_dtype": "FLOAT_16" 9 | }, 10 | "decoder_text_encoder": { 11 | "weight_dtype": "FLOAT_16" 12 | }, 13 | "decoder_vqgan": { 14 | "weight_dtype": "FLOAT_16" 15 | }, 16 | "decoder": { 17 | "model_name": "warp-ai/wuerstchen", 18 | "weight_dtype": "FLOAT_16" 19 | }, 20 | "effnet_encoder": { 21 | "model_name": "warp-ai/EfficientNetEncoder", 22 | "weight_dtype": "FLOAT_16" 23 | }, 24 | "learning_rate": 0.0003, 25 | "model_type": "WUERSTCHEN_2", 26 | "output_model_destination": "models/lora.safetensors", 27 | "output_model_format": "SAFETENSORS", 28 | "resolution": "1024", 29 | "training_method": "LORA", 30 | "loss_weight_fn": "P2", 31 | "loss_weight_strength": 1.0 32 | } 33 | -------------------------------------------------------------------------------- /training_presets/#qwen LoRA 16GB.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Qwen/Qwen-Image", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "QWEN", 6 | "resolution": "512", 7 | "gradient_checkpointing": "CPU_OFFLOADED", 8 | "layer_offload_fraction": 0.5, 9 | "dataloader_threads": 1, 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "FLOAT_8" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "FLOAT_8" 17 | }, 18 | "training_method": "LORA", 19 | "vae": { 20 | "weight_dtype": "FLOAT_32" 21 | }, 22 | "train_dtype": "BFLOAT_16", 23 | "weight_dtype": "BFLOAT_16", 24 | "output_dtype": "BFLOAT_16", 25 | "timestep_distribution": "LOGIT_NORMAL", 26 | "layer_filter": "attn,img_mlp,txt_mlp", 27 | "layer_filter_preset": "attn-mlp", 28 | "quantization": { 29 | "layer_filter": "transformer_block", 30 | "layer_filter_preset": "blocks" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /training_presets/#qwen LoRA 24GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Qwen/Qwen-Image", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "QWEN", 6 | "resolution": "512", 7 | "gradient_checkpointing": "CPU_OFFLOADED", 8 | "layer_offload_fraction": 0.1, 9 | "dataloader_threads": 1, 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "FLOAT_8" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "FLOAT_8" 17 | }, 18 | "training_method": "LORA", 19 | "vae": { 20 | "weight_dtype": "FLOAT_32" 21 | }, 22 | "train_dtype": "BFLOAT_16", 23 | "weight_dtype": "BFLOAT_16", 24 | "output_dtype": "BFLOAT_16", 25 | "timestep_distribution": "LOGIT_NORMAL", 26 | "layer_filter": "attn,img_mlp,txt_mlp", 27 | "layer_filter_preset": "attn-mlp", 28 | "quantization": { 29 | "layer_filter": "transformer_block", 30 | "layer_filter_preset": "blocks" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusionXLEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLEmbeddingLoader import StableDiffusionXLEmbeddingLoader 4 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | StableDiffusionXLEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ 9 | ModelType.STABLE_DIFFUSION_XL_10_BASE: "resources/sd_model_spec/sd_xl_base_1.0-embedding.json", 10 | ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING: "resources/sd_model_spec/sd_xl_base_1.0_inpainting-embedding.json", 11 | }, 12 | model_class=StableDiffusionXLModel, 13 | model_loader_class=StableDiffusionXLModelLoader, 14 | embedding_loader_class=StableDiffusionXLEmbeddingLoader, 15 | ) 16 | -------------------------------------------------------------------------------- /resources/docker/Vast-NVIDIA-CLI.Dockerfile: -------------------------------------------------------------------------------- 1 | #To build, run 2 | # docker build -t <image> .
-f Vast-NVIDIA-CLI.Dockerfile 3 | # docker tag <image> <user>/<repo>:<tag> 4 | # docker push <user>/<repo>:<tag> 5 | 6 | FROM vastai/pytorch:cuda-12.8.1-auto 7 | 8 | WORKDIR / 9 | RUN git clone https://github.com/Nerogar/OneTrainer 10 | RUN cd OneTrainer \ 11 | && export OT_PLATFORM_REQUIREMENTS=requirements-cuda.txt \ 12 | && export OT_LAZY_UPDATES=true \ 13 | && export OT_PYTHON_CMD=/venv/main/bin/python \ 14 | && ./install.sh \ 15 | && pip cache purge \ 16 | && rm -r ~/.cache/pip 17 | RUN apt-get update --yes \ 18 | && apt-get install --yes --no-install-recommends \ 19 | joe \ 20 | less \ 21 | gh \ 22 | iputils-ping \ 23 | nano \ 24 | && apt-get autoremove -y \ 25 | && apt-get clean \ 26 | && rm -rf /var/lib/apt/lists/* 27 | RUN pip install nvitop \ 28 | && pip cache purge \ 29 | && rm -rf ~/.cache/pip 30 | RUN ln -snf /OneTrainer /workspace/OneTrainer 31 | -------------------------------------------------------------------------------- /modules/modelLoader/wuerstchen/WuerstchenLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.WuerstchenModel import WuerstchenModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_stable_cascade_lora import convert_stable_cascade_lora_key_sets 8 | 9 | 10 | class WuerstchenLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | if model.model_type.is_stable_cascade(): 18 | return convert_stable_cascade_lora_key_sets() 19 | return None 20 | 21 | def load( 22 | self, 23 | model: WuerstchenModel, 24 | model_names: ModelNames, 25 | ): 26 | return self._load(model, model_names) 27 | -------------------------------------------------------------------------------- /training_presets/#chroma LoRA 8GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "lodestones/Chroma1-HD", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "CHROMA_1", 6 | "resolution": "512", 7 | "gradient_checkpointing": "CPU_OFFLOADED", 8 | "layer_offload_fraction": 0.6, 9 | "dataloader_threads": 1, 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "FLOAT_8" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "BFLOAT_16" 17 | }, 18 | "training_method": "LORA", 19 | "vae": { 20 | "weight_dtype": "FLOAT_32" 21 | }, 22 | "train_dtype": "BFLOAT_16", 23 | "weight_dtype": "BFLOAT_16", 24 | "output_dtype": "BFLOAT_16", 25 | "timestep_distribution": "INVERTED_PARABOLA", 26 | "noising_weight": 7.7, 27 | "layer_filter": "attn,ff.net", 28 | "layer_filter_preset": "attn-mlp", 29 | "quantization": { 30 | "layer_filter": "transformer_block", 31 | "layer_filter_preset": "blocks" 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusion3LoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.stableDiffusion3.StableDiffusion3EmbeddingLoader import
StableDiffusion3EmbeddingLoader 4 | from modules.modelLoader.stableDiffusion3.StableDiffusion3LoRALoader import StableDiffusion3LoRALoader 5 | from modules.modelLoader.stableDiffusion3.StableDiffusion3ModelLoader import StableDiffusion3ModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | StableDiffusion3LoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ 10 | ModelType.STABLE_DIFFUSION_3: "resources/sd_model_spec/sd_3_2b_1.0-lora.json", 11 | ModelType.STABLE_DIFFUSION_35: "resources/sd_model_spec/sd_3.5_1.0-lora.json", 12 | }, 13 | model_class=StableDiffusion3Model, 14 | model_loader_class=StableDiffusion3ModelLoader, 15 | embedding_loader_class=StableDiffusion3EmbeddingLoader, 16 | lora_loader_class=StableDiffusion3LoRALoader, 17 | ) 18 | -------------------------------------------------------------------------------- /docs/Contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome in any form. Here are some guidelines to make this easier for everyone. 4 | 5 | ## Discussion 6 | 7 | Discussions can be started in two ways: either through the "Discussions" tab on GitHub, or on the 8 | [Discord](https://discord.gg/KwgcQd5scF) server. If you have any questions or ideas, these are a great way to help 9 | improve OneTrainer. 10 | 11 | ## Issues 12 | 13 | If you find any problems, or you want to suggest new features, please create an issue. Before creating an issue, please 14 | check if one already exists for the same topic to avoid duplication. 15 | 16 | ## Pull Requests 17 | 18 | Before creating a larger pull request for a new feature, please consider joining the Discord or starting a discussion. This can 19 | avoid situations where multiple people work on the same change. It also helps keep the changes aligned with the 20 | general project structure and vision. Please also take a look at the [project structure documentation](ProjectStructure.md). 21 | For smaller changes or fixes, this is not needed.
22 | -------------------------------------------------------------------------------- /training_presets/#wuerstchen 2.0 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "warp-ai/wuerstchen-prior", 4 | "prior": { 5 | "weight_dtype": "FLOAT_16" 6 | }, 7 | "text_encoder": { 8 | "weight_dtype": "FLOAT_16" 9 | }, 10 | "decoder_text_encoder": { 11 | "weight_dtype": "FLOAT_16" 12 | }, 13 | "decoder_vqgan": { 14 | "weight_dtype": "FLOAT_16" 15 | }, 16 | "decoder": { 17 | "model_name": "warp-ai/wuerstchen", 18 | "weight_dtype": "FLOAT_16" 19 | }, 20 | "effnet_encoder": { 21 | "model_name": "warp-ai/EfficientNetEncoder", 22 | "weight_dtype": "FLOAT_16" 23 | }, 24 | "latent_caching": false, 25 | "learning_rate": 0.0003, 26 | "learning_rate_warmup_steps": 20, 27 | "model_type": "WUERSTCHEN_2", 28 | "output_model_destination": "models/embedding.safetensors", 29 | "output_model_format": "SAFETENSORS", 30 | "resolution": "1024", 31 | "sample_after": 1, 32 | "training_method": "EMBEDDING", 33 | "loss_weight_fn": "P2", 34 | "loss_weight_strength": 1.0 35 | } 36 | -------------------------------------------------------------------------------- /modules/util/enum/LossScaler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import modules.util.multi_gpu_util as multi 4 | 5 | 6 | class LossScaler(Enum): 7 | NONE = 'NONE' 8 | BATCH = 'BATCH' 9 | GLOBAL_BATCH = 'GLOBAL_BATCH' 10 | GRADIENT_ACCUMULATION = 'GRADIENT_ACCUMULATION' 11 | BOTH = 'BOTH' 12 | GLOBAL_BOTH = 'GLOBAL_BOTH' 13 | 14 | def __str__(self): 15 | return self.value 16 | 17 | def get_scale(self, batch_size: int, accumulation_steps: int) -> int: 18 | match self: 19 | case LossScaler.NONE: 20 | return 1 21 | case LossScaler.BATCH: 22 | return batch_size 23 | case LossScaler.GLOBAL_BATCH: 24 | return batch_size * multi.world_size() 25 | case LossScaler.GRADIENT_ACCUMULATION: 26 | return accumulation_steps 27 | case LossScaler.BOTH: 28 | return accumulation_steps * batch_size 29 | case LossScaler.GLOBAL_BOTH: 30 | return accumulation_steps * batch_size * multi.world_size() 31 | case _: 32 | raise ValueError 33 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusionXLLoRAModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader 3 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLEmbeddingLoader import StableDiffusionXLEmbeddingLoader 4 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLLoRALoader import StableDiffusionXLLoRALoader 5 | from modules.modelLoader.stableDiffusionXL.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | StableDiffusionXLLoRAModelLoader = make_lora_model_loader( 9 | model_spec_map={ 10 | ModelType.STABLE_DIFFUSION_XL_10_BASE: "resources/sd_model_spec/sd_xl_base_1.0-lora.json", 11 | ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING: "resources/sd_model_spec/sd_xl_base_1.0_inpainting-lora.json", 12 | }, 13 | model_class=StableDiffusionXLModel, 14 | model_loader_class=StableDiffusionXLModelLoader, 15 | embedding_loader_class=StableDiffusionXLEmbeddingLoader, 16 | lora_loader_class=StableDiffusionXLLoRALoader, 17 | 
) 18 | -------------------------------------------------------------------------------- /modules/modelSaver/QwenLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.QwenModel import QwenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.qwen.QwenLoRASaver import QwenLoRASaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class QwenLoRAModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: QwenModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | lora_model_saver = QwenLoRASaver() 27 | 28 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 29 | if output_model_format == ModelFormat.INTERNAL: 30 | self._save_internal_data(model, output_model_destination) 31 | -------------------------------------------------------------------------------- /export_debug.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Avoid footgun by explicitly navigating to the directory containing the batch file 4 | cd /d "%~dp0" 5 | 6 | REM Verify that OneTrainer is our current working directory 7 | if not exist "scripts\train_ui.py" ( 8 | echo Error: train_ui.py does not exist, you have done something very wrong. Reclone the repository. 9 | goto :end 10 | ) 11 | 12 | if not defined PYTHON (set PYTHON=python) 13 | if not defined VENV_DIR (set "VENV_DIR=%~dp0venv") 14 | 15 | :check_venv 16 | dir "%VENV_DIR%" >NUL 2>NUL 17 | if not errorlevel 1 goto :activate_venv 18 | echo venv not found, please run install.bat first 19 | goto :end 20 | 21 | :activate_venv 22 | echo activating venv %VENV_DIR% 23 | set PYTHON="%VENV_DIR%\Scripts\python.exe" -X utf8 24 | echo Using Python %PYTHON% 25 | 26 | :launch 27 | echo Generating debug report... 28 | %PYTHON% scripts\generate_debug_report.py 29 | if errorlevel 1 ( 30 | echo Error: Debug report generation failed with code %ERRORLEVEL% 31 | ) else ( 32 | echo Now upload the debug report to your GitHub issue or post in Discord.
33 | ) 34 | 35 | :end 36 | pause 37 | -------------------------------------------------------------------------------- /training_presets/#z-image LoRA 16GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Tongyi-MAI/Z-Image-Turbo", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "Z_IMAGE", 6 | "resolution": "512", 7 | "compile": true, 8 | "transformer": { 9 | "train": true, 10 | "weight_dtype": "INT_W8A8", 11 | "model_name": "https://huggingface.co/ostris/Z-Image-De-Turbo/blob/main/z_image_de_turbo_v1_bf16.safetensors" 12 | }, 13 | "text_encoder": { 14 | "train": false, 15 | "weight_dtype": "FLOAT_8" 16 | }, 17 | "training_method": "LORA", 18 | "vae": { 19 | "weight_dtype": "FLOAT_32" 20 | }, 21 | "train_dtype": "BFLOAT_16", 22 | "weight_dtype": "BFLOAT_16", 23 | "output_dtype": "BFLOAT_16", 24 | "layer_filter": "^(?=.*attention)(?!.*refiner).*,^(?=.*feed_forward)(?!.*refiner).*", 25 | "layer_filter_preset": "attn-mlp", 26 | "layer_filter_regex": true, 27 | "quantization": { 28 | "layer_filter": "layers", 29 | "layer_filter_preset": "blocks" 30 | }, 31 | "dataloader_threads": 1, 32 | "timestep_distribution": "LOGIT_NORMAL" 33 | } 34 | -------------------------------------------------------------------------------- /modules/model/util/t5_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import T5EncoderModel 4 | 5 | 6 | def encode_t5( 7 | text_encoder: T5EncoderModel, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | text_encoder_output: Tensor | None = None, 12 | use_attention_mask: bool = True, 13 | attention_mask: Tensor | None = None, 14 | add_layer_norm: bool = True, 15 | ) -> Tensor: 16 | if text_encoder_output is None and text_encoder is not None: 17 | text_encoder_output = text_encoder( 18 | tokens, 19 | attention_mask=attention_mask if use_attention_mask else None, 20 | output_hidden_states=True, 21 | return_dict=True, 22 | ) 23 | hidden_state_output_index = default_layer - layer_skip 24 | text_encoder_output = text_encoder_output.hidden_states[hidden_state_output_index] 25 | if hidden_state_output_index != -1 and add_layer_norm: 26 | text_encoder_output = text_encoder.encoder.final_layer_norm(text_encoder_output) 27 | 28 | return text_encoder_output 29 | -------------------------------------------------------------------------------- /modules/modelSaver/QwenFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.QwenModel import QwenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.qwen.QwenModelSaver import QwenModelSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class QwenFineTuneModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: QwenModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | base_model_saver = QwenModelSaver() 27 | 28 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 29 | 30 | if 
output_model_format == ModelFormat.INTERNAL: 31 | self._save_internal_data(model, output_model_destination) 32 | -------------------------------------------------------------------------------- /modules/modelSaver/ZImageLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.ZImageModel import ZImageModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.zImage.ZImageLoRASaver import ZImageLoRASaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class ZImageLoRAModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: ZImageModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | lora_model_saver = ZImageLoRASaver() 27 | 28 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 29 | if output_model_format == ModelFormat.INTERNAL: 30 | self._save_internal_data(model, output_model_destination) 31 | -------------------------------------------------------------------------------- /modules/model/util/gemma_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import Gemma2Model 4 | 5 | 6 | def encode_gemma( 7 | text_encoder: Gemma2Model, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | text_encoder_output: Tensor | None = None, 12 | use_attention_mask: bool = True, 13 | attention_mask: Tensor | None = None, 14 | add_layer_norm: bool = True, 15 | ) -> Tensor: 16 | if text_encoder_output is None and text_encoder is not None: 17 | text_encoder_output = text_encoder( 18 | tokens, 19 | attention_mask=attention_mask if use_attention_mask else None, 20 | output_hidden_states=True, 21 | return_dict=True, 22 | use_cache=False, 23 | ) 24 | hidden_state_output_index = default_layer - layer_skip 25 | text_encoder_output = text_encoder_output.hidden_states[hidden_state_output_index] 26 | if hidden_state_output_index != -1 and add_layer_norm: 27 | text_encoder_output = text_encoder.norm(text_encoder_output) 28 | 29 | return text_encoder_output 30 | -------------------------------------------------------------------------------- /modules/modelSaver/ZImageFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.ZImageModel import ZImageModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.zImage.ZImageModelSaver import ZImageModelSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class ZImageFineTuneModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: ZImageModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | base_model_saver = ZImageModelSaver() 
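# The savers in this package share a two-phase shape: delegate the weight
# export to the architecture-specific saver, then, only for the INTERNAL
# format (which stores everything needed to resume training), persist the
# additional state through InternalModelSaverMixin._save_internal_data.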
27 | 28 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 29 | 30 | if output_model_format == ModelFormat.INTERNAL: 31 | self._save_internal_data(model, output_model_destination) 32 | -------------------------------------------------------------------------------- /modules/util/enum/LearningRateScaler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import modules.util.multi_gpu_util as multi 4 | 5 | 6 | class LearningRateScaler(Enum): 7 | NONE = 'NONE' 8 | BATCH = 'BATCH' 9 | GLOBAL_BATCH = 'GLOBAL_BATCH' 10 | GRADIENT_ACCUMULATION = 'GRADIENT_ACCUMULATION' 11 | BOTH = 'BOTH' 12 | GLOBAL_BOTH = 'GLOBAL_BOTH' 13 | 14 | def __str__(self): 15 | return self.value 16 | 17 | def get_scale(self, batch_size: int, accumulation_steps: int) -> int: 18 | match self: 19 | case LearningRateScaler.NONE: 20 | return 1 21 | case LearningRateScaler.BATCH: 22 | return batch_size 23 | case LearningRateScaler.GLOBAL_BATCH: 24 | return batch_size * multi.world_size() 25 | case LearningRateScaler.GRADIENT_ACCUMULATION: 26 | return accumulation_steps 27 | case LearningRateScaler.BOTH: 28 | return accumulation_steps * batch_size 29 | case LearningRateScaler.GLOBAL_BOTH: 30 | return accumulation_steps * batch_size * multi.world_size() 31 | case _: 32 | raise ValueError 33 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | requires-python = ">=3.10" 3 | 4 | [tool.ruff] 5 | extend-exclude = [ 6 | # Exclude all third-party dependencies and environments. 7 | # NOTE: Conda installs mgds+diffusers into "src/" in the project directory. 8 | ".venv*", 9 | "venv*", 10 | "conda_env*", 11 | "src", 12 | # Do not lint the universal Python 2/3 version checker. 
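# (presumably because it must still parse under very old interpreters,
# where the modern-syntax lint rules enabled below would flag it)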
13 | "scripts/util/version_check.py", 14 | ] 15 | line-length = 120 16 | 17 | [tool.ruff.lint] 18 | select = ["F", "E", "W", "I", "B", "UP", "YTT", "BLE", "C4", "T10", "ISC", "ICN", "PIE", "PYI", "RSE", "RET", "SIM", "PGH", "FLY", "NPY", "PERF"] 19 | ignore = ["BLE001", "E402", "E501", "B024", "PGH003", "RET504", "RET505", "SIM102", "UP015", "PYI041"] 20 | 21 | [tool.ruff.lint.isort.sections] 22 | hf = ["diffusers*", "transformers*"] 23 | mgds = ["mgds"] 24 | torch = ["torch*"] 25 | 26 | [tool.ruff.format] 27 | quote-style = "double" 28 | docstring-code-format = true 29 | 30 | [tool.ruff.lint.isort] 31 | known-first-party = ["modules"] 32 | section-order = [ 33 | "future", 34 | "standard-library", 35 | "first-party", 36 | "mgds", 37 | "torch", 38 | "hf", 39 | "third-party", 40 | "local-folder", 41 | ] 42 | -------------------------------------------------------------------------------- /training_presets/#hidream LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "HiDream-ai/HiDream-I1-Full", 4 | "batch_size": 4, 5 | "gradient_checkpointing": "CPU_OFFLOADED", 6 | "layer_offload_fraction": 0.5, 7 | "dataloader_threads": 1, 8 | "learning_rate": 0.0003, 9 | "model_type": "HI_DREAM_FULL", 10 | "output_model_destination": "models/lora.safetensors", 11 | "output_model_format": "SAFETENSORS", 12 | "resolution": "512", 13 | "timestep_distribution": "LOGIT_NORMAL", 14 | "dynamic_timestep_shifting": false, 15 | "train_dtype": "BFLOAT_16", 16 | "training_method": "LORA", 17 | "transformer": { 18 | "train": true, 19 | "weight_dtype": "FLOAT_8" 20 | }, 21 | "text_encoder": { 22 | "train": false, 23 | "weight_dtype": "FLOAT_8" 24 | }, 25 | "text_encoder_2": { 26 | "train": false, 27 | "weight_dtype": "FLOAT_8" 28 | }, 29 | "text_encoder_3": { 30 | "train": false, 31 | "weight_dtype": "FLOAT_8" 32 | }, 33 | "text_encoder_4": { 34 | "model_name": "meta-llama/Llama-3.1-8B-Instruct", 35 | "train": false, 36 | "weight_dtype": "FLOAT_8" 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Request a feature 3 | title: "[Feat]: " 4 | labels: ["enhancement"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this feature request! 10 | - type: textarea 11 | id: motivation 12 | attributes: 13 | label: Describe your use-case. 14 | description: Give as much detail as possible. Can the functionality be achieved using existing features? 15 | placeholder: Tell us what you're missing! 16 | validations: 17 | required: true 18 | - type: textarea 19 | id: proposed-solution 20 | attributes: 21 | label: What would you like to see as a solution? 22 | description: Describe the proposed solution with a workflow. Be precise. 23 | placeholder: Tell us what you'd like to see! 24 | validations: 25 | required: true 26 | - type: textarea 27 | id: alternative-solutions 28 | attributes: 29 | label: Have you considered alternatives? List them here. 30 | description: Give a sketch of alternative solutions you've considered. 31 | placeholder: How else can it be achieved? 
32 | validations: 33 | required: false 34 | -------------------------------------------------------------------------------- /training_presets/#z-image LoRA 8GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Tongyi-MAI/Z-Image-Turbo", 3 | "batch_size": 2, 4 | "learning_rate": 0.0003, 5 | "model_type": "Z_IMAGE", 6 | "resolution": "512", 7 | "gradient_checkpointing": "CPU_OFFLOADED", 8 | "layer_offload_fraction": 0.6, 9 | "compile": true, 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "INT_W8A8", 13 | "model_name": "https://huggingface.co/ostris/Z-Image-De-Turbo/blob/main/z_image_de_turbo_v1_bf16.safetensors" 14 | }, 15 | "text_encoder": { 16 | "train": false, 17 | "weight_dtype": "FLOAT_8" 18 | }, 19 | "training_method": "LORA", 20 | "vae": { 21 | "weight_dtype": "FLOAT_32" 22 | }, 23 | "train_dtype": "BFLOAT_16", 24 | "weight_dtype": "BFLOAT_16", 25 | "output_dtype": "BFLOAT_16", 26 | "layer_filter": "^(?=.*attention)(?!.*refiner).*,^(?=.*feed_forward)(?!.*refiner).*", 27 | "layer_filter_preset": "attn-mlp", 28 | "layer_filter_regex": true, 29 | "quantization": { 30 | "layer_filter": "layers", 31 | "layer_filter_preset": "blocks" 32 | }, 33 | "dataloader_threads": 1, 34 | "timestep_distribution": "LOGIT_NORMAL" 35 | } 36 | -------------------------------------------------------------------------------- /modules/ui/SampleParamsWindow.py: -------------------------------------------------------------------------------- 1 | from modules.ui.SampleFrame import SampleFrame 2 | from modules.util.config.SampleConfig import SampleConfig 3 | from modules.util.ui import components 4 | from modules.util.ui.ui_utils import set_window_icon 5 | from modules.util.ui.UIState import UIState 6 | 7 | import customtkinter as ctk 8 | 9 | 10 | class SampleParamsWindow(ctk.CTkToplevel): 11 | def __init__(self, parent, sample: SampleConfig, ui_state: UIState, *args, **kwargs): 12 | super().__init__(parent, *args, **kwargs) 13 | 14 | self.sample = sample 15 | self.ui_state = ui_state 16 | 17 | self.title("Sample") 18 | self.geometry("800x500") 19 | self.resizable(True, True) 20 | 21 | self.grid_rowconfigure(0, weight=1) 22 | self.grid_rowconfigure(1, weight=0) 23 | self.grid_columnconfigure(0, weight=1) 24 | 25 | frame = SampleFrame(self, self.sample, self.ui_state) 26 | frame.grid(row=0, column=0, padx=0, pady=0, sticky="nsew") 27 | 28 | components.button(self, 1, 0, "ok", self.__ok) 29 | 30 | self.wait_visibility() 31 | self.grab_set() 32 | self.focus_set() 33 | self.after(200, lambda: set_window_icon(self)) 34 | 35 | 36 | def __ok(self): 37 | self.destroy() 38 | -------------------------------------------------------------------------------- /resources/docker/RunPod-NVIDIA-CLI.Dockerfile: -------------------------------------------------------------------------------- 1 | #To build, run 2 | # docker build -t <image> . -f RunPod-NVIDIA-CLI.Dockerfile 3 | # docker tag <image> <user>/<repo>:<tag> 4 | # docker push <user>/<repo>:<tag> 5 | 6 | FROM runpod/pytorch:2.8.0-py3.11-cuda12.8.1-cudnn-devel-ubuntu22.04 7 | #the base image is barely used, pytorch is the wrong version.
However, by using 8 | #a base image that is popular on RunPod, the base image likely is already available 9 | #in the image cache of a pod, and no download is necessary 10 | 11 | WORKDIR / 12 | RUN git clone https://github.com/Nerogar/OneTrainer 13 | RUN cd OneTrainer \ 14 | && export OT_PLATFORM_REQUIREMENTS=requirements-cuda.txt \ 15 | && export OT_LAZY_UPDATES=true \ 16 | && ./install.sh \ 17 | && pip cache purge \ 18 | && rm -r ~/.cache/pip 19 | RUN apt-get update --yes \ 20 | && apt-get install --yes --no-install-recommends \ 21 | joe \ 22 | less \ 23 | gh \ 24 | iputils-ping \ 25 | nano \ 26 | && apt-get autoremove -y \ 27 | && apt-get clean \ 28 | && rm -rf /var/lib/apt/lists/* 29 | RUN pip install nvitop \ 30 | && pip cache purge \ 31 | && rm -rf ~/.cache/pip 32 | COPY RunPod-NVIDIA-CLI-start.sh.patch /start.sh.patch 33 | RUN patch /start.sh < /start.sh.patch 34 | -------------------------------------------------------------------------------- /modules/modelSaver/FluxEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class FluxEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: FluxModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = FluxEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/SanaEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class SanaEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: SanaModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = SanaEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | 
self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/model/util/llama_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import LlamaModel 4 | 5 | 6 | def encode_llama( 7 | text_encoder: LlamaModel, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | text_encoder_output: Tensor | None = None, 12 | use_attention_mask: bool = True, 13 | attention_mask: Tensor | None = None, 14 | crop_start: int | None = None, 15 | ) -> tuple[Tensor, Tensor, Tensor]: 16 | if text_encoder_output is None and text_encoder is not None: 17 | text_encoder_output = text_encoder( 18 | tokens, 19 | attention_mask=attention_mask if use_attention_mask else None, 20 | output_hidden_states=True, 21 | return_dict=True, 22 | use_cache=False, 23 | ) 24 | hidden_state_output_index = default_layer - layer_skip 25 | text_encoder_output = text_encoder_output.hidden_states[hidden_state_output_index] 26 | 27 | if crop_start is not None: 28 | tokens = tokens[:, crop_start:] 29 | text_encoder_output = text_encoder_output[:, crop_start:] 30 | attention_mask = attention_mask[:, crop_start:] 31 | 32 | return text_encoder_output, attention_mask, tokens 33 | -------------------------------------------------------------------------------- /modules/util/args/CalculateLossArgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from modules.util.args.BaseArgs import BaseArgs 5 | 6 | 7 | class CalculateLossArgs(BaseArgs): 8 | config_path: str 9 | output_path: str 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def parse_args() -> 'CalculateLossArgs': 16 | parser = argparse.ArgumentParser(description="One Trainer Loss Calculation Script.") 17 | 18 | # @formatter:off 19 | 20 | parser.add_argument("--config-path", type=str, required=True, dest="config_path", help="The path to the config file") 21 | parser.add_argument("--output-path", type=str, required=True, dest="output_path", help="The path to the output file") 22 | 23 | # @formatter:on 24 | 25 | args = CalculateLossArgs.default_values() 26 | args.from_dict(vars(parser.parse_args())) 27 | return args 28 | 29 | @staticmethod 30 | def default_values() -> 'CalculateLossArgs': 31 | data = [] 32 | 33 | # name, default value, data type, nullable 34 | data.append(("config_path", None, str, True)) 35 | data.append(("output_path", "losses.json", str, False)) 36 | 37 | return CalculateLossArgs(data) 38 | -------------------------------------------------------------------------------- /modules/modelSaver/ChromaEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class ChromaEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | 
self, 20 | model: ChromaModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = ChromaEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/HiDreamEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hidream.HiDreamEmbeddingSaver import HiDreamEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class HiDreamEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: HiDreamModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = HiDreamEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/zImage/ZImageLoRASaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.ZImageModel import ZImageModel 2 | from modules.modelSaver.mixin.LoRASaverMixin import LoRASaverMixin 3 | from modules.util.enum.ModelFormat import ModelFormat 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 9 | 10 | 11 | class ZImageLoRASaver( 12 | LoRASaverMixin, 13 | ): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def _get_convert_key_sets(self, model: ZImageModel) -> list[LoraConversionKeySet] | None: 18 | return None 19 | 20 | def _get_state_dict( 21 | self, 22 | model: ZImageModel, 23 | ) -> dict[str, Tensor]: 24 | state_dict = {} 25 | if model.transformer_lora is not None: 26 | state_dict |= model.transformer_lora.state_dict() 27 | if model.lora_state_dict is not None: 28 | state_dict |= model.lora_state_dict 29 | return state_dict 30 | 31 | def save( 32 | self, 33 | model: ZImageModel, 34 | output_model_format: ModelFormat, 35 | output_model_destination: str, 36 | dtype: torch.dtype | None, 37 | ): 38 | self._save(model, output_model_format, output_model_destination, dtype) 39 | -------------------------------------------------------------------------------- /modules/modelSaver/WuerstchenEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import 
WuerstchenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class WuerstchenEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: WuerstchenModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype, 25 | ): 26 | embedding_model_saver = WuerstchenEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/PixArtAlphaEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class PixArtAlphaEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: PixArtAlphaModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = PixArtAlphaEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/HunyuanVideoEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class HunyuanVideoEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: HunyuanVideoModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | 
output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = HunyuanVideoEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/util/args/CaptionUIArgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from modules.util.args.BaseArgs import BaseArgs 5 | 6 | 7 | class CaptionUIArgs(BaseArgs): 8 | dir: str 9 | include_subdirectories: bool 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def parse_args() -> 'CaptionUIArgs': 16 | parser = argparse.ArgumentParser(description="One Trainer Caption UI Script.") 17 | 18 | # @formatter:off 19 | 20 | parser.add_argument("--dir", type=str, required=False, default=None, dest="dir", help="The initial directory to load training data from") 21 | parser.add_argument("--include-subdirectories", action="store_true", required=False, default=False, dest="include_subdirectories", help="Whether to include subdirectories when processing samples") 22 | 23 | # @formatter:on 24 | 25 | args = CaptionUIArgs.default_values() 26 | args.from_dict(vars(parser.parse_args())) 27 | return args 28 | 29 | @staticmethod 30 | def default_values() -> 'CaptionUIArgs': 31 | data = [] 32 | 33 | # name, default value, data type, nullable 34 | data.append(("dir", None, str, True)) 35 | data.append(("include_subdirectories", False, bool, False)) 36 | 37 | return CaptionUIArgs(data) 38 | -------------------------------------------------------------------------------- /modules/modelSaver/SanaFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver 5 | from modules.modelSaver.sana.SanaModelSaver import SanaModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class SanaFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: SanaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | base_model_saver = SanaModelSaver() 28 | embedding_model_saver = SanaEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionEmbeddingModelSaver.py: 
-------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class StableDiffusionEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: StableDiffusionModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = StableDiffusionEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /scripts/util/import_util.py: -------------------------------------------------------------------------------- 1 | def script_imports(allow_zluda: bool = True): 2 | import logging 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | # Filter out the Triton warning on startup. 8 | # xformers is not installed anymore, but might still exist for some installations. 9 | logging \ 10 | .getLogger("xformers") \ 11 | .addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) 12 | 13 | # Insert ourselves as the highest-priority library path, so our modules are 14 | # always found without any risk of being shadowed by another import path. 
15 | # 3 .parent calls to navigate from /scripts/util/import_util.py to the main directory 16 | onetrainer_lib_path = Path(__file__).absolute().parent.parent.parent 17 | sys.path.insert(0, str(onetrainer_lib_path)) 18 | 19 | if allow_zluda and sys.platform.startswith('win'): 20 | from modules.zluda import ZLUDAInstaller 21 | 22 | zluda_path = ZLUDAInstaller.get_path() 23 | 24 | if os.path.exists(zluda_path): 25 | try: 26 | ZLUDAInstaller.load(zluda_path) 27 | print(f'Using ZLUDA in {zluda_path}') 28 | except Exception as e: 29 | print(f'Failed to load ZLUDA: {e}') 30 | 31 | from modules.zluda import ZLUDA 32 | 33 | ZLUDA.initialize() 34 | -------------------------------------------------------------------------------- /modules/modelSaver/FluxFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver 4 | from modules.modelSaver.flux.FluxModelSaver import FluxModelSaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class FluxFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: FluxModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = FluxModelSaver() 28 | embedding_model_saver = FluxEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusion3EmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class StableDiffusion3EmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: StableDiffusion3Model, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = StableDiffusion3EmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, 
output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/util/loss/masked_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def masked_losses( 6 | losses: Tensor, 7 | mask: Tensor, 8 | unmasked_weight: float, 9 | normalize_masked_area_loss: bool, 10 | ) -> Tensor: 11 | clamped_mask = torch.clamp(mask, unmasked_weight, 1) 12 | 13 | losses *= clamped_mask 14 | 15 | if normalize_masked_area_loss: 16 | losses = losses / clamped_mask.mean(dim=(1, 2, 3), keepdim=True) 17 | 18 | return losses 19 | 20 | 21 | def masked_losses_with_prior( 22 | losses: Tensor, 23 | prior_losses: Tensor | None, 24 | mask: Tensor, 25 | unmasked_weight: float, 26 | normalize_masked_area_loss: bool, 27 | masked_prior_preservation_weight: float, 28 | ) -> Tensor: 29 | clamped_mask = torch.clamp(mask, unmasked_weight, 1) 30 | 31 | losses *= clamped_mask 32 | 33 | if normalize_masked_area_loss: 34 | losses = losses / clamped_mask.mean(dim=(1, 2, 3), keepdim=True) 35 | 36 | if masked_prior_preservation_weight == 0 or prior_losses is None: 37 | return losses 38 | 39 | clamped_mask = (1 - clamped_mask) 40 | prior_losses *= clamped_mask * masked_prior_preservation_weight 41 | 42 | if normalize_masked_area_loss: 43 | prior_losses = prior_losses / clamped_mask.mean(dim=(1, 2, 3), keepdim=True) 44 | 45 | return losses + prior_losses 46 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionXLEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import StableDiffusionXLEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class StableDiffusionXLEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: StableDiffusionXLModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = StableDiffusionXLEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /scripts/generate_captions.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.module.Blip2Model import Blip2Model 6 | from modules.module.BlipModel import BlipModel 7 | from modules.module.WDModel import WDModel 8 | from modules.util.args.GenerateCaptionsArgs import GenerateCaptionsArgs 9 | from modules.util.enum.GenerateCaptionsModel import GenerateCaptionsModel 10 | 11 | import torch 
12 | 13 | 14 | def main(): 15 | args = GenerateCaptionsArgs.parse_args() 16 | 17 | model = None 18 | if args.model == GenerateCaptionsModel.BLIP: 19 | model = BlipModel(torch.device(args.device), args.dtype.torch_dtype()) 20 | elif args.model == GenerateCaptionsModel.BLIP2: 21 | model = Blip2Model(torch.device(args.device), args.dtype.torch_dtype()) 22 | elif args.model == GenerateCaptionsModel.WD14_VIT_2: 23 | model = WDModel(torch.device(args.device), args.dtype.torch_dtype()) 24 | 25 | model.caption_folder( 26 | sample_dir=args.sample_dir, 27 | initial_caption=args.initial_caption, 28 | caption_prefix=args.caption_prefix, 29 | caption_postfix=args.caption_postfix, 30 | mode=args.mode, 31 | error_callback=lambda filename: print("Error while processing image " + filename), 32 | include_subdirectories=args.include_subdirectories 33 | ) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /modules/modelLoader/mixin/ModelSpecModelLoaderMixin.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | from abc import ABCMeta 4 | 5 | from modules.util.enum.ModelType import ModelType 6 | from modules.util.modelSpec.ModelSpec import ModelSpec 7 | 8 | from safetensors import safe_open 9 | 10 | 11 | class ModelSpecModelLoaderMixin(metaclass=ABCMeta): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def _default_model_spec_name( 16 | self, 17 | model_type: ModelType, 18 | ) -> str | None: 19 | return None 20 | 21 | def _load_default_model_spec( 22 | self, 23 | model_type: ModelType, 24 | safetensors_file_name: str | None = None, 25 | ) -> ModelSpec: 26 | model_spec = None 27 | 28 | model_spec_name = self._default_model_spec_name(model_type) 29 | if model_spec_name: 30 | with open(model_spec_name, "r", encoding="utf-8") as model_spec_file: 31 | model_spec = ModelSpec.from_dict(json.load(model_spec_file)) 32 | else: 33 | model_spec = ModelSpec() 34 | 35 | if safetensors_file_name: 36 | with contextlib.suppress(Exception), safe_open(safetensors_file_name, framework="pt") as f: 37 | if "modelspec.sai_model_spec" in f.metadata(): 38 | model_spec = ModelSpec.from_dict(f.metadata()) 39 | 40 | return model_spec 41 | -------------------------------------------------------------------------------- /modules/util/convert/rescale_noise_scheduler_to_zero_terminal_snr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusers import DDIMScheduler 4 | 5 | 6 | def rescale_noise_scheduler_to_zero_terminal_snr(noise_scheduler: DDIMScheduler): 7 | """ 8 | From: Common Diffusion Noise Schedules and Sample Steps are Flawed (https://arxiv.org/abs/2305.08891) 9 | 10 | Rescales the noise scheduler's betas so that the final timestep has exactly zero terminal SNR. 11 | 12 | Args: 13 | noise_scheduler: The noise scheduler to transform 14 | 15 | Returns: 16 | The rescaled betas. 17 | """ 18 | alphas_cumprod = noise_scheduler.alphas_cumprod 19 | sqrt_alphas_cumprod = alphas_cumprod ** 0.5 20 | 21 | # Store old values. 22 | alphas_cumprod_sqrt_0 = sqrt_alphas_cumprod[0].clone() 23 | alphas_cumprod_sqrt_T = sqrt_alphas_cumprod[-1].clone() 24 | 25 | # Shift so last timestep is zero. 26 | sqrt_alphas_cumprod -= alphas_cumprod_sqrt_T 27 | 28 | # Scale so first timestep is back to old value.
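# Together, the shift above and the scale below apply an affine map
# sqrt_ac <- (sqrt_ac - sqrt_ac_T) * sqrt_ac_0 / (sqrt_ac_0 - sqrt_ac_T),
# which sends the terminal timestep to exactly zero SNR while keeping the
# first timestep at its original value.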
29 | sqrt_alphas_cumprod *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) 30 | 31 | # Convert alphas_cumprod_sqrt to betas 32 | alphas_cumprod = sqrt_alphas_cumprod ** 2 33 | alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] 34 | alphas = torch.cat([alphas_cumprod[0:1], alphas]) 35 | betas = 1 - alphas 36 | 37 | noise_scheduler.betas = betas 38 | noise_scheduler.alphas = alphas 39 | noise_scheduler.alphas_cumprod = alphas_cumprod 40 | 41 | return betas 42 | -------------------------------------------------------------------------------- /modules/modelSaver/ChromaFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver 4 | from modules.modelSaver.chroma.ChromaModelSaver import ChromaModelSaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class ChromaFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: ChromaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = ChromaModelSaver() 28 | embedding_model_saver = ChromaEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/qwen/QwenLoRASaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.QwenModel import QwenModel 2 | from modules.modelSaver.mixin.LoRASaverMixin import LoRASaverMixin 3 | from modules.util.enum.ModelFormat import ModelFormat 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 9 | 10 | 11 | class QwenLoRASaver( 12 | LoRASaverMixin, 13 | ): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def _get_convert_key_sets(self, model: QwenModel) -> list[LoraConversionKeySet] | None: 18 | return None 19 | 20 | def _get_state_dict( 21 | self, 22 | model: QwenModel, 23 | ) -> dict[str, Tensor]: 24 | state_dict = {} 25 | if model.text_encoder_lora is not None: 26 | state_dict |= model.text_encoder_lora.state_dict() 27 | if model.transformer_lora is not None: 28 | state_dict |= model.transformer_lora.state_dict() 29 | if model.lora_state_dict is not None: 30 | state_dict |= model.lora_state_dict 31 | return state_dict 32 | 33 | def save( 34 | self, 35 | model: QwenModel, 36 | output_model_format: ModelFormat, 37 | output_model_destination: str, 38 | dtype: torch.dtype | None, 39 | ): 40 | self._save(model, output_model_format, output_model_destination, dtype) 41 | -------------------------------------------------------------------------------- /resources/model_config/stable_cascade/stable_cascade_prior_1.0b.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableCascadeUNet", 3 | "_diffusers_version": "0.27.0.dev0", 4 | "block_out_channels": [ 5 | 1536, 6 | 1536 7 | ], 8 | "block_types_per_layer": [ 9 | [ 10 | "SDCascadeResBlock", 11 | "SDCascadeTimestepBlock", 12 | "SDCascadeAttnBlock" 13 | ], 14 | [ 15 | "SDCascadeResBlock", 16 | "SDCascadeTimestepBlock", 17 | "SDCascadeAttnBlock" 18 | ] 19 | ], 20 | "clip_image_in_channels": 768, 21 | "clip_seq": 4, 22 | "clip_text_in_channels": 1280, 23 | "clip_text_pooled_in_channels": 1280, 24 | "conditioning_dim": 1536, 25 | "down_blocks_repeat_mappers": [ 26 | 1, 27 | 1 28 | ], 29 | "down_num_layers_per_block": [ 30 | 4, 31 | 12 32 | ], 33 | "dropout": [ 34 | 0.1, 35 | 0.1 36 | ], 37 | "effnet_in_channels": null, 38 | "in_channels": 16, 39 | "kernel_size": 3, 40 | "num_attention_heads": [ 41 | 24, 42 | 24 43 | ], 44 | "out_channels": 16, 45 | "patch_size": 1, 46 | "pixel_mapper_in_channels": null, 47 | "self_attn": true, 48 | "switch_level": [ 49 | false 50 | ], 51 | "timestep_conditioning_type": [ 52 | "sca", 53 | "crp" 54 | ], 55 | "timestep_ratio_embedding_dim": 64, 56 | "up_blocks_repeat_mappers": [ 57 | 1, 58 | 1 59 | ], 60 | "up_num_layers_per_block": [ 61 | 12, 62 | 4 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /resources/model_config/stable_cascade/stable_cascade_prior_3.6b.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableCascadeUNet", 3 | "_diffusers_version": "0.27.0.dev0", 4 | "block_out_channels": [ 5 | 2048, 6 | 2048 7 | ], 8 | "block_types_per_layer": [ 9 | [ 10 | "SDCascadeResBlock", 11 | "SDCascadeTimestepBlock", 12 | "SDCascadeAttnBlock" 13 | ], 14 | [ 15 | "SDCascadeResBlock", 16 | "SDCascadeTimestepBlock", 17 | "SDCascadeAttnBlock" 18 | ] 19 | ], 20 | "clip_image_in_channels": 768, 21 | "clip_seq": 4, 22 | "clip_text_in_channels": 1280, 23 | "clip_text_pooled_in_channels": 1280, 24 | "conditioning_dim": 2048, 25 | "down_blocks_repeat_mappers": [ 26 | 1, 27 | 1 28 | ], 29 | "down_num_layers_per_block": [ 30 | 8, 31 | 24 32 | ], 33 | "dropout": [ 34 | 0.1, 35 | 0.1 36 | ], 37 | "effnet_in_channels": null, 38 | "in_channels": 16, 39 | "kernel_size": 3, 40 | "num_attention_heads": [ 41 | 32, 42 | 32 43 | ], 44 | "out_channels": 16, 45 | "patch_size": 1, 46 | "pixel_mapper_in_channels": null, 47 | "self_attn": true, 48 | "switch_level": [ 49 | false 50 | ], 51 | "timestep_conditioning_type": [ 52 | "sca", 53 | "crp" 54 | ], 55 | "timestep_ratio_embedding_dim": 64, 56 | "up_blocks_repeat_mappers": [ 57 | 1, 58 | 1 59 | ], 60 | "up_num_layers_per_block": [ 61 | 24, 62 | 8 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusionFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader 3 | from modules.modelLoader.stableDiffusion.StableDiffusionEmbeddingLoader import StableDiffusionEmbeddingLoader 4 | from modules.modelLoader.stableDiffusion.StableDiffusionModelLoader import StableDiffusionModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | StableDiffusionFineTuneModelLoader = make_fine_tune_model_loader( 8 | model_spec_map={ 9 | 
ModelType.STABLE_DIFFUSION_15: "resources/sd_model_spec/sd_1.5.json", 10 | ModelType.STABLE_DIFFUSION_15_INPAINTING: "resources/sd_model_spec/sd_1.5_inpainting.json", 11 | ModelType.STABLE_DIFFUSION_20: "resources/sd_model_spec/sd_2.0.json", 12 | ModelType.STABLE_DIFFUSION_20_BASE: "resources/sd_model_spec/sd_2.0.json", 13 | ModelType.STABLE_DIFFUSION_20_INPAINTING: "resources/sd_model_spec/sd_2.0_inpainting.json", 14 | ModelType.STABLE_DIFFUSION_20_DEPTH: "resources/sd_model_spec/sd_2.0_depth.json", 15 | ModelType.STABLE_DIFFUSION_21: "resources/sd_model_spec/sd_2.1.json", 16 | ModelType.STABLE_DIFFUSION_21_BASE: "resources/sd_model_spec/sd_2.1.json", 17 | }, 18 | model_class=StableDiffusionModel, 19 | model_loader_class=StableDiffusionModelLoader, 20 | embedding_loader_class=StableDiffusionEmbeddingLoader, 21 | ) 22 | -------------------------------------------------------------------------------- /training_presets/#chroma Finetune 24GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "lodestones/Chroma1-HD", 3 | "batch_size": 2, 4 | "learning_rate": 1e-5, 5 | "model_type": "CHROMA_1", 6 | "resolution": "512", 7 | "transformer": { 8 | "train": true, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "text_encoder": { 12 | "train": false, 13 | "weight_dtype": "BFLOAT_16" 14 | }, 15 | "training_method": "FINE_TUNE", 16 | "vae": { 17 | "weight_dtype": "FLOAT_32" 18 | }, 19 | "train_dtype": "BFLOAT_16", 20 | "weight_dtype": "BFLOAT_16", 21 | "output_dtype": "BFLOAT_16", 22 | "timestep_distribution": "INVERTED_PARABOLA", 23 | "noising_weight": 7.7, 24 | "optimizer": { 25 | "optimizer": "ADAFACTOR" 26 | }, 27 | "optimizer_defaults": { 28 | "ADAFACTOR": { 29 | "optimizer": "ADAFACTOR", 30 | "fused_back_pass": true, 31 | "beta1": null, 32 | "clip_threshold": 1.0, 33 | "decay_rate": -0.8, 34 | "eps": 1e-30, 35 | "eps2": 0.001, 36 | "relative_step": false, 37 | "scale_parameter": false, 38 | "stochastic_rounding": true, 39 | "warmup_init": false, 40 | "weight_decay": 0.0 41 | } 42 | }, 43 | "quantization": { 44 | "layer_filter": "transformer_block", 45 | "layer_filter_preset": "blocks" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /modules/modelSaver/WuerstchenFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver 5 | from modules.modelSaver.wuerstchen.WuerstchenModelSaver import WuerstchenModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class WuerstchenFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: WuerstchenModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | base_model_saver = WuerstchenModelSaver() 28 | embedding_model_saver = WuerstchenEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model,
output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/module/Blip2Model.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseImageCaptionModel import BaseImageCaptionModel, CaptionSample 2 | 3 | import torch 4 | 5 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 6 | 7 | 8 | class Blip2Model(BaseImageCaptionModel): 9 | def __init__(self, device: torch.device, dtype: torch.dtype): 10 | self.device = device 11 | self.dtype = dtype 12 | 13 | self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 14 | 15 | self.model = Blip2ForConditionalGeneration.from_pretrained( 16 | "Salesforce/blip2-opt-2.7b", 17 | torch_dtype=self.dtype 18 | ) 19 | self.model.eval() 20 | self.model.to(self.device) 21 | 22 | def generate_caption( 23 | self, 24 | caption_sample: CaptionSample, 25 | initial_caption: str = "", 26 | caption_prefix: str = "", 27 | caption_postfix: str = "", 28 | ) -> str: 29 | inputs = self.processor(caption_sample.get_image(), initial_caption, return_tensors="pt") 30 | inputs = inputs.to(self.device, self.dtype) 31 | with torch.no_grad(): 32 | outputs = self.model.generate(**inputs) 33 | predicted_caption = self.processor.decode(outputs[0], skip_special_tokens=True) 34 | predicted_caption = (caption_prefix + initial_caption + predicted_caption + caption_postfix).strip() 35 | 36 | return predicted_caption 37 | -------------------------------------------------------------------------------- /modules/module/BlipModel.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseImageCaptionModel import BaseImageCaptionModel, CaptionSample 2 | 3 | import torch 4 | 5 | from transformers import BlipForConditionalGeneration, BlipProcessor 6 | 7 | 8 | class BlipModel(BaseImageCaptionModel): 9 | def __init__(self, device: torch.device, dtype: torch.dtype): 10 | self.device = device 11 | self.dtype = dtype 12 | 13 | self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") 14 | 15 | self.model = BlipForConditionalGeneration.from_pretrained( 16 | "Salesforce/blip-image-captioning-large", 17 | torch_dtype=self.dtype 18 | ) 19 | self.model.eval() 20 | self.model.to(self.device) 21 | 22 | def generate_caption( 23 | self, 24 | caption_sample: CaptionSample, 25 | initial_caption: str = "", 26 | caption_prefix: str = "", 27 | caption_postfix: str = "", 28 | ): 29 | inputs = self.processor(caption_sample.get_image(), initial_caption, return_tensors="pt") 30 | inputs = inputs.to(self.device, self.dtype) 31 | with torch.no_grad(): 32 | outputs = self.model.generate(**inputs) 33 | predicted_caption = self.processor.decode(outputs[0], skip_special_tokens=True) 34 | predicted_caption = (caption_prefix + predicted_caption + caption_postfix).strip() 35 | 36 | return predicted_caption 37 | -------------------------------------------------------------------------------- /modules/modelSaver/PixArtAlphaFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from 
modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver 5 | from modules.modelSaver.pixartAlpha.PixArtAlphaModelSaver import PixArtAlphaModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class PixArtAlphaFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: PixArtAlphaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | base_model_saver = PixArtAlphaModelSaver() 28 | embedding_model_saver = PixArtAlphaEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | import json 6 | 7 | from modules.util import create 8 | from modules.util.args.TrainArgs import TrainArgs 9 | from modules.util.callbacks.TrainCallbacks import TrainCallbacks 10 | from modules.util.commands.TrainCommands import TrainCommands 11 | from modules.util.config.SecretsConfig import SecretsConfig 12 | from modules.util.config.TrainConfig import TrainConfig 13 | 14 | 15 | def main(): 16 | args = TrainArgs.parse_args() 17 | callbacks = TrainCallbacks() 18 | commands = TrainCommands() 19 | 20 | train_config = TrainConfig.default_values() 21 | with open(args.config_path, "r") as f: 22 | train_config.from_dict(json.load(f)) 23 | 24 | try: 25 | with open("secrets.json" if args.secrets_path is None else args.secrets_path, "r") as f: 26 | secrets_dict=json.load(f) 27 | train_config.secrets = SecretsConfig.default_values().from_dict(secrets_dict) 28 | except FileNotFoundError: 29 | if args.secrets_path is not None: 30 | raise 31 | 32 | trainer = create.create_trainer(train_config, callbacks, commands) 33 | 34 | trainer.start() 35 | 36 | canceled = False 37 | try: 38 | trainer.train() 39 | except KeyboardInterrupt: 40 | canceled = True 41 | 42 | if not canceled or train_config.backup_before_save: 43 | trainer.end() 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /training_presets/#qwen Finetune 16GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Qwen/Qwen-Image", 3 | "batch_size": 2, 4 | "learning_rate": 1e-5, 5 | "model_type": "QWEN", 6 | "resolution": "512", 7 | "gradient_checkpointing": "CPU_OFFLOADED", 8 | "layer_offload_fraction": 0.75, 9 | "dataloader_threads": 1, 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "BFLOAT_16" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "FLOAT_8" 17 | }, 18 | "training_method": "FINE_TUNE", 19 | "vae": { 20 | "weight_dtype": "FLOAT_32" 21 | }, 22 | "train_dtype": "BFLOAT_16", 23 | "weight_dtype": "BFLOAT_16", 24 | "output_dtype": "BFLOAT_16", 25 | "timestep_distribution": "LOGIT_NORMAL", 26 | "optimizer": { 27 | 
"optimizer": "ADAFACTOR" 28 | }, 29 | "optimizer_defaults": { 30 | "ADAFACTOR": { 31 | "optimizer": "ADAFACTOR", 32 | "fused_back_pass": true, 33 | "beta1": null, 34 | "clip_threshold": 1.0, 35 | "decay_rate": -0.8, 36 | "eps": 1e-30, 37 | "eps2": 0.001, 38 | "relative_step": false, 39 | "scale_parameter": false, 40 | "stochastic_rounding": true, 41 | "warmup_init": false, 42 | "weight_decay": 0.0 43 | } 44 | }, 45 | "layer_filter": "transformer_block", 46 | "layer_filter_preset": "blocks" 47 | } 48 | -------------------------------------------------------------------------------- /training_presets/#qwen Finetune 24GB.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Qwen/Qwen-Image", 3 | "batch_size": 2, 4 | "learning_rate": 1e-5, 5 | "model_type": "QWEN", 6 | "resolution": "512", 7 | "gradient_checkpointing": "CPU_OFFLOADED", 8 | "layer_offload_fraction": 0.55, 9 | "dataloader_threads": 1, 10 | "transformer": { 11 | "train": true, 12 | "weight_dtype": "BFLOAT_16" 13 | }, 14 | "text_encoder": { 15 | "train": false, 16 | "weight_dtype": "FLOAT_8" 17 | }, 18 | "training_method": "FINE_TUNE", 19 | "vae": { 20 | "weight_dtype": "FLOAT_32" 21 | }, 22 | "train_dtype": "BFLOAT_16", 23 | "weight_dtype": "BFLOAT_16", 24 | "output_dtype": "BFLOAT_16", 25 | "timestep_distribution": "LOGIT_NORMAL", 26 | "optimizer": { 27 | "optimizer": "ADAFACTOR" 28 | }, 29 | "optimizer_defaults": { 30 | "ADAFACTOR": { 31 | "optimizer": "ADAFACTOR", 32 | "fused_back_pass": true, 33 | "beta1": null, 34 | "clip_threshold": 1.0, 35 | "decay_rate": -0.8, 36 | "eps": 1e-30, 37 | "eps2": 0.001, 38 | "relative_step": false, 39 | "scale_parameter": false, 40 | "stochastic_rounding": true, 41 | "warmup_init": false, 42 | "weight_decay": 0.0 43 | } 44 | }, 45 | "layer_filter": "transformer_block", 46 | "layer_filter_preset": "blocks" 47 | } 48 | -------------------------------------------------------------------------------- /modules/modelSaver/HunyuanVideoFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver 4 | from modules.modelSaver.hunyuanVideo.HunyuanVideoModelSaver import HunyuanVideoModelSaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class HunyuanVideoFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: HunyuanVideoModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = HunyuanVideoModelSaver() 28 | embedding_model_saver = HunyuanVideoEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- 
/modules/modelSaver/SanaLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver 5 | from modules.modelSaver.sana.SanaLoRASaver import SanaLoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class SanaLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: SanaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | lora_model_saver = SanaLoRASaver() 28 | embedding_model_saver = SanaEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/FluxLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver 4 | from modules.modelSaver.flux.FluxLoRASaver import FluxLoRASaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class FluxLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: FluxModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = FluxLoRASaver() 28 | embedding_model_saver = FluxEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report 3 | title: "[Bug]: " 4 | labels: ["bug"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! 10 | - type: textarea 11 | id: actual-behavior 12 | attributes: 13 | label: What happened? 
14 | description: Give a description of what happened. 15 | placeholder: Tell us what you saw! 16 | validations: 17 | required: true 18 | - type: textarea 19 | id: expected-behavior 20 | attributes: 21 | label: What did you expect would happen? 22 | description: Also tell us, what did you expect to happen? 23 | placeholder: Tell us what you expected to see! 24 | validations: 25 | required: true 26 | - type: textarea 27 | id: logs 28 | attributes: 29 | label: Relevant log output 30 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 31 | render: shell 32 | - type: textarea 33 | id: environment 34 | attributes: 35 | label: Generate and upload debug_report.log 36 | description: Please attach your `debug_report.log` file. This file is generated by double-clicking on `export_debug.bat` in the OneTrainer folder on Windows, or by running `./run-cmd.sh generate_debug_report` on Mac/Linux. 37 | placeholder: Drag and drop the file or click the paperclip icon below to upload. 38 | render: '' 39 | -------------------------------------------------------------------------------- /modules/modelSaver/ChromaLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.ChromaModel import ChromaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver 4 | from modules.modelSaver.chroma.ChromaLoRASaver import ChromaLoRASaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class ChromaLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: ChromaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = ChromaLoRASaver() 28 | embedding_model_saver = ChromaEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelLoader/StableDiffusionEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader 3 | from modules.modelLoader.stableDiffusion.StableDiffusionEmbeddingLoader import StableDiffusionEmbeddingLoader 4 | from modules.modelLoader.stableDiffusion.StableDiffusionModelLoader import StableDiffusionModelLoader 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | StableDiffusionEmbeddingModelLoader = make_embedding_model_loader( 8 | model_spec_map={ 9 | ModelType.STABLE_DIFFUSION_15: "resources/sd_model_spec/sd_1.5-embedding.json", 10 | ModelType.STABLE_DIFFUSION_15_INPAINTING:
"resources/sd_model_spec/sd_1.5_inpainting-embedding.json", 11 | ModelType.STABLE_DIFFUSION_20: "resources/sd_model_spec/sd_2.0-embedding.json", 12 | ModelType.STABLE_DIFFUSION_20_BASE: "resources/sd_model_spec/sd_2.0-embedding.json", 13 | ModelType.STABLE_DIFFUSION_20_INPAINTING: "resources/sd_model_spec/sd_2.0_inpainting-embedding.json", 14 | ModelType.STABLE_DIFFUSION_20_DEPTH: "resources/sd_model_spec/sd_2.0_depth-embedding.json", 15 | ModelType.STABLE_DIFFUSION_21: "resources/sd_model_spec/sd_2.1-embedding.json", 16 | ModelType.STABLE_DIFFUSION_21_BASE: "resources/sd_model_spec/sd_2.1-embedding.json", 17 | }, 18 | model_class=StableDiffusionModel, 19 | model_loader_class=StableDiffusionModelLoader, 20 | embedding_loader_class=StableDiffusionEmbeddingLoader, 21 | ) 22 | -------------------------------------------------------------------------------- /modules/modelSetup/mixin/ModelSetupFlowMatchingMixin.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class ModelSetupFlowMatchingMixin(metaclass=ABCMeta): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self.__sigma = None 12 | self.__one_minus_sigma = None 13 | 14 | def _add_noise_discrete( 15 | self, 16 | scaled_latent_image: Tensor, 17 | latent_noise: Tensor, 18 | timestep: Tensor, 19 | timesteps: Tensor, 20 | ) -> tuple[Tensor, Tensor]: 21 | if self.__sigma is None: 22 | num_timesteps = timesteps.shape[-1] 23 | all_timesteps = torch.arange(start=1, end=num_timesteps + 1, step=1, dtype=torch.int32, device=scaled_latent_image.device) 24 | self.__sigma = all_timesteps / num_timesteps 25 | self.__one_minus_sigma = 1.0 - self.__sigma 26 | 27 | orig_dtype = scaled_latent_image.dtype 28 | 29 | sigmas = self.__sigma[timestep] 30 | one_minus_sigmas = self.__one_minus_sigma[timestep] 31 | 32 | while sigmas.dim() < scaled_latent_image.dim(): 33 | sigmas = sigmas.unsqueeze(-1) 34 | one_minus_sigmas = one_minus_sigmas.unsqueeze(-1) 35 | 36 | scaled_noisy_latent_image = latent_noise.to(dtype=sigmas.dtype) * sigmas \ 37 | + scaled_latent_image.to(dtype=sigmas.dtype) * one_minus_sigmas 38 | 39 | return scaled_noisy_latent_image.to(dtype=orig_dtype), sigmas 40 | -------------------------------------------------------------------------------- /modules/util/enum/GradientReducePrecision.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | 5 | 6 | class GradientReducePrecision(Enum): 7 | WEIGHT_DTYPE = 'WEIGHT_DTYPE' 8 | FLOAT_32 = 'FLOAT_32' 9 | WEIGHT_DTYPE_STOCHASTIC = 'WEIGHT_DTYPE_STOCHASTIC' 10 | FLOAT_32_STOCHASTIC = 'FLOAT_32_STOCHASTIC' 11 | 12 | def torch_dtype(self, weight_dtype: torch.dtype) -> torch.dtype: 13 | match self: 14 | case GradientReducePrecision.WEIGHT_DTYPE: 15 | return weight_dtype 16 | case GradientReducePrecision.FLOAT_32: 17 | return torch.float32 18 | case GradientReducePrecision.WEIGHT_DTYPE_STOCHASTIC: 19 | return weight_dtype 20 | case GradientReducePrecision.FLOAT_32_STOCHASTIC: 21 | return torch.float32 22 | case _: 23 | raise ValueError 24 | 25 | def stochastic_rounding(self, weight_dtype: torch.dtype) -> bool: 26 | match self: 27 | case GradientReducePrecision.WEIGHT_DTYPE: 28 | return False 29 | case GradientReducePrecision.FLOAT_32: 30 | return False 31 | case GradientReducePrecision.WEIGHT_DTYPE_STOCHASTIC: 32 | return weight_dtype == torch.bfloat16 33 | case 
GradientReducePrecision.FLOAT_32_STOCHASTIC: 34 | return weight_dtype == torch.bfloat16 35 | case _: 36 | raise ValueError 37 | 38 | def __str__(self): 39 | return self.value 40 | -------------------------------------------------------------------------------- /modules/modelSaver/HiDreamLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hidream.HiDreamEmbeddingSaver import HiDreamEmbeddingSaver 4 | from modules.modelSaver.hidream.HiDreamLoRASaver import HiDreamLoRASaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class HiDreamLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: HiDreamModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = HiDreamLoRASaver() 28 | embedding_model_saver = HiDreamEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusion3FineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver 5 | from modules.modelSaver.stableDiffusion3.StableDiffusion3ModelSaver import StableDiffusion3ModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusion3FineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusion3Model, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = StableDiffusion3ModelSaver() 28 | embedding_model_saver = StableDiffusion3EmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionFineTuneModelSaver.py: 
-------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver 5 | from modules.modelSaver.stableDiffusion.StableDiffusionModelSaver import StableDiffusionModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusionFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusionModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = StableDiffusionModelSaver() 28 | embedding_model_saver = StableDiffusionEmbeddingSaver() 29 | 30 | base_model_saver.save(model, model_type, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | --------------------------------------------------------------------------------
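The *FineTuneModelSaver, *LoRAModelSaver and *EmbeddingModelSaver classes above all expose the same five-argument save() entry point, so driving one directly is straightforward. A minimal usage sketch with FluxLoRAModelSaver; the enum member names ModelType.FLUX_DEV_1 and ModelFormat.SAFETENSORS are assumptions (check ModelType.py and ModelFormat.py for the exact values), and `model` stands in for a trained FluxModel:

import torch

from modules.modelSaver.FluxLoRAModelSaver import FluxLoRAModelSaver
from modules.util.enum.ModelFormat import ModelFormat
from modules.util.enum.ModelType import ModelType


def save_flux_lora(model) -> None:
    # model: a trained FluxModel, e.g. the result of a LoRA training run.
    saver = FluxLoRAModelSaver()
    saver.save(
        model=model,
        model_type=ModelType.FLUX_DEV_1,  # assumed member name
        output_model_format=ModelFormat.SAFETENSORS,  # assumed member name
        output_model_destination="models/my_flux_lora.safetensors",
        # dtype is typed Optional in the signatures above; an explicit dtype
        # presumably converts the saved weights to that precision.
        dtype=torch.bfloat16,
    )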