├── .editorconfig ├── .gitattributes ├── .github ├── FUNDING.yml └── ISSUE_TEMPLATE │ ├── bug.yml │ ├── config.yml │ └── feature.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LAUNCH-SCRIPTS.md ├── LICENSE.txt ├── README.md ├── docs ├── CaptioningAndMasking.md ├── CliTraining.md ├── Contributing.md ├── DockerImage.md ├── EmbeddingTraining.md ├── Overview.md ├── ProjectStructure.md ├── QuickStartGuide.md └── RamOffloading.md ├── embedding_templates ├── .gitignore └── subject.txt ├── export_debug.bat ├── install.bat ├── install.sh ├── lib.include.sh ├── modules ├── cloud │ ├── BaseCloud.py │ ├── BaseFileSync.py │ ├── BaseSSHFileSync.py │ ├── FabricFileSync.py │ ├── LinuxCloud.py │ ├── NativeSCPFileSync.py │ └── RunpodCloud.py ├── dataLoader │ ├── BaseDataLoader.py │ ├── FluxBaseDataLoader.py │ ├── HiDreamBaseDataLoader.py │ ├── HunyuanVideoBaseDataLoader.py │ ├── PixArtAlphaBaseDataLoader.py │ ├── SanaBaseDataLoader.py │ ├── StableDiffusion3BaseDataLoader.py │ ├── StableDiffusionBaseDataLoader.py │ ├── StableDiffusionFineTuneVaeDataLoader.py │ ├── StableDiffusionXLBaseDataLoader.py │ ├── WuerstchenBaseDataLoader.py │ ├── flux │ │ └── ShuffleFluxFillMaskChannels.py │ ├── mixin │ │ ├── DataLoaderMgdsMixin.py │ │ └── DataLoaderText2ImageMixin.py │ └── wuerstchen │ │ └── EncodeWuerstchenEffnet.py ├── model │ ├── BaseModel.py │ ├── FluxModel.py │ ├── HiDreamModel.py │ ├── HunyuanVideoModel.py │ ├── PixArtAlphaModel.py │ ├── SanaModel.py │ ├── StableDiffusion3Model.py │ ├── StableDiffusionModel.py │ ├── StableDiffusionXLModel.py │ ├── WuerstchenModel.py │ └── util │ │ ├── clip_util.py │ │ ├── gemma_util.py │ │ ├── llama_util.py │ │ └── t5_util.py ├── modelLoader │ ├── BaseModelLoader.py │ ├── FluxEmbeddingModelLoader.py │ ├── FluxFineTuneModelLoader.py │ ├── FluxLoRAModelLoader.py │ ├── HiDreamEmbeddingModelLoader.py │ ├── HiDreamFineTuneModelLoader.py │ ├── HiDreamLoRAModelLoader.py │ ├── HunyuanVideoEmbeddingModelLoader.py │ ├── HunyuanVideoFineTuneModelLoader.py │ ├── HunyuanVideoLoRAModelLoader.py │ ├── PixArtAlphaEmbeddingModelLoader.py │ ├── PixArtAlphaFineTuneModelLoader.py │ ├── PixArtAlphaLoRAModelLoader.py │ ├── SanaEmbeddingModelLoader.py │ ├── SanaFineTuneModelLoader.py │ ├── SanaLoRAModelLoader.py │ ├── StableDiffusion3EmbeddingModelLoader.py │ ├── StableDiffusion3FineTuneModelLoader.py │ ├── StableDiffusion3LoRAModelLoader.py │ ├── StableDiffusionEmbeddingModelLoader.py │ ├── StableDiffusionFineTuneModelLoader.py │ ├── StableDiffusionLoRAModelLoader.py │ ├── StableDiffusionXLEmbeddingModelLoader.py │ ├── StableDiffusionXLFineTuneModelLoader.py │ ├── StableDiffusionXLLoRAModelLoader.py │ ├── WuerstchenEmbeddingModelLoader.py │ ├── WuerstchenFineTuneModelLoader.py │ ├── WuerstchenLoRAModelLoader.py │ ├── flux │ │ ├── FluxEmbeddingLoader.py │ │ ├── FluxLoRALoader.py │ │ └── FluxModelLoader.py │ ├── hiDream │ │ ├── HiDreamEmbeddingLoader.py │ │ ├── HiDreamLoRALoader.py │ │ └── HiDreamModelLoader.py │ ├── hunyuanVideo │ │ ├── HunyuanVideoEmbeddingLoader.py │ │ ├── HunyuanVideoLoRALoader.py │ │ └── HunyuanVideoModelLoader.py │ ├── mixin │ │ ├── EmbeddingLoaderMixin.py │ │ ├── HFModelLoaderMixin.py │ │ ├── InternalModelLoaderMixin.py │ │ ├── LoRALoaderMixin.py │ │ ├── ModelSpecModelLoaderMixin.py │ │ └── SDConfigModelLoaderMixin.py │ ├── pixartAlpha │ │ ├── PixArtAlphaEmbeddingLoader.py │ │ ├── PixArtAlphaLoRALoader.py │ │ └── PixArtAlphaModelLoader.py │ ├── sana │ │ ├── SanaEmbeddingLoader.py │ │ ├── SanaLoRALoader.py │ │ └── SanaModelLoader.py │ ├── stableDiffusion │ │ ├── 
StableDiffusionEmbeddingLoader.py │ │ ├── StableDiffusionLoRALoader.py │ │ └── StableDiffusionModelLoader.py │ ├── stableDiffusion3 │ │ ├── StableDiffusion3EmbeddingLoader.py │ │ ├── StableDiffusion3LoRALoader.py │ │ └── StableDiffusion3ModelLoader.py │ ├── stableDiffusionXL │ │ ├── StableDiffusionXLEmbeddingLoader.py │ │ ├── StableDiffusionXLLoRALoader.py │ │ └── StableDiffusionXLModelLoader.py │ └── wuerstchen │ │ ├── WuerstchenEmbeddingLoader.py │ │ ├── WuerstchenLoRALoader.py │ │ └── WuerstchenModelLoader.py ├── modelSampler │ ├── BaseModelSampler.py │ ├── FluxSampler.py │ ├── HiDreamSampler.py │ ├── HunyuanVideoSampler.py │ ├── PixArtAlphaSampler.py │ ├── SanaSampler.py │ ├── StableDiffusion3Sampler.py │ ├── StableDiffusionSampler.py │ ├── StableDiffusionVaeSampler.py │ ├── StableDiffusionXLSampler.py │ └── WuerstchenSampler.py ├── modelSaver │ ├── BaseModelSaver.py │ ├── FluxEmbeddingModelSaver.py │ ├── FluxFineTuneModelSaver.py │ ├── FluxLoRAModelSaver.py │ ├── HiDreamEmbeddingModelSaver.py │ ├── HiDreamLoRAModelSaver.py │ ├── HunyuanVideoEmbeddingModelSaver.py │ ├── HunyuanVideoFineTuneModelSaver.py │ ├── HunyuanVideoLoRAModelSaver.py │ ├── PixArtAlphaEmbeddingModelSaver.py │ ├── PixArtAlphaFineTuneModelSaver.py │ ├── PixArtAlphaLoRAModelSaver.py │ ├── SanaEmbeddingModelSaver.py │ ├── SanaFineTuneModelSaver.py │ ├── SanaLoRAModelSaver.py │ ├── StableDiffusion3EmbeddingModelSaver.py │ ├── StableDiffusion3FineTuneModelSaver.py │ ├── StableDiffusion3LoRAModelSaver.py │ ├── StableDiffusionEmbeddingModelSaver.py │ ├── StableDiffusionFineTuneModelSaver.py │ ├── StableDiffusionLoRAModelSaver.py │ ├── StableDiffusionXLEmbeddingModelSaver.py │ ├── StableDiffusionXLFineTuneModelSaver.py │ ├── StableDiffusionXLLoRAModelSaver.py │ ├── WuerstchenEmbeddingModelSaver.py │ ├── WuerstchenFineTuneModelSaver.py │ ├── WuerstchenLoRAModelSaver.py │ ├── flux │ │ ├── FluxEmbeddingSaver.py │ │ ├── FluxLoRASaver.py │ │ └── FluxModelSaver.py │ ├── hidream │ │ ├── HiDreamEmbeddingSaver.py │ │ ├── HiDreamLoRASaver.py │ │ └── HiDreamModelSaver.py │ ├── hunyuanVideo │ │ ├── HunyuanVideoEmbeddingSaver.py │ │ ├── HunyuanVideoLoRASaver.py │ │ └── HunyuanVideoModelSaver.py │ ├── mixin │ │ ├── DtypeModelSaverMixin.py │ │ ├── EmbeddingSaverMixin.py │ │ ├── InternalModelSaverMixin.py │ │ └── LoRASaverMixin.py │ ├── pixartAlpha │ │ ├── PixArtAlphaEmbeddingSaver.py │ │ ├── PixArtAlphaLoRASaver.py │ │ └── PixArtAlphaModelSaver.py │ ├── sana │ │ ├── SanaEmbeddingSaver.py │ │ ├── SanaLoRASaver.py │ │ └── SanaModelSaver.py │ ├── stableDiffusion │ │ ├── StableDiffusionEmbeddingSaver.py │ │ ├── StableDiffusionLoRASaver.py │ │ └── StableDiffusionModelSaver.py │ ├── stableDiffusion3 │ │ ├── StableDiffusion3EmbeddingSaver.py │ │ ├── StableDiffusion3LoRASaver.py │ │ └── StableDiffusion3ModelSaver.py │ ├── stableDiffusionXL │ │ ├── StableDiffusionXLEmbeddingSaver.py │ │ ├── StableDiffusionXLLoRASaver.py │ │ └── StableDiffusionXLModelSaver.py │ └── wuerstchen │ │ ├── WuerstchenEmbeddingSaver.py │ │ ├── WuerstchenLoRASaver.py │ │ └── WuerstchenModelSaver.py ├── modelSetup │ ├── BaseFluxSetup.py │ ├── BaseHiDreamSetup.py │ ├── BaseHunyuanVideoSetup.py │ ├── BaseModelSetup.py │ ├── BasePixArtAlphaSetup.py │ ├── BaseSanaSetup.py │ ├── BaseStableDiffusion3Setup.py │ ├── BaseStableDiffusionSetup.py │ ├── BaseStableDiffusionXLSetup.py │ ├── BaseWuerstchenSetup.py │ ├── FluxEmbeddingSetup.py │ ├── FluxFineTuneSetup.py │ ├── FluxLoRASetup.py │ ├── HiDreamEmbeddingSetup.py │ ├── HiDreamFineTuneSetup.py │ ├── HiDreamLoRASetup.py │ ├── 
HunyuanVideoEmbeddingSetup.py │ ├── HunyuanVideoFineTuneSetup.py │ ├── HunyuanVideoLoRASetup.py │ ├── PixArtAlphaEmbeddingSetup.py │ ├── PixArtAlphaFineTuneSetup.py │ ├── PixArtAlphaLoRASetup.py │ ├── SanaEmbeddingSetup.py │ ├── SanaFineTuneSetup.py │ ├── SanaLoRASetup.py │ ├── StableDiffusion3EmbeddingSetup.py │ ├── StableDiffusion3FineTuneSetup.py │ ├── StableDiffusion3LoRASetup.py │ ├── StableDiffusionEmbeddingSetup.py │ ├── StableDiffusionFineTuneSetup.py │ ├── StableDiffusionFineTuneVaeSetup.py │ ├── StableDiffusionLoRASetup.py │ ├── StableDiffusionXLEmbeddingSetup.py │ ├── StableDiffusionXLFineTuneSetup.py │ ├── StableDiffusionXLLoRASetup.py │ ├── WuerstchenEmbeddingSetup.py │ ├── WuerstchenFineTuneSetup.py │ ├── WuerstchenLoRASetup.py │ └── mixin │ │ ├── ModelSetupDebugMixin.py │ │ ├── ModelSetupDiffusionLossMixin.py │ │ ├── ModelSetupDiffusionMixin.py │ │ ├── ModelSetupEmbeddingMixin.py │ │ ├── ModelSetupFlowMatchingMixin.py │ │ └── ModelSetupNoiseMixin.py ├── module │ ├── AdditionalEmbeddingWrapper.py │ ├── AestheticScoreModel.py │ ├── BaseImageCaptionModel.py │ ├── BaseImageMaskModel.py │ ├── BaseRembgModel.py │ ├── Blip2Model.py │ ├── BlipModel.py │ ├── ClipSegModel.py │ ├── EMAModule.py │ ├── GenerateLossesModel.py │ ├── HPSv2ScoreModel.py │ ├── LoRAModule.py │ ├── MaskByColor.py │ ├── RembgHumanModel.py │ ├── RembgModel.py │ ├── WDModel.py │ └── quantized │ │ ├── LinearFp8.py │ │ ├── LinearNf4.py │ │ └── mixin │ │ ├── QuantizedLinearMixin.py │ │ └── QuantizedModuleMixin.py ├── trainer │ ├── BaseTrainer.py │ ├── CloudTrainer.py │ └── GenericTrainer.py ├── ui │ ├── AdditionalEmbeddingsTab.py │ ├── CaptionUI.py │ ├── CloudTab.py │ ├── ConceptTab.py │ ├── ConceptWindow.py │ ├── ConfigList.py │ ├── ConvertModelUI.py │ ├── GenerateCaptionsWindow.py │ ├── GenerateMasksWindow.py │ ├── LoraTab.py │ ├── ModelTab.py │ ├── OffloadingWindow.py │ ├── OptimizerParamsWindow.py │ ├── ProfilingWindow.py │ ├── SampleFrame.py │ ├── SampleParamsWindow.py │ ├── SampleWindow.py │ ├── SamplingTab.py │ ├── SchedulerParamsWindow.py │ ├── TimestepDistributionWindow.py │ ├── TopBar.py │ ├── TrainUI.py │ ├── TrainingTab.py │ └── VideoToolUI.py ├── util │ ├── CustomGradScaler.py │ ├── DiffusionScheduleCoefficients.py │ ├── LayerOffloadConductor.py │ ├── ModelNames.py │ ├── ModelWeightDtypes.py │ ├── NamedParameterGroup.py │ ├── TimedActionMixin.py │ ├── TrainProgress.py │ ├── args │ │ ├── BaseArgs.py │ │ ├── CalculateLossArgs.py │ │ ├── CaptionUIArgs.py │ │ ├── ConvertModelArgs.py │ │ ├── CreateTrainFilesArgs.py │ │ ├── GenerateCaptionsArgs.py │ │ ├── GenerateMasksArgs.py │ │ ├── SampleArgs.py │ │ ├── TrainArgs.py │ │ └── arg_type_util.py │ ├── bf16_stochastic_rounding.py │ ├── callbacks │ │ └── TrainCallbacks.py │ ├── checkpointing_util.py │ ├── commands │ │ └── TrainCommands.py │ ├── concept_stats.py │ ├── config │ │ ├── BaseConfig.py │ │ ├── CloudConfig.py │ │ ├── ConceptConfig.py │ │ ├── SampleConfig.py │ │ ├── SecretsConfig.py │ │ └── TrainConfig.py │ ├── conv_util.py │ ├── convert │ │ ├── convert_diffusers_to_ckpt_util.py │ │ ├── convert_flux_diffusers_to_ckpt.py │ │ ├── convert_hunyuan_video_diffusers_to_ckpt.py │ │ ├── convert_pixart_diffusers_to_ckpt.py │ │ ├── convert_sd3_diffusers_to_ckpt.py │ │ ├── convert_sd_diffusers_to_ckpt.py │ │ ├── convert_sdxl_diffusers_to_ckpt.py │ │ ├── convert_stable_cascade_ckpt_to_diffusers.py │ │ ├── convert_stable_cascade_diffusers_to_ckpt.py │ │ └── rescale_noise_scheduler_to_zero_terminal_snr.py │ ├── create.py │ ├── dtype_util.py │ ├── enum │ │ ├── 
AudioFormat.py │ │ ├── BalancingStrategy.py │ │ ├── CloudAction.py │ │ ├── CloudFileSync.py │ │ ├── CloudType.py │ │ ├── ConceptType.py │ │ ├── ConfigPart.py │ │ ├── DataType.py │ │ ├── EMAMode.py │ │ ├── FileType.py │ │ ├── GenerateCaptionsModel.py │ │ ├── GenerateMasksModel.py │ │ ├── GradientCheckpointingMethod.py │ │ ├── ImageFormat.py │ │ ├── LearningRateScaler.py │ │ ├── LearningRateScheduler.py │ │ ├── LossScaler.py │ │ ├── LossWeight.py │ │ ├── ModelFormat.py │ │ ├── ModelType.py │ │ ├── NoiseScheduler.py │ │ ├── Optimizer.py │ │ ├── TimeUnit.py │ │ ├── TimestepDistribution.py │ │ ├── TrainingMethod.py │ │ └── VideoFormat.py │ ├── git_util.py │ ├── image_util.py │ ├── loss │ │ ├── masked_loss.py │ │ └── vb_loss.py │ ├── lr_scheduler_util.py │ ├── memory_util.py │ ├── modelSpec │ │ └── ModelSpec.py │ ├── optimizer │ │ ├── CAME.py │ │ ├── adafactor_extensions.py │ │ ├── adam_extensions.py │ │ └── adamw_extensions.py │ ├── optimizer_util.py │ ├── path_util.py │ ├── quantization_util.py │ ├── time_util.py │ ├── torch_util.py │ └── ui │ │ ├── ToolTip.py │ │ ├── UIState.py │ │ ├── components.py │ │ ├── dialogs.py │ │ └── ui_utils.py └── zluda │ ├── ZLUDA.py │ └── ZLUDAInstaller.py ├── pyproject.toml ├── requirements-cuda.txt ├── requirements-default.txt ├── requirements-dev.txt ├── requirements-global.txt ├── requirements-rocm.txt ├── requirements.txt ├── resources ├── docker │ ├── NVIDIA-UI.Dockerfile │ ├── RunPod-NVIDIA-CLI-start.sh.patch │ └── RunPod-NVIDIA-CLI.Dockerfile ├── icons │ ├── icon.ico │ ├── icon.png │ ├── icon_discord.png │ └── icon_small.png ├── images │ └── OneTrainerGUI.gif ├── model_config │ ├── stable_cascade │ │ ├── stable_cascade_prior_1.0b.json │ │ └── stable_cascade_prior_3.6b.json │ ├── stable_diffusion │ │ ├── v1-inference.yaml │ │ ├── v1-inpainting-inference.yaml │ │ ├── v2-inference-v.yaml │ │ ├── v2-inference.yaml │ │ ├── v2-inpainting-inference.yaml │ │ ├── v2-midas-inference.yaml │ │ └── x4-upscaling.yaml │ └── stable_diffusion_xl │ │ ├── sd_xl_base-inpainting.yaml │ │ ├── sd_xl_base.yaml │ │ └── sd_xl_refiner.yaml └── sd_model_spec │ ├── flux_dev_1.0-embedding.json │ ├── flux_dev_1.0-lora.json │ ├── flux_dev_1.0.json │ ├── flux_dev_fill_1.0-embedding.json │ ├── flux_dev_fill_1.0-lora.json │ ├── flux_dev_fill_1.0.json │ ├── hi_dream_full-embedding.json │ ├── hi_dream_full-lora.json │ ├── hi_dream_full.json │ ├── hunyuan_video-embedding.json │ ├── hunyuan_video-lora.json │ ├── hunyuan_video.json │ ├── pixart_alpha_1.0-embedding.json │ ├── pixart_alpha_1.0-lora.json │ ├── pixart_alpha_1.0.json │ ├── pixart_sigma_1.0-embedding.json │ ├── pixart_sigma_1.0-lora.json │ ├── pixart_sigma_1.0.json │ ├── sana-embedding.json │ ├── sana-lora.json │ ├── sana.json │ ├── sd_1.5-embedding.json │ ├── sd_1.5-lora.json │ ├── sd_1.5.json │ ├── sd_1.5_inpainting-embedding.json │ ├── sd_1.5_inpainting-lora.json │ ├── sd_1.5_inpainting.json │ ├── sd_2.0-embedding.json │ ├── sd_2.0-lora.json │ ├── sd_2.0.json │ ├── sd_2.0_depth-embedding.json │ ├── sd_2.0_depth-lora.json │ ├── sd_2.0_depth.json │ ├── sd_2.0_inpainting-embedding.json │ ├── sd_2.0_inpainting-lora.json │ ├── sd_2.0_inpainting.json │ ├── sd_2.1-embedding.json │ ├── sd_2.1-lora.json │ ├── sd_2.1.json │ ├── sd_3.5_1.0-embedding.json │ ├── sd_3.5_1.0-lora.json │ ├── sd_3.5_1.0.json │ ├── sd_3_2b_1.0-embedding.json │ ├── sd_3_2b_1.0-lora.json │ ├── sd_3_2b_1.0.json │ ├── sd_xl_base_1.0-embedding.json │ ├── sd_xl_base_1.0-lora.json │ ├── sd_xl_base_1.0.json │ ├── sd_xl_base_1.0_Inpainting-embedding.json │ ├── 
sd_xl_base_1.0_inpainting-lora.json │ ├── sd_xl_base_1.0_inpainting.json │ ├── stable_cascade_1.0-embedding.json │ ├── stable_cascade_1.0-lora.json │ ├── stable_cascade_1.0.json │ ├── wuerstchen_2.0-embedding.json │ ├── wuerstchen_2.0-lora.json │ └── wuerstchen_2.0.json ├── run-cmd.sh ├── scripts ├── README.md ├── calculate_loss.py ├── caption_ui.py ├── convert_model.py ├── convert_model_ui.py ├── create_train_files.py ├── generate_captions.py ├── generate_debug_report.py ├── generate_masks.py ├── install_zluda.py ├── sample.py ├── train.py ├── train_remote.py ├── train_ui.py ├── util │ ├── import_util.py │ └── version_check.py └── video_tool_ui.py ├── start-ui.bat ├── start-ui.sh ├── training_presets ├── #flux LoRA.json ├── #hidream LoRA.json ├── #hunyuan video LoRA.json ├── #pixart alpha 1.0.json ├── #pixart sigma 1.0 LoRA.json ├── #pixart sigma 1.0.json ├── #sana 1.6b.json ├── #sd 1.5 LoRA.json ├── #sd 1.5 embedding.json ├── #sd 1.5 inpaint masked.json ├── #sd 1.5 inpaint.json ├── #sd 1.5 masked.json ├── #sd 1.5.json ├── #sd 2.0 inpaint LoRA.json ├── #sd 2.0 inpaint.json ├── #sd 2.1 LoRA.json ├── #sd 2.1 embedding.json ├── #sd 2.1.json ├── #sd 3 LoRA.json ├── #sd 3.json ├── #sdxl 1.0 LoRA.json ├── #sdxl 1.0 embedding.json ├── #sdxl 1.0 inpaint LoRA.json ├── #sdxl 1.0.json ├── #stable cascade.json ├── #wuerstchen 2.0 LoRA.json ├── #wuerstchen 2.0 embedding.json ├── #wuerstchen 2.0.json └── .gitignore ├── update.bat └── update.sh /.editorconfig: -------------------------------------------------------------------------------- 1 | [*] 2 | charset = utf-8 3 | end_of_line = lf 4 | indent_size = 4 5 | indent_style = space 6 | trim_trailing_whitespace = true 7 | insert_final_newline = true 8 | max_line_length = 120 9 | tab_width = 4 10 | ij_continuation_indent_size = 8 11 | ij_formatter_off_tag = @formatter:off 12 | ij_formatter_on_tag = @formatter:on 13 | ij_formatter_tags_enabled = true 14 | ij_smart_tabs = false 15 | ij_visual_guides = none 16 | ij_wrap_on_typing = false 17 | 18 | [*.bat] 19 | end_of_line = crlf 20 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | *.bat text eol=crlf 3 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: Nerogar 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report 3 | title: "[Bug]: " 4 | labels: ["bug"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! 10 | - type: textarea 11 | id: actual-behavior 12 | attributes: 13 | label: What happened? 14 | description: Give a description of what happened? 15 | placeholder: Tell us what you saw! 16 | validations: 17 | required: true 18 | - type: textarea 19 | id: expected-behavior 20 | attributes: 21 | label: What did you expect would happen? 22 | description: Also tell us, what did you expect to happen? 23 | placeholder: Tell us what you saw! 24 | validations: 25 | required: true 26 | - type: textarea 27 | id: logs 28 | attributes: 29 | label: Relevant log output 30 | description: Please copy and paste any relevant log output. 
This will be automatically formatted into code, so no need for backticks. 31 | render: shell 32 | - type: textarea 33 | id: environment 34 | attributes: 35 | label: Generate and upload debug_report.log 36 | description: Please attach your `debug_report.log` file. This file is generated by double-clicking on `export_debug.bat` in the OneTrainer folder on Windows, or by running `./run-cmd.sh generate_debug_report` on Mac/Linux. 37 | placeholder: Drag and drop the file or click the paperclip icon below to upload. 38 | render: '' 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Discord 4 | url: https://discord.gg/KwgcQd5scF 5 | about: Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Request a feature 3 | title: "[Feat]: " 4 | labels: ["enhancement"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this feature request! 10 | - type: textarea 11 | id: motivation 12 | attributes: 13 | label: Describe your use-case. 14 | description: Give as much detail as possible. Can the functionality be achieved using existing features? 15 | placeholder: Tell us what you're missing! 16 | validations: 17 | required: true 18 | - type: textarea 19 | id: proposed-solution 20 | attributes: 21 | label: What would you like to see as a solution? 22 | description: Describe the proposed solution with a workflow. Be precise. 23 | placeholder: Tell us what you'd like to see! 24 | validations: 25 | required: true 26 | - type: textarea 27 | id: alternative-solutions 28 | attributes: 29 | label: Have you considered alternatives? List them here. 30 | description: Give a sketch of alternative solutions you've considered. 31 | placeholder: How else can it be achieved? 
32 | validations: 33 | required: false 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # development 2 | .idea 3 | *.bak 4 | *.pyc 5 | *.swp 6 | *~ 7 | debug*.ipynb 8 | debug*.py 9 | /debug* 10 | 11 | # user data 12 | /workspace* 13 | /models* 14 | /training_concepts 15 | /training_samples 16 | /training_user_settings 17 | /external 18 | /update.var 19 | config.json 20 | secrets.json 21 | 22 | # environments 23 | /.venv* 24 | /venv* 25 | /conda_env* 26 | .python-version 27 | *.egg-info 28 | 29 | # pixi environments 30 | .pixi 31 | pixi.lock 32 | pixi.toml 33 | 34 | # misc files 35 | /src 36 | train.bat 37 | debug_report.log 38 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-case-conflict 7 | - id: check-illegal-windows-names 8 | - id: destroyed-symlinks 9 | - id: fix-byte-order-marker 10 | - id: mixed-line-ending 11 | - id: trailing-whitespace 12 | - id: end-of-file-fixer 13 | - id: check-executables-have-shebangs 14 | - id: check-yaml 15 | 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: v0.11.8 18 | hooks: 19 | # Run the Ruff linter, but not the formatter. 20 | - id: ruff 21 | args: ["--fix"] 22 | types_or: [ python, pyi, jupyter ] 23 | 24 | ci: 25 | autofix_prs: false 26 | -------------------------------------------------------------------------------- /docs/CliTraining.md: -------------------------------------------------------------------------------- 1 | # Training from CLI 2 | 3 | All training functionality is available through the CLI command `python scripts/train.py`. The training configuration is 4 | stored in a json file that is passed to this script. Some options require specifying paths to files with a specific 5 | layout. These files can be created using the create_train_files.py script. You can call the script like 6 | this `python scripts/create_train_files.py -h`. 7 | 8 | To simplify the creation of the training config, you can export your settings from the UI by using the export button. 9 | This will create a single file that contains every setting. 10 | -------------------------------------------------------------------------------- /docs/Contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome in any form. Here are some guidelines to make this easier for everyone. 4 | 5 | ## Discussion 6 | 7 | Discussions can be started in two ways. Either through the "Discussions" tab on GitHub, or on the 8 | [Discord](https://discord.gg/KwgcQd5scF) server. If you have any questions or ideas, these are a great way to help 9 | improve OneTrainer. 10 | 11 | ## Issues 12 | 13 | If you find any problems, or you want to suggest new features, please create an issue. Before creating an issue, please 14 | check if one already exists for the same topic to avoid duplications. 15 | 16 | ## Pull Requests 17 | 18 | Before creating a bigger pull request for a new feature, please consider joining or creating a discussion. This can 19 | avoid situations where multiple people work on the same change. 
It also helps in keeping the changes aligned with the 20 | general project structure and vision. Please also take a look at 21 | the [project structure documentation](ProjectStructure.md). 22 | For smaller changes or fixes, this is not needed. 23 | -------------------------------------------------------------------------------- /docs/DockerImage.md: -------------------------------------------------------------------------------- 1 | # Docker Image 2 | 3 | A Dockerfile based on the `nvidia/cuda:11.8.0-devel-ubuntu22.04` image is provided. 4 | 5 | This image requires `nvidia-driver-525` and `nvidia-docker2` installed on the host. 6 | 7 | ## Building Image 8 | 9 | Build using: 10 | 11 | ``` 12 | docker build -t myuser/onetrainer:latest -f Dockerfile . 13 | ``` 14 | 15 | ## Running Image 16 | 17 | This is an example: 18 | 19 | ``` 20 | docker run \ 21 | --gpus all \ 22 | -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix \ 23 | -i \ 24 | --tty \ 25 | --shm-size=512m \ 26 | myuser/onetrainer:latest 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/EmbeddingTraining.md: -------------------------------------------------------------------------------- 1 | # Embedding Training 2 | 3 | To get a general overview of the UI, please read the [Quick Start Guide](QuickStartGuide.md) first. 4 | 5 | Training an embedding requires a few settings that differ from Fine-Tuning or LoRA training. You can enable Embedding 6 | training in the dropdown in the top right corner. 7 | 8 | ### Concepts 9 | 10 | To train an embedding, you need to use special prompts. Each prompt needs to include `<embedding>` in the place where 11 | you want to place your trained embedding. For example, if you want to train the style of a painting, your prompt could 12 | be `a painting in the style of <embedding>`. If you don't want to add a custom prompt for every training image, you can 13 | select "From single text file" as the prompt source of your concept. Then select a text file containing one prompt per 14 | line. An example of such a file can be found in the embedding_templates directory. 15 | 16 | ### Special Embedding Settings 17 | 18 | If you select "Embedding" as your training method, a new tab called "embedding" will appear. Here you can specify: 19 | 20 | - Base embedding: An already trained embedding you want to continue training on. Leave this blank to train a new 21 | embedding. 22 | - Token count: The number of tokens to train. A higher token count can learn more detail, but will also take up 23 | more of the available tokens in each prompt when generating an image. 24 | - Initial embedding text: The text to use when initializing a new embedding. Choosing a good text can significantly 25 | speed up training.
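To make the prompt format described above concrete, here is a minimal sketch that writes such a single-text-file prompt source. The file name and the prompt lines are illustrative only and are not part of the repository; any plain text file with one prompt per line works, and the bundled templates live in the embedding_templates directory.

```python
# Illustrative only: create a prompt file where every line contains the
# <embedding> placeholder, as described in the Concepts section above.
prompts = [
    "a painting in the style of <embedding>",
    "a cropped painting in the style of <embedding>",
    "a detailed painting in the style of <embedding>",
]

with open("embedding_templates/painting_style.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(prompts) + "\n")
```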
26 | -------------------------------------------------------------------------------- /docs/Overview.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | [Quick Start Guide](QuickStartGuide.md) 4 | 5 | [Training from CLI](CliTraining.md) 6 | 7 | [Embedding Training](EmbeddingTraining.md) 8 | 9 | [Captioning and Masking](CaptioningAndMasking.md) 10 | -------------------------------------------------------------------------------- /embedding_templates/.gitignore: -------------------------------------------------------------------------------- 1 | # ignores everything except the builtin template files 2 | * 3 | !.gitignore 4 | !subject.txt 5 | -------------------------------------------------------------------------------- /embedding_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a 2 | a rendering of a 3 | a cropped photo of the 4 | the photo of a 5 | a photo of a clean 6 | a photo of a dirty 7 | a dark photo of the 8 | a photo of my 9 | a photo of the cool 10 | a close-up photo of a 11 | a bright photo of the 12 | a cropped photo of a 13 | a photo of the 14 | a good photo of the 15 | a photo of one 16 | a close-up photo of the 17 | a rendition of the 18 | a photo of the clean 19 | a rendition of a 20 | a photo of a nice 21 | a good photo of a 22 | a photo of the nice 23 | a photo of the small 24 | a photo of the weird 25 | a photo of the large 26 | a photo of a cool 27 | a photo of a small 28 | -------------------------------------------------------------------------------- /export_debug.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Avoid footgun by explictly navigating to the directory containing the batch file 4 | cd /d "%~dp0" 5 | 6 | REM Verify that OneTrainer is our current working directory 7 | if not exist "scripts\train_ui.py" ( 8 | echo Error: train_ui.py does not exist, you have done something very wrong. Reclone the repository. 9 | goto :end 10 | ) 11 | 12 | if not defined PYTHON (set PYTHON=python) 13 | if not defined VENV_DIR (set "VENV_DIR=%~dp0venv") 14 | 15 | :check_venv 16 | dir "%VENV_DIR%" >NUL 2>NUL 17 | if not errorlevel 1 goto :activate_venv 18 | echo venv not found, please run install.bat first 19 | goto :end 20 | 21 | :activate_venv 22 | echo activating venv %VENV_DIR% 23 | set PYTHON="%VENV_DIR%\Scripts\python.exe" 24 | echo Using Python %PYTHON% 25 | 26 | :launch 27 | echo Generating debug report... 28 | %PYTHON% scripts\generate_debug_report.py 29 | if errorlevel 1 ( 30 | echo Error: Debug report generation failed with code %ERRORLEVEL% 31 | ) else ( 32 | echo Now upload the debug report to your Github issue or post in Discord. 
33 | ) 34 | 35 | :end 36 | pause 37 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | source "${BASH_SOURCE[0]%/*}/lib.include.sh" 6 | 7 | prepare_runtime_environment 8 | -------------------------------------------------------------------------------- /modules/cloud/BaseFileSync.py: -------------------------------------------------------------------------------- 1 | 2 | import concurrent.futures 3 | from abc import ABCMeta, abstractmethod 4 | from math import ceil 5 | from pathlib import Path 6 | 7 | from modules.util.config.CloudConfig import CloudConfig, CloudSecretsConfig 8 | 9 | 10 | class BaseFileSync(metaclass=ABCMeta): 11 | def __init__(self, config: CloudConfig, secrets: CloudSecretsConfig): 12 | super().__init__() 13 | self.config = config 14 | self.secrets = secrets 15 | 16 | 17 | def sync_up(self,local : Path,remote : Path): 18 | if local.is_dir(): 19 | self.sync_up_dir(local=local,remote=remote,recursive=True) 20 | else: 21 | self.sync_up_file(local=local,remote=remote) 22 | 23 | @abstractmethod 24 | def sync_up_file(self,local : Path,remote : Path): 25 | pass 26 | 27 | @abstractmethod 28 | def sync_up_dir(self,local : Path,remote: Path,recursive: bool): 29 | pass 30 | 31 | @abstractmethod 32 | def sync_down_file(self,local : Path,remote : Path): 33 | pass 34 | 35 | @abstractmethod 36 | def sync_down_dir(self,local : Path,remote : Path,filter=None): 37 | pass 38 | 39 | @staticmethod 40 | def _run_batches(fn,tasks,workers:int,max_batch_size=None): 41 | futures=[] 42 | if len(tasks) == 0: 43 | return 44 | batch_size=ceil(len(tasks) / workers) 45 | if max_batch_size is not None: 46 | batch_size=min(max_batch_size,batch_size) 47 | with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: 48 | for i in range(0,len(tasks),batch_size): 49 | batch=tasks[i:i+batch_size] 50 | futures.append(executor.submit(fn,batch)) 51 | 52 | for future in futures: 53 | if (exception:=future.exception()): 54 | raise exception 55 | -------------------------------------------------------------------------------- /modules/dataLoader/BaseDataLoader.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from modules.dataLoader.mixin.DataLoaderMgdsMixin import DataLoaderMgdsMixin 4 | 5 | from mgds.MGDS import MGDS, TrainDataLoader 6 | 7 | import torch 8 | 9 | 10 | class BaseDataLoader( 11 | DataLoaderMgdsMixin, 12 | metaclass=ABCMeta, 13 | ): 14 | 15 | def __init__( 16 | self, 17 | train_device: torch.device, 18 | temp_device: torch.device, 19 | ): 20 | super().__init__() 21 | 22 | self.train_device = train_device 23 | self.temp_device = temp_device 24 | 25 | @abstractmethod 26 | def get_data_set(self) -> MGDS: 27 | pass 28 | 29 | @abstractmethod 30 | def get_data_loader(self) -> TrainDataLoader: 31 | pass 32 | -------------------------------------------------------------------------------- /modules/dataLoader/flux/ShuffleFluxFillMaskChannels.py: -------------------------------------------------------------------------------- 1 | from mgds.PipelineModule import PipelineModule 2 | from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule 3 | 4 | 5 | class ShuffleFluxFillMaskChannels( 6 | PipelineModule, 7 | RandomAccessPipelineModule, 8 | ): 9 | def __init__(self, in_name: str, out_name: str): 10 | 
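        # Rearranges a single-channel pixel-space mask into 8*8 = 64 channels at 1/8
        # resolution (a pixel-unshuffle over the VAE scale factor), presumably so the
        # mask matches the spatial layout of the packed latents used for Flux Fill's
        # mask conditioning; see get_item() below for the actual reshaping.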
super().__init__() 11 | self.in_name = in_name 12 | self.out_name = out_name 13 | 14 | def length(self) -> int: 15 | return self._get_previous_length(self.in_name) 16 | 17 | def get_inputs(self) -> list[str]: 18 | return [self.in_name] 19 | 20 | def get_outputs(self) -> list[str]: 21 | return [self.out_name] 22 | 23 | def get_item(self, variation: int, index: int, requested_name: str = None) -> dict: 24 | mask = self._get_previous_item(variation, self.in_name, index) 25 | 26 | height, width = mask.shape[1], mask.shape[2] 27 | vae_scale_factor = 8 28 | 29 | # height, 8, width, 8 30 | mask = mask.view( 31 | height // vae_scale_factor, 32 | vae_scale_factor, 33 | width // vae_scale_factor, 34 | vae_scale_factor, 35 | ) 36 | # 8, 8, height, width 37 | mask = mask.permute(1, 3, 0, 2) 38 | # 8*8, height, width 39 | mask = mask.reshape( 40 | vae_scale_factor * vae_scale_factor, 41 | height // vae_scale_factor, 42 | width // vae_scale_factor, 43 | ) 44 | 45 | return { 46 | self.out_name: mask 47 | } 48 | -------------------------------------------------------------------------------- /modules/dataLoader/mixin/DataLoaderMgdsMixin.py: -------------------------------------------------------------------------------- 1 | import json 2 | from abc import ABCMeta 3 | 4 | from modules.util.config.ConceptConfig import ConceptConfig 5 | from modules.util.config.TrainConfig import TrainConfig 6 | from modules.util.enum.ConceptType import ConceptType 7 | from modules.util.TrainProgress import TrainProgress 8 | 9 | from mgds.MGDS import MGDS 10 | from mgds.PipelineModule import PipelineState 11 | 12 | import torch 13 | 14 | 15 | class DataLoaderMgdsMixin(metaclass=ABCMeta): 16 | 17 | def _create_mgds( 18 | self, 19 | config: TrainConfig, 20 | definition: list, 21 | train_progress: TrainProgress, 22 | is_validation: bool = False, 23 | ): 24 | concepts = config.concepts 25 | if concepts is None: 26 | with open(config.concept_file_name, 'r') as f: 27 | concepts = [ConceptConfig.default_values().from_dict(c) for c in json.load(f)] 28 | 29 | # choose all validation concepts, or none of them, depending on is_validation 30 | concepts = [concept for concept in concepts if (ConceptType(concept.type) == ConceptType.VALIDATION) == is_validation] 31 | 32 | # convert before passing to MGDS 33 | concepts = [c.to_dict() for c in concepts] 34 | 35 | settings = { 36 | "target_resolution": config.resolution, 37 | "target_frames": config.frames, 38 | } 39 | 40 | # Just defaults for now. 
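        # MGDS assembles the dataset pipeline from the filtered concept list, the
        # module `definition` supplied by the concrete data loader, and the settings
        # above; initial_epoch/initial_epoch_sample come from TrainProgress so a
        # resumed run continues where the previous one stopped.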
41 | ds = MGDS( 42 | torch.device(config.train_device), 43 | concepts, 44 | settings, 45 | definition, 46 | batch_size=config.batch_size, 47 | state=PipelineState(config.dataloader_threads), 48 | initial_epoch=train_progress.epoch, 49 | initial_epoch_sample=train_progress.epoch_sample, 50 | ) 51 | 52 | return ds 53 | -------------------------------------------------------------------------------- /modules/dataLoader/wuerstchen/EncodeWuerstchenEffnet.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | 3 | from modules.model.WuerstchenModel import WuerstchenEfficientNetEncoder 4 | 5 | from mgds.MGDS import PipelineModule 6 | from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule 7 | 8 | import torch 9 | 10 | 11 | class EncodeWuerstchenEffnet( 12 | PipelineModule, 13 | RandomAccessPipelineModule, 14 | ): 15 | def __init__( 16 | self, 17 | in_name: str, 18 | out_name: str, 19 | effnet_encoder: WuerstchenEfficientNetEncoder, 20 | autocast_contexts: list[torch.autocast | None] = None, 21 | dtype: torch.dtype | None = None, 22 | ): 23 | super().__init__() 24 | self.in_name = in_name 25 | self.out_name = out_name 26 | self.effnet_encoder = effnet_encoder 27 | 28 | self.autocast_contexts = [nullcontext()] if autocast_contexts is None else autocast_contexts 29 | self.dtype = dtype 30 | 31 | def length(self) -> int: 32 | return self._get_previous_length(self.in_name) 33 | 34 | def get_inputs(self) -> list[str]: 35 | return [self.in_name] 36 | 37 | def get_outputs(self) -> list[str]: 38 | return [self.out_name] 39 | 40 | def get_item(self, variation: int, index: int, requested_name: str = None) -> dict: 41 | image = self._get_previous_item(variation, self.in_name, index) 42 | 43 | if self.dtype: 44 | image = image.to(device=self.effnet_encoder.device, dtype=self.dtype) 45 | 46 | with self._all_contexts(self.autocast_contexts): 47 | image_embeddings = self.effnet_encoder(image.unsqueeze(0)).squeeze() 48 | 49 | return { 50 | self.out_name: image_embeddings 51 | } 52 | -------------------------------------------------------------------------------- /modules/model/util/clip_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 4 | 5 | 6 | def encode_clip( 7 | text_encoder: CLIPTextModel | CLIPTextModelWithProjection, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | add_output: bool = True, 12 | text_encoder_output: Tensor | None = None, 13 | add_pooled_output: bool = False, 14 | pooled_text_encoder_output: Tensor | None = None, 15 | use_attention_mask: bool = True, 16 | attention_mask: Tensor | None = None, 17 | add_layer_norm: bool = True, 18 | ) -> tuple[Tensor, Tensor]: 19 | if (add_output and text_encoder_output is None) \ 20 | or (add_pooled_output and pooled_text_encoder_output is None) \ 21 | and text_encoder is not None: 22 | 23 | text_encoder_output = text_encoder( 24 | tokens, 25 | attention_mask=attention_mask if use_attention_mask else None, 26 | return_dict=True, 27 | output_hidden_states=True, 28 | ) 29 | 30 | pooled_text_encoder_output = None 31 | if add_pooled_output: 32 | if hasattr(text_encoder_output, "text_embeds"): 33 | pooled_text_encoder_output = text_encoder_output.text_embeds 34 | if hasattr(text_encoder_output, "pooler_output"): 35 | pooled_text_encoder_output = 
text_encoder_output.pooler_output 36 | 37 | text_encoder_output = text_encoder_output.hidden_states[default_layer - layer_skip] if add_output else None 38 | 39 | if add_layer_norm and text_encoder_output is not None: 40 | final_layer_norm = text_encoder.text_model.final_layer_norm 41 | text_encoder_output = final_layer_norm(text_encoder_output) 42 | 43 | return text_encoder_output, pooled_text_encoder_output 44 | -------------------------------------------------------------------------------- /modules/model/util/gemma_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import Gemma2Model 4 | 5 | 6 | def encode_gemma( 7 | text_encoder: Gemma2Model, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | text_encoder_output: Tensor | None = None, 12 | use_attention_mask: bool = True, 13 | attention_mask: Tensor | None = None, 14 | add_layer_norm: bool = True, 15 | ) -> Tensor: 16 | if text_encoder_output is None and text_encoder is not None: 17 | text_encoder_output = text_encoder( 18 | tokens, 19 | attention_mask=attention_mask if use_attention_mask else None, 20 | output_hidden_states=True, 21 | return_dict=True, 22 | use_cache=False, 23 | ) 24 | hidden_state_output_index = default_layer - layer_skip 25 | text_encoder_output = text_encoder_output.hidden_states[hidden_state_output_index] 26 | if hidden_state_output_index != -1 and add_layer_norm: 27 | text_encoder_output = text_encoder.norm(text_encoder_output) 28 | 29 | return text_encoder_output 30 | -------------------------------------------------------------------------------- /modules/model/util/llama_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import LlamaModel 4 | 5 | 6 | def encode_llama( 7 | text_encoder: LlamaModel, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | text_encoder_output: Tensor | None = None, 12 | use_attention_mask: bool = True, 13 | attention_mask: Tensor | None = None, 14 | crop_start: int | None = None, 15 | ) -> tuple[Tensor, Tensor, Tensor]: 16 | if text_encoder_output is None and text_encoder is not None: 17 | text_encoder_output = text_encoder( 18 | tokens, 19 | attention_mask=attention_mask if use_attention_mask else None, 20 | output_hidden_states=True, 21 | return_dict=True, 22 | use_cache=False, 23 | ) 24 | hidden_state_output_index = default_layer - layer_skip 25 | text_encoder_output = text_encoder_output.hidden_states[hidden_state_output_index] 26 | 27 | if crop_start is not None: 28 | tokens = tokens[:, crop_start:] 29 | text_encoder_output = text_encoder_output[:, crop_start:] 30 | attention_mask = attention_mask[:, crop_start:] 31 | 32 | return text_encoder_output, attention_mask, tokens 33 | -------------------------------------------------------------------------------- /modules/model/util/t5_util.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from transformers import T5EncoderModel 4 | 5 | 6 | def encode_t5( 7 | text_encoder: T5EncoderModel, 8 | tokens: Tensor | None = None, 9 | default_layer: int = -1, 10 | layer_skip: int = 0, 11 | text_encoder_output: Tensor | None = None, 12 | use_attention_mask: bool = True, 13 | attention_mask: Tensor | None = None, 14 | add_layer_norm: bool = True, 15 | ) -> Tensor: 16 | if text_encoder_output is None and 
text_encoder is not None: 17 | text_encoder_output = text_encoder( 18 | tokens, 19 | attention_mask=attention_mask if use_attention_mask else None, 20 | output_hidden_states=True, 21 | return_dict=True, 22 | ) 23 | hidden_state_output_index = default_layer - layer_skip 24 | text_encoder_output = text_encoder_output.hidden_states[hidden_state_output_index] 25 | if hidden_state_output_index != -1 and add_layer_norm: 26 | text_encoder_output = text_encoder.encoder.final_layer_norm(text_encoder_output) 27 | 28 | return text_encoder_output 29 | -------------------------------------------------------------------------------- /modules/modelLoader/BaseModelLoader.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | import os 4 | from abc import ABCMeta, abstractmethod 5 | 6 | from modules.model.BaseModel import BaseModel 7 | from modules.util.enum.ModelType import ModelType 8 | from modules.util.ModelNames import ModelNames 9 | from modules.util.ModelWeightDtypes import ModelWeightDtypes 10 | from modules.util.TrainProgress import TrainProgress 11 | 12 | import torch 13 | 14 | 15 | class BaseModelLoader(metaclass=ABCMeta): 16 | 17 | def _load_internal_state( 18 | self, 19 | model: BaseModel, 20 | base_model_name: str, 21 | ): 22 | with open(os.path.join(base_model_name, "meta.json"), "r") as meta_file: 23 | meta = json.load(meta_file) 24 | train_progress = TrainProgress( 25 | epoch=meta['train_progress']['epoch'], 26 | epoch_step=meta['train_progress']['epoch_step'], 27 | epoch_sample=meta['train_progress']['epoch_sample'], 28 | global_step=meta['train_progress']['global_step'], 29 | ) 30 | 31 | # optimizer 32 | with contextlib.suppress(FileNotFoundError): 33 | model.optimizer_state_dict = torch.load(os.path.join(base_model_name, "optimizer", "optimizer.pt"), 34 | weights_only=True) 35 | 36 | # ema 37 | with contextlib.suppress(FileNotFoundError): 38 | model.ema_state_dict = torch.load(os.path.join(base_model_name, "ema", "ema.pt"), weights_only=True) 39 | 40 | # meta 41 | model.train_progress = train_progress 42 | 43 | 44 | @abstractmethod 45 | def load( 46 | self, 47 | model_type: ModelType, 48 | model_names: ModelNames, 49 | weight_dtypes: ModelWeightDtypes, 50 | ) -> BaseModel | None: 51 | pass 52 | -------------------------------------------------------------------------------- /modules/modelLoader/HiDreamFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelLoader.BaseModelLoader import BaseModelLoader 3 | from modules.modelLoader.hiDream.HiDreamEmbeddingLoader import HiDreamEmbeddingLoader 4 | from modules.modelLoader.hiDream.HiDreamModelLoader import HiDreamModelLoader 5 | from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin 6 | from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin 7 | from modules.util.enum.ModelType import ModelType 8 | from modules.util.ModelNames import ModelNames 9 | from modules.util.ModelWeightDtypes import ModelWeightDtypes 10 | 11 | 12 | class HiDreamFineTuneModelLoader( 13 | BaseModelLoader, 14 | ModelSpecModelLoaderMixin, 15 | InternalModelLoaderMixin, 16 | ): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def _default_model_spec_name( 21 | self, 22 | model_type: ModelType, 23 | ) -> str | None: 24 | match model_type: 25 | case ModelType.HI_DREAM_FULL: 26 | return 
"resources/sd_model_spec/hi_dream_full.json" 27 | case _: 28 | return None 29 | 30 | def load( 31 | self, 32 | model_type: ModelType, 33 | model_names: ModelNames, 34 | weight_dtypes: ModelWeightDtypes, 35 | ) -> HiDreamModel | None: 36 | base_model_loader = HiDreamModelLoader() 37 | embedding_loader = HiDreamEmbeddingLoader() 38 | 39 | model = HiDreamModel(model_type=model_type) 40 | 41 | self._load_internal_data(model, model_names.base_model) 42 | model.model_spec = self._load_default_model_spec(model_type) 43 | 44 | base_model_loader.load(model, model_type, model_names, weight_dtypes) 45 | embedding_loader.load(model, model_names.base_model, model_names) 46 | 47 | return model 48 | -------------------------------------------------------------------------------- /modules/modelLoader/SanaEmbeddingModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelLoader.BaseModelLoader import BaseModelLoader 3 | from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin 4 | from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin 5 | from modules.modelLoader.sana.SanaEmbeddingLoader import SanaEmbeddingLoader 6 | from modules.modelLoader.sana.SanaModelLoader import SanaModelLoader 7 | from modules.util.enum.ModelType import ModelType 8 | from modules.util.ModelNames import ModelNames 9 | from modules.util.ModelWeightDtypes import ModelWeightDtypes 10 | 11 | 12 | class SanaEmbeddingModelLoader( 13 | BaseModelLoader, 14 | ModelSpecModelLoaderMixin, 15 | InternalModelLoaderMixin, 16 | ): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def _default_model_spec_name( 21 | self, 22 | model_type: ModelType, 23 | ) -> str | None: 24 | match model_type: 25 | case ModelType.SANA: 26 | return "resources/sd_model_spec/sana-embedding.json" 27 | case _: 28 | return None 29 | 30 | def load( 31 | self, 32 | model_type: ModelType, 33 | model_names: ModelNames, 34 | weight_dtypes: ModelWeightDtypes, 35 | ) -> SanaModel | None: 36 | base_model_loader = SanaModelLoader() 37 | embedding_loader = SanaEmbeddingLoader() 38 | 39 | model = SanaModel(model_type=model_type) 40 | self._load_internal_data(model, model_names.embedding.model_name) 41 | model.model_spec = self._load_default_model_spec(model_type) 42 | 43 | if model_names.base_model: 44 | base_model_loader.load(model, model_type, model_names, weight_dtypes) 45 | embedding_loader.load(model, model_names.embedding.model_name, model_names) 46 | 47 | return model 48 | -------------------------------------------------------------------------------- /modules/modelLoader/SanaFineTuneModelLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelLoader.BaseModelLoader import BaseModelLoader 3 | from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin 4 | from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin 5 | from modules.modelLoader.sana.SanaEmbeddingLoader import SanaEmbeddingLoader 6 | from modules.modelLoader.sana.SanaModelLoader import SanaModelLoader 7 | from modules.util.enum.ModelType import ModelType 8 | from modules.util.ModelNames import ModelNames 9 | from modules.util.ModelWeightDtypes import ModelWeightDtypes 10 | 11 | 12 | class SanaFineTuneModelLoader( 13 | BaseModelLoader, 14 | ModelSpecModelLoaderMixin, 
15 | InternalModelLoaderMixin, 16 | ): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def _default_model_spec_name( 21 | self, 22 | model_type: ModelType, 23 | ) -> str | None: 24 | match model_type: 25 | case ModelType.SANA: 26 | return "resources/sd_model_spec/sana.json" 27 | case _: 28 | return None 29 | 30 | def load( 31 | self, 32 | model_type: ModelType, 33 | model_names: ModelNames, 34 | weight_dtypes: ModelWeightDtypes, 35 | ) -> SanaModel | None: 36 | base_model_loader = SanaModelLoader() 37 | embedding_loader = SanaEmbeddingLoader() 38 | 39 | model = SanaModel(model_type=model_type) 40 | 41 | self._load_internal_data(model, model_names.base_model) 42 | model.model_spec = self._load_default_model_spec(model_type) 43 | 44 | base_model_loader.load(model, model_type, model_names, weight_dtypes) 45 | embedding_loader.load(model, model_names.base_model, model_names) 46 | 47 | return model 48 | -------------------------------------------------------------------------------- /modules/modelLoader/flux/FluxEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class FluxEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: FluxModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/flux/FluxLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.FluxModel import FluxModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class FluxLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_flux_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: FluxModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/hiDream/HiDreamEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class HiDreamEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: HiDreamModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/hiDream/HiDreamLoRALoader.py: -------------------------------------------------------------------------------- 1 | from 
modules.model.BaseModel import BaseModel 2 | from modules.model.HiDreamModel import HiDreamModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_hidream_lora import convert_hidream_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class HiDreamLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_hidream_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: HiDreamModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/hunyuanVideo/HunyuanVideoEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class HunyuanVideoEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: HunyuanVideoModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/hunyuanVideo/HunyuanVideoLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_hunyuan_video_lora import convert_hunyuan_video_lora_key_sets 7 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 8 | 9 | 10 | class HunyuanVideoLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_hunyuan_video_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: HunyuanVideoModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/mixin/InternalModelLoaderMixin.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | import os 4 | from abc import ABCMeta 5 | 6 | from modules.model.BaseModel import BaseModel 7 | from modules.util.TrainProgress import TrainProgress 8 | 9 | import torch 10 | 11 | 12 | class InternalModelLoaderMixin(metaclass=ABCMeta): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _load_internal_data( 17 | self, 18 | model: BaseModel, 19 | model_name: str, 20 | ): 21 | if os.path.exists(os.path.join(model_name, "meta.json")): 22 | # train progress 23 | with open(os.path.join(model_name, "meta.json"), "r") as meta_file: 24 | meta = json.load(meta_file) 25 | train_progress = TrainProgress( 26 | epoch=meta['train_progress']['epoch'], 27 | 
epoch_step=meta['train_progress']['epoch_step'], 28 | epoch_sample=meta['train_progress']['epoch_sample'], 29 | global_step=meta['train_progress']['global_step'], 30 | ) 31 | 32 | # optimizer 33 | with contextlib.suppress(FileNotFoundError): 34 | model.optimizer_state_dict = torch.load(os.path.join(model_name, "optimizer", "optimizer.pt"), 35 | weights_only=True) 36 | 37 | # ema 38 | with contextlib.suppress(FileNotFoundError): 39 | model.ema_state_dict = torch.load(os.path.join(model_name, "ema", "ema.pt"), weights_only=True) 40 | 41 | # meta 42 | model.train_progress = train_progress 43 | -------------------------------------------------------------------------------- /modules/modelLoader/mixin/ModelSpecModelLoaderMixin.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | from abc import ABCMeta 4 | 5 | from modules.util.enum.ModelType import ModelType 6 | from modules.util.modelSpec.ModelSpec import ModelSpec 7 | 8 | from safetensors import safe_open 9 | 10 | 11 | class ModelSpecModelLoaderMixin(metaclass=ABCMeta): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def _default_model_spec_name( 16 | self, 17 | model_type: ModelType, 18 | ) -> str | None: 19 | return None 20 | 21 | def _load_default_model_spec( 22 | self, 23 | model_type: ModelType, 24 | safetensors_file_name: str | None = None, 25 | ) -> ModelSpec: 26 | model_spec = None 27 | 28 | model_spec_name = self._default_model_spec_name(model_type) 29 | if model_spec_name: 30 | with open(model_spec_name, "r", encoding="utf-8") as model_spec_file: 31 | model_spec = ModelSpec.from_dict(json.load(model_spec_file)) 32 | else: 33 | model_spec = ModelSpec() 34 | 35 | if safetensors_file_name: 36 | with contextlib.suppress(Exception), safe_open(safetensors_file_name, framework="pt") as f: 37 | if "modelspec.sai_model_spec" in f.metadata(): 38 | model_spec = ModelSpec.from_dict(f.metadata()) 39 | 40 | return model_spec 41 | -------------------------------------------------------------------------------- /modules/modelLoader/mixin/SDConfigModelLoaderMixin.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABCMeta 3 | 4 | from modules.util.enum.ModelType import ModelType 5 | 6 | import yaml 7 | 8 | 9 | class SDConfigModelLoaderMixin(metaclass=ABCMeta): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def _default_sd_config_name( 14 | self, 15 | model_type: ModelType, 16 | ) -> str | None: 17 | return None 18 | 19 | def _get_sd_config_name( 20 | self, 21 | model_type: ModelType, 22 | base_model_name: str | None = None, 23 | ) -> str | None: 24 | yaml_name = None 25 | 26 | if base_model_name: 27 | new_yaml_name = os.path.splitext(base_model_name)[0] + '.yaml' 28 | if os.path.exists(new_yaml_name): 29 | yaml_name = new_yaml_name 30 | 31 | if not yaml_name: 32 | new_yaml_name = os.path.splitext(base_model_name)[0] + '.yml' 33 | if os.path.exists(new_yaml_name): 34 | yaml_name = new_yaml_name 35 | 36 | if not yaml_name: 37 | new_yaml_name = self._default_sd_config_name(model_type) 38 | if new_yaml_name: 39 | yaml_name = new_yaml_name 40 | 41 | return yaml_name 42 | 43 | def _load_sd_config( 44 | self, 45 | model_type: ModelType, 46 | base_model_name: str | None = None, 47 | ) -> dict | None: 48 | yaml_name = self._get_sd_config_name(model_type, base_model_name) 49 | 50 | if yaml_name: 51 | with open(yaml_name, "r") as f: 52 | return yaml.safe_load(f) 53 | else: 54 | return None 55 | 
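To make the lookup order implemented by `_get_sd_config_name` above concrete, here is a minimal sketch of a loader built on `SDConfigModelLoaderMixin`. The class name and the checkpoint path are hypothetical, and the `ModelType` member name is assumed rather than taken from the enum source; the default config path is one of the files listed under `resources/model_config/` in the tree above.

```python
from modules.modelLoader.mixin.SDConfigModelLoaderMixin import SDConfigModelLoaderMixin
from modules.util.enum.ModelType import ModelType


class ExampleConfigLoader(SDConfigModelLoaderMixin):
    # Hypothetical loader used only to illustrate the mixin.
    def _default_sd_config_name(self, model_type: ModelType) -> str | None:
        # Fallback used when no .yaml/.yml file sits next to the checkpoint.
        return "resources/model_config/stable_diffusion/v1-inference.yaml"


loader = ExampleConfigLoader()
# Resolution order: "models/my_model.yaml", then "models/my_model.yml",
# then the default path returned above.
sd_config = loader._load_sd_config(ModelType.STABLE_DIFFUSION_15, "models/my_model.safetensors")
```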
-------------------------------------------------------------------------------- /modules/modelLoader/pixartAlpha/PixArtAlphaEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class PixArtAlphaEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: PixArtAlphaModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/pixartAlpha/PixArtAlphaLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_pixart_lora import convert_pixart_lora_key_sets 8 | 9 | 10 | class PixArtAlphaLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_pixart_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: PixArtAlphaModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/sana/SanaEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class SanaEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: SanaModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/sana/SanaLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.SanaModel import SanaModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | 8 | 9 | class SanaLoRALoader( 10 | LoRALoaderMixin 11 | ): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 16 | return None # TODO: not yet implemented 17 | 18 | def load( 19 | self, 20 | model: SanaModel, 21 | model_names: ModelNames, 22 | ): 23 | return self._load(model, model_names) 24 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion/StableDiffusionEmbeddingLoader.py: 
-------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class StableDiffusionEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: StableDiffusionModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion/StableDiffusionLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.StableDiffusionModel import StableDiffusionModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_sd_lora import convert_sd_lora_key_sets 8 | 9 | 10 | class StableDiffusionLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_sd_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: StableDiffusionModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion3/StableDiffusion3EmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class StableDiffusion3EmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: StableDiffusion3Model, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusion3/StableDiffusion3LoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_sd3_lora import convert_sd3_lora_key_sets 8 | 9 | 10 | class StableDiffusion3LoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_sd3_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: StableDiffusion3Model, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | 
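The LoRA loaders above all follow the same pattern: subclass LoRALoaderMixin, report the architecture-specific OMI key-set conversion from _get_convert_key_sets, and delegate the actual loading to self._load. A minimal sketch of how one of them could be driven follows; the helper function and the assumption that only the lora field of ModelNames needs to be populated are illustrative, not taken from this repository.

from modules.model.StableDiffusion3Model import StableDiffusion3Model
from modules.modelLoader.stableDiffusion3.StableDiffusion3LoRALoader import StableDiffusion3LoRALoader
from modules.util.ModelNames import ModelNames


def attach_lora(model: StableDiffusion3Model, lora_path: str):
    # model is assumed to be an already constructed StableDiffusion3Model; the loader
    # converts checkpoint keys via convert_sd3_lora_key_sets() and hands the rest
    # to LoRALoaderMixin._load
    loader = StableDiffusion3LoRALoader()
    return loader.load(model, ModelNames(lora=lora_path))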
-------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusionXL/StableDiffusionXLEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class StableDiffusionXLEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: StableDiffusionXLModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/stableDiffusionXL/StableDiffusionXLLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_sdxl_lora import convert_sdxl_lora_key_sets 8 | 9 | 10 | class StableDiffusionXLLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | return convert_sdxl_lora_key_sets() 18 | 19 | def load( 20 | self, 21 | model: StableDiffusionXLModel, 22 | model_names: ModelNames, 23 | ): 24 | return self._load(model, model_names) 25 | -------------------------------------------------------------------------------- /modules/modelLoader/wuerstchen/WuerstchenEmbeddingLoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelLoader.mixin.EmbeddingLoaderMixin import EmbeddingLoaderMixin 3 | from modules.util.ModelNames import ModelNames 4 | 5 | 6 | class WuerstchenEmbeddingLoader( 7 | EmbeddingLoaderMixin 8 | ): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def load( 13 | self, 14 | model: WuerstchenModel, 15 | directory: str, 16 | model_names: ModelNames, 17 | ): 18 | self._load(model, directory, model_names) 19 | -------------------------------------------------------------------------------- /modules/modelLoader/wuerstchen/WuerstchenLoRALoader.py: -------------------------------------------------------------------------------- 1 | from modules.model.BaseModel import BaseModel 2 | from modules.model.WuerstchenModel import WuerstchenModel 3 | from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin 4 | from modules.util.ModelNames import ModelNames 5 | 6 | from omi_model_standards.convert.lora.convert_lora_util import LoraConversionKeySet 7 | from omi_model_standards.convert.lora.convert_stable_cascade_lora import convert_stable_cascade_lora_key_sets 8 | 9 | 10 | class WuerstchenLoRALoader( 11 | LoRALoaderMixin 12 | ): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: 17 | if model.model_type.is_stable_cascade(): 18 | return convert_stable_cascade_lora_key_sets() 19 | return 
None 20 | 21 | def load( 22 | self, 23 | model: WuerstchenModel, 24 | model_names: ModelNames, 25 | ): 26 | return self._load(model, model_names) 27 | -------------------------------------------------------------------------------- /modules/modelSaver/BaseModelSaver.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from modules.model.BaseModel import BaseModel 4 | from modules.util.enum.ModelFormat import ModelFormat 5 | from modules.util.enum.ModelType import ModelType 6 | 7 | import torch 8 | 9 | 10 | class BaseModelSaver(metaclass=ABCMeta): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | @abstractmethod 15 | def save( 16 | self, 17 | model: BaseModel, 18 | model_type: ModelType, 19 | output_model_format: ModelFormat, 20 | output_model_destination: str, 21 | dtype: torch.dtype | None, 22 | ): 23 | pass 24 | -------------------------------------------------------------------------------- /modules/modelSaver/FluxEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class FluxEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: FluxModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = FluxEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/FluxFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver 4 | from modules.modelSaver.flux.FluxModelSaver import FluxModelSaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class FluxFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: FluxModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = FluxModelSaver() 28 | embedding_model_saver = FluxEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, 
output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/FluxLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.FluxModel import FluxModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver 4 | from modules.modelSaver.flux.FluxLoRASaver import FluxLoRASaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class FluxLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: FluxModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = FluxLoRASaver() 28 | embedding_model_saver = FluxEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/HiDreamEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hidream.HiDreamEmbeddingSaver import HiDreamEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class HiDreamEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: HiDreamModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = HiDreamEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/HiDreamLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HiDreamModel import HiDreamModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hidream.HiDreamEmbeddingSaver import HiDreamEmbeddingSaver 4 | from 
modules.modelSaver.hidream.HiDreamLoRASaver import HiDreamLoRASaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class HiDreamLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: HiDreamModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = HiDreamLoRASaver() 28 | embedding_model_saver = HiDreamEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/HunyuanVideoEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver 4 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class HunyuanVideoEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: HunyuanVideoModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = HunyuanVideoEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/HunyuanVideoFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver 4 | from modules.modelSaver.hunyuanVideo.HunyuanVideoModelSaver import HunyuanVideoModelSaver 5 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class HunyuanVideoFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 
21 | model: HunyuanVideoModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = HunyuanVideoModelSaver() 28 | embedding_model_saver = HunyuanVideoEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/HunyuanVideoLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.model.HunyuanVideoModel import HunyuanVideoModel 3 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 4 | from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver 5 | from modules.modelSaver.hunyuanVideo.HunyuanVideoLoRASaver import HunyuanVideoLoRASaver 6 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 7 | from modules.util.enum.ModelFormat import ModelFormat 8 | from modules.util.enum.ModelType import ModelType 9 | 10 | import torch 11 | 12 | 13 | class HunyuanVideoLoRAModelSaver( 14 | BaseModelSaver, 15 | InternalModelSaverMixin, 16 | ): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def save( 21 | self, 22 | model: HunyuanVideoModel, 23 | model_type: ModelType, 24 | output_model_format: ModelFormat, 25 | output_model_destination: str, 26 | dtype: torch.dtype | None, 27 | ): 28 | lora_model_saver = HunyuanVideoLoRASaver() 29 | embedding_model_saver = HunyuanVideoEmbeddingSaver() 30 | 31 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 32 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 33 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 34 | 35 | if output_model_format == ModelFormat.INTERNAL: 36 | self._save_internal_data(model, output_model_destination) 37 | -------------------------------------------------------------------------------- /modules/modelSaver/PixArtAlphaEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class PixArtAlphaEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: PixArtAlphaModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = PixArtAlphaEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, 
dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/PixArtAlphaFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver 5 | from modules.modelSaver.pixartAlpha.PixArtAlphaModelSaver import PixArtAlphaModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class PixArtAlphaFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: PixArtAlphaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | base_model_saver = PixArtAlphaModelSaver() 28 | embedding_model_saver = PixArtAlphaEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/PixArtAlphaLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.PixArtAlphaModel import PixArtAlphaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver 5 | from modules.modelSaver.pixartAlpha.PixArtAlphaLoRASaver import PixArtAlphaLoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class PixArtAlphaLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: PixArtAlphaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | lora_model_saver = PixArtAlphaLoRASaver() 28 | embedding_model_saver = PixArtAlphaEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/SanaEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel 
import SanaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class SanaEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: SanaModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = SanaEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/SanaFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver 5 | from modules.modelSaver.sana.SanaModelSaver import SanaModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class SanaFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: SanaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | base_model_saver = SanaModelSaver() 28 | embedding_model_saver = SanaEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/SanaLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.SanaModel import SanaModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver 5 | from modules.modelSaver.sana.SanaLoRASaver import SanaLoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class SanaLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: SanaModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: 
str, 25 | dtype: torch.dtype, 26 | ): 27 | lora_model_saver = SanaLoRASaver() 28 | embedding_model_saver = SanaEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusion3EmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class StableDiffusion3EmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: StableDiffusion3Model, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = StableDiffusion3EmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusion3FineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver 5 | from modules.modelSaver.stableDiffusion3.StableDiffusion3ModelSaver import StableDiffusion3ModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusion3FineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusion3Model, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = StableDiffusion3ModelSaver() 28 | embedding_model_saver = StableDiffusion3EmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == 
ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusion3LoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusion3Model import StableDiffusion3Model 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver 5 | from modules.modelSaver.stableDiffusion3.StableDiffusion3LoRASaver import StableDiffusion3LoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusion3LoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusion3Model, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = StableDiffusion3LoRASaver() 28 | embedding_model_saver = StableDiffusion3EmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class StableDiffusionEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: StableDiffusionModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = StableDiffusionEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | 
from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver 5 | from modules.modelSaver.stableDiffusion.StableDiffusionModelSaver import StableDiffusionModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusionFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusionModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = StableDiffusionModelSaver() 28 | embedding_model_saver = StableDiffusionEmbeddingSaver() 29 | 30 | base_model_saver.save(model, model_type, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionModel import StableDiffusionModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver 5 | from modules.modelSaver.stableDiffusion.StableDiffusionLoRASaver import StableDiffusionLoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusionLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusionModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = StableDiffusionLoRASaver() 28 | embedding_model_saver = StableDiffusionEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionXLEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import 
StableDiffusionXLEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class StableDiffusionXLEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: StableDiffusionXLModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype | None, 25 | ): 26 | embedding_model_saver = StableDiffusionXLEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionXLFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import StableDiffusionXLEmbeddingSaver 5 | from modules.modelSaver.stableDiffusionXL.StableDiffusionXLModelSaver import StableDiffusionXLModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusionXLFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: StableDiffusionXLModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | base_model_saver = StableDiffusionXLModelSaver() 28 | embedding_model_saver = StableDiffusionXLEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/StableDiffusionXLLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.StableDiffusionXLModel import StableDiffusionXLModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import StableDiffusionXLEmbeddingSaver 5 | from modules.modelSaver.stableDiffusionXL.StableDiffusionXLLoRASaver import StableDiffusionXLLoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class StableDiffusionXLLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | 
self, 21 | model: StableDiffusionXLModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype | None, 26 | ): 27 | lora_model_saver = StableDiffusionXLLoRASaver() 28 | embedding_model_saver = StableDiffusionXLEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/WuerstchenEmbeddingModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver 5 | from modules.util.enum.ModelFormat import ModelFormat 6 | from modules.util.enum.ModelType import ModelType 7 | 8 | import torch 9 | 10 | 11 | class WuerstchenEmbeddingModelSaver( 12 | BaseModelSaver, 13 | InternalModelSaverMixin, 14 | ): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def save( 19 | self, 20 | model: WuerstchenModel, 21 | model_type: ModelType, 22 | output_model_format: ModelFormat, 23 | output_model_destination: str, 24 | dtype: torch.dtype, 25 | ): 26 | embedding_model_saver = WuerstchenEmbeddingSaver() 27 | 28 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 29 | embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) 30 | 31 | if output_model_format == ModelFormat.INTERNAL: 32 | self._save_internal_data(model, output_model_destination) 33 | -------------------------------------------------------------------------------- /modules/modelSaver/WuerstchenFineTuneModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver 5 | from modules.modelSaver.wuerstchen.WuerstchenModelSaver import WuerstchenModelSaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class WuerstchenFineTuneModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: WuerstchenModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | base_model_saver = WuerstchenModelSaver() 28 | embedding_model_saver = WuerstchenEmbeddingSaver() 29 | 30 | base_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 32 | 33 | if 
output_model_format == ModelFormat.INTERNAL: 34 | self._save_internal_data(model, output_model_destination) 35 | -------------------------------------------------------------------------------- /modules/modelSaver/WuerstchenLoRAModelSaver.py: -------------------------------------------------------------------------------- 1 | from modules.model.WuerstchenModel import WuerstchenModel 2 | from modules.modelSaver.BaseModelSaver import BaseModelSaver 3 | from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin 4 | from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver 5 | from modules.modelSaver.wuerstchen.WuerstchenLoRASaver import WuerstchenLoRASaver 6 | from modules.util.enum.ModelFormat import ModelFormat 7 | from modules.util.enum.ModelType import ModelType 8 | 9 | import torch 10 | 11 | 12 | class WuerstchenLoRAModelSaver( 13 | BaseModelSaver, 14 | InternalModelSaverMixin, 15 | ): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def save( 20 | self, 21 | model: WuerstchenModel, 22 | model_type: ModelType, 23 | output_model_format: ModelFormat, 24 | output_model_destination: str, 25 | dtype: torch.dtype, 26 | ): 27 | lora_model_saver = WuerstchenLoRASaver() 28 | embedding_model_saver = WuerstchenEmbeddingSaver() 29 | 30 | lora_model_saver.save(model, output_model_format, output_model_destination, dtype) 31 | if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: 32 | embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) 33 | 34 | if output_model_format == ModelFormat.INTERNAL: 35 | self._save_internal_data(model, output_model_destination) 36 | -------------------------------------------------------------------------------- /modules/modelSaver/mixin/EmbeddingSaverMixin.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABCMeta, abstractmethod 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from safetensors.torch import save_file 10 | 11 | 12 | class EmbeddingSaverMixin(metaclass=ABCMeta): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | @abstractmethod 17 | def _to_state_dict( 18 | self, 19 | embedding: Any | None, 20 | embedding_state_dict: dict[str, Tensor] | None, 21 | dtype: torch.dtype | None, 22 | ): 23 | pass 24 | 25 | def _save_safetensors( 26 | self, 27 | embedding: Any | None, 28 | embedding_state_dict: dict[str, Tensor] | None, 29 | destination: str, 30 | dtype: torch.dtype | None, 31 | ): 32 | os.makedirs(Path(destination).parent.absolute(), exist_ok=True) 33 | 34 | state_dict = self._to_state_dict( 35 | embedding, 36 | embedding_state_dict, 37 | dtype, 38 | ) 39 | 40 | save_file(state_dict, destination) 41 | 42 | def _save_internal( 43 | self, 44 | embedding: Any | None, 45 | embedding_state: dict[str, Tensor] | None, 46 | embedding_uuid: str, 47 | destination: str, 48 | ): 49 | safetensors_embedding_name = os.path.join( 50 | destination, 51 | "embeddings", 52 | f"{embedding_uuid}.safetensors", 53 | ) 54 | self._save_safetensors( 55 | embedding, 56 | embedding_state, 57 | safetensors_embedding_name, 58 | None, 59 | ) 60 | -------------------------------------------------------------------------------- /modules/modelSaver/mixin/InternalModelSaverMixin.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from 
abc import ABCMeta 4 | 5 | from modules.model.BaseModel import BaseModel 6 | 7 | import torch 8 | 9 | 10 | class InternalModelSaverMixin(metaclass=ABCMeta): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def _save_internal_data( 15 | self, 16 | model: BaseModel, 17 | destination: str, 18 | ): 19 | # optimizer 20 | os.makedirs(os.path.join(destination, "optimizer"), exist_ok=True) 21 | optimizer_state_dict = model.optimizer.state_dict() 22 | optimizer_state_dict["param_group_mapping"] = model.param_group_mapping 23 | optimizer_state_dict["param_group_optimizer_mapping"] = \ 24 | [str(model.train_config.optimizer.optimizer) for _ in model.param_group_mapping] 25 | 26 | torch.save(optimizer_state_dict, os.path.join(destination, "optimizer", "optimizer.pt")) 27 | 28 | # ema 29 | if model.ema: 30 | os.makedirs(os.path.join(destination, "ema"), exist_ok=True) 31 | torch.save(model.ema.state_dict(), os.path.join(destination, "ema", "ema.pt")) 32 | 33 | # meta 34 | with open(os.path.join(destination, "meta.json"), "w") as meta_file: 35 | json.dump({ 36 | 'train_progress': { 37 | 'epoch': model.train_progress.epoch, 38 | 'epoch_step': model.train_progress.epoch_step, 39 | 'epoch_sample': model.train_progress.epoch_sample, 40 | 'global_step': model.train_progress.global_step, 41 | }, 42 | }, meta_file) 43 | -------------------------------------------------------------------------------- /modules/modelSetup/mixin/ModelSetupFlowMatchingMixin.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class ModelSetupFlowMatchingMixin(metaclass=ABCMeta): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self.__sigma = None 12 | self.__one_minus_sigma = None 13 | 14 | def _add_noise_discrete( 15 | self, 16 | scaled_latent_image: Tensor, 17 | latent_noise: Tensor, 18 | timestep: Tensor, 19 | timesteps: Tensor, 20 | ) -> tuple[Tensor, Tensor]: 21 | if self.__sigma is None: 22 | num_timesteps = timesteps.shape[-1] 23 | all_timesteps = torch.arange(start=1, end=num_timesteps + 1, step=1, dtype=torch.int32, device=scaled_latent_image.device) 24 | self.__sigma = all_timesteps / num_timesteps 25 | self.__one_minus_sigma = 1.0 - self.__sigma 26 | 27 | orig_dtype = scaled_latent_image.dtype 28 | 29 | sigmas = self.__sigma[timestep] 30 | one_minus_sigmas = self.__one_minus_sigma[timestep] 31 | 32 | while sigmas.dim() < scaled_latent_image.dim(): 33 | sigmas = sigmas.unsqueeze(-1) 34 | one_minus_sigmas = one_minus_sigmas.unsqueeze(-1) 35 | 36 | scaled_noisy_latent_image = latent_noise.to(dtype=sigmas.dtype) * sigmas \ 37 | + scaled_latent_image.to(dtype=sigmas.dtype) * one_minus_sigmas 38 | 39 | return scaled_noisy_latent_image.to(dtype=orig_dtype), sigmas 40 | -------------------------------------------------------------------------------- /modules/module/Blip2Model.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseImageCaptionModel import BaseImageCaptionModel, CaptionSample 2 | 3 | import torch 4 | 5 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 6 | 7 | 8 | class Blip2Model(BaseImageCaptionModel): 9 | def __init__(self, device: torch.device, dtype: torch.dtype): 10 | self.device = device 11 | self.dtype = dtype 12 | 13 | self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 14 | 15 | self.model = Blip2ForConditionalGeneration.from_pretrained( 16 | 
"Salesforce/blip2-opt-2.7b", 17 | torch_dtype=self.dtype 18 | ) 19 | self.model.eval() 20 | self.model.to(self.device) 21 | 22 | def generate_caption( 23 | self, 24 | caption_sample: CaptionSample, 25 | initial_caption: str = "", 26 | caption_prefix: str = "", 27 | caption_postfix: str = "", 28 | ) -> str: 29 | inputs = self.processor(caption_sample.get_image(), initial_caption, return_tensors="pt") 30 | inputs = inputs.to(self.device, self.dtype) 31 | with torch.no_grad(): 32 | outputs = self.model.generate(**inputs) 33 | predicted_caption = self.processor.decode(outputs[0], skip_special_tokens=True) 34 | predicted_caption = (caption_prefix + initial_caption + predicted_caption + caption_postfix).strip() 35 | 36 | return predicted_caption 37 | -------------------------------------------------------------------------------- /modules/module/BlipModel.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseImageCaptionModel import BaseImageCaptionModel, CaptionSample 2 | 3 | import torch 4 | 5 | from transformers import BlipForConditionalGeneration, BlipProcessor 6 | 7 | 8 | class BlipModel(BaseImageCaptionModel): 9 | def __init__(self, device: torch.device, dtype: torch.dtype): 10 | self.device = device 11 | self.dtype = dtype 12 | 13 | self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") 14 | 15 | self.model = BlipForConditionalGeneration.from_pretrained( 16 | "Salesforce/blip-image-captioning-large", 17 | torch_dtype=self.dtype 18 | ) 19 | self.model.eval() 20 | self.model.to(self.device) 21 | 22 | def generate_caption( 23 | self, 24 | caption_sample: CaptionSample, 25 | initial_caption: str = "", 26 | caption_prefix: str = "", 27 | caption_postfix: str = "", 28 | ): 29 | inputs = self.processor(caption_sample.get_image(), initial_caption, return_tensors="pt") 30 | inputs = inputs.to(self.device, self.dtype) 31 | with torch.no_grad(): 32 | outputs = self.model.generate(**inputs) 33 | predicted_caption = self.processor.decode(outputs[0], skip_special_tokens=True) 34 | predicted_caption = (caption_prefix + predicted_caption + caption_postfix).strip() 35 | 36 | return predicted_caption 37 | -------------------------------------------------------------------------------- /modules/module/RembgHumanModel.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseRembgModel import BaseRembgModel 2 | 3 | import torch 4 | 5 | 6 | class RembgHumanModel(BaseRembgModel): 7 | def __init__(self, device: torch.device, dtype: torch.dtype): 8 | super().__init__( 9 | model_filename="u2net_human_seg.onnx", 10 | model_path="https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx", 11 | model_md5="md5:c09ddc2e0104f800e3e1bb4652583d1f", 12 | device=device, 13 | dtype=dtype, 14 | ) 15 | -------------------------------------------------------------------------------- /modules/module/RembgModel.py: -------------------------------------------------------------------------------- 1 | from modules.module.BaseRembgModel import BaseRembgModel 2 | 3 | import torch 4 | 5 | 6 | class RembgModel(BaseRembgModel): 7 | def __init__(self, device: torch.device, dtype: torch.dtype): 8 | super().__init__( 9 | model_filename="u2net.onnx", 10 | model_path="https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx", 11 | model_md5="md5:60024c5c889badc19c04ad937298a77b", 12 | device=device, 13 | dtype=dtype, 14 | ) 15 | 
-------------------------------------------------------------------------------- /modules/module/quantized/mixin/QuantizedLinearMixin.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class QuantizedLinearMixin(metaclass=ABCMeta): 7 | @abstractmethod 8 | def original_weight_shape(self) -> tuple[int, ...]: 9 | pass 10 | 11 | @abstractmethod 12 | def unquantized_weight(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: 13 | pass 14 | -------------------------------------------------------------------------------- /modules/module/quantized/mixin/QuantizedModuleMixin.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class QuantizedModuleMixin(metaclass=ABCMeta): 7 | @abstractmethod 8 | def quantize(self, device: torch.device | None = None): 9 | pass 10 | -------------------------------------------------------------------------------- /modules/ui/SampleParamsWindow.py: -------------------------------------------------------------------------------- 1 | from modules.ui.SampleFrame import SampleFrame 2 | from modules.util.config.SampleConfig import SampleConfig 3 | from modules.util.ui import components 4 | from modules.util.ui.ui_utils import set_window_icon 5 | from modules.util.ui.UIState import UIState 6 | 7 | import customtkinter as ctk 8 | 9 | 10 | class SampleParamsWindow(ctk.CTkToplevel): 11 | def __init__(self, parent, sample: SampleConfig, ui_state: UIState, *args, **kwargs): 12 | super().__init__(parent, *args, **kwargs) 13 | 14 | self.sample = sample 15 | self.ui_state = ui_state 16 | 17 | self.title("Sample") 18 | self.geometry("800x500") 19 | self.resizable(True, True) 20 | 21 | self.grid_rowconfigure(0, weight=1) 22 | self.grid_rowconfigure(1, weight=0) 23 | self.grid_columnconfigure(0, weight=1) 24 | 25 | frame = SampleFrame(self, self.sample, self.ui_state) 26 | frame.grid(row=0, column=0, padx=0, pady=0, sticky="nsew") 27 | 28 | components.button(self, 1, 0, "ok", self.__ok) 29 | 30 | self.wait_visibility() 31 | self.grab_set() 32 | self.focus_set() 33 | self.after(200, lambda: set_window_icon(self)) 34 | 35 | 36 | def __ok(self): 37 | self.destroy() 38 | -------------------------------------------------------------------------------- /modules/util/ModelNames.py: -------------------------------------------------------------------------------- 1 | class EmbeddingName: 2 | def __init__( 3 | self, 4 | uuid: str, 5 | model_name: str, 6 | ): 7 | self.uuid = uuid 8 | self.model_name = model_name 9 | 10 | 11 | class ModelNames: 12 | def __init__( 13 | self, 14 | base_model: str = "", 15 | prior_model: str = "", 16 | effnet_encoder_model: str = "", 17 | decoder_model: str = "", 18 | text_encoder_4: str = "", 19 | vae_model: str = "", 20 | lora: str = "", 21 | embedding: EmbeddingName | None = None, 22 | additional_embeddings: list[EmbeddingName] | None = None, 23 | include_text_encoder: bool = True, 24 | include_text_encoder_2: bool = True, 25 | include_text_encoder_3: bool = True, 26 | include_text_encoder_4: bool = True, 27 | ): 28 | self.base_model = base_model 29 | self.prior_model = prior_model 30 | self.effnet_encoder_model = effnet_encoder_model 31 | self.decoder_model = decoder_model 32 | self.text_encoder_4 = text_encoder_4 33 | self.vae_model = vae_model 34 | self.lora = lora 35 | self.embedding = embedding 36 | self.additional_embeddings = [] 
if additional_embeddings is None else additional_embeddings 37 | self.include_text_encoder = include_text_encoder 38 | self.include_text_encoder_2 = include_text_encoder_2 39 | self.include_text_encoder_3 = include_text_encoder_3 40 | self.include_text_encoder_4 = include_text_encoder_4 41 | 42 | def all_embedding(self): 43 | if self.embedding is not None: 44 | return self.additional_embeddings + [self.embedding] 45 | else: 46 | return self.additional_embeddings 47 | -------------------------------------------------------------------------------- /modules/util/TrainProgress.py: -------------------------------------------------------------------------------- 1 | class TrainProgress: 2 | def __init__( 3 | self, 4 | epoch: int = 0, 5 | epoch_step: int = 0, 6 | epoch_sample: int = 0, 7 | global_step: int = 0, 8 | ): 9 | self.epoch = epoch 10 | self.epoch_step = epoch_step 11 | self.epoch_sample = epoch_sample 12 | self.global_step = global_step 13 | 14 | def next_step(self, batch_size: int): 15 | self.epoch_step += 1 16 | self.epoch_sample += batch_size 17 | self.global_step += 1 18 | 19 | def next_epoch(self): 20 | self.epoch_step = 0 21 | self.epoch_sample = 0 22 | self.epoch += 1 23 | 24 | def filename_string(self): 25 | return f"{self.global_step}-{self.epoch}-{self.epoch_step}" 26 | -------------------------------------------------------------------------------- /modules/util/args/CalculateLossArgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from modules.util.args.BaseArgs import BaseArgs 5 | 6 | 7 | class CalculateLossArgs(BaseArgs): 8 | config_path: str 9 | output_path: str 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def parse_args() -> 'CalculateLossArgs': 16 | parser = argparse.ArgumentParser(description="One Trainer Loss Calculation Script.") 17 | 18 | # @formatter:off 19 | 20 | parser.add_argument("--config-path", type=str, required=True, dest="config_path", help="The path to the config file") 21 | parser.add_argument("--output-path", type=str, required=True, dest="output_path", help="The path to the output file") 22 | 23 | # @formatter:on 24 | 25 | args = CalculateLossArgs.default_values() 26 | args.from_dict(vars(parser.parse_args())) 27 | return args 28 | 29 | @staticmethod 30 | def default_values() -> 'CalculateLossArgs': 31 | data = [] 32 | 33 | # name, default value, data type, nullable 34 | data.append(("config_path", None, str, True)) 35 | data.append(("output_path", "losses.json", str, False)) 36 | 37 | return CalculateLossArgs(data) 38 | -------------------------------------------------------------------------------- /modules/util/args/CaptionUIArgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from modules.util.args.BaseArgs import BaseArgs 5 | 6 | 7 | class CaptionUIArgs(BaseArgs): 8 | dir: str 9 | include_subdirectories: bool 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def parse_args() -> 'CaptionUIArgs': 16 | parser = argparse.ArgumentParser(description="One Trainer Caption UI Script.") 17 | 18 | # @formatter:off 19 | 20 | parser.add_argument("--dir", type=str, required=False, default=None, dest="dir", help="The initial directory to load training data from") 21 | parser.add_argument("--include-subdirectories", 
action="store_true", required=False, default=False, dest="include_subdirectories", help="Whether to include subdirectories when processing samples") 22 | 23 | # @formatter:on 24 | 25 | args = CaptionUIArgs.default_values() 26 | args.from_dict(vars(parser.parse_args())) 27 | return args 28 | 29 | @staticmethod 30 | def default_values() -> 'CaptionUIArgs': 31 | data = [] 32 | 33 | # name, default value, data type, nullable 34 | data.append(("dir", None, str, True)) 35 | data.append(("include_subdirectories", False, bool, False)) 36 | 37 | return CaptionUIArgs(data) 38 | -------------------------------------------------------------------------------- /modules/util/args/CreateTrainFilesArgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from modules.util.args.BaseArgs import BaseArgs 5 | 6 | 7 | class CreateTrainFilesArgs(BaseArgs): 8 | config_output_destination: str 9 | concepts_output_destination: str 10 | samples_output_destination: str 11 | 12 | def __init__(self, data: list[(str, Any, type, bool)]): 13 | super().__init__(data) 14 | 15 | @staticmethod 16 | def parse_args() -> 'CreateTrainFilesArgs': 17 | parser = argparse.ArgumentParser(description="One Trainer Create Train Files Script.") 18 | 19 | # @formatter:off 20 | 21 | parser.add_argument("--config-output-destination", type=str, required=False, default=None, dest="config_output_destination", help="The destination filename to save a default config file") 22 | parser.add_argument("--concepts-output-destination", type=str, required=False, default=None, dest="concepts_output_destination", help="The destination filename to save a default concepts file") 23 | parser.add_argument("--samples-output-destination", type=str, required=False, default=None, dest="samples_output_destination", help="The destination filename to save a default samples file") 24 | 25 | # @formatter:on 26 | 27 | args = CreateTrainFilesArgs.default_values() 28 | args.from_dict(vars(parser.parse_args())) 29 | return args 30 | 31 | 32 | @staticmethod 33 | def default_values(): 34 | data = [] 35 | 36 | data.append(("config_output_destination", None, str, True)) 37 | data.append(("concepts_output_destination", None, str, True)) 38 | data.append(("samples_output_destination", None, str, True)) 39 | 40 | return CreateTrainFilesArgs(data) 41 | -------------------------------------------------------------------------------- /modules/util/args/TrainArgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from modules.util.args.BaseArgs import BaseArgs 5 | 6 | 7 | class TrainArgs(BaseArgs): 8 | config_path: str 9 | secrets_path: str 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def parse_args() -> 'TrainArgs': 16 | parser = argparse.ArgumentParser(description="One Trainer Training Script.") 17 | 18 | # @formatter:off 19 | 20 | parser.add_argument("--config-path", type=str, required=True, dest="config_path", help="The path to the config file") 21 | parser.add_argument("--secrets-path", type=str, required=False, dest="secrets_path", help="The path to the secrets file") 22 | parser.add_argument("--callback-path", type=str, required=False, dest="callback_path", help="The path to the callback pickle file") 23 | parser.add_argument("--command-path", type=str, required=False, dest="command_path", help="The path to the command pickle 
file") 24 | 25 | # @formatter:on 26 | 27 | args = TrainArgs.default_values() 28 | args.from_dict(vars(parser.parse_args())) 29 | return args 30 | 31 | @staticmethod 32 | def default_values() -> 'TrainArgs': 33 | data = [] 34 | 35 | # name, default value, data type, nullable 36 | data.append(("config_path", None, str, True)) 37 | data.append(("secrets_path", None, str, True)) 38 | data.append(("callback_path", None, str, True)) 39 | data.append(("command_path", None, str, True)) 40 | 41 | return TrainArgs(data) 42 | -------------------------------------------------------------------------------- /modules/util/args/arg_type_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def torch_device(device_name: str) -> torch.device: 5 | return torch.device(device_name) 6 | 7 | def nullable_bool(bool_value: str) -> bool: 8 | return bool_value.lower() == "true" 9 | -------------------------------------------------------------------------------- /modules/util/config/SecretsConfig.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from modules.util.config.BaseConfig import BaseConfig 4 | from modules.util.config.CloudConfig import CloudSecretsConfig 5 | 6 | 7 | class SecretsConfig(BaseConfig): 8 | huggingface_token: str 9 | cloud: CloudSecretsConfig 10 | 11 | def __init__(self, data: list[(str, Any, type, bool)]): 12 | super().__init__(data) 13 | 14 | @staticmethod 15 | def default_values() -> 'SecretsConfig': 16 | data = [] 17 | 18 | # name, default value, data type, nullable 19 | data.append(("huggingface_token", "", str, False)) 20 | 21 | # cloud 22 | cloud = CloudSecretsConfig.default_values() 23 | data.append(("cloud", cloud, CloudSecretsConfig, False)) 24 | 25 | return SecretsConfig(data) 26 | -------------------------------------------------------------------------------- /modules/util/conv_util.py: -------------------------------------------------------------------------------- 1 | from modules.module.LoRAModule import LoRAModuleWrapper 2 | 3 | from torch import nn 4 | 5 | 6 | def apply_circular_padding_to_conv2d(module: nn.Module | LoRAModuleWrapper): 7 | for m in module.modules(): 8 | if isinstance(m, nn.Conv2d): 9 | m.padding_mode = 'circular' 10 | -------------------------------------------------------------------------------- /modules/util/convert/rescale_noise_scheduler_to_zero_terminal_snr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusers import DDIMScheduler 4 | 5 | 6 | def rescale_noise_scheduler_to_zero_terminal_snr(noise_scheduler: DDIMScheduler): 7 | """ 8 | From: Common Diffusion Noise Schedules and Sample Steps are Flawed (https://arxiv.org/abs/2305.08891) 9 | 10 | Rescales the noise scheduler's betas so that the final timestep reaches zero terminal SNR. 11 | 12 | Args: 13 | noise_scheduler: The noise scheduler to transform 14 | 15 | Returns: 16 | The rescaled betas 17 | """ 18 | alphas_cumprod = noise_scheduler.alphas_cumprod 19 | sqrt_alphas_cumprod = alphas_cumprod ** 0.5 20 | 21 | # Store old values. 22 | alphas_cumprod_sqrt_0 = sqrt_alphas_cumprod[0].clone() 23 | alphas_cumprod_sqrt_T = sqrt_alphas_cumprod[-1].clone() 24 | 25 | # Shift so last timestep is zero. 26 | sqrt_alphas_cumprod -= alphas_cumprod_sqrt_T 27 | 28 | # Scale so first timestep is back to old value.
29 | sqrt_alphas_cumprod *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) 30 | 31 | # Convert alphas_cumprod_sqrt to betas 32 | alphas_cumprod = sqrt_alphas_cumprod ** 2 33 | alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] 34 | alphas = torch.cat([alphas_cumprod[0:1], alphas]) 35 | betas = 1 - alphas 36 | 37 | noise_scheduler.betas = betas 38 | noise_scheduler.alphas = alphas 39 | noise_scheduler.alphas_cumprod = alphas_cumprod 40 | 41 | return betas 42 | -------------------------------------------------------------------------------- /modules/util/enum/AudioFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class AudioFormat(Enum): 5 | MP3 = 'MP3' 6 | 7 | def __str__(self): 8 | return self.value 9 | 10 | def extension(self) -> str: 11 | match self: 12 | case AudioFormat.MP3: 13 | return ".mp3" 14 | case _: 15 | return "" 16 | -------------------------------------------------------------------------------- /modules/util/enum/BalancingStrategy.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BalancingStrategy(Enum): 5 | REPEATS = 'REPEATS' 6 | SAMPLES = 'SAMPLES' 7 | 8 | def __str__(self): 9 | return self.value 10 | -------------------------------------------------------------------------------- /modules/util/enum/CloudAction.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class CloudAction(Enum): 5 | NONE = 'NONE' 6 | STOP = 'STOP' 7 | DELETE = 'DELETE' 8 | def __str__(self): 9 | return self.value 10 | -------------------------------------------------------------------------------- /modules/util/enum/CloudFileSync.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class CloudFileSync(Enum): 5 | FABRIC_SFTP = 'FABRIC_SFTP' 6 | NATIVE_SCP = 'NATIVE_SCP' 7 | def __str__(self): 8 | return self.value 9 | -------------------------------------------------------------------------------- /modules/util/enum/CloudType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class CloudType(Enum): 5 | RUNPOD = 'RUNPOD' 6 | LINUX = 'LINUX' 7 | def __str__(self): 8 | return self.value 9 | -------------------------------------------------------------------------------- /modules/util/enum/ConceptType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ConceptType(Enum): 5 | STANDARD = 'STANDARD' 6 | VALIDATION = 'VALIDATION' 7 | PRIOR_PREDICTION = 'PRIOR_PREDICTION' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /modules/util/enum/ConfigPart.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ConfigPart(Enum): 5 | NONE = 'NONE' 6 | SETTINGS = 'SETTINGS' 7 | ALL = 'ALL' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /modules/util/enum/DataType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | 5 | 6 | class DataType(Enum): 7 | NONE = 'NONE' 8 | FLOAT_8 = 'FLOAT_8' 9 | FLOAT_16 = 'FLOAT_16' 10 | FLOAT_32 = 
'FLOAT_32' 11 | BFLOAT_16 = 'BFLOAT_16' 12 | TFLOAT_32 = 'TFLOAT_32' 13 | INT_8 = 'INT_8' 14 | NFLOAT_4 = 'NFLOAT_4' 15 | 16 | def __str__(self): 17 | return self.value 18 | 19 | def torch_dtype( 20 | self, 21 | supports_quantization: bool = True, 22 | ): 23 | if self.is_quantized() and not supports_quantization: 24 | return torch.float16 25 | 26 | match self: 27 | case DataType.FLOAT_16: 28 | return torch.float16 29 | case DataType.FLOAT_32: 30 | return torch.float32 31 | case DataType.BFLOAT_16: 32 | return torch.bfloat16 33 | case DataType.TFLOAT_32: 34 | return torch.float32 35 | case _: 36 | return None 37 | 38 | def enable_tf(self): 39 | return self == DataType.TFLOAT_32 40 | 41 | def is_quantized(self): 42 | return self in [DataType.FLOAT_8, 43 | DataType.INT_8, 44 | DataType.NFLOAT_4] 45 | 46 | def quantize_fp8(self): 47 | return self == DataType.FLOAT_8 48 | 49 | def quantize_int8(self): 50 | return self == DataType.INT_8 51 | 52 | def quantize_nf4(self): 53 | return self == DataType.NFLOAT_4 54 | -------------------------------------------------------------------------------- /modules/util/enum/EMAMode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class EMAMode(Enum): 5 | OFF = 'OFF' 6 | GPU = 'GPU' 7 | CPU = 'CPU' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /modules/util/enum/FileType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class FileType(Enum): 5 | IMAGE = 'IMAGE' 6 | VIDEO = 'VIDEO' 7 | AUDIO = 'AUDIO' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /modules/util/enum/GenerateCaptionsModel.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GenerateCaptionsModel(Enum): 5 | BLIP = 'BLIP' 6 | BLIP2 = 'BLIP2' 7 | WD14_VIT_2 = 'WD14_VIT_2' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /modules/util/enum/GenerateMasksModel.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GenerateMasksModel(Enum): 5 | CLIPSEG = 'CLIPSEG' 6 | REMBG = 'REMBG' 7 | REMBG_HUMAN = 'REMBG_HUMAN' 8 | COLOR = 'COLOR' 9 | 10 | def __str__(self): 11 | return self.value 12 | -------------------------------------------------------------------------------- /modules/util/enum/GradientCheckpointingMethod.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class GradientCheckpointingMethod(Enum): 5 | OFF = 'OFF' 6 | ON = 'ON' 7 | CPU_OFFLOADED = 'CPU_OFFLOADED' 8 | 9 | def __str__(self): 10 | return self.value 11 | 12 | def enabled(self): 13 | return self == GradientCheckpointingMethod.ON \ 14 | or self == GradientCheckpointingMethod.CPU_OFFLOADED 15 | 16 | def offload(self): 17 | return self == GradientCheckpointingMethod.CPU_OFFLOADED 18 | -------------------------------------------------------------------------------- /modules/util/enum/ImageFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ImageFormat(Enum): 5 | PNG = 'PNG' 6 | JPG = 'JPG' 7 | 8 | def __str__(self): 9 | return 
self.value 10 | 11 | def extension(self) -> str: 12 | match self: 13 | case ImageFormat.PNG: 14 | return ".png" 15 | case ImageFormat.JPG: 16 | return ".jpg" 17 | case _: 18 | return "" 19 | 20 | def pil_format(self) -> str: 21 | match self: 22 | case ImageFormat.PNG: 23 | return "PNG" 24 | case ImageFormat.JPG: 25 | return "JPEG" 26 | case _: 27 | return "" 28 | -------------------------------------------------------------------------------- /modules/util/enum/LearningRateScaler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LearningRateScaler(Enum): 5 | NONE = 'NONE' 6 | BATCH = 'BATCH' 7 | GRADIENT_ACCUMULATION = 'GRADIENT_ACCUMULATION' 8 | BOTH = 'BOTH' 9 | 10 | def __str__(self): 11 | return self.value 12 | -------------------------------------------------------------------------------- /modules/util/enum/LearningRateScheduler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LearningRateScheduler(Enum): 5 | CONSTANT = 'CONSTANT' 6 | LINEAR = 'LINEAR' 7 | COSINE = 'COSINE' 8 | COSINE_WITH_RESTARTS = 'COSINE_WITH_RESTARTS' 9 | COSINE_WITH_HARD_RESTARTS = 'COSINE_WITH_HARD_RESTARTS' 10 | REX = 'REX' 11 | ADAFACTOR = 'ADAFACTOR' 12 | CUSTOM = 'CUSTOM' 13 | 14 | def __str__(self): 15 | return self.value 16 | -------------------------------------------------------------------------------- /modules/util/enum/LossScaler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LossScaler(Enum): 5 | NONE = 'NONE' 6 | BATCH = 'BATCH' 7 | GRADIENT_ACCUMULATION = 'GRADIENT_ACCUMULATION' 8 | BOTH = 'BOTH' 9 | 10 | def __str__(self): 11 | return self.value 12 | -------------------------------------------------------------------------------- /modules/util/enum/LossWeight.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class LossWeight(Enum): 5 | CONSTANT = 'CONSTANT' 6 | P2 = 'P2' 7 | MIN_SNR_GAMMA = 'MIN_SNR_GAMMA' 8 | DEBIASED_ESTIMATION = 'DEBIASED_ESTIMATION' 9 | SIGMA = 'SIGMA' 10 | 11 | def __str__(self): 12 | return self.value 13 | -------------------------------------------------------------------------------- /modules/util/enum/ModelFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ModelFormat(Enum): 5 | DIFFUSERS = 'DIFFUSERS' 6 | CKPT = 'CKPT' 7 | SAFETENSORS = 'SAFETENSORS' 8 | LEGACY_SAFETENSORS = 'LEGACY_SAFETENSORS' 9 | 10 | INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training 11 | 12 | def __str__(self): 13 | return self.value 14 | 15 | 16 | def file_extension(self) -> str: 17 | match self: 18 | case ModelFormat.DIFFUSERS: 19 | return '' 20 | case ModelFormat.CKPT: 21 | return '.ckpt' 22 | case ModelFormat.SAFETENSORS: 23 | return '.safetensors' 24 | case ModelFormat.LEGACY_SAFETENSORS: 25 | return '.safetensors' 26 | case _: 27 | return '' 28 | 29 | def is_single_file(self) -> bool: 30 | return self.file_extension() != '' 31 | -------------------------------------------------------------------------------- /modules/util/enum/NoiseScheduler.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class NoiseScheduler(Enum): 5 | DDIM = 'DDIM' 6 | 7 | EULER = 'EULER' 8 | EULER_A = 'EULER_A' 9 | 
DPMPP = 'DPMPP' 10 | DPMPP_SDE = 'DPMPP_SDE' 11 | UNIPC = 'UNIPC' 12 | 13 | EULER_KARRAS = 'EULER_KARRAS' 14 | DPMPP_KARRAS = 'DPMPP_KARRAS' 15 | DPMPP_SDE_KARRAS = 'DPMPP_SDE_KARRAS' 16 | UNIPC_KARRAS = 'UNIPC_KARRAS' 17 | 18 | def __str__(self): 19 | return self.value 20 | -------------------------------------------------------------------------------- /modules/util/enum/TimeUnit.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TimeUnit(Enum): 5 | EPOCH = 'EPOCH' 6 | STEP = 'STEP' 7 | SECOND = 'SECOND' 8 | MINUTE = 'MINUTE' 9 | HOUR = 'HOUR' 10 | 11 | NEVER = 'NEVER' 12 | ALWAYS = 'ALWAYS' 13 | 14 | def __str__(self): 15 | return self.value 16 | 17 | def is_time_unit(self) -> bool: 18 | return self == TimeUnit.SECOND \ 19 | or self == TimeUnit.MINUTE \ 20 | or self == TimeUnit.HOUR 21 | -------------------------------------------------------------------------------- /modules/util/enum/TimestepDistribution.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TimestepDistribution(Enum): 5 | UNIFORM = 'UNIFORM' 6 | SIGMOID = 'SIGMOID' 7 | LOGIT_NORMAL = 'LOGIT_NORMAL' 8 | HEAVY_TAIL = 'HEAVY_TAIL' 9 | COS_MAP = 'COS_MAP' 10 | 11 | def __str__(self): 12 | return self.value 13 | -------------------------------------------------------------------------------- /modules/util/enum/TrainingMethod.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TrainingMethod(Enum): 5 | FINE_TUNE = 'FINE_TUNE' 6 | LORA = 'LORA' 7 | EMBEDDING = 'EMBEDDING' 8 | FINE_TUNE_VAE = 'FINE_TUNE_VAE' 9 | 10 | def __str__(self): 11 | return self.value 12 | -------------------------------------------------------------------------------- /modules/util/enum/VideoFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class VideoFormat(Enum): 5 | PNG_IMAGE_SEQUENCE = 'PNG_IMAGE_SEQUENCE' 6 | JPG_IMAGE_SEQUENCE = 'JPG_IMAGE_SEQUENCE' 7 | MP4 = 'MP4' 8 | 9 | def __str__(self): 10 | return self.value 11 | 12 | def extension(self) -> str: 13 | match self: 14 | case VideoFormat.PNG_IMAGE_SEQUENCE: 15 | return ".png" 16 | case VideoFormat.JPG_IMAGE_SEQUENCE: 17 | return ".jpg" 18 | case VideoFormat.MP4: 19 | return ".mp4" 20 | case _: 21 | return "" 22 | 23 | def pil_format(self) -> str: 24 | match self: 25 | case VideoFormat.PNG_IMAGE_SEQUENCE: 26 | return "PNG" 27 | case VideoFormat.JPG_IMAGE_SEQUENCE: 28 | return "JPEG" 29 | case _: 30 | return "" 31 | -------------------------------------------------------------------------------- /modules/util/git_util.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def get_git_branch() -> str: 5 | try: 6 | return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('ascii').strip() 7 | except subprocess.CalledProcessError: 8 | return "git not installed" 9 | 10 | 11 | def get_git_revision() -> str: 12 | try: 13 | return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() 14 | except subprocess.CalledProcessError: 15 | return "git not installed" 16 | -------------------------------------------------------------------------------- /modules/util/image_util.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, 
ImageOps 2 | 3 | 4 | def load_image(path: str, convert_mode: str = 'RGB') -> Image.Image: 5 | image = Image.open(path) 6 | image = ImageOps.exif_transpose(image) 7 | if convert_mode: 8 | image = image.convert(convert_mode) 9 | return image 10 | -------------------------------------------------------------------------------- /modules/util/loss/masked_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def masked_losses( 6 | losses: Tensor, 7 | mask: Tensor, 8 | unmasked_weight: float, 9 | normalize_masked_area_loss: bool, 10 | ) -> Tensor: 11 | clamped_mask = torch.clamp(mask, unmasked_weight, 1) 12 | 13 | losses *= clamped_mask 14 | 15 | if normalize_masked_area_loss: 16 | losses = losses / clamped_mask.mean(dim=(1, 2, 3), keepdim=True) 17 | 18 | return losses 19 | 20 | 21 | def masked_losses_with_prior( 22 | losses: Tensor, 23 | prior_losses: Tensor | None, 24 | mask: Tensor, 25 | unmasked_weight: float, 26 | normalize_masked_area_loss: bool, 27 | masked_prior_preservation_weight: float, 28 | ) -> Tensor: 29 | clamped_mask = torch.clamp(mask, unmasked_weight, 1) 30 | 31 | losses *= clamped_mask 32 | 33 | if normalize_masked_area_loss: 34 | losses = losses / clamped_mask.mean(dim=(1, 2, 3), keepdim=True) 35 | 36 | if masked_prior_preservation_weight == 0: 37 | return losses 38 | 39 | clamped_mask = (1 - clamped_mask) 40 | prior_losses *= clamped_mask * masked_prior_preservation_weight 41 | 42 | if normalize_masked_area_loss: 43 | prior_losses = prior_losses / clamped_mask.mean(dim=(1, 2, 3), keepdim=True) 44 | 45 | return losses + prior_losses 46 | -------------------------------------------------------------------------------- /modules/util/memory_util.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import torch 4 | 5 | 6 | class TorchMemoryRecorder: 7 | 8 | def __init__(self, filename: str = "memory.pickle", enabled: bool = True): 9 | self.filename = filename 10 | self.enabled = enabled and platform.system() == 'Linux' 11 | 12 | def __enter__(self): 13 | if self.enabled: 14 | torch.cuda.memory._record_memory_history() 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | if self.enabled: 18 | try: 19 | torch.cuda.memory._dump_snapshot(filename=self.filename) 20 | print(f"dumped memory snapshot to {self.filename}") 21 | except Exception: 22 | print(f"could not dump memory snapshot {self.filename}") 23 | 24 | torch.cuda.memory._record_memory_history(enabled=None) 25 | -------------------------------------------------------------------------------- /modules/util/path_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | from typing import Any 4 | 5 | 6 | def safe_filename( 7 | text: str, 8 | allow_spaces: bool = True, 9 | max_length: int | None = 32, 10 | ): 11 | legal_chars = [' ', '.', '_', '-', '#'] 12 | if not allow_spaces: 13 | text = text.replace(' ', '_') 14 | 15 | text = ''.join(filter(lambda x: str.isalnum(x) or x in legal_chars, text)).strip() 16 | 17 | if max_length is not None: 18 | text = text[0: max_length] 19 | 20 | return text.strip() 21 | 22 | 23 | def canonical_join(base_path: str, *paths: str): 24 | # Creates a canonical path name that can be used for comparisons. 25 | # Also, Windows does understand / instead of \, so these paths can be used as usual. 
26 | 27 | joined = os.path.join(base_path, *paths) 28 | return joined.replace('\\', '/') 29 | 30 | 31 | def write_json_atomic(path: str, obj: Any): 32 | with open(path + ".write", "w") as f: 33 | json.dump(obj, f, indent=4) 34 | os.replace(path + ".write", path) 35 | 36 | 37 | SUPPORTED_IMAGE_EXTENSIONS = {'.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.webp'} 38 | SUPPORTED_VIDEO_EXTENSIONS = {'.webm', '.mkv', '.flv', '.avi', '.mov', '.wmv', '.mp4', '.mpeg', '.m4v'} 39 | 40 | 41 | def supported_image_extensions() -> set[str]: 42 | return SUPPORTED_IMAGE_EXTENSIONS 43 | 44 | 45 | def is_supported_image_extension(extension: str) -> bool: 46 | return extension.lower() in SUPPORTED_IMAGE_EXTENSIONS 47 | 48 | 49 | def supported_video_extensions() -> set[str]: 50 | return SUPPORTED_VIDEO_EXTENSIONS 51 | 52 | 53 | def is_supported_video_extension(extension: str) -> bool: 54 | return extension.lower() in SUPPORTED_VIDEO_EXTENSIONS 55 | -------------------------------------------------------------------------------- /modules/util/time_util.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | def get_string_timestamp(): 5 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 6 | -------------------------------------------------------------------------------- /modules/util/ui/ToolTip.py: -------------------------------------------------------------------------------- 1 | import customtkinter as ctk 2 | 3 | 4 | class ToolTip: 5 | """ 6 | create a tooltip for a given widget 7 | """ 8 | 9 | def __init__(self, widget, text='widget info', x_position=20, wide=False): 10 | self.widget = widget 11 | self.text = text 12 | self.x_position = x_position 13 | 14 | self.waittime = 500 # milliseconds 15 | self.wraplength = 180 if not wide else 350 # pixels 16 | self.widget.bind("<Enter>", self.enter) 17 | self.widget.bind("<Leave>", self.leave) 18 | self.widget.bind("<ButtonPress>", self.leave) 19 | self.id = None 20 | self.tw = None 21 | 22 | def enter(self, event=None): 23 | self.schedule() 24 | 25 | def leave(self, event=None): 26 | self.unschedule() 27 | self.hidetip() 28 | 29 | def schedule(self): 30 | self.unschedule() 31 | self.id = self.widget.after(self.waittime, self.showtip) 32 | 33 | def unschedule(self): 34 | id = self.id 35 | self.id = None 36 | if id: 37 | self.widget.after_cancel(id) 38 | 39 | def showtip(self, event=None): 40 | x = y = 0 41 | x, y, cx, cy = self.widget.bbox("insert") 42 | x += self.widget.winfo_rootx() + 25 43 | y += self.widget.winfo_rooty() + self.x_position 44 | # creates a toplevel window 45 | self.tw = ctk.CTkToplevel(self.widget) 46 | # Leaves only the label and removes the app window 47 | self.tw.wm_overrideredirect(True) 48 | self.tw.wm_geometry(f"+{x}+{y}") 49 | label = ctk.CTkLabel(self.tw, text=self.text, justify='left', wraplength=self.wraplength) 50 | label.pack(padx=8, pady=8) 51 | 52 | def hidetip(self): 53 | tw = self.tw 54 | self.tw = None 55 | if tw: 56 | tw.destroy() 57 | -------------------------------------------------------------------------------- /modules/util/ui/dialogs.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import customtkinter as ctk 4 | 5 | 6 | class StringInputDialog(ctk.CTkToplevel): 7 | def __init__( 8 | self, 9 | parent, 10 | title: str, 11 | question: str, 12 | callback: Callable[[str], None], 13 | default_value: str = None, 14 | validate_callback: Callable[[str], bool] = None, 15 | *args, **kwargs 16
| ): 17 | super().__init__(parent, *args, **kwargs) 18 | self.parent = parent 19 | 20 | self.callback = callback 21 | self.validate_callback = validate_callback 22 | 23 | self.grid_columnconfigure(0, weight=1) 24 | self.grid_columnconfigure(1, weight=1) 25 | 26 | self.title(title) 27 | self.geometry("300x120") 28 | self.resizable(False, False) 29 | self.wait_visibility() 30 | self.grab_set() 31 | self.focus_set() 32 | 33 | self.question_label = ctk.CTkLabel(self, text=question) 34 | self.question_label.grid(row=0, column=0, columnspan=2, sticky="we", padx=5, pady=5) 35 | 36 | self.entry = ctk.CTkEntry(self, width=150) 37 | self.entry.grid(row=1, column=0, columnspan=2, sticky="we", padx=10, pady=5) 38 | 39 | self.ok_button = ctk.CTkButton(self, width=30, text="ok", command=self.ok) 40 | self.ok_button.grid(row=2, column=0, sticky="we", padx=10, pady=5) 41 | 42 | self.ok_button = ctk.CTkButton(self, width=30, text="cancel", command=self.cancel) 43 | self.ok_button.grid(row=2, column=1, sticky="we", padx=10, pady=5) 44 | 45 | if default_value is not None: 46 | self.entry.insert(0, default_value) 47 | 48 | def ok(self): 49 | if self.validate_callback is None or self.validate_callback(self.entry.get()): 50 | self.callback(self.entry.get()) 51 | self.destroy() 52 | 53 | def cancel(self): 54 | self.destroy() 55 | -------------------------------------------------------------------------------- /modules/zluda/ZLUDA.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.util.config.TrainConfig import TrainConfig 3 | 4 | import torch 5 | from torch._prims_common import DeviceLikeType 6 | 7 | 8 | def is_zluda(device: DeviceLikeType): 9 | device = torch.device(device) 10 | if device.type == "cpu": 11 | return False 12 | return torch.cuda.get_device_name(device).endswith("[ZLUDA]") 13 | 14 | 15 | def test(device: DeviceLikeType) -> Exception | None: 16 | device = torch.device(device) 17 | try: 18 | ten1 = torch.randn((2, 4,), device=device) 19 | ten2 = torch.randn((4, 8,), device=device) 20 | out = torch.mm(ten1, ten2) 21 | assert out.sum().is_nonzero() 22 | return None 23 | except Exception as e: 24 | return e 25 | 26 | 27 | def initialize(): 28 | torch.backends.cudnn.enabled = False 29 | torch.backends.cuda.enable_flash_sdp(False) 30 | torch.backends.cuda.enable_math_sdp(True) 31 | torch.backends.cuda.enable_mem_efficient_sdp(False) 32 | if hasattr(torch.backends.cuda, "enable_cudnn_sdp"): 33 | torch.backends.cuda.enable_cudnn_sdp(False) 34 | 35 | 36 | def initialize_devices(config: TrainConfig): 37 | if not is_zluda(config.train_device) and not is_zluda(config.temp_device): 38 | return 39 | devices = [config.train_device, config.temp_device,] 40 | for i in range(2): 41 | device = torch.device(devices[i]) 42 | result = test(device) 43 | if result is not None: 44 | print(f'ZLUDA device failed to pass basic operation test: index={device.index}, device_name={torch.cuda.get_device_name(device)}') 45 | print(result) 46 | devices[i] = 'cpu' 47 | config.train_device, config.temp_device = devices 48 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | requires-python = ">=3.10" 3 | 4 | [tool.ruff] 5 | extend-exclude = [ 6 | # Exclude all third-party dependencies and environments. 7 | # NOTE: Conda installs mgds+diffusers into "src/" in the project directory. 
8 | ".venv*", 9 | "venv*", 10 | "conda_env*", 11 | "src", 12 | # Do not lint the universal Python 2/3 version checker. 13 | "scripts/util/version_check.py", 14 | ] 15 | line-length = 120 16 | 17 | [tool.ruff.lint] 18 | select = ["F", "E", "W", "I", "B", "UP", "YTT", "BLE", "C4", "T10", "ISC", "ICN", "PIE", "PYI", "RSE", "RET", "SIM", "PGH", "FLY", "NPY", "PERF"] 19 | ignore = ["BLE001", "E402", "E501", "B024", "PGH003", "RET504", "RET505", "SIM102", "UP015", "PYI041"] 20 | 21 | [tool.ruff.lint.isort.sections] 22 | hf = ["diffusers*", "transformers*"] 23 | mgds = ["mgds"] 24 | torch = ["torch*"] 25 | 26 | [tool.ruff.format] 27 | quote-style = "double" 28 | docstring-code-format = true 29 | 30 | [tool.ruff.lint.isort] 31 | known-first-party = ["modules"] 32 | section-order = [ 33 | "future", 34 | "standard-library", 35 | "first-party", 36 | "mgds", 37 | "torch", 38 | "hf", 39 | "third-party", 40 | "local-folder", 41 | ] 42 | -------------------------------------------------------------------------------- /requirements-cuda.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | --extra-index-url https://download.pytorch.org/whl/cu124 3 | torch==2.6.0+cu124 4 | torchvision==0.21.0+cu124 5 | onnxruntime-gpu==1.20.1 6 | 7 | # optimizers 8 | bitsandbytes==0.45.2 # bitsandbytes for 8-bit optimizers and weight quantization 9 | -------------------------------------------------------------------------------- /requirements-default.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | torch==2.6.0 3 | torchvision==0.21.0 4 | onnxruntime==1.20.1 5 | 6 | # optimizers 7 | # TODO 8 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Git commit hook handler. 2 | # NOTE: It's a user-wide tool installed outside venv. Don't pin exact versions. 3 | # SEE: https://pre-commit.com/ 4 | pre-commit>=4.0.1 5 | -------------------------------------------------------------------------------- /requirements-global.txt: -------------------------------------------------------------------------------- 1 | # base requirements 2 | numpy==2.2.2 3 | opencv-python==4.11.0.86 4 | pillow==11.1.0 5 | imagesize==1.4.1 #for concept statistics 6 | tqdm==4.67.1 7 | PyYAML==6.0.2 8 | huggingface-hub==0.28.1 9 | scipy==1.15.1; sys_platform != 'win32' 10 | matplotlib==3.10.0 11 | av==14.1.0 # using an older version. only update once torchvision has been updated as well. 
12 | yt-dlp #no pinned version, frequently updated for compatibility with sites 13 | scenedetect==0.6.6 14 | 15 | # pytorch 16 | accelerate==1.3.0 17 | safetensors==0.5.2 18 | tensorboard==2.18.0 19 | pytorch-lightning==2.5.0.post0 20 | 21 | # diffusion models 22 | -e git+https://github.com/huggingface/diffusers.git@5873377#egg=diffusers 23 | transformers==4.48.3 24 | sentencepiece==0.2.0 # transitive dependency of transformers for tokenizer loading 25 | omegaconf==2.3.0 # needed to load stable diffusion from single ckpt files 26 | invisible-watermark==0.2.0 # needed for the SDXL pipeline 27 | 28 | # model conversion 29 | -e git+https://github.com/Open-Model-Initiative/OMI-Model-Standards.git@4ad235c#egg=omi_model_standards 30 | 31 | # other models 32 | pooch==1.8.2 33 | open-clip-torch==2.30.0 34 | 35 | # data loader 36 | -e git+https://github.com/Nerogar/mgds.git@11ff4aa#egg=mgds 37 | 38 | # optimizers 39 | dadaptation==3.2 # dadaptation optimizers 40 | lion-pytorch==0.2.3 # lion optimizer 41 | prodigyopt==1.1.2 # prodigy optimizer 42 | schedulefree==1.4.0 # schedule-free optimizers 43 | pytorch_optimizer==3.4.0 # pytorch optimizers 44 | prodigy-plus-schedule-free==1.9.1 # prodigy+schedulefree optimizer 45 | 46 | # Profiling 47 | scalene==1.5.51 48 | 49 | # ui 50 | customtkinter==5.2.2 51 | 52 | # cloud 53 | runpod==1.7.7 54 | fabric==3.2.2 55 | 56 | # debug 57 | psutil==6.1.1 58 | requests==2.32.3 59 | -------------------------------------------------------------------------------- /requirements-rocm.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | --extra-index-url https://download.pytorch.org/whl/rocm6.2.4 3 | torch==2.6.0+rocm6.2.4 4 | torchvision==0.21.0+rocm6.2.4 5 | onnxruntime==1.20.1 6 | 7 | # optimizers 8 | # TODO 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements-cuda.txt 2 | -r requirements-global.txt 3 | -------------------------------------------------------------------------------- /resources/docker/NVIDIA-UI.Dockerfile: -------------------------------------------------------------------------------- 1 | # Note: as of April 2025, this Dockerfile is outdated and requires adjustments 2 | 3 | # Inspiration for setup @ https://dev.to/ordigital/nvidia-525-cuda-118-python-310-pytorch-gpu-docker-image-1l4a 4 | FROM docker.io/nvidia/cuda:11.8.0-devel-ubuntu22.04 5 | 6 | ENV PYTHONUNBUFFERED=1 7 | 8 | # SYSTEM 9 | RUN apt-get update --yes --quiet && DEBIAN_FRONTEND=noninteractive apt-get install --yes --quiet --no-install-recommends \ 10 | software-properties-common \ 11 | build-essential apt-utils \ 12 | wget curl vim git ca-certificates kmod \ 13 | nvidia-driver-525 \ 14 | && rm -rf /var/lib/apt/lists/* 15 | 16 | # PYTHON 3.10 17 | RUN add-apt-repository --yes ppa:deadsnakes/ppa && apt-get update --yes --quiet 18 | RUN DEBIAN_FRONTEND=noninteractive apt-get install --yes --quiet --no-install-recommends \ 19 | python3.10 \ 20 | python3.10-dev \ 21 | python3.10-distutils \ 22 | python3.10-lib2to3 \ 23 | python3.10-gdbm \ 24 | python3.10-tk \ 25 | pip 26 | 27 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 999 \ 28 | && update-alternatives --config python3 && ln -s /usr/bin/python3 /usr/bin/python 29 | 30 | RUN pip install --upgrade pip 31 | 32 | # Create and set the working directory 33 | RUN mkdir -p /OneTrainer 34 | WORKDIR /OneTrainer 35 | 
36 | # Copy the current directory's contents to the container image 37 | COPY . /OneTrainer 38 | WORKDIR /OneTrainer 39 | 40 | # Install requirements 41 | RUN python3 --version 42 | RUN python3 -m pip install -r requirements.txt 43 | 44 | # Run the training UI 45 | CMD ["python3", "scripts/train_ui.py"] 46 | -------------------------------------------------------------------------------- /resources/docker/RunPod-NVIDIA-CLI-start.sh.patch: -------------------------------------------------------------------------------- 1 | 83a84,88 2 | > link_onetrainer() { 3 | > if [[ ! -f /workspace/OneTrainer ]]; then 4 | > ln -snf /OneTrainer /workspace/OneTrainer 5 | > fi 6 | > } 7 | 93a100 8 | > link_onetrainer 9 | -------------------------------------------------------------------------------- /resources/docker/RunPod-NVIDIA-CLI.Dockerfile: -------------------------------------------------------------------------------- 1 | #To build, run 2 | # docker build -t . -f RunPod-NVIDIA-CLI.Dockerfile 3 | # docker tag /: 4 | # docker push /: 5 | 6 | FROM runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04 7 | #the base image is barely used, pytorch 2.4 is the wrong version. However, by using 8 | #a base image that is popular on RunPod, the base image likely is already available 9 | #in the image cache of a pod, and no download is necessary 10 | 11 | WORKDIR / 12 | RUN git clone https://github.com/Nerogar/OneTrainer 13 | RUN cd OneTrainer \ 14 | && export OT_PLATFORM_REQUIREMENTS=requirements-cuda.txt \ 15 | && export OT_LAZY_UPDATES=true \ 16 | && ./install.sh \ 17 | && pip cache purge \ 18 | && rm -r ~/.cache/pip 19 | RUN apt-get update --yes \ 20 | && apt-get install --yes --no-install-recommends \ 21 | joe \ 22 | less \ 23 | gh \ 24 | iputils-ping \ 25 | nano \ 26 | && apt-get autoremove -y \ 27 | && apt-get clean \ 28 | && rm -rf /var/lib/apt/lists/* 29 | RUN pip install nvitop \ 30 | && pip cache purge \ 31 | && rm -rf ~/.cache/pip 32 | COPY RunPod-NVIDIA-CLI-start.sh.patch /start.sh.patch 33 | RUN patch /start.sh < /start.sh.patch 34 | -------------------------------------------------------------------------------- /resources/icons/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/16e938d3c9ca97b0450e8811fcf0cc2f884f7daa/resources/icons/icon.ico -------------------------------------------------------------------------------- /resources/icons/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/16e938d3c9ca97b0450e8811fcf0cc2f884f7daa/resources/icons/icon.png -------------------------------------------------------------------------------- /resources/icons/icon_discord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/16e938d3c9ca97b0450e8811fcf0cc2f884f7daa/resources/icons/icon_discord.png -------------------------------------------------------------------------------- /resources/icons/icon_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nerogar/OneTrainer/16e938d3c9ca97b0450e8811fcf0cc2f884f7daa/resources/icons/icon_small.png -------------------------------------------------------------------------------- /resources/images/OneTrainerGUI.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Nerogar/OneTrainer/16e938d3c9ca97b0450e8811fcf0cc2f884f7daa/resources/images/OneTrainerGUI.gif -------------------------------------------------------------------------------- /resources/model_config/stable_cascade/stable_cascade_prior_1.0b.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableCascadeUNet", 3 | "_diffusers_version": "0.27.0.dev0", 4 | "block_out_channels": [ 5 | 1536, 6 | 1536 7 | ], 8 | "block_types_per_layer": [ 9 | [ 10 | "SDCascadeResBlock", 11 | "SDCascadeTimestepBlock", 12 | "SDCascadeAttnBlock" 13 | ], 14 | [ 15 | "SDCascadeResBlock", 16 | "SDCascadeTimestepBlock", 17 | "SDCascadeAttnBlock" 18 | ] 19 | ], 20 | "clip_image_in_channels": 768, 21 | "clip_seq": 4, 22 | "clip_text_in_channels": 1280, 23 | "clip_text_pooled_in_channels": 1280, 24 | "conditioning_dim": 1536, 25 | "down_blocks_repeat_mappers": [ 26 | 1, 27 | 1 28 | ], 29 | "down_num_layers_per_block": [ 30 | 4, 31 | 12 32 | ], 33 | "dropout": [ 34 | 0.1, 35 | 0.1 36 | ], 37 | "effnet_in_channels": null, 38 | "in_channels": 16, 39 | "kernel_size": 3, 40 | "num_attention_heads": [ 41 | 24, 42 | 24 43 | ], 44 | "out_channels": 16, 45 | "patch_size": 1, 46 | "pixel_mapper_in_channels": null, 47 | "self_attn": true, 48 | "switch_level": [ 49 | false 50 | ], 51 | "timestep_conditioning_type": [ 52 | "sca", 53 | "crp" 54 | ], 55 | "timestep_ratio_embedding_dim": 64, 56 | "up_blocks_repeat_mappers": [ 57 | 1, 58 | 1 59 | ], 60 | "up_num_layers_per_block": [ 61 | 12, 62 | 4 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /resources/model_config/stable_cascade/stable_cascade_prior_3.6b.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableCascadeUNet", 3 | "_diffusers_version": "0.27.0.dev0", 4 | "block_out_channels": [ 5 | 2048, 6 | 2048 7 | ], 8 | "block_types_per_layer": [ 9 | [ 10 | "SDCascadeResBlock", 11 | "SDCascadeTimestepBlock", 12 | "SDCascadeAttnBlock" 13 | ], 14 | [ 15 | "SDCascadeResBlock", 16 | "SDCascadeTimestepBlock", 17 | "SDCascadeAttnBlock" 18 | ] 19 | ], 20 | "clip_image_in_channels": 768, 21 | "clip_seq": 4, 22 | "clip_text_in_channels": 1280, 23 | "clip_text_pooled_in_channels": 1280, 24 | "conditioning_dim": 2048, 25 | "down_blocks_repeat_mappers": [ 26 | 1, 27 | 1 28 | ], 29 | "down_num_layers_per_block": [ 30 | 8, 31 | 24 32 | ], 33 | "dropout": [ 34 | 0.1, 35 | 0.1 36 | ], 37 | "effnet_in_channels": null, 38 | "in_channels": 16, 39 | "kernel_size": 3, 40 | "num_attention_heads": [ 41 | 32, 42 | 32 43 | ], 44 | "out_channels": 16, 45 | "patch_size": 1, 46 | "pixel_mapper_in_channels": null, 47 | "self_attn": true, 48 | "switch_level": [ 49 | false 50 | ], 51 | "timestep_conditioning_type": [ 52 | "sca", 53 | "crp" 54 | ], 55 | "timestep_ratio_embedding_dim": 64, 56 | "up_blocks_repeat_mappers": [ 57 | 1, 58 | 1 59 | ], 60 | "up_num_layers_per_block": [ 61 | 24, 62 | 8 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-dev/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev 1.0 Embedding" 6 | } 7 | 
-------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-dev/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-dev", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_fill_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-fill-dev/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev Fill 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_fill_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-fill-dev/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev Fill 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/flux_dev_fill_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "Flux.1-fill-dev", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "FluxDev Fill 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hi_dream_full-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hidream-i1/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HiDream I1 Full Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hi_dream_full-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hidream-i1/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HiDream I1 Full LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hi_dream_full.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hidream-i1", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HiDream I1 Full" 6 | } 7 | 
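# --- Illustrative sketch (not a file from this repository) ---------------------------------
# The JSON templates under resources/sd_model_spec/ follow the "modelspec" (sai_model_spec)
# convention: flat string key/value pairs intended to be written into a .safetensors file
# header. How the repository actually attaches them is handled by the model saver classes;
# the standalone example below only demonstrates the general idea using the public
# safetensors API, with a dummy state dict standing in for a trained model.
import json

import torch
from safetensors.torch import save_file

with open("resources/sd_model_spec/hi_dream_full.json") as f:
    model_spec: dict[str, str] = json.load(f)

# hypothetical tensors standing in for a real state dict
state_dict = {"example.weight": torch.zeros(4, 4)}

# safetensors metadata must be a flat str -> str mapping, which these templates already are
save_file(state_dict, "example.safetensors", metadata=model_spec)
# --------------------------------------------------------------------------------------------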
-------------------------------------------------------------------------------- /resources/sd_model_spec/hunyuan_video-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hunyuan-video/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HunyuanVideo Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hunyuan_video-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hunyuan-video/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HunyuanVideo LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/hunyuan_video.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "hunyuan-video", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "HunyuanVideo" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_alpha_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-alpha/embedding", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-alpha", 5 | "modelspec.title": "PixArt Alpha 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_alpha_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-alpha/lora", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-alpha", 5 | "modelspec.title": "PixArt Alpha 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_alpha_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-alpha", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-alpha", 5 | "modelspec.title": "PixArt Alpha 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_sigma_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-sigma/embedding", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-sigma", 5 | "modelspec.title": "PixArt Sigma 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_sigma_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-sigma/lora", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-sigma", 5 | "modelspec.title": 
"PixArt Sigma 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/pixart_sigma_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "pixart-sigma", 4 | "modelspec.implementation": "https://github.com/PixArt-alpha/PixArt-sigma", 5 | "modelspec.title": "PixArt Sigma 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sana-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "sana/embedding", 4 | "modelspec.implementation": "https://github.com/NVlabs/Sana", 5 | "modelspec.title": "Sana Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sana-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "sana/lora", 4 | "modelspec.implementation": "https://github.com/NVlabs/Sana", 5 | "modelspec.title": "Sana LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sana.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "sana", 4 | "modelspec.implementation": "https://github.com/NVlabs/Sana", 5 | "modelspec.title": "Sana" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1", 4 | "modelspec.implementation": "https://github.com/CompVis/stable-diffusion", 5 | "modelspec.title": "Stable Diffusion 1.5" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5_inpainting-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1-inpainting/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 Inpainting Embedding" 6 | } 7 | 
-------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5_inpainting-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1-inpainting/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 1.5 Inpainting LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_1.5_inpainting.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v1-inpainting", 4 | "modelspec.implementation": "https://github.com/CompVis/stable-diffusion", 5 | "modelspec.title": "Stable Diffusion 1.5 Inpainting" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_depth-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-depth/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Depth Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_depth-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-depth/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Depth LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_depth.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-depth", 4 
| "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.0 Depth" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_inpainting-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-inpainting/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Inpainting Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_inpainting-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-inpainting/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.0 Inpainting LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.0_inpainting.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2-inpainting", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.0 Inpainting" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.1-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.1 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.1-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 2.1 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_2.1.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v2", 4 | "modelspec.implementation": "https://github.com/Stability-AI/StableDiffusion", 5 | "modelspec.title": "Stable Diffusion 2.1" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3.5_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3.5/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3.5 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3.5_1.0-lora.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3.5/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3.5 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3.5_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3.5", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3.5" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3_2b_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3-medium/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3 Medium Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3_2b_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3-medium/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3 Medium LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_3_2b_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-v3-medium", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion 3 Medium" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0_Inpainting-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base-inpainting/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Inpainting Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0_inpainting-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base-inpainting/lora", 4 | "modelspec.implementation": 
"https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Inpainting LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/sd_xl_base_1.0_inpainting.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-diffusion-xl-v1-base-inpainting", 4 | "modelspec.implementation": "https://github.com/Stability-AI/generative-models", 5 | "modelspec.title": "Stable Diffusion XL 1.0 Base Inpainting" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/stable_cascade_1.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-cascade-v1-prior/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Cascade 1.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/stable_cascade_1.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-cascade-v1-prior/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Cascade 1.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/stable_cascade_1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "stable-cascade-v1-prior", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Stable Cascade 1.0" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/wuerstchen_2.0-embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "wuerstchen-v2-prior/embedding", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Wuerstchen 2.0 Embedding" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/wuerstchen_2.0-lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "wuerstchen-v2-prior/lora", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Wuerstchen 2.0 LoRA" 6 | } 7 | -------------------------------------------------------------------------------- /resources/sd_model_spec/wuerstchen_2.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelspec.sai_model_spec": "1.0.0", 3 | "modelspec.architecture": "wuerstchen-v2-prior", 4 | "modelspec.implementation": "https://github.com/huggingface/diffusers", 5 | "modelspec.title": "Wuerstchen 2.0" 6 | } 7 | -------------------------------------------------------------------------------- /run-cmd.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 
3 | set -e 4 | 5 | source "${BASH_SOURCE[0]%/*}/lib.include.sh" 6 | 7 | # Fetch and validate the name of the target script. 8 | if [[ -z "${1}" ]]; then 9 | print_error "You must provide the name of the script to execute, such as \"train\"." 10 | exit 1 11 | fi 12 | 13 | OT_CUSTOM_SCRIPT_FILE="scripts/${1}.py" 14 | if [[ ! -f "${OT_CUSTOM_SCRIPT_FILE}" ]]; then 15 | print_error "Custom script file \"${OT_CUSTOM_SCRIPT_FILE}\" does not exist." 16 | exit 1 17 | fi 18 | 19 | prepare_runtime_environment 20 | 21 | # Remove $1 (name of the script) and pass all remaining arguments to the script. 22 | shift 23 | run_python_in_active_env "${OT_CUSTOM_SCRIPT_FILE}" "$@" 24 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # OneTrainer Scripts 2 | 3 | For an overview of these scripts, see the 4 | [CLI Mode subsection](../README.md#cli-mode) in the OneTrainer README. 5 | -------------------------------------------------------------------------------- /scripts/calculate_loss.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | import json 6 | 7 | from modules.module.GenerateLossesModel import GenerateLossesModel 8 | from modules.util.args.CalculateLossArgs import CalculateLossArgs 9 | from modules.util.config.TrainConfig import TrainConfig 10 | 11 | 12 | def main(): 13 | args = CalculateLossArgs.parse_args() 14 | 15 | train_config = TrainConfig.default_values() 16 | with open(args.config_path, "r") as f: 17 | train_config.from_dict(json.load(f)) 18 | 19 | trainer = GenerateLossesModel(train_config, args.output_path) 20 | trainer.start() 21 | 22 | 23 | if __name__ == '__main__': 24 | main() 25 | -------------------------------------------------------------------------------- /scripts/caption_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.CaptionUI import CaptionUI 6 | from modules.util.args.CaptionUIArgs import CaptionUIArgs 7 | 8 | 9 | def main(): 10 | args = CaptionUIArgs.parse_args() 11 | 12 | ui = CaptionUI(None, args.dir, args.include_subdirectories) 13 | ui.mainloop() 14 | 15 | 16 | if __name__ == '__main__': 17 | main() 18 | -------------------------------------------------------------------------------- /scripts/convert_model.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from uuid import uuid4 6 | 7 | from modules.util import create 8 | from modules.util.args.ConvertModelArgs import ConvertModelArgs 9 | from modules.util.enum.TrainingMethod import TrainingMethod 10 | from modules.util.ModelNames import EmbeddingName, ModelNames 11 | 12 | 13 | def main(): 14 | args = ConvertModelArgs.parse_args() 15 | 16 | model_loader = create.create_model_loader(model_type=args.model_type, training_method=args.training_method) 17 | model_saver = create.create_model_saver(model_type=args.model_type, training_method=args.training_method) 18 | 19 | print("Loading model " + args.input_name) 20 | if args.training_method in [TrainingMethod.FINE_TUNE]: 21 | model = model_loader.load( 22 | model_type=args.model_type, 23 | model_names=ModelNames( 24 | base_model=args.input_name, 25 | ), 26 | 
weight_dtypes=args.weight_dtypes(), 27 | ) 28 | elif args.training_method in [TrainingMethod.LORA, TrainingMethod.EMBEDDING]: 29 | model = model_loader.load( 30 | model_type=args.model_type, 31 | model_names=ModelNames( 32 | lora=args.input_name, 33 | embedding=EmbeddingName(str(uuid4()), args.input_name), 34 | ), 35 | weight_dtypes=args.weight_dtypes(), 36 | ) 37 | else: 38 | raise Exception("could not load model: " + args.input_name) 39 | 40 | print("Saving model " + args.output_model_destination) 41 | model_saver.save( 42 | model=model, 43 | model_type=args.model_type, 44 | output_model_format=args.output_model_format, 45 | output_model_destination=args.output_model_destination, 46 | dtype=args.output_dtype.torch_dtype(), 47 | ) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /scripts/convert_model_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.ConvertModelUI import ConvertModelUI 6 | 7 | 8 | def main(): 9 | ui = ConvertModelUI(None) 10 | ui.mainloop() 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /scripts/create_train_files.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | import json 6 | import os 7 | from pathlib import Path 8 | 9 | from modules.util.args.CreateTrainFilesArgs import CreateTrainFilesArgs 10 | from modules.util.config.ConceptConfig import ConceptConfig 11 | from modules.util.config.SampleConfig import SampleConfig 12 | from modules.util.config.TrainConfig import TrainConfig 13 | 14 | 15 | def main(): 16 | args = CreateTrainFilesArgs.parse_args() 17 | 18 | print(args.to_dict()) 19 | 20 | if args.config_output_destination: 21 | print("config") 22 | data = TrainConfig.default_values().to_dict() 23 | os.makedirs(Path(args.config_output_destination).parent.absolute(), exist_ok=True) 24 | 25 | with open(args.config_output_destination, "w") as f: 26 | json.dump(data, f, indent=4) 27 | 28 | if args.concepts_output_destination: 29 | print("concepts") 30 | data = [ConceptConfig.default_values().to_dict()] 31 | os.makedirs(Path(args.concepts_output_destination).parent.absolute(), exist_ok=True) 32 | 33 | with open(args.concepts_output_destination, "w") as f: 34 | json.dump(data, f, indent=4) 35 | 36 | if args.samples_output_destination: 37 | print("samples") 38 | data = [SampleConfig.default_values().to_dict()] 39 | os.makedirs(Path(args.samples_output_destination).parent.absolute(), exist_ok=True) 40 | 41 | with open(args.samples_output_destination, "w") as f: 42 | json.dump(data, f, indent=4) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/generate_captions.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.module.Blip2Model import Blip2Model 6 | from modules.module.BlipModel import BlipModel 7 | from modules.module.WDModel import WDModel 8 | from modules.util.args.GenerateCaptionsArgs import GenerateCaptionsArgs 9 | from modules.util.enum.GenerateCaptionsModel import GenerateCaptionsModel 10 | 11 | import
torch 12 | 13 | 14 | def main(): 15 | args = GenerateCaptionsArgs.parse_args() 16 | 17 | model = None 18 | if args.model == GenerateCaptionsModel.BLIP: 19 | model = BlipModel(torch.device(args.device), args.dtype.torch_dtype()) 20 | elif args.model == GenerateCaptionsModel.BLIP2: 21 | model = Blip2Model(torch.device(args.device), args.dtype.torch_dtype()) 22 | elif args.model == GenerateCaptionsModel.WD14_VIT_2: 23 | model = WDModel(torch.device(args.device), args.dtype.torch_dtype()) 24 | 25 | model.caption_folder( 26 | sample_dir=args.sample_dir, 27 | initial_caption=args.initial_caption, 28 | caption_prefix=args.caption_prefix, 29 | caption_postfix=args.caption_postfix, 30 | mode=args.mode, 31 | error_callback=lambda filename: print("Error while processing image " + filename), 32 | include_subdirectories=args.include_subdirectories 33 | ) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /scripts/generate_masks.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.module.ClipSegModel import ClipSegModel 6 | from modules.module.MaskByColor import MaskByColor 7 | from modules.module.RembgHumanModel import RembgHumanModel 8 | from modules.module.RembgModel import RembgModel 9 | from modules.util.args.GenerateMasksArgs import GenerateMasksArgs 10 | from modules.util.enum.GenerateMasksModel import GenerateMasksModel 11 | 12 | import torch 13 | 14 | 15 | def main(): 16 | args = GenerateMasksArgs.parse_args() 17 | 18 | model = None 19 | if args.model == GenerateMasksModel.CLIPSEG: 20 | model = ClipSegModel(torch.device(args.device), args.dtype.torch_dtype()) 21 | elif args.model == GenerateMasksModel.REMBG: 22 | model = RembgModel(torch.device(args.device), args.dtype.torch_dtype()) 23 | elif args.model == GenerateMasksModel.REMBG_HUMAN: 24 | model = RembgHumanModel(torch.device(args.device), args.dtype.torch_dtype()) 25 | elif args.model == GenerateMasksModel.COLOR: 26 | model = MaskByColor(torch.device(args.device), args.dtype.torch_dtype()) 27 | 28 | model.mask_folder( 29 | sample_dir=args.sample_dir, 30 | prompts=args.prompts, 31 | mode=args.mode, 32 | threshold=args.threshold, 33 | smooth_pixels=args.smooth_pixels, 34 | expand_pixels=args.expand_pixels, 35 | alpha=args.alpha, 36 | error_callback=lambda filename: print("Error while processing image " + filename), 37 | include_subdirectories=args.include_subdirectories 38 | ) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /scripts/install_zluda.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports(allow_zluda=False) 4 | 5 | import sys 6 | 7 | from modules.zluda import ZLUDAInstaller 8 | 9 | if __name__ == '__main__': 10 | try: 11 | zluda_path = ZLUDAInstaller.get_path() 12 | ZLUDAInstaller.install(zluda_path) 13 | ZLUDAInstaller.make_copy(zluda_path) 14 | except Exception as e: 15 | print(f'Failed to install ZLUDA: {e}') 16 | sys.exit(1) 17 | 18 | print(f'ZLUDA installed: {zluda_path}') 19 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | 
script_imports() 4 | 5 | import json 6 | 7 | from modules.trainer.GenericTrainer import GenericTrainer 8 | from modules.util.args.TrainArgs import TrainArgs 9 | from modules.util.callbacks.TrainCallbacks import TrainCallbacks 10 | from modules.util.commands.TrainCommands import TrainCommands 11 | from modules.util.config.SecretsConfig import SecretsConfig 12 | from modules.util.config.TrainConfig import TrainConfig 13 | 14 | 15 | def main(): 16 | args = TrainArgs.parse_args() 17 | callbacks = TrainCallbacks() 18 | commands = TrainCommands() 19 | 20 | train_config = TrainConfig.default_values() 21 | with open(args.config_path, "r") as f: 22 | train_config.from_dict(json.load(f)) 23 | 24 | try: 25 | with open("secrets.json" if args.secrets_path is None else args.secrets_path, "r") as f: 26 | secrets_dict=json.load(f) 27 | train_config.secrets = SecretsConfig.default_values().from_dict(secrets_dict) 28 | except FileNotFoundError: 29 | if args.secrets_path is not None: 30 | raise 31 | 32 | trainer = GenericTrainer(train_config, callbacks, commands) 33 | 34 | trainer.start() 35 | 36 | canceled = False 37 | try: 38 | trainer.train() 39 | except KeyboardInterrupt: 40 | canceled = True 41 | 42 | if not canceled or train_config.backup_before_save: 43 | trainer.end() 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/train_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.TrainUI import TrainUI 6 | 7 | 8 | def main(): 9 | ui = TrainUI() 10 | ui.mainloop() 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /scripts/util/import_util.py: -------------------------------------------------------------------------------- 1 | def script_imports(allow_zluda: bool = True): 2 | import logging 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | # Filter out the Triton warning on startup. 8 | # xformers is not installed anymore, but might still exist for some installations. 9 | logging \ 10 | .getLogger("xformers") \ 11 | .addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) 12 | 13 | # Insert ourselves as the highest-priority library path, so our modules are 14 | # always found without any risk of being shadowed by another import path. 15 | # 3 .parent calls to navigate from /scripts/util/import_util.py to the main directory 16 | onetrainer_lib_path = Path(__file__).absolute().parent.parent.parent 17 | sys.path.insert(0, str(onetrainer_lib_path)) 18 | 19 | if allow_zluda and sys.platform.startswith('win'): 20 | from modules.zluda import ZLUDAInstaller 21 | 22 | zluda_path = ZLUDAInstaller.get_path() 23 | 24 | if os.path.exists(zluda_path): 25 | try: 26 | ZLUDAInstaller.load(zluda_path) 27 | print(f'Using ZLUDA in {zluda_path}') 28 | except Exception as e: 29 | print(f'Failed to load ZLUDA: {e}') 30 | 31 | from modules.zluda import ZLUDA 32 | 33 | ZLUDA.initialize() 34 | -------------------------------------------------------------------------------- /scripts/util/version_check.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | # Python Version Check. 5 | # IMPORTANT: All code below must be backwards-compatible with Python 2+. 
6 | 7 | 8 | def exit_err(msg): 9 | sys.stderr.write("Error: " + msg + "\n") 10 | sys.stderr.flush() 11 | sys.exit(1) 12 | 13 | 14 | def str_to_tuple(data): 15 | return tuple(map(lambda x: int(x, 10), data.split("."))) 16 | 17 | 18 | def tuple_to_str(data): 19 | return ".".join(map(str, data)) 20 | 21 | 22 | def exit_wrong_version(msg, min_ver, too_high_ver): 23 | exit_err( 24 | "Your Python version is %s: %s. Must be >= %s and < %s." 25 | % (msg, sys.version, tuple_to_str(min_ver), tuple_to_str(too_high_ver)) 26 | ) 27 | 28 | 29 | if len(sys.argv) < 3: 30 | exit_err("Version check requires 2 arguments: [min_ver] [too_high_ver]") 31 | 32 | min_ver = str_to_tuple(sys.argv[1]) 33 | too_high_ver = str_to_tuple(sys.argv[2]) 34 | 35 | # Specifically exclude Python 3.11.0 as Scalene does NOT support it https://pypi.org/project/scalene/ 36 | if sys.version_info[:3] == (3, 11, 0): 37 | exit_err("Python 3.11.0 specifically is not supported (due to Scalene). Please use a different Python version.") 38 | 39 | if sys.version_info < min_ver: 40 | exit_wrong_version("too low", min_ver, too_high_ver) 41 | 42 | if sys.version_info >= too_high_ver: 43 | exit_wrong_version("too high", min_ver, too_high_ver) 44 | -------------------------------------------------------------------------------- /scripts/video_tool_ui.py: -------------------------------------------------------------------------------- 1 | from util.import_util import script_imports 2 | 3 | script_imports() 4 | 5 | from modules.ui.VideoToolUI import VideoToolUI 6 | 7 | 8 | def main(): 9 | ui = VideoToolUI(None) 10 | ui.mainloop() 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /start-ui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Avoid footgun by explicitly navigating to the directory containing the batch file 4 | cd /d "%~dp0" 5 | 6 | REM Verify that OneTrainer is our current working directory 7 | if not exist "scripts\train_ui.py" ( 8 | echo Error: train_ui.py does not exist, you have done something very wrong. Reclone the repository. 9 | goto :end 10 | ) 11 | 12 | if not defined PYTHON ( 13 | where python >NUL 2>NUL 14 | if errorlevel 1 ( 15 | echo Error: Python is not installed or not in PATH 16 | goto :end 17 | ) 18 | set PYTHON=python 19 | ) 20 | if not defined VENV_DIR (set "VENV_DIR=%~dp0venv") 21 | 22 | :check_venv 23 | dir "%VENV_DIR%" > NUL 2> NUL 24 | if not errorlevel 1 goto :activate_venv 25 | echo venv not found, please run install.bat first 26 | goto :end 27 | 28 | :activate_venv 29 | echo activating venv %VENV_DIR% 30 | if not exist "%VENV_DIR%\Scripts\python.exe" ( 31 | echo Error: Python executable not found in virtual environment 32 | goto :end 33 | ) 34 | set PYTHON="%VENV_DIR%\Scripts\python.exe" 35 | if defined PROFILE (set PYTHON=%PYTHON% -m scalene --off --cpu --gpu --profile-all --no-browser) 36 | echo Using Python %PYTHON% 37 | 38 | :launch 39 | echo Starting UI...
40 | %PYTHON% scripts\train_ui.py 41 | if errorlevel 1 ( 42 | echo Error: UI script exited with code %ERRORLEVEL% 43 | ) 44 | 45 | :end 46 | pause 47 | -------------------------------------------------------------------------------- /start-ui.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | source "${BASH_SOURCE[0]%/*}/lib.include.sh" 6 | 7 | prepare_runtime_environment 8 | 9 | run_python_in_active_env "scripts/train_ui.py" "$@" 10 | -------------------------------------------------------------------------------- /training_presets/#flux LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "black-forest-labs/FLUX.1-dev", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "FLUX_DEV_1", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "768", 10 | "prior": { 11 | "train": true, 12 | "weight_dtype": "NFLOAT_4" 13 | }, 14 | "text_encoder": { 15 | "train": false 16 | }, 17 | "text_encoder_2": { 18 | "train": false, 19 | "weight_dtype": "NFLOAT_4" 20 | }, 21 | "training_method": "LORA", 22 | "vae": { 23 | "weight_dtype": "FLOAT_32" 24 | }, 25 | "train_dtype": "BFLOAT_16", 26 | "weight_dtype": "BFLOAT_16", 27 | "timestep_distribution": "LOGIT_NORMAL", 28 | "dynamic_timestep_shifting": false 29 | } 30 | -------------------------------------------------------------------------------- /training_presets/#hidream LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "HiDream-ai/HiDream-I1-Full", 4 | "batch_size": 4, 5 | "gradient_checkpointing": "CPU_OFFLOADED", 6 | "layer_offload_fraction": 0.5, 7 | "dataloader_threads": 1, 8 | "learning_rate": 0.0003, 9 | "model_type": "HI_DREAM_FULL", 10 | "output_model_destination": "models/lora.safetensors", 11 | "output_model_format": "SAFETENSORS", 12 | "resolution": "512", 13 | "timestep_distribution": "LOGIT_NORMAL", 14 | "dynamic_timestep_shifting": false, 15 | "train_dtype": "BFLOAT_16", 16 | "training_method": "LORA", 17 | "weight_dtype": "FLOAT_8", 18 | "text_encoder": { 19 | "train": false 20 | }, 21 | "text_encoder_2": { 22 | "train": false 23 | }, 24 | "text_encoder_3": { 25 | "train": false 26 | }, 27 | "text_encoder_4": { 28 | "model_name": "meta-llama/Llama-3.1-8B-Instruct", 29 | "train": false 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /training_presets/#hunyuan video LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "hunyuanvideo-community/HunyuanVideo", 4 | "batch_size": 4, 5 | "gradient_checkpointing": "CPU_OFFLOADED", 6 | "layer_offload_fraction": 0.5, 7 | "dataloader_threads": 1, 8 | "learning_rate": 0.0003, 9 | "model_type": "HUNYUAN_VIDEO", 10 | "output_model_destination": "models/lora.safetensors", 11 | "output_model_format": "SAFETENSORS", 12 | "resolution": "512", 13 | "timestep_distribution": "LOGIT_NORMAL", 14 | "dynamic_timestep_shifting": false, 15 | "train_dtype": "BFLOAT_16", 16 | "training_method": "LORA", 17 | "weight_dtype": "FLOAT_8", 18 | "text_encoder": { 19 | "train": false 20 | }, 21 | "text_encoder_2": { 22 | "train": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- 
/training_presets/#pixart alpha 1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "PixArt-alpha/PixArt-XL-2-1024-MS", 3 | "model_type": "PIXART_ALPHA", 4 | "output_model_destination": "models/model", 5 | "output_model_format": "DIFFUSERS", 6 | "resolution": "1024", 7 | "text_encoder": { 8 | "train": false, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "training_method": "FINE_TUNE" 12 | } 13 | -------------------------------------------------------------------------------- /training_presets/#pixart sigma 1.0 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "PixArt-sigma/PixArt-XL-2-1024-MS", 3 | "batch_size": 4, 4 | "model_type": "PIXART_SIGMA", 5 | "epochs": 50, 6 | "learning_rate": 0.00001, 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false 12 | }, 13 | "training_method": "LORA", 14 | "lora_rank": 16, 15 | "lora_alpha": 1.0, 16 | "vae": { 17 | "weight_dtype": "FLOAT_32" 18 | }, 19 | "weight_dtype": "BFLOAT_16" 20 | } 21 | -------------------------------------------------------------------------------- /training_presets/#pixart sigma 1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "PixArt-sigma/PixArt-XL-2-1024-MS", 3 | "batch_size": 8, 4 | "model_type": "PIXART_SIGMA", 5 | "epochs": 100, 6 | "learning_rate": 0.00002, 7 | "output_model_destination": "models/model", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false 12 | }, 13 | "training_method": "FINE_TUNE", 14 | "vae": { 15 | "weight_dtype": "FLOAT_32" 16 | }, 17 | "weight_dtype": "BFLOAT_16" 18 | } 19 | -------------------------------------------------------------------------------- /training_presets/#sana 1.6b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", 3 | "model_type": "SANA", 4 | "output_model_destination": "models/model", 5 | "output_model_format": "DIFFUSERS", 6 | "resolution": "1024", 7 | "text_encoder": { 8 | "train": false, 9 | "weight_dtype": "BFLOAT_16" 10 | }, 11 | "training_method": "FINE_TUNE", 12 | "train_dtype": "BFLOAT_16" 13 | } 14 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_15", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "512", 10 | "training_method": "LORA", 11 | "vae": { 12 | "weight_dtype": "FLOAT_32" 13 | }, 14 | "weight_dtype": "FLOAT_16" 15 | } 16 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 4 | "latent_caching": false, 5 | "learning_rate": 0.0003, 6 | "learning_rate_warmup_steps": 20, 7 | "model_type": "STABLE_DIFFUSION_15", 8 | 
"output_model_destination": "models/embedding.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "sample_after": 1, 12 | "training_method": "EMBEDDING", 13 | "vae": { 14 | "weight_dtype": "FLOAT_32" 15 | }, 16 | "weight_dtype": "FLOAT_16" 17 | } 18 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 inpaint masked.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-inpainting", 3 | "batch_size": 4, 4 | "masked_training": true, 5 | "model_type": "STABLE_DIFFUSION_15_INPAINTING", 6 | "normalize_masked_area_loss": true, 7 | "output_model_destination": "models/model.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "512", 10 | "training_method": "FINE_TUNE", 11 | "vae": { 12 | "weight_dtype": "FLOAT_32" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 inpaint.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-inpainting", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_15_INPAINTING", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "vae": { 10 | "weight_dtype": "FLOAT_32" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5 masked.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 3 | "batch_size": 4, 4 | "masked_training": true, 5 | "max_noising_strength": 0.8, 6 | "model_type": "STABLE_DIFFUSION_15", 7 | "normalize_masked_area_loss": true, 8 | "output_model_destination": "models/model.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "training_method": "FINE_TUNE", 12 | "vae": { 13 | "weight_dtype": "FLOAT_32" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /training_presets/#sd 1.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stable-diffusion-v1-5/stable-diffusion-v1-5", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_15", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "unet": { 10 | "train": true 11 | }, 12 | "vae": { 13 | "weight_dtype": "FLOAT_32" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /training_presets/#sd 2.0 inpaint LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-2-inpainting", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_20_INPAINTING", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "512", 10 | "training_method": "LORA", 11 | "vae": { 12 | "weight_dtype": "FLOAT_16" 13 | }, 14 | "weight_dtype": "FLOAT_16" 15 | } 16 | -------------------------------------------------------------------------------- 
/training_presets/#sd 2.0 inpaint.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-diffusion-2-inpainting", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_20_INPAINTING", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "vae": { 10 | "weight_dtype": "FLOAT_32" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /training_presets/#sd 2.1 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-2-1", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_21", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "512", 10 | "training_method": "LORA", 11 | "vae": { 12 | "weight_dtype": "FLOAT_16" 13 | }, 14 | "weight_dtype": "FLOAT_16" 15 | } 16 | -------------------------------------------------------------------------------- /training_presets/#sd 2.1 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-2-1", 4 | "latent_caching": false, 5 | "learning_rate": 0.0003, 6 | "learning_rate_warmup_steps": 20, 7 | "model_type": "STABLE_DIFFUSION_21", 8 | "output_model_destination": "models/embedding.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "512", 11 | "sample_after": 1, 12 | "training_method": "EMBEDDING", 13 | "vae": { 14 | "weight_dtype": "FLOAT_32" 15 | }, 16 | "weight_dtype": "FLOAT_16" 17 | } 18 | -------------------------------------------------------------------------------- /training_presets/#sd 2.1.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-diffusion-2-1", 3 | "batch_size": 4, 4 | "model_type": "STABLE_DIFFUSION_21", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "resolution": "512", 8 | "training_method": "FINE_TUNE", 9 | "vae": { 10 | "weight_dtype": "FLOAT_32" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /training_presets/#sd 3 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-3-medium-diffusers", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_3", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false 12 | }, 13 | "text_encoder_2": { 14 | "train": false 15 | }, 16 | "text_encoder_3": { 17 | "train": false 18 | }, 19 | "training_method": "LORA", 20 | "vae": { 21 | "weight_dtype": "FLOAT_32" 22 | }, 23 | "weight_dtype": "FLOAT_16", 24 | "timestep_distribution": "LOGIT_NORMAL" 25 | } 26 | -------------------------------------------------------------------------------- /training_presets/#sd 3.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-diffusion-3-medium-diffusers", 3 | "model_type": 
"STABLE_DIFFUSION_3", 4 | "output_dtype": "FLOAT_16", 5 | "output_model_destination": "models/model.safetensors", 6 | "output_model_format": "SAFETENSORS", 7 | "prior": { 8 | "weight_dtype": "FLOAT_32" 9 | }, 10 | "resolution": "1024", 11 | "text_encoder": { 12 | "train": false 13 | }, 14 | "text_encoder_2": { 15 | "train": false 16 | }, 17 | "text_encoder_3": { 18 | "train": false 19 | }, 20 | "training_method": "FINE_TUNE", 21 | "vae": { 22 | "weight_dtype": "FLOAT_32" 23 | }, 24 | "weight_dtype": "FLOAT_16", 25 | "timestep_distribution": "LOGIT_NORMAL" 26 | } 27 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_XL_10_BASE", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false 12 | }, 13 | "text_encoder_2": { 14 | "train": false 15 | }, 16 | "training_method": "LORA", 17 | "vae": { 18 | "weight_dtype": "FLOAT_32" 19 | }, 20 | "weight_dtype": "FLOAT_16" 21 | } 22 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0", 4 | "latent_caching": false, 5 | "learning_rate": 0.0003, 6 | "learning_rate_warmup_steps": 20, 7 | "model_type": "STABLE_DIFFUSION_XL_10_BASE", 8 | "output_model_destination": "models/embedding.safetensors", 9 | "output_model_format": "SAFETENSORS", 10 | "resolution": "1024", 11 | "sample_after": 1, 12 | "training_method": "EMBEDDING", 13 | "vae": { 14 | "weight_dtype": "FLOAT_32" 15 | }, 16 | "weight_dtype": "FLOAT_16" 17 | } 18 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0 inpaint LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", 4 | "batch_size": 4, 5 | "learning_rate": 0.0003, 6 | "model_type": "STABLE_DIFFUSION_XL_10_BASE_INPAINTING", 7 | "output_model_destination": "models/lora.safetensors", 8 | "output_model_format": "SAFETENSORS", 9 | "resolution": "1024", 10 | "text_encoder": { 11 | "train": false 12 | }, 13 | "text_encoder_2": { 14 | "train": false 15 | }, 16 | "training_method": "LORA", 17 | "vae": { 18 | "weight_dtype": "FLOAT_32" 19 | }, 20 | "weight_dtype": "FLOAT_16" 21 | } 22 | -------------------------------------------------------------------------------- /training_presets/#sdxl 1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0", 3 | "model_type": "STABLE_DIFFUSION_XL_10_BASE", 4 | "output_model_destination": "models/model.safetensors", 5 | "output_model_format": "SAFETENSORS", 6 | "resolution": "1024", 7 | "text_encoder": { 8 | "train": false 9 | }, 10 | "text_encoder_2": { 11 | "train": false 12 | }, 13 | "training_method": "FINE_TUNE", 14 | "vae": { 15 | "weight_dtype": "FLOAT_32" 16 | }, 17 | "weight_dtype": "BFLOAT_16" 18 | } 19 | 
-------------------------------------------------------------------------------- /training_presets/#stable cascade.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "stabilityai/stable-cascade-prior", 3 | "decoder": { 4 | "weight_dtype": "FLOAT_16", 5 | "model_name": "stabilityai/stable-cascade" 6 | }, 7 | "decoder_vqgan": { 8 | "weight_dtype": "FLOAT_16" 9 | }, 10 | "effnet_encoder": { 11 | "model_name": "TODO: MANUALLY DOWNLOAD EFFNET WEIGHTS", 12 | "weight_dtype": "FLOAT_16" 13 | }, 14 | "model_type": "STABLE_CASCADE_1", 15 | "optimizer": { 16 | "__version": 0, 17 | "optimizer": "ADAFACTOR", 18 | "beta1": null, 19 | "clip_threshold": 1.0, 20 | "decay_rate": -0.8, 21 | "eps": 1e-30, 22 | "eps2": 0.001, 23 | "relative_step": false, 24 | "scale_parameter": false, 25 | "stochastic_rounding": true, 26 | "warmup_init": false, 27 | "weight_decay": 0.0 28 | }, 29 | "optimizer_defaults": { 30 | "ADAFACTOR": { 31 | "__version": 0, 32 | "optimizer": "ADAFACTOR", 33 | "beta1": null, 34 | "clip_threshold": 1.0, 35 | "decay_rate": -0.8, 36 | "eps": 1e-30, 37 | "eps2": 0.001, 38 | "relative_step": false, 39 | "scale_parameter": false, 40 | "stochastic_rounding": true, 41 | "warmup_init": false, 42 | "weight_decay": 0.0 43 | } 44 | }, 45 | "output_dtype": "BFLOAT_16", 46 | "output_model_destination": "models/model", 47 | "output_model_format": "SAFETENSORS", 48 | "resolution": "1024", 49 | "training_method": "FINE_TUNE", 50 | "weight_dtype": "BFLOAT_16", 51 | "loss_weight_fn": "P2", 52 | "loss_weight_strength": 1.0 53 | } 54 | -------------------------------------------------------------------------------- /training_presets/#wuerstchen 2.0 LoRA.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "warp-ai/wuerstchen-prior", 4 | "decoder": { 5 | "model_name": "warp-ai/wuerstchen" 6 | }, 7 | "effnet_encoder": { 8 | "model_name": "warp-ai/EfficientNetEncoder" 9 | }, 10 | "learning_rate": 0.0003, 11 | "model_type": "WUERSTCHEN_2", 12 | "output_model_destination": "models/lora.safetensors", 13 | "output_model_format": "SAFETENSORS", 14 | "resolution": "1024", 15 | "training_method": "LORA", 16 | "weight_dtype": "FLOAT_16", 17 | "loss_weight_fn": "P2", 18 | "loss_weight_strength": 1.0 19 | } 20 | -------------------------------------------------------------------------------- /training_presets/#wuerstchen 2.0 embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "backup_after": 10, 3 | "base_model_name": "warp-ai/wuerstchen-prior", 4 | "decoder": { 5 | "model_name": "warp-ai/wuerstchen" 6 | }, 7 | "effnet_encoder": { 8 | "model_name": "warp-ai/EfficientNetEncoder" 9 | }, 10 | "latent_caching": false, 11 | "learning_rate": 0.0003, 12 | "learning_rate_warmup_steps": 20, 13 | "model_type": "WUERSTCHEN_2", 14 | "output_model_destination": "models/embedding.safetensors", 15 | "output_model_format": "SAFETENSORS", 16 | "resolution": "1024", 17 | "sample_after": 1, 18 | "training_method": "EMBEDDING", 19 | "weight_dtype": "FLOAT_16", 20 | "loss_weight_fn": "P2", 21 | "loss_weight_strength": 1.0 22 | } 23 | -------------------------------------------------------------------------------- /training_presets/#wuerstchen 2.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_model_name": "warp-ai/wuerstchen-prior", 3 | "decoder": { 4 | "weight_dtype": "FLOAT_16", 5 | 
"model_name": "warp-ai/wuerstchen" 6 | }, 7 | "decoder_text_encoder": { 8 | "weight_dtype": "FLOAT_16" 9 | }, 10 | "decoder_vqgan": { 11 | "weight_dtype": "FLOAT_16" 12 | }, 13 | "effnet_encoder": { 14 | "model_name": "warp-ai/EfficientNetEncoder", 15 | "weight_dtype": "FLOAT_16" 16 | }, 17 | "model_type": "WUERSTCHEN_2", 18 | "optimizer": { 19 | "__version": 0, 20 | "optimizer": "ADAFACTOR", 21 | "beta1": null, 22 | "clip_threshold": 1.0, 23 | "decay_rate": -0.8, 24 | "eps": 1e-30, 25 | "eps2": 0.001, 26 | "relative_step": false, 27 | "scale_parameter": false, 28 | "stochastic_rounding": true, 29 | "warmup_init": false, 30 | "weight_decay": 0.0 31 | }, 32 | "optimizer_defaults": { 33 | "ADAFACTOR": { 34 | "__version": 0, 35 | "optimizer": "ADAFACTOR", 36 | "beta1": null, 37 | "clip_threshold": 1.0, 38 | "decay_rate": -0.8, 39 | "eps": 1e-30, 40 | "eps2": 0.001, 41 | "relative_step": false, 42 | "scale_parameter": false, 43 | "stochastic_rounding": true, 44 | "warmup_init": false, 45 | "weight_decay": 0.0 46 | } 47 | }, 48 | "output_model_destination": "models/model", 49 | "output_model_format": "DIFFUSERS", 50 | "resolution": "1024", 51 | "training_method": "FINE_TUNE", 52 | "loss_weight_fn": "P2", 53 | "loss_weight_strength": 1.0 54 | } 55 | -------------------------------------------------------------------------------- /training_presets/.gitignore: -------------------------------------------------------------------------------- 1 | # ignores everything except the builtin config files starting with # 2 | * 3 | !.gitignore 4 | !\#* 5 | 6 | # #.json is an exception (from the exception). This file stores the ui state from a previous run. 7 | \#.json 8 | -------------------------------------------------------------------------------- /update.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Change our working dir to the root of the project. 6 | cd -- "$(dirname -- "${BASH_SOURCE[0]}")" 7 | 8 | # Pull the latest changes via Git. 9 | echo "[OneTrainer] Updating OneTrainer to latest version from Git repository..." 10 | git pull 11 | 12 | # Load the newest version of the function library. 13 | source "lib.include.sh" 14 | 15 | # Prepare runtime and upgrade all dependencies to latest compatible version. 16 | prepare_runtime_environment upgrade 17 | --------------------------------------------------------------------------------