├── .github └── workflows │ └── python-tests.yaml ├── .gitignore ├── Dockerfile ├── INSTALL.md ├── LICENSE ├── OPTIONS.md ├── README.md ├── TUTORIAL.md ├── config ├── caption_filter_list.txt.example ├── config.json.example ├── lycoris_config.json.example ├── multidatabackend-multiresolution.json.example ├── multidatabackend.json.example └── user_prompt_library.json.example ├── configure.py ├── convert_sd_checkpoint.py ├── convert_sdxl_checkpoint.py ├── docker-start.sh ├── documentation ├── CONTROLNET.md ├── DATALOADER.md ├── DEEPFLOYD.md ├── DEEPSPEED.md ├── DISTRIBUTED.md ├── DOCKER.md ├── DREAMBOOTH.md ├── LYCORIS.md ├── MIXTURE_OF_EXPERTS.md ├── QUICKSTART.md ├── data_presets │ ├── README.md │ ├── preset.md │ ├── preset_dalle3.md │ ├── preset_midjourney.md │ ├── preset_nijijourney.md │ └── preset_pexels.md ├── evaluation │ ├── CLIP_SCORES.md │ └── EVAL_LOSS.md └── quickstart │ ├── AURAFLOW.md │ ├── FLUX.md │ ├── HIDREAM.md │ ├── KOLORS.md │ ├── LTXVIDEO.md │ ├── OMNIGEN.md │ ├── SANA.md │ ├── SD3.md │ ├── SDXL.md │ ├── SIGMA.md │ └── WAN.md ├── filter_list.txt ├── helpers ├── caching │ ├── memory.py │ ├── text_embeds.py │ └── vae.py ├── configuration │ ├── cmd_args.py │ ├── env_file.py │ ├── json_file.py │ ├── loader.py │ └── toml_file.py ├── data_backend │ ├── aws.py │ ├── base.py │ ├── csv_url_list.py │ ├── factory.py │ └── local.py ├── image_manipulation │ ├── brightness.py │ ├── cropping.py │ ├── load.py │ └── training_sample.py ├── legacy │ └── pipeline.py ├── log_format.py ├── metadata │ └── backends │ │ ├── base.py │ │ ├── discovery.py │ │ └── parquet.py ├── models │ ├── all.py │ ├── auraflow │ │ ├── model.py │ │ ├── pipeline.py │ │ └── transformer.py │ ├── common.py │ ├── deepfloyd │ │ └── model.py │ ├── flux │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── model.py │ │ ├── pipeline.py │ │ └── transformer.py │ ├── hidream │ │ ├── model.py │ │ ├── pipeline.py │ │ ├── schedule.py │ │ └── transformer.py │ ├── kolors │ │ ├── model.py │ │ └── pipeline.py │ ├── ltxvideo │ │ ├── __init__.py │ │ └── model.py │ ├── omnigen │ │ ├── collator.py │ │ └── model.py │ ├── pixart │ │ ├── model.py │ │ └── pipeline.py │ ├── sana │ │ ├── __init__.py │ │ ├── model.py │ │ ├── pipeline.py │ │ └── transformer.py │ ├── sd1x │ │ ├── model.py │ │ └── pipeline.py │ ├── sd3 │ │ ├── __init__.py │ │ ├── expanded.py │ │ ├── model.py │ │ ├── pipeline.py │ │ └── transformer.py │ ├── sdxl │ │ ├── model.py │ │ └── pipeline.py │ └── wan │ │ ├── __init__.py │ │ ├── model.py │ │ ├── pipeline.py │ │ └── transformer.py ├── multiaspect │ ├── dataset.py │ ├── image.py │ ├── sampler.py │ ├── state.py │ └── video.py ├── prompt_expander │ └── __init__.py ├── prompts.py ├── publishing │ ├── huggingface.py │ └── metadata.py ├── training │ ├── __init__.py │ ├── adapter.py │ ├── collate.py │ ├── custom_schedule.py │ ├── deepspeed.py │ ├── default_settings │ │ ├── __init__.py │ │ └── safety_check.py │ ├── diffusion_model.py │ ├── ema.py │ ├── error_handling.py │ ├── evaluation.py │ ├── exceptions.py │ ├── gradient_checkpointing_interval.py │ ├── min_snr_gamma.py │ ├── model_freeze.py │ ├── multi_process.py │ ├── optimizer_param.py │ ├── optimizers │ │ ├── adamw_bfloat16 │ │ │ ├── __init__.py │ │ │ └── stochastic │ │ │ │ └── __init__.py │ │ ├── adamw_schedulefree │ │ │ └── __init__.py │ │ └── soap │ │ │ └── __init__.py │ ├── peft_init.py │ ├── quantisation │ │ ├── __init__.py │ │ ├── peft_workarounds.py │ │ ├── quanto_workarounds.py │ │ └── torchao_workarounds.py │ ├── save_hooks.py │ ├── state_tracker.py │ ├── trainer.py │ ├── 
validation.py │ └── wrappers.py └── webhooks │ ├── config.py │ ├── handler.py │ └── mixin.py ├── inference.py ├── inference_comparison.py ├── install ├── apple │ ├── poetry.lock │ ├── poetry.toml │ └── pyproject.toml ├── github │ ├── poetry.lock │ └── pyproject.toml └── rocm │ ├── poetry.lock │ ├── poetry.toml │ └── pyproject.toml ├── notebook.ipynb ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── service_worker.py ├── simpletuner_sdk ├── api_state.py ├── configuration.py ├── interface.py ├── thread_keeper │ └── __init__.py └── training_host.py ├── tests ├── __init__.py ├── helpers │ └── data.py ├── test_collate.py ├── test_cropping.py ├── test_custom_schedules.py ├── test_dataset.py ├── test_ema.py ├── test_image.py ├── test_metadata_backend.py ├── test_model_card.py ├── test_prompthandler.py ├── test_sampler.py ├── test_state.py ├── test_trainer.py ├── test_training_sample.py ├── test_vae.py └── test_webhooks.py ├── toolkit ├── README.md ├── captioning │ ├── caption_backend_server.php │ ├── caption_with_blip.py │ ├── caption_with_blip3.py │ ├── caption_with_cogvlm.py │ ├── caption_with_cogvlm_remote.py │ ├── caption_with_florence.py │ ├── caption_with_gemini.py │ ├── caption_with_gemma.py │ ├── caption_with_gpt4.py │ ├── caption_with_internvl.py │ ├── caption_with_llava.py │ ├── classes │ │ ├── Authorization.php │ │ ├── BackendController.php │ │ └── S3Uploader.php │ ├── composer.json │ └── composer.lock ├── datasets │ ├── README.md │ ├── analyze_aspect_ratios_json.py │ ├── analyze_laion_data.py │ ├── check_latent_corruption.py │ ├── clear_s3_bucket.py │ ├── controlnet │ │ └── create_canny_edge.py │ ├── crop.py │ ├── csv_to_s3.py │ ├── dataset_from_kellyc.py │ ├── dataset_from_laion.py │ ├── dataset_from_pixilart.py │ ├── discord_scrape.py │ ├── enhance_with_controlnet.py │ ├── folder_to_parquet.py │ ├── masked_loss │ │ ├── generate_dataset_masks.py │ │ ├── generate_dataset_masks_via_huggingface.py │ │ └── requirements.txt │ ├── random_recrop_for_json_image_metadata.py │ ├── retrieve_s3_bucket.py │ └── update_parquet.py └── inference │ ├── inference_ddpm.py │ ├── inference_karras.py │ ├── inference_sigma.py │ ├── inference_snr_test.py │ ├── sigma │ └── __init__.py │ ├── tile_images.py │ ├── tile_samplers.py │ └── tile_shortnames.py ├── train.py └── train.sh /.github/workflows/python-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Python Unit Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Maximize build space 17 | uses: AdityaGarg8/remove-unwanted-software@v4.1 18 | with: 19 | remove-android: 'true' 20 | 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: 3.11 27 | 28 | - name: Install Poetry 29 | run: python -m pip install --upgrade pip poetry 30 | 31 | - name: Install Dependencies 32 | run: poetry -C install/apple install 33 | 34 | - name: Run Tests 35 | run: poetry -C ./ -P install/apple run python -m unittest discover tests/ 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.code-workspace 2 | multidatabackend*json 3 | # Python and virtual environment files 4 | output/ 5 | temp/ 6 | env.sh 7 | *.pem 8 | */config/config.json 9 | user_prompt_library.json 10 | __pycache__/ 11 | .venv/ 
12 | cache/ 13 | *.pyc 14 | *.pyo 15 | *.pyd 16 | *.pyc 17 | *.pyo 18 | *.pyd 19 | *.so 20 | *.dylib 21 | *.egg-info/ 22 | dist/ 23 | build/ 24 | *.egg 25 | *.pyc 26 | *.pyo 27 | *.pyd 28 | *.so 29 | *.dylib 30 | *.egg-info/ 31 | dist/ 32 | build/ 33 | *.egg 34 | venv/ 35 | *.log 36 | 37 | # IDE files 38 | .vscode/ 39 | .idea/ 40 | *.swp 41 | *.swo 42 | *.swn 43 | 44 | # OS generated files 45 | .DS_Store 46 | Thumbs.db 47 | 48 | files.tbz 49 | */config/auth.json 50 | work 51 | multidatabackend.json 52 | multidatabackend_sd2x.json 53 | 54 | config/*.json 55 | config/*.env 56 | wandb/ 57 | cache/ 58 | vae_cache/ 59 | vendor/ 60 | inference/ 61 | webhooks.json 62 | untracked/ 63 | config/*/ 64 | config/*.toml 65 | static 66 | templates 67 | api_state.json 68 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # SimpleTuner needs CU141 2 | FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 3 | 4 | # /workspace is the default volume for Runpod & other hosts 5 | WORKDIR /workspace 6 | 7 | # Update apt-get 8 | RUN apt-get update -y 9 | 10 | # Prevents different commands from being stuck by waiting 11 | # on user input during build 12 | ENV DEBIAN_FRONTEND noninteractive 13 | 14 | # Install libg dependencies 15 | RUN apt install libgl1-mesa-glx -y 16 | RUN apt-get install 'ffmpeg'\ 17 | 'libsm6'\ 18 | 'libxext6' -y 19 | 20 | # Install misc unix libraries 21 | RUN apt-get install -y --no-install-recommends openssh-server \ 22 | openssh-client \ 23 | git \ 24 | git-lfs \ 25 | wget \ 26 | curl \ 27 | tmux \ 28 | tldr \ 29 | nvtop \ 30 | vim \ 31 | rsync \ 32 | net-tools \ 33 | less \ 34 | iputils-ping \ 35 | 7zip \ 36 | zip \ 37 | unzip \ 38 | htop \ 39 | inotify-tools 40 | 41 | # Set up git to support LFS, and to store credentials; useful for Huggingface Hub 42 | RUN git config --global credential.helper store && \ 43 | git lfs install 44 | 45 | # Install Python VENV 46 | RUN apt-get install -y python3.10-venv 47 | 48 | # Ensure SSH access. Not needed for Runpod but is required on Vast and other Docker hosts 49 | EXPOSE 22/tcp 50 | 51 | # Python 52 | RUN apt-get update -y && apt-get install -y python3 python3-pip 53 | RUN python3 -m pip install pip --upgrade 54 | 55 | # HF 56 | ENV HF_HOME=/workspace/huggingface 57 | 58 | RUN pip3 install "huggingface_hub[cli]" 59 | 60 | # WanDB 61 | RUN pip3 install wandb 62 | 63 | # Clone SimpleTuner 64 | RUN git clone https://github.com/bghira/SimpleTuner --branch release 65 | # RUN git clone https://github.com/bghira/SimpleTuner --branch main # Uncomment to use latest (possibly unstable) version 66 | 67 | # Install SimpleTuner 68 | RUN pip3 install poetry 69 | RUN cd SimpleTuner && python3 -m venv .venv && poetry install --no-root 70 | RUN chmod +x SimpleTuner/train.sh 71 | 72 | # Copy start script with exec permissions 73 | COPY --chmod=755 docker-start.sh /start.sh 74 | 75 | # Dummy entrypoint 76 | ENTRYPOINT [ "/start.sh" ] 77 | -------------------------------------------------------------------------------- /config/caption_filter_list.txt.example: -------------------------------------------------------------------------------- 1 | The image does not provide a clear view of the entire scene, making it challenging to accurately caption. However, based on the visible elements, a possible caption could be: 2 | The image does not explicitly depict any LGBT themes. It shows two hands holding a shell against a beach backdrop. 
A caption like 3 | The image does not depict a South American scene. It shows a 4 | The image does not depict a South American 5 | The photo does not prominently feature any anatomical features. It primarily showcases 6 | The image showcases 7 | The image features 8 | The image captures 9 | The image depicts 10 | The photo showcases 11 | The photo features 12 | The photo captures 13 | The image does not have a caption provided in the image itself. However, based on the content, a suitable caption might be: 14 | The image does not have a clear caption as it is an experimental photo. However, one could describe it as 15 | The image does not have a clear subject or context .* for a precise caption. 16 | Caption: 17 | The image does not have a clear caption .* describe it as 18 | The image does not have a clear caption .* could be: 19 | The image does not have .* appears to be 20 | The image does not have a clear subject or context .* appears to be 21 | The image does not have a clear subject or context .* could be: 22 | The image does not have a direct caption .*: 23 | The image does not require a caption .* ' 24 | smoking vaping 25 | ^there is 26 | araffe 27 | arafed 28 | ^someone is -------------------------------------------------------------------------------- /config/config.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "--resume_from_checkpoint": "latest", 3 | "--data_backend_config": "config/multidatabackend.json", 4 | "--aspect_bucket_rounding": 2, 5 | "--seed": 42, 6 | "--minimum_image_size": 0, 7 | "--output_dir": "output/models", 8 | "--lora_type": "lycoris", 9 | "--lycoris_config": "config/lycoris_config.json", 10 | "--max_train_steps": 10000, 11 | "--num_train_epochs": 0, 12 | "--checkpointing_steps": 500, 13 | "--checkpoints_total_limit": 5, 14 | "--hub_model_id": "simpletuner-lora", 15 | "--push_to_hub": "true", 16 | "--push_checkpoints_to_hub": "true", 17 | "--tracker_project_name": "lora-training", 18 | "--tracker_run_name": "simpletuner-lora", 19 | "--report_to": "wandb", 20 | "--model_type": "lora", 21 | "--pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", 22 | "--model_family": "flux", 23 | "--train_batch_size": 1, 24 | "--gradient_checkpointing": "true", 25 | "--caption_dropout_probability": 0.1, 26 | "--resolution_type": "pixel_area", 27 | "--resolution": 1024, 28 | "--validation_seed": 42, 29 | "--validation_steps": 500, 30 | "--validation_resolution": "1024x1024", 31 | "--validation_guidance": 3.0, 32 | "--validation_guidance_rescale": "0.0", 33 | "--validation_num_inference_steps": "20", 34 | "--validation_prompt": "A photo-realistic image of a cat", 35 | "--mixed_precision": "bf16", 36 | "--optimizer": "adamw_bf16", 37 | "--learning_rate": "1e-4", 38 | "--lr_scheduler": "polynomial", 39 | "--lr_warmup_steps": 100, 40 | "--validation_torch_compile": "false", 41 | "--disable_benchmark": "false" 42 | } -------------------------------------------------------------------------------- /config/lycoris_config.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "algo": "lokr", 3 | "multiplier": 1.0, 4 | "full_matrix": true, 5 | "linear_alpha": 1, 6 | "factor": 16, 7 | "apply_preset": { 8 | "target_module": [ 9 | "Attention", 10 | "FeedForward" 11 | ], 12 | "module_algo_map": { 13 | "Attention": { 14 | "factor": 16 15 | }, 16 | "FeedForward": { 17 | "factor": 8 18 | } 19 | } 20 | } 21 | } 
-------------------------------------------------------------------------------- /config/multidatabackend-multiresolution.json.example: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "dreambooth-512", 4 | "type": "local", 5 | "instance_data_dir": "/home/user/simpletuner/data/dreambooth", 6 | "crop": false, "crop_style": "random", 7 | "minimum_image_size": 128, 8 | "resolution": 512, 9 | "resolution_type": "pixel_area", "repeats": 10, 10 | "metadata_backend": "discovery", 11 | "caption_strategy": "textfile", 12 | "cache_dir_vae": "cache/vae-512" 13 | }, 14 | { 15 | "id": "dreambooth-1024", 16 | "type": "local", 17 | "instance_data_dir": "/home/user/simpletuner/data/dreambooth", 18 | "crop": false, "crop_style": "random", 19 | "minimum_image_size": 128, 20 | "resolution": 1024, 21 | "resolution_type": "pixel_area", "repeats": 10, 22 | "metadata_backend": "discovery", 23 | "caption_strategy": "textfile", 24 | "cache_dir_vae": "cache/vae-1024" 25 | }, 26 | { 27 | "id": "dreambooth-512-crop", 28 | "type": "local", 29 | "instance_data_dir": "/home/user/simpletuner/data/dreambooth", 30 | "crop": true, "crop_style": "random", 31 | "minimum_image_size": 128, 32 | "resolution": 512, 33 | "resolution_type": "pixel_area", "repeats": 10, 34 | "metadata_backend": "discovery", 35 | "caption_strategy": "textfile", 36 | "cache_dir_vae": "cache/vae-512-crop" 37 | }, 38 | { 39 | "id": "dreambooth-1024-crop", 40 | "type": "local", 41 | "instance_data_dir": "/home/user/simpletuner/data/dreambooth", 42 | "crop": true, "crop_style": "random", 43 | "minimum_image_size": 128, 44 | "resolution": 1024, 45 | "resolution_type": "pixel_area", "repeats": 10, 46 | "metadata_backend": "discovery", 47 | "caption_strategy": "textfile", 48 | "cache_dir_vae": "cache/vae-1024-crop" 49 | }, 50 | { 51 | "id": "text-embed-cache", 52 | "dataset_type": "text_embeds", 53 | "default": true, 54 | "type": "local", 55 | "cache_dir": "cache/text" 56 | } 57 | ] 58 | -------------------------------------------------------------------------------- /config/multidatabackend.json.example: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "something-special-to-remember-by", 4 | "type": "local", 5 | "instance_data_dir": "/path/to/data/tree", 6 | "crop": false, 7 | "crop_style": "random|center|corner", 8 | "crop_aspect": "square|preserve", 9 | "minimum_image_size": 1024, 10 | "maximum_image_size": 1536, 11 | "target_downsample_size": 1024, 12 | "resolution": 1024, 13 | "resolution_type": "pixel_area|area|pixel", 14 | "prepend_instance_prompt": false, 15 | "instance_prompt": "cat girls", 16 | "only_instance_prompt": false, 17 | "caption_strategy": "filename", 18 | "cache_dir_vae": "/path/to/vaecache", 19 | "vae_cache_clear_each_epoch": true, 20 | "probability": 1.0, 21 | "repeats": 5, 22 | "text_embeds": "alt-embed-cache", 23 | "skip_file_discovery": "vae,aspect,text,metadata", 24 | "preserve_data_backend_cache": true 25 | }, 26 | { 27 | "id": "another-special-name-for-another-backend", 28 | "type": "aws", 29 | "aws_bucket_name": "something-yummy", 30 | "aws_region_name": null, 31 | "aws_endpoint_url": "https://foo.bar/", 32 | "aws_access_key_id": "wpz-764e9734523434", 33 | "aws_secret_access_key": "xyz-sdajkhfhakhfjd", 34 | "aws_data_prefix": "", 35 | "cache_dir_vae": "s3prefix/for/vaecache", 36 | "vae_cache_clear_each_epoch": true, 37 | "repeats": 2, 38 | "ignore_epochs": false 39 | }, 40 | { 41 | "id": "an example backend for text 
embeds.", 42 | "dataset_type": "text_embeds", 43 | "default": true, 44 | "type": "aws", 45 | "aws_bucket_name": "textembeds-something-yummy", 46 | "aws_region_name": null, 47 | "aws_endpoint_url": "https://foo.bar/", 48 | "aws_access_key_id": "wpz-764e9734523434", 49 | "aws_secret_access_key": "xyz-sdajkhfhakhfjd", 50 | "aws_data_prefix": "", 51 | "cache_dir": "" 52 | }, 53 | { 54 | "id": "alt-embed-cache", 55 | "dataset_type": "text_embeds", 56 | "default": false, 57 | "type": "local", 58 | "cache_dir": "/path/to/textembed_cache" 59 | } 60 | ] -------------------------------------------------------------------------------- /config/user_prompt_library.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "shortname_here": "your prompt to validate on", 3 | "another_shortname": "another prompt to validate on" 4 | } -------------------------------------------------------------------------------- /docker-start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Export useful ENV variables, including all Runpod specific vars, to /etc/rp_environment 4 | # This file can then later be sourced in a login shell 5 | echo "Exporting environment variables..." 6 | printenv | 7 | grep -E '^RUNPOD_|^PATH=|^HF_HOME=|^HF_TOKEN=|^HUGGING_FACE_HUB_TOKEN=|^WANDB_API_KEY=|^WANDB_TOKEN=|^_=' | 8 | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >>/etc/rp_environment 9 | 10 | # Add it to Bash login script 11 | echo 'source /etc/rp_environment' >>~/.bashrc 12 | 13 | # Vast.ai uses $SSH_PUBLIC_KEY 14 | if [[ $SSH_PUBLIC_KEY ]]; then 15 | PUBLIC_KEY="${SSH_PUBLIC_KEY}" 16 | fi 17 | 18 | # Runpod uses $PUBLIC_KEY 19 | if [[ $PUBLIC_KEY ]]; then 20 | mkdir -p ~/.ssh 21 | chmod 700 ~/.ssh 22 | echo "${PUBLIC_KEY}" >>~/.ssh/authorized_keys 23 | chmod 700 -R ~/.ssh 24 | fi 25 | 26 | # Start SSH server 27 | service ssh start 28 | 29 | # Login to HF 30 | if [[ -n "${HF_TOKEN:-$HUGGING_FACE_HUB_TOKEN}" ]]; then 31 | huggingface-cli login --token "${HF_TOKEN:-$HUGGING_FACE_HUB_TOKEN}" --add-to-git-credential 32 | else 33 | echo "HF_TOKEN or HUGGING_FACE_HUB_TOKEN not set; skipping login" 34 | fi 35 | 36 | # Login to WanDB 37 | if [[ -n "${WANDB_API_KEY:-$WANDB_TOKEN}" ]]; then 38 | wandb login "${WANDB_API_KEY:-$WANDB_TOKEN}" 39 | else 40 | echo "WANDB_API_KEY or WANDB_TOKEN not set; skipping login" 41 | fi 42 | 43 | # 🫡 44 | sleep infinity 45 | -------------------------------------------------------------------------------- /documentation/LYCORIS.md: -------------------------------------------------------------------------------- 1 | # LyCORIS 2 | 3 | ## Background 4 | 5 | [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) is an extensive suite of parameter-efficient fine-tuning (PEFT) methods that allow you to finetune models while using less VRAM and produces smaller distributable weights. 6 | 7 | ## Using LyCORIS 8 | 9 | To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file. 10 | 11 | The following will go into your `config.json`: 12 | ```json 13 | { 14 | "model_type": "lora", 15 | "lora_type": "lycoris", 16 | "lycoris_config": "config/lycoris_config.json", 17 | "validation_lycoris_strength": 1.0, 18 | ...the rest of your settings... 
19 | } 20 | ``` 21 | 22 | 23 | The LyCORIS configuration file is in the format: 24 | 25 | ```json 26 | { 27 | "algo": "lokr", 28 | "multiplier": 1.0, 29 | "linear_dim": 10000, 30 | "linear_alpha": 1, 31 | "factor": 10, 32 | "apply_preset": { 33 | "target_module": [ 34 | "Attention", 35 | "FeedForward" 36 | ], 37 | "module_algo_map": { 38 | "Attention": { 39 | "factor": 10 40 | }, 41 | "FeedForward": { 42 | "factor": 4 43 | } 44 | } 45 | } 46 | } 47 | ``` 48 | 49 | ### Fields 50 | 51 | Optional fields: 52 | - apply_preset for LycorisNetwork.apply_preset 53 | - any keyword arguments specific to the selected algorithm, at the end. 54 | 55 | Mandatory fields: 56 | - multiplier, which should be set to 1.0 only unless you know what to expect 57 | - linear_dim 58 | - linear_alpha 59 | 60 | For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs). 61 | 62 | ## Potential problems 63 | 64 | When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory. 65 | 66 | This seems to be potentially exacerbated when using SageAttention (with `--sageattention_usage=training`), making it all but guaranteed that the model will immediately fail. 67 | 68 | The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks. 69 | 70 | ## LyCORIS Inference Example 71 | 72 | Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference. 73 | 74 | ```py 75 | import torch 76 | 77 | from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL 78 | from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel 79 | from diffusers.pipelines.flux.pipeline_flux import FluxPipeline 80 | from transformers import AutoModelForCausalLM, CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast 81 | 82 | from lycoris import create_lycoris_from_weights 83 | 84 | device = "cuda" if torch.cuda.is_available() else "cpu" 85 | dtype = torch.bfloat16 86 | bfl_repo = "black-forest-labs/FLUX.1-dev" 87 | 88 | scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler") 89 | text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype) 90 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype) 91 | text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype) 92 | tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype) 93 | vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype) 94 | transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer") 95 | 96 | lycoris_safetensors_path = 'pytorch_lora_weights.safetensors' 97 | lycoris_strength = 1.0 98 | wrapper, _ = create_lycoris_from_weights(lycoris_strength, lycoris_safetensors_path, transformer) 99 | wrapper.merge_to() # using apply_to() will be slower. 
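# NOTE (added for completeness): the pipeline call further down passes `generator`, which this
# snippet never defines. A throwaway CPU generator keeps the example runnable; the seed value
# is arbitrary.
generator = torch.Generator(device="cpu").manual_seed(42)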
100 | 101 | transformer.to(device, dtype=dtype) 102 | 103 | pipe = FluxPipeline( 104 | scheduler=scheduler, 105 | text_encoder=text_encoder, 106 | tokenizer=tokenizer, 107 | text_encoder_2=text_encoder_2, 108 | tokenizer_2=tokenizer_2, 109 | vae=vae, 110 | transformer=transformer, 111 | ) 112 | 113 | pipe.enable_sequential_cpu_offload() 114 | 115 | with torch.inference_mode(): 116 | image = pipe( 117 | prompt="a pokemon that looks like a pizza is eating a popsicle", 118 | width=1280, 119 | height=768, 120 | num_inference_steps=15, 121 | generator=generator, 122 | guidance_scale=3.5, 123 | ).images[0] 124 | image.save('image.png') 125 | 126 | # optionally, save a merged pipeline containing the LyCORIS baked-in: 127 | pipe.save_pretrained('/path/to/output/pipeline') 128 | ``` 129 | -------------------------------------------------------------------------------- /documentation/QUICKSTART.md: -------------------------------------------------------------------------------- 1 | # Quickstart Guide 2 | 3 | > ⚠️ These tutorials are a work-in-progress. They contain full end-to-end instructions for a basic training session. 4 | 5 | **Note**: For more advanced configurations, see the [tutorial](/TUTORIAL.md), [dataloader configuration guide](/documentation/DATALOADER.md), and the [options breakdown](/OPTIONS.md) pages. 6 | 7 | ## PixArt Sigma (1K, 2K & 4K) 8 | 9 | For a fun and lightweight model, see [this quickstart guide](/documentation/quickstart/SIGMA.md) 10 | 11 | ## NVLabs Sana (1024px, currently) 12 | 13 | Probably the fastest model currently; see [this quickstart guide](/documentation/quickstart/SANA.md) 14 | 15 | ## Kwai Kolors 16 | 17 | An SDXL-like U-net based architecture that uses a language model called ChatGLM for its text parsing can be found [here](/documentation/quickstart/KOLORS.md) 18 | 19 | ## Stable Diffusion 3 20 | 21 | For personalisation of the Stable Diffusion 3 model family, see [this quickstart guide](/documentation/quickstart/SD3.md) 22 | 23 | ## Flux.1 24 | 25 | For training of the enormous monster known as Flux, see [its specific quickstart guide](/documentation/quickstart/FLUX.md) -------------------------------------------------------------------------------- /documentation/data_presets/README.md: -------------------------------------------------------------------------------- 1 | # Dataset configuration presets 2 | 3 | For various large-scale datasets on Hugging Face Hub, configuration details are provided here to give a head start on making things work. 4 | 5 | To add a new preset, use [this template](/documentation/data_presets/preset.md) to submit a new pull-request. 6 | 7 | - [DALLE-3 1M](/documentation/data_presets/preset_dalle3.md) 8 | - [ptx0/photo-concept-bucket](/documentation/data_presets/preset_pexels.md) 9 | - [Midjourney v6 520k](/documentation/data_presets/preset_midjourney.md) 10 | - [Nijijourney v6 520k](/documentation/data_presets/preset_nijijourney.md) -------------------------------------------------------------------------------- /documentation/data_presets/preset.md: -------------------------------------------------------------------------------- 1 | # Dataset Name 2 | 3 | ## Details 4 | 5 | - **Hub link**: ... 6 | - **Description**: ... 7 | - **Caption format(s)**: ... 8 | 9 | ## (optional) Required preprocessing steps 10 | 11 | If your dataset requires any steps to prepare it for use in SimpleTuner, those should be listed here, along with any code snippets that would help. 
12 | 13 | ## Dataloader configuration example 14 | 15 | Here, you'll place the `multidatabackend.json` contents that will work for this dataset. 16 | 17 | ```json 18 | ... 19 | ``` 20 | -------------------------------------------------------------------------------- /documentation/data_presets/preset_midjourney.md: -------------------------------------------------------------------------------- 1 | # Midjourney v6 520k 2 | 3 | ## Details 4 | 5 | - **Hub link**: [terminusresearch/midjourney-v6-520k-raw](https://huggingface.co/datasets/terminusresearch/midjourney-v6-520k-raw) 6 | - **Description**: ~520,000 high quality outputs where any Japanese user prompts have been re-captioned with GPT-3.5-Turbo. 7 | - **Caption format(s)**: Parquet 8 | 9 | ## Required storage 10 | 11 | This dataset contains all image data, and as such, it will be difficult to extract without adequate disk space. **Ensure you have at least 1.5TB of disk space available to extract it.** 12 | 13 | T5-XXL text embeds for this model will consume ~520GB even with `--compress_disk_cache` enabled. 14 | The VAE embeds will consume just under 80 to 100GB of space, depending on the model being trained and the resolution of the embeds. 15 | 16 | 17 | ## Download 18 | 19 | ```bash 20 | huggingface-cli download --repo-type=dataset terminusresearch/midjourney-v6-520k-raw --local-dir=midjourney-v6-520k-raw 21 | ``` 22 | 23 | This will simultaneously download the chunked tar segments from Hugging Face Hub. 24 | 25 | ## Extract 26 | 27 | ```bash 28 | cd midjourney-v6-520k-raw 29 | cat *.tar | tar x 30 | ``` 31 | 32 | This will create a folder containing all of the samples inside the current directory. 33 | 34 | ## Dataloader configuration example 35 | 36 | ```json 37 | { 38 | "id": "midjourney-v6-520k-raw", 39 | "type": "local", 40 | "cache_dir_vae": "cache/vae-mj-520k/", 41 | "crop": true, 42 | "crop_aspect": "square", 43 | "resolution": 1.0, 44 | "maximum_image_size": 1.0, 45 | "minimum_image_size": 0.75, 46 | "target_downsample_size": 1.00, 47 | "resolution_type": "area", 48 | "caption_strategy": "parquet", 49 | "metadata_backend": "parquet", 50 | "parquet": { 51 | "path": "/path/to/midjourney-v6-520k-raw/train.parquet", 52 | "caption_column": "gpt_caption", 53 | "filename_column": "id", 54 | "width_column": "width", 55 | "height_column": "height", 56 | "identifier_includes_extension": false 57 | } 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /documentation/data_presets/preset_nijijourney.md: -------------------------------------------------------------------------------- 1 | # Niji v6 520k 2 | 3 | ## Details 4 | 5 | - **Hub link**: [terminusresearch/nijijourney-v6-520k-raw](https://huggingface.co/datasets/terminusresearch/nijijourney-v6-520k-raw) 6 | - **Description**: ~520,000 high quality outputs where any Japanese user prompts have been re-captioned with GPT-3.5-Turbo. 7 | - **Caption format(s)**: Parquet 8 | 9 | ## Required storage 10 | 11 | This dataset contains all image data, and as such, it will be difficult to extract without adequate disk space. **Ensure you have at least 1.5TB of disk space available to extract it.** 12 | 13 | T5-XXL text embeds for this model will consume ~520GB even with `--compress_disk_cache` enabled. 14 | The VAE embeds will consume just under 80 to 100GB of space, depending on the model being trained and the resolution of the embeds. 
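For reference, `--compress_disk_cache` is an ordinary trainer option; in a JSON-style config it is just another entry (shown in isolation here, alongside whatever settings you already use):

```json
{
  "--compress_disk_cache": "true"
}
```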
15 | 16 | ## Download 17 | 18 | ```bash 19 | huggingface-cli download --repo-type=dataset terminusresearch/nijijourney-v6-520k-raw --local-dir=nijijourney-v6-520k-raw 20 | ``` 21 | 22 | This will simultaneously download the chunked tar segments from Hugging Face Hub. 23 | 24 | ## Extract 25 | 26 | ```bash 27 | cd nijijourney-v6-520k-raw 28 | cat *.tar | tar x 29 | ``` 30 | 31 | This will create a folder containing all of the samples inside the current directory. 32 | 33 | ## Dataloader configuration example 34 | 35 | ```json 36 | { 37 | "id": "nijijourney-v6-520k-raw", 38 | "type": "local", 39 | "cache_dir_vae": "cache/vae-nj-520k/", 40 | "crop": true, 41 | "crop_aspect": "square", 42 | "resolution": 1.0, 43 | "maximum_image_size": 1.0, 44 | "minimum_image_size": 0.75, 45 | "target_downsample_size": 1.00, 46 | "resolution_type": "area", 47 | "caption_strategy": "parquet", 48 | "metadata_backend": "parquet", 49 | "parquet": { 50 | "path": "/path/to/nijijourney-v6-520k-raw/train.parquet", 51 | "caption_column": "gpt_caption", 52 | "filename_column": "id", 53 | "width_column": "width", 54 | "height_column": "height", 55 | "identifier_includes_extension": false 56 | } 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /documentation/data_presets/preset_pexels.md: -------------------------------------------------------------------------------- 1 | # Photo concept bucket 2 | 3 | ## Details 4 | 5 | - **Hub link**: [ptx0/photo-concept-bucket](https://huggingface.co/datasets/ptx0/photo-concept-bucket) 6 | - **Description**: ~567,000 high quality photographs across dense concept distribution, captioned with CogVLM. 7 | - **Caption format(s)**: Parquet 8 | 9 | ## Required preprocessing steps 10 | 11 | As the photo-concept-bucket repository does not include image data, this must be retrieved by you directly from the Pexels server. 12 | 13 | An example script for downloading the dataset is provided, but you must ensure you are following the terms and conditions of the Pexels service, at the time of consumption. 
14 | 15 | To download the captions and URL list: 16 | 17 | ```bash 18 | huggingface-cli download --repo-type=dataset ptx0/photo-concept-bucket --local-dir=/home/user/training/photo-concept-bucket 19 | ``` 20 | 21 | Place this file into `/home/user/training/photo-concept-bucket`: 22 | 23 | `download.py` 24 | ```py 25 | from concurrent.futures import ThreadPoolExecutor 26 | import pyarrow.parquet as pq 27 | import os 28 | import requests 29 | from PIL import Image 30 | from io import BytesIO 31 | 32 | # Load the Parquet file 33 | parquet_file = 'photo-concept-bucket.parquet' 34 | df = pq.read_table(parquet_file).to_pandas() 35 | 36 | # Define the output directory 37 | output_dir = 'train' 38 | os.makedirs(output_dir, exist_ok=True) 39 | 40 | def resize_for_condition_image(input_image: Image, resolution: int): 41 | input_image = input_image.convert("RGB") 42 | W, H = input_image.size 43 | k = float(resolution) / min(H, W) 44 | H *= k 45 | W *= k 46 | H = int(round(H / 64.0)) * 64 47 | W = int(round(W / 64.0)) * 64 48 | img = input_image.resize((W, H), resample=Image.LANCZOS) 49 | return img 50 | 51 | def download_and_save(row): 52 | img_url = row['url'] 53 | caption = row['cogvlm_caption'] 54 | img_id = row['id'] 55 | 56 | try: 57 | # Download the image 58 | img_response = requests.get(img_url) 59 | if img_response.status_code == 200: 60 | img = Image.open(BytesIO(img_response.content)) 61 | img_path = os.path.join(output_dir, f"{img_id}.png") 62 | img.save(img_path) 63 | 64 | # Write the caption to a text file 65 | caption_path = os.path.join(output_dir, f"{img_id}.txt") 66 | with open(caption_path, 'w') as caption_file: 67 | caption_file.write(caption) 68 | except Exception as e: 69 | print(f"Failed to download or save data for id {img_id}: {e}") 70 | 71 | # Run the download in parallel 72 | with ThreadPoolExecutor() as executor: 73 | executor.map(download_and_save, [row for _, row in df.iterrows()]) 74 | ``` 75 | 76 | This script will simultaneously download the images from Pexels and write their captions into the `train/` directory as a txt file. 77 | 78 | > ⚠️ This dataset is extremely large, and will consume more than 7TB of local disk space to retrieve as-is. It's recommended that you add a resize step to this retrieval, if you don't wish to store the whole 20 megapixel dataset. 
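If you do want that resize step, one minimal sketch is to downsample inside `download_and_save()` using the `resize_for_condition_image()` helper that the script above already defines (it is otherwise unused); the 1024 px target edge here is only an example value:

```py
def download_and_save(row):
    img_url = row['url']
    caption = row['cogvlm_caption']
    img_id = row['id']

    try:
        img_response = requests.get(img_url)
        if img_response.status_code == 200:
            # Downsample before writing to disk instead of keeping the ~20 megapixel original.
            img = Image.open(BytesIO(img_response.content))
            img = resize_for_condition_image(img, 1024)
            img.save(os.path.join(output_dir, f"{img_id}.png"))

            # Caption handling is unchanged from the original script.
            with open(os.path.join(output_dir, f"{img_id}.txt"), 'w') as caption_file:
                caption_file.write(caption)
    except Exception as e:
        print(f"Failed to download or save data for id {img_id}: {e}")
```

This keeps the shortest edge near 1024 px while preserving aspect ratio, which cuts the on-disk footprint substantially.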
79 | 80 | ## Dataloader configuration example 81 | 82 | ```json 83 | { 84 | "id": "photo-concept-bucket", 85 | "type": "local", 86 | "instance_data_dir": "/home/user/training/photo-concept-bucket/train", 87 | "crop": true, 88 | "crop_aspect": "square", 89 | "crop_style": "center", 90 | "resolution": 1.0, 91 | "minimum_image_size": 1.0, 92 | "maximum_image_size": 1.5, 93 | "target_downsample_size": 1.25, 94 | "resolution_type": "area", 95 | "cache_dir_vae": "/home/user/training/photo-concept-bucket/cache/vae", 96 | "caption_strategy": "parquet", 97 | "metadata_backend": "parquet", 98 | "parquet": { 99 | "path": "/home/user/training/photo-concept-bucket/photo-concept-bucket.parquet", 100 | "caption_column": "cogvlm_caption", 101 | "fallback_caption_column": "tags", 102 | "filename_column": "id", 103 | "width_column": "width", 104 | "height_column": "height" 105 | } 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /documentation/evaluation/CLIP_SCORES.md: -------------------------------------------------------------------------------- 1 | # CLIP score tracking 2 | 3 | CLIP scores are loosely related to measurement of a model's ability to follow prompts; it is not at all related to image quality/fidelity. 4 | 5 | The `clip/mean` score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for determining general prompt adherence, though is typically evaluated across a very large (~5,000) number of test prompts (eg. Parti Prompts). 6 | 7 | CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value around `.30` to `.39` is reached, the comparison seems to become less meaningful. Models that show an average CLIP score around `.33` can outperform a model with an average CLIP score of `.36` in human analysis. However, a model with a very low average CLIP score around `0.18` to `0.22` will probably be pretty poorly-performing. 8 | 9 | Within a single test run, some prompts will result in a very low CLIP score of around `0.14` (`clip/min` value in the tracker charts) even though their images align fairly well with the user prompt and have high image quality; conversely, CLIP scores as high as `0.39` (`clip/max` value in the tracker charts) may appear from images with questionable quality, as this test is not meant to capture this information. This is why such a large number of prompts are typically used to measure model performance - _and even then_.. 10 | 11 | On its own, CLIP scores do not take long to calculate; however, the number of prompts required for meaningful evaluation can make it take an incredibly long time. 12 | 13 | Since it doesn't take much to run, it doesn't hurt to include CLIP evaluation in small training runs. Perhaps you will discover a pattern of the outputs where it makes sense to abandon a training run or adjust other hyperparameters such as the learning rate. 14 | 15 | To include a standard prompt library for evaluation, `--validation_prompt_library` can be provided and then we will generate a somewhat relative benchmark between training runs. 16 | 17 | In `config.json`: 18 | 19 | ```json 20 | { 21 | ... 22 | "evaluation_type": "clip", 23 | "pretrained_evaluation_model_name_or_path": "openai/clip-vit-large-patch14-336", 24 | "report_to": "tensorboard", # or wandb 25 | ... 
26 | } 27 | ``` 28 | 29 | ## Compatibility 30 | 31 | SageAttention is currently not compatible with CLIP score tracking. One or the other must be disabled. -------------------------------------------------------------------------------- /documentation/evaluation/EVAL_LOSS.md: -------------------------------------------------------------------------------- 1 | An experimental feature in SimpleTuner implements the ideas behind ["Demystifying SD fine-tuning"](https://github.com/spacepxl/demystifying-sd-finetuning) to provide a stable loss value for evaluation. 2 | 3 | Due to its experimental nature, it may cause problems or lack functionality / integration that a fully finalised feature might have. 4 | 5 | It is fine to use this feature in production, but beware of the potential for bugs or changes in future versions. 6 | 7 | Example dataloader: 8 | 9 | ```json 10 | [ 11 | { 12 | "id": "something-special-to-remember-by", 13 | "crop": false, 14 | "type": "local", 15 | "instance_data_dir": "/datasets/pseudo-camera-10k/train", 16 | "minimum_image_size": 512, 17 | "maximum_image_size": 1536, 18 | "target_downsample_size": 512, 19 | "resolution": 512, 20 | "resolution_type": "pixel_area", 21 | "caption_strategy": "filename", 22 | "cache_dir_vae": "cache/vae/sana", 23 | "vae_cache_clear_each_epoch": false, 24 | "skip_file_discovery": "" 25 | }, 26 | { 27 | "id": "sana-eval", 28 | "type": "local", 29 | "dataset_type": "eval", 30 | "instance_data_dir": "/datasets/test_datasets/squares", 31 | "resolution": 1024, 32 | "minimum_image_size": 1024, 33 | "maximum_image_size": 1024, 34 | "target_downsample_size": 1024, 35 | "resolution_type": "pixel_area", 36 | "cache_dir_vae": "cache/vae/sana-eval", 37 | "caption_strategy": "filename" 38 | }, 39 | { 40 | "id": "text-embed-cache", 41 | "dataset_type": "text_embeds", 42 | "default": true, 43 | "type": "local", 44 | "cache_dir": "cache/text/sana" 45 | } 46 | ] 47 | ``` 48 | 49 | - Eval image datasets can be configured exactly like a normal image dataset. 50 | - The evaluation dataset is **not** used for training. 51 | - It's recommended to use images that represent concepts outside of your training set. 52 | 53 | To configure and enable evaluation loss calculations: 54 | 55 | ```json 56 | { 57 | "--eval_steps_interval": 10, 58 | "--num_eval_images": 1, 59 | "--report_to": "wandb", 60 | } 61 | ``` 62 | 63 | > **Note**: Weights & Biases (wandb) is currently required for the full evaluation charting functionality. Other trackers only receive the single mean value. -------------------------------------------------------------------------------- /filter_list.txt: -------------------------------------------------------------------------------- 1 | The image does not provide a clear view of the entire scene, making it challenging to accurately caption. However, based on the visible elements, a possible caption could be: 2 | The image does not .* depict .* caption like 3 | The image does not depict .* South America.* shows a 4 | The image does not depict a South American 5 | The photo does not prominently feature any anatomical features. It primarily showcases 6 | The image showcases 7 | The image features 8 | The image captures 9 | The image depicts 10 | The photo showcases 11 | The photo features 12 | The photo captures 13 | The image does not have a caption provided in the image itself. However, based on the content, a suitable caption might be: 14 | The image does not have a clear caption as it is an experimental photo. 
However, one could describe it as 15 | The image does not have a clear subject or context .* for a precise caption. 16 | Caption: 17 | The image does not have a clear caption .* describe it as 18 | The image does not have a clear caption .* could be: 19 | The image does not have .* appears to be 20 | The image does not have a clear subject or context .* appears to be 21 | The image does not have a clear subject or context .* could be: 22 | The image does not have a direct caption .*: 23 | The image does not require a caption .* ' 24 | smoking vaping 25 | ^there is 26 | araffe 27 | arafed 28 | ^someone is 29 | ^The image is 30 | ^setting. 31 | The caption for .* could be: 32 | The image does not .* showcases 33 | The image does not .* shows -------------------------------------------------------------------------------- /helpers/caching/memory.py: -------------------------------------------------------------------------------- 1 | def reclaim_memory(): 2 | import gc 3 | import torch 4 | 5 | if torch.cuda.is_available(): 6 | gc.collect() 7 | torch.cuda.empty_cache() 8 | torch.cuda.ipc_collect() 9 | 10 | if torch.backends.mps.is_available(): 11 | torch.mps.empty_cache() 12 | torch.mps.synchronize() 13 | gc.collect() 14 | -------------------------------------------------------------------------------- /helpers/configuration/json_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | # Set up logging 6 | from helpers.training.multi_process import _get_rank 7 | 8 | logger = logging.getLogger("SimpleTuner") 9 | if _get_rank() > 0: 10 | logger.setLevel(logging.WARNING) 11 | else: 12 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 13 | 14 | 15 | def normalize_args(args_dict): 16 | """ 17 | Normalize arguments, ensuring they have '--' at the start if necessary. 18 | 19 | :param args_dict: A dictionary of arguments that may or may not have '--' prefixes. 20 | :return: A normalized dictionary of arguments. 21 | """ 22 | normalized = [] 23 | for key, value in args_dict.items(): 24 | # Add -- prefix if not present 25 | if (type(value) is bool and value) or value == "true": 26 | if not key.startswith("--"): 27 | normalized_key = f"--{key}" 28 | else: 29 | normalized_key = key 30 | elif type(value) is bool and not value or value == "false": 31 | logger.warning(f"Skipping false argument: {key}") 32 | continue 33 | else: 34 | if not key.startswith("--"): 35 | normalized_key = f"--{key}={value}" 36 | else: 37 | normalized_key = f"{key}={value}" 38 | normalized.append(normalized_key) 39 | return normalized 40 | 41 | 42 | def load_json_config(): 43 | """ 44 | Load configuration from a JSON file that directly specifies command-line arguments. 45 | 46 | :param json_path: The path to the JSON file. 47 | :return: A dictionary containing the configuration. 
48 | """ 49 | config_json_path = "config/config.json" 50 | env = os.environ.get( 51 | "SIMPLETUNER_ENVIRONMENT", 52 | os.environ.get("SIMPLETUNER_ENV", os.environ.get("ENV", None)), 53 | ) 54 | if env and env != "default": 55 | config_json_path = f"config/{env}/config.json" 56 | 57 | if not os.path.isfile(config_json_path): 58 | raise ValueError(f"JSON configuration file not found: {config_json_path}") 59 | 60 | with open(config_json_path, "r") as file: 61 | try: 62 | config = json.load(file) 63 | logger.info(f"[CONFIG.JSON] Loaded configuration from {config_json_path}") 64 | return normalize_args(config) 65 | except json.JSONDecodeError as e: 66 | raise ValueError(f"Failed to parse JSON file {config_json_path}: {e}") 67 | -------------------------------------------------------------------------------- /helpers/configuration/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from helpers.configuration import toml_file, json_file, env_file, cmd_args 4 | from helpers.training.state_tracker import StateTracker 5 | import sys 6 | 7 | logger = logging.getLogger("SimpleTuner") 8 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 9 | 10 | helpers = { 11 | "json": json_file.load_json_config, 12 | "toml": toml_file.load_toml_config, 13 | "env": env_file.load_env_config, 14 | "cmd": cmd_args.parse_cmdline_args, 15 | } 16 | 17 | default_config_paths = { 18 | "json": "config.json", 19 | "toml": "config.toml", 20 | "env": "config.env", 21 | } 22 | 23 | 24 | def attach_env_to_path_if_not_present(backend: str, env: str = None): 25 | backend_cfg_path = default_config_paths.get(backend) 26 | if env and env != "default": 27 | return f"config/{env}/{backend_cfg_path}" 28 | return f"config/{backend_cfg_path}" 29 | 30 | 31 | def load_config(args: dict = None, exit_on_error: bool = False): 32 | # Check if help is requested; bypass configuration loading if true 33 | if "-h" in sys.argv or "--help" in sys.argv: 34 | return helpers["cmd"]() 35 | 36 | mapped_config = args 37 | if mapped_config is None or not mapped_config: 38 | config_backend = os.environ.get( 39 | "SIMPLETUNER_CONFIG_BACKEND", 40 | os.environ.get("CONFIG_BACKEND", os.environ.get("CONFIG_TYPE", "env")), 41 | ).lower() 42 | config_env = os.environ.get( 43 | "SIMPLETUNER_ENVIRONMENT", 44 | os.environ.get("SIMPLETUNER_ENV", os.environ.get("ENV", "default")), 45 | ) 46 | config_backend_path = "config" 47 | if config_env and config_env != "default" and config_env is not None: 48 | config_backend_path = os.path.join("config", config_env) 49 | StateTracker.set_config_path(config_backend_path) 50 | logger.info("Using {} configuration backend.".format(config_backend)) 51 | mapped_config = helpers[config_backend]() 52 | if config_backend == "cmd": 53 | return mapped_config 54 | 55 | # Other configs need to be passed through parse_cmdline_args to be made whole and have complete defaults and safety checks applied. 
56 | configuration = helpers["cmd"]( 57 | input_args=mapped_config, exit_on_error=exit_on_error 58 | ) 59 | 60 | return configuration 61 | -------------------------------------------------------------------------------- /helpers/configuration/toml_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import toml 3 | import logging 4 | 5 | # Set up logging 6 | from helpers.training.multi_process import _get_rank 7 | 8 | logger = logging.getLogger("SimpleTuner") 9 | if _get_rank() > 0: 10 | logger.setLevel(logging.WARNING) 11 | else: 12 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 13 | 14 | 15 | def normalize_args(args_dict): 16 | """ 17 | Normalize arguments, ensuring they have '--' at the start if necessary. 18 | 19 | :param args_dict: A dictionary of arguments that may or may not have '--' prefixes. 20 | :return: A normalized dictionary of arguments. 21 | """ 22 | normalized = [] 23 | for key, value in args_dict.items(): 24 | # Add -- prefix if not present 25 | if type(value) is bool and value or value == "true": 26 | if not key.startswith("--"): 27 | normalized_key = f"--{key}" 28 | else: 29 | normalized_key = key 30 | elif type(value) is bool and not value or value == "false": 31 | logger.warning(f"Skipping false argument: {key}") 32 | continue 33 | else: 34 | print(f"Value: {value}, type: {type(value)}") 35 | if not key.startswith("--"): 36 | normalized_key = f"--{key}={value}" 37 | else: 38 | normalized_key = f"{key}={value}" 39 | normalized.append(normalized_key) 40 | return normalized 41 | 42 | 43 | def load_toml_config(): 44 | """ 45 | Load configuration from a TOML file that directly specifies command-line arguments. 46 | 47 | :param toml_path: The path to the TOML file. 48 | :return: A dictionary containing the configuration. 49 | """ 50 | config_toml_path = "config/config.toml" 51 | env = os.environ.get( 52 | "SIMPLETUNER_ENVIRONMENT", 53 | os.environ.get("SIMPLETUNER_ENV", os.environ.get("ENV", None)), 54 | ) 55 | if env and env != "default": 56 | config_toml_path = f"config/{env}/config.toml" 57 | 58 | if not os.path.isfile(config_toml_path): 59 | raise ValueError(f"Can not find config file: {config_toml_path}") 60 | 61 | with open(config_toml_path, "r") as file: 62 | try: 63 | config = toml.load(file) 64 | logger.info(f"[CONFIG.TOML] Loaded configuration from {config_toml_path}") 65 | toml_config = config 66 | except toml.TomlDecodeError as e: 67 | logger.error(f"Failed to parse TOML file {config_toml_path}: {e}") 68 | toml_config = {} 69 | normalized_config = normalize_args(toml_config) 70 | logger.info( 71 | f"[CONFIG] Loaded and normalized TOML configuration: {normalized_config}" 72 | ) 73 | 74 | return normalized_config 75 | -------------------------------------------------------------------------------- /helpers/data_backend/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from io import BytesIO 3 | import gzip 4 | import torch 5 | 6 | 7 | class BaseDataBackend(ABC): 8 | @abstractmethod 9 | def read(self, identifier): 10 | """ 11 | Read data based on the identifier. 12 | """ 13 | pass 14 | 15 | @abstractmethod 16 | def write(self, identifier, data): 17 | """ 18 | Write data to the specified identifier. 19 | """ 20 | pass 21 | 22 | @abstractmethod 23 | def delete(self, identifier): 24 | """ 25 | Delete data associated with the identifier. 
26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def exists(self, identifier): 31 | """ 32 | Check if the identifier exists. 33 | """ 34 | pass 35 | 36 | @abstractmethod 37 | def open_file(self, identifier, mode): 38 | """ 39 | Open the identifier (file or object) in the specified mode. 40 | """ 41 | pass 42 | 43 | @abstractmethod 44 | def list_files(self, file_extensions: list, instance_data_dir: str = None) -> tuple: 45 | """ 46 | List all files matching the pattern. 47 | """ 48 | pass 49 | 50 | @abstractmethod 51 | def read_image(self, filepath: str, delete_problematic_images: bool = False): 52 | """ 53 | Read an image from the backend and return a PIL Image. 54 | """ 55 | pass 56 | 57 | @abstractmethod 58 | def read_image_batch(self, filepaths: str, delete_problematic_images: bool = False): 59 | """ 60 | Read a batch of images from the backend and return a list of PIL Images. 61 | """ 62 | pass 63 | 64 | @abstractmethod 65 | def create_directory(self, directory_path): 66 | """ 67 | Creates a directory in the backend. 68 | """ 69 | pass 70 | 71 | @abstractmethod 72 | def torch_load(self, filename): 73 | """ 74 | Reads content from the backend and loads it with torch. 75 | """ 76 | pass 77 | 78 | @abstractmethod 79 | def torch_save(self, data, filename): 80 | """ 81 | Saves the data using torch to the backend. 82 | """ 83 | pass 84 | 85 | @abstractmethod 86 | def write_batch(self, identifiers, files): 87 | """ 88 | Write a batch of files to the specified identifiers. 89 | """ 90 | pass 91 | 92 | def _decompress_torch(self, gzip_data: BytesIO): 93 | """ 94 | We've read the gzip from disk. Just decompress it. 95 | """ 96 | # bytes object might not have seek. workaround: 97 | if not hasattr(gzip_data, "seek"): 98 | gzip_data = BytesIO(gzip_data) 99 | gzip_data.seek(0) 100 | with gzip.GzipFile(fileobj=gzip_data, mode="rb") as file: 101 | decompressed_data = file.read() 102 | return BytesIO(decompressed_data) 103 | 104 | def _compress_torch(self, data): 105 | """ 106 | Compress the torch data before writing it to disk. 
107 | """ 108 | output_data_container = BytesIO() 109 | torch.save(data, output_data_container) 110 | output_data_container.seek(0) 111 | 112 | with BytesIO() as compressed_output: 113 | with gzip.GzipFile(fileobj=compressed_output, mode="wb") as file: 114 | file.write(output_data_container.getvalue()) 115 | return compressed_output.getvalue() 116 | -------------------------------------------------------------------------------- /helpers/image_manipulation/brightness.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def calculate_luminance(img: Image.Image): 8 | if isinstance(img, np.ndarray): 9 | np_img = img 10 | elif isinstance(img, Image.Image): 11 | np_img = np.asarray(img.convert("RGB")) 12 | else: 13 | raise ValueError( 14 | f"Unexpected image type for luminance calculation: {type(img)}" 15 | ) 16 | r, g, b = np_img[:, :, 0], np_img[:, :, 1], np_img[:, :, 2] 17 | luminance = 0.299 * r + 0.587 * g + 0.114 * b 18 | avg_luminance = np.mean(luminance) 19 | return avg_luminance 20 | 21 | 22 | def worker_batch_luminance(imgs: list): 23 | return [calculate_luminance(img) for img in imgs] 24 | 25 | 26 | def calculate_batch_luminance(imgs: list): 27 | num_processes = multiprocessing.cpu_count() 28 | with multiprocessing.Pool(num_processes) as pool: 29 | # Splitting images into batches for each process 30 | img_batches = [imgs[i::num_processes] for i in range(num_processes)] 31 | results = pool.map(worker_batch_luminance, img_batches) 32 | 33 | # Flatten the results and calculate average luminance 34 | all_luminance_values = [lum for sublist in results for lum in sublist] 35 | return sum(all_luminance_values) / len(all_luminance_values) 36 | -------------------------------------------------------------------------------- /helpers/log_format.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from colorama import Fore, Back, Style, init 4 | 5 | 6 | class ColorizedFormatter(logging.Formatter): 7 | level_colors = { 8 | logging.DEBUG: Fore.CYAN, 9 | logging.INFO: Fore.GREEN, 10 | logging.WARNING: Fore.YELLOW, 11 | logging.ERROR: Fore.RED, 12 | logging.CRITICAL: Fore.RED + Back.WHITE + Style.BRIGHT, 13 | } 14 | 15 | def format(self, record): 16 | level_color = self.level_colors.get(record.levelno, "") 17 | reset_color = Style.RESET_ALL 18 | message = super().format(record) 19 | return f"{level_color}{message}{reset_color}" 20 | 21 | 22 | # Initialize colorama 23 | init(autoreset=True) 24 | 25 | # Create a logger 26 | logger = logging.getLogger() 27 | logger.setLevel(logging.DEBUG) # Set lowest level to capture everything 28 | 29 | # Create handlers 30 | console_handler = logging.StreamHandler() 31 | console_handler.setLevel( 32 | logging.INFO 33 | ) # Change to ERROR if you want to suppress INFO messages too 34 | console_handler.setFormatter( 35 | ColorizedFormatter("%(asctime)s [%(levelname)s] %(message)s") 36 | ) 37 | 38 | # blank out the existing debug.log, if exists 39 | if os.path.exists("debug.log"): 40 | with open("debug.log", "w"): 41 | pass 42 | 43 | # Create a file handler 44 | file_handler = logging.FileHandler("debug.log") 45 | file_handler.setLevel(logging.DEBUG) # Capture debug and above 46 | file_handler.setFormatter( 47 | logging.Formatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s") 48 | ) 49 | 50 | # Remove existing handlers 51 | for handler in logger.handlers[:]: 52 | 
logger.removeHandler(handler) 53 | 54 | # Add handlers to the logger 55 | logger.addHandler(console_handler) 56 | logger.addHandler(file_handler) 57 | 58 | forward_logger = logging.getLogger("diffusers.models.unet_2d_condition") 59 | forward_logger.setLevel(logging.WARNING) 60 | 61 | pil_logger = logging.getLogger("PIL") 62 | pil_logger.setLevel(logging.INFO) 63 | pil_logger = logging.getLogger("PIL.Image") 64 | pil_logger.setLevel("ERROR") 65 | pil_logger = logging.getLogger("PIL.PngImagePlugin") 66 | pil_logger.setLevel("ERROR") 67 | transformers_logger = logging.getLogger("transformers.configuration_utils") 68 | transformers_logger.setLevel("ERROR") 69 | transformers_logger = logging.getLogger("transformers.processing_utils") 70 | transformers_logger.setLevel("ERROR") 71 | diffusers_logger = logging.getLogger("diffusers.configuration_utils") 72 | diffusers_logger.setLevel("ERROR") 73 | diffusers_utils_logger = logging.getLogger("diffusers.pipelines.pipeline_utils") 74 | diffusers_utils_logger.setLevel("ERROR") 75 | torchdistlogger = logging.getLogger("torch.distributed.nn.jit.instantiator") 76 | torchdistlogger.setLevel("WARNING") 77 | torch_utils_logger = logging.getLogger("diffusers.utils.torch_utils") 78 | torch_utils_logger.setLevel("ERROR") 79 | 80 | import warnings 81 | 82 | # Suppress specific PIL warning 83 | warnings.filterwarnings( 84 | "ignore", 85 | category=UserWarning, 86 | module="PIL", 87 | message="Palette images with Transparency expressed in bytes should be converted to RGBA images", 88 | ) 89 | warnings.filterwarnings( 90 | "ignore", 91 | category=FutureWarning, 92 | module="transformers.deepspeed", 93 | message="transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations", 94 | ) 95 | 96 | # Ignore torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. 97 | warnings.filterwarnings( 98 | "ignore", 99 | category=DeprecationWarning, 100 | module="torch.utils._pytree", 101 | message="torch.utils._pytree._register_pytree_node is deprecated. 
Please use torch.utils._pytree.register_pytree_node instead.", 102 | ) 103 | 104 | warnings.filterwarnings( 105 | "ignore", 106 | ) 107 | warnings.filterwarnings( 108 | "ignore", 109 | message=".*is deprecated.*", 110 | ) 111 | -------------------------------------------------------------------------------- /helpers/models/all.py: -------------------------------------------------------------------------------- 1 | from helpers.models.sd3.model import SD3 2 | from helpers.models.deepfloyd.model import DeepFloydIF 3 | from helpers.models.sana.model import Sana 4 | from helpers.models.sdxl.model import SDXL 5 | from helpers.models.kolors.model import Kolors 6 | from helpers.models.flux.model import Flux 7 | from helpers.models.wan.model import Wan 8 | from helpers.models.ltxvideo.model import LTXVideo 9 | from helpers.models.sd1x.model import StableDiffusion1, StableDiffusion2 10 | from helpers.models.pixart.model import PixartSigma 11 | from helpers.models.hidream.model import HiDream 12 | from helpers.models.omnigen.model import OmniGen 13 | from helpers.models.auraflow.model import Auraflow 14 | 15 | model_families = { 16 | "sd1x": StableDiffusion1, 17 | "sd2x": StableDiffusion2, 18 | "sd3": SD3, 19 | "deepfloyd": DeepFloydIF, 20 | "sana": Sana, 21 | "sdxl": SDXL, 22 | "kolors": Kolors, 23 | "flux": Flux, 24 | "wan": Wan, 25 | "ltxvideo": LTXVideo, 26 | "pixart_sigma": PixartSigma, 27 | "omnigen": OmniGen, 28 | "hidream": HiDream, 29 | "auraflow": Auraflow, 30 | } 31 | 32 | 33 | def get_all_model_flavours() -> list: 34 | """ 35 | Returns a list of all model flavours available in the model families. 36 | """ 37 | flavours = [] 38 | for model_family, model_implementation in model_families.items(): 39 | flavours.extend(list(model_implementation.get_flavour_choices())) 40 | return flavours 41 | 42 | 43 | def get_model_flavour_choices(key_to_find: str = None): 44 | flavours = "" 45 | for model_family, model_implementation in model_families.items(): 46 | if key_to_find is not None and model_family == key_to_find: 47 | return model_implementation.get_flavour_choices() 48 | flavours += f""" 49 | {model_family}: {model_implementation.get_flavour_choices()} 50 | """ 51 | 52 | return flavours 53 | -------------------------------------------------------------------------------- /helpers/models/flux/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import math 4 | from helpers.models.flux.pipeline import FluxPipeline 5 | from helpers.training import steps_remaining_in_epoch 6 | from diffusers.pipelines.flux.pipeline_flux import ( 7 | calculate_shift as calculate_shift_flux, 8 | ) 9 | 10 | 11 | def update_flux_schedule_to_fast(args, noise_scheduler_to_copy): 12 | if args.flux_fast_schedule and args.model_family.lower() == "flux": 13 | # 4-step noise schedule [0.7, 0.1, 0.1, 0.1] from SD3-Turbo paper 14 | for i in range(0, 250): 15 | noise_scheduler_to_copy.sigmas[i] = 1.0 16 | for i in range(250, 500): 17 | noise_scheduler_to_copy.sigmas[i] = 0.3 18 | for i in range(500, 750): 19 | noise_scheduler_to_copy.sigmas[i] = 0.2 20 | for i in range(750, 1000): 21 | noise_scheduler_to_copy.sigmas[i] = 0.1 22 | return noise_scheduler_to_copy 23 | 24 | 25 | def pack_latents(latents, batch_size, num_channels_latents, height, width): 26 | latents = latents.view( 27 | batch_size, num_channels_latents, height // 2, 2, width // 2, 2 28 | ) 29 | latents = latents.permute(0, 2, 4, 1, 3, 5) 30 | latents = latents.reshape( 31 | batch_size, 
(height // 2) * (width // 2), num_channels_latents * 4 32 | ) 33 | 34 | return latents 35 | 36 | 37 | def unpack_latents(latents, height, width, vae_scale_factor): 38 | batch_size, num_patches, channels = latents.shape 39 | 40 | height = height // vae_scale_factor 41 | width = width // vae_scale_factor 42 | 43 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 44 | latents = latents.permute(0, 3, 1, 4, 2, 5) 45 | 46 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 47 | 48 | return latents 49 | 50 | 51 | def prepare_latent_image_ids(batch_size, height, width, device, dtype): 52 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 53 | latent_image_ids[..., 1] = ( 54 | latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 55 | ) 56 | latent_image_ids[..., 2] = ( 57 | latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 58 | ) 59 | 60 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( 61 | latent_image_ids.shape 62 | ) 63 | 64 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) 65 | latent_image_ids = latent_image_ids.reshape( 66 | batch_size, 67 | latent_image_id_height * latent_image_id_width, 68 | latent_image_id_channels, 69 | ) 70 | 71 | return latent_image_ids.to(device=device, dtype=dtype)[0] 72 | -------------------------------------------------------------------------------- /helpers/models/omnigen/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines.omnigen.processor_omnigen import OmniGenCollator 3 | 4 | 5 | class OmniGenTrainingCollator(OmniGenCollator): 6 | """ 7 | A specialized collator that works with pre-cached latents instead of raw pixels. 
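Each element of `features` is expected to be a `(mllm_inputs, latents)` pair, where `latents` is a pre-encoded VAE latent rather than a pixel tensor. A rough usage sketch (the inputs below are placeholders, not real model data):

    collator = OmniGenTrainingCollator(pad_token_id=2, hidden_size=3072)
    batch = collator([(mllm_inputs_a, latents_a), (mllm_inputs_b, latents_b)])
    # batch["output_latents"] is the stacked latent tensor for the pair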
8 | """ 9 | 10 | def __init__( 11 | self, 12 | pad_token_id: int = 2, 13 | hidden_size: int = 3072, 14 | keep_raw_resolution: bool = True, 15 | ): 16 | super().__init__(pad_token_id, hidden_size) 17 | self.keep_raw_resolution = keep_raw_resolution 18 | 19 | def __call__(self, features): 20 | # Extract text processing part from mllm_inputs 21 | mllm_inputs = [f[0] for f in features] 22 | 23 | # Extract pre-computed latents instead of raw images 24 | output_latents = [f[1] for f in features] # These are already latents 25 | output_latents = torch.stack(output_latents, dim=0) 26 | 27 | # Process text inputs normally 28 | target_img_size = [ 29 | [x.shape[-2] * 8, x.shape[-1] * 8] for x in output_latents 30 | ] # Convert latent size to image size 31 | ( 32 | all_padded_input_ids, 33 | all_position_ids, 34 | all_attention_mask, 35 | all_padding_images, 36 | all_pixel_values, 37 | all_image_sizes, 38 | ) = self.process_mllm_input(mllm_inputs, target_img_size) 39 | 40 | # Handle input image latents if needed 41 | input_latents = None 42 | if len(all_pixel_values) > 0: 43 | # If we have input images that would normally go through VAE, 44 | # they should already be pre-encoded too 45 | input_latents = ( 46 | torch.cat(all_pixel_values, dim=0) 47 | if not isinstance(all_pixel_values[0], list) 48 | else all_pixel_values 49 | ) 50 | 51 | # Return the processed data with latents instead of pixel values 52 | data = { 53 | "input_ids": all_padded_input_ids, 54 | "attention_mask": all_attention_mask, 55 | "position_ids": all_position_ids, 56 | "input_img_latents": input_latents, # Renamed to match transformer forward params 57 | "input_image_sizes": all_image_sizes, 58 | "padding_images": all_padding_images, 59 | "output_latents": output_latents, # These are now latents, not images 60 | } 61 | return data 62 | -------------------------------------------------------------------------------- /helpers/models/sana/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bghira/SimpleTuner/ba0a415594f524fc6de51205e2da075a47ea37ee/helpers/models/sana/__init__.py -------------------------------------------------------------------------------- /helpers/models/sd3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bghira/SimpleTuner/ba0a415594f524fc6de51205e2da075a47ea37ee/helpers/models/sd3/__init__.py -------------------------------------------------------------------------------- /helpers/models/wan/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution 3 | 4 | """ 5 | This module provides utility functions for normalizing latent representations (WAN latents) 6 | and computing the posterior distribution for a variational autoencoder (VAE) using a diagonal 7 | Gaussian distribution. The posterior is computed from latent tensors that encode both the mean 8 | (mu) and log-variance (logvar) parameters of the Gaussian. 9 | 10 | Functions: 11 | normalize_wan_latents(latents, latents_mean, latents_std): 12 | Normalizes the latent tensor using given mean and standard deviation values. 
13 | 14 | compute_wan_posterior(latents, latents_mean, latents_std): 15 | Computes the posterior distribution from the latent tensor by splitting it into mean 16 | and log-variance components, normalizing each, and then constructing a DiagonalGaussianDistribution. 17 | """ 18 | 19 | 20 | def normalize_wan_latents( 21 | latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor 22 | ) -> torch.Tensor: 23 | """ 24 | Normalize latent representations (WAN latents) using provided mean and standard deviation. 25 | 26 | This function reshapes the provided mean and standard deviation tensors so they can be broadcasted 27 | across the dimensions of the `latents` tensor. It then normalizes the latents by subtracting 28 | the mean and scaling by the standard deviation. 29 | 30 | Args: 31 | latents (torch.Tensor): The input latent tensor to be normalized. 32 | Expected shape is (batch_size, channels, ...). 33 | latents_mean (torch.Tensor): The mean tensor for normalization. 34 | Expected to have shape (channels,). 35 | latents_std (torch.Tensor): The standard deviation tensor for normalization. 36 | Expected to have shape (channels,). 37 | 38 | Returns: 39 | torch.Tensor: The normalized latent tensor, with the same shape as the input `latents`. 40 | """ 41 | # Reshape latents_mean to (1, channels, 1, 1, 1) to allow broadcasting across batch and spatial dimensions. 42 | latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) 43 | 44 | # Reshape latents_std similarly to ensure it is broadcastable and resides on the same device as latents. 45 | latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) 46 | 47 | # Convert latents to float (if not already) and apply normalization: 48 | # subtract the mean, then scale by latents_std (callers pass 1.0 / std, so this divides by the true standard deviation). 49 | # print(f"Shapes: {latents.shape}, {latents_mean.shape}, {latents_std.shape}") 50 | latents = ((latents.float() - latents_mean) * latents_std).to(latents) 51 | 52 | return latents 53 | 54 | 55 | def compute_wan_posterior( 56 | latents: torch.Tensor, latents_mean: list, latents_std: list 57 | ) -> DiagonalGaussianDistribution: 58 | """ 59 | Compute the WAN posterior distribution from latent representations. 60 | 61 | The input `latents` tensor holds the mean (mu) and log-variance (logvar) of a Gaussian 62 | distribution concatenated along the channel dimension. The whole tensor is normalized with 63 | `normalize_wan_latents` (using the reciprocal of `latents_std`) and then used to 64 | instantiate a DiagonalGaussianDistribution, which splits it into the two components and 65 | represents the approximate posterior q(z|x) in a VAE framework. 66 | 67 | Args: 68 | latents (torch.Tensor): A tensor containing concatenated latent representations. 69 | It is assumed that the first half of the channels corresponds 70 | to the mean (mu) and the second half corresponds to the log-variance (logvar). 71 | latents_mean (list): The per-channel mean values used for normalization. 72 | latents_std (list): The per-channel standard deviation values used for normalization. 73 | 74 | Returns: 75 | DiagonalGaussianDistribution: A diagonal Gaussian distribution representing the 76 | computed posterior distribution. 77 | """ 78 | latents_mean = torch.tensor(latents_mean) 79 | latents_std = 1.0 / torch.tensor(latents_std) 80 | latents = normalize_wan_latents(latents, latents_mean, latents_std) 81 | 82 | # Construct the posterior distribution using the DiagonalGaussianDistribution.
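# DiagonalGaussianDistribution expects the mean and log-variance halves to be
# concatenated along the channel axis and splits them internally with
# torch.chunk(latents, 2, dim=1).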
83 | # This distribution represents a diagonal covariance Gaussian, parameterized by [mu, logvar]. 84 | posterior = DiagonalGaussianDistribution(latents) 85 | 86 | return posterior 87 | -------------------------------------------------------------------------------- /helpers/multiaspect/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from helpers.training.state_tracker import StateTracker 3 | from helpers.multiaspect.image import MultiaspectImage 4 | from helpers.image_manipulation.training_sample import TrainingSample 5 | import logging 6 | import os 7 | 8 | logger = logging.getLogger("MultiAspectDataset") 9 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 10 | 11 | 12 | class MultiAspectDataset(Dataset): 13 | """ 14 | A multi-aspect dataset requires special consideration and handling. 15 | This class implements bucketed data loading for precomputed text embeddings. 16 | This class does not do any image transforms, as those are handled by VAECache. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | id: str, 22 | datasets: list, 23 | print_names: bool = False, 24 | is_regularisation_data: bool = False, 25 | is_i2v_data: bool = False, 26 | ): 27 | self.id = id 28 | self.datasets = datasets 29 | self.print_names = print_names 30 | self.is_regularisation_data = is_regularisation_data 31 | self.is_i2v_data = is_i2v_data 32 | 33 | def __len__(self): 34 | # Sum the length of all data backends: 35 | return sum([len(dataset) for dataset in self.datasets]) 36 | 37 | def __getitem__(self, image_tuple): 38 | output_data = { 39 | "training_samples": [], 40 | "conditioning_samples": [], 41 | "is_regularisation_data": self.is_regularisation_data, 42 | "is_i2v_data": self.is_i2v_data, 43 | } 44 | first_aspect_ratio = None 45 | for sample in image_tuple: 46 | if type(sample) is TrainingSample: 47 | image_metadata = sample.image_metadata 48 | else: 49 | image_metadata = sample 50 | if "target_size" in image_metadata: 51 | calculated_aspect_ratio = ( 52 | MultiaspectImage.calculate_image_aspect_ratio( 53 | image_metadata["target_size"] 54 | ) 55 | ) 56 | if first_aspect_ratio is None: 57 | first_aspect_ratio = calculated_aspect_ratio 58 | elif first_aspect_ratio != calculated_aspect_ratio: 59 | raise ValueError( 60 | f"Aspect ratios must be the same for all images in a batch. Expected: {first_aspect_ratio}, got: {calculated_aspect_ratio}" 61 | ) 62 | if "deepfloyd" not in StateTracker.get_args().model_family and ( 63 | image_metadata["original_size"] is None 64 | or image_metadata["target_size"] is None 65 | ): 66 | raise Exception( 67 | f"Metadata was unavailable for image: {image_metadata['image_path']}. Ensure --skip_file_discovery=metadata is not set." 68 | ) 69 | 70 | if self.print_names: 71 | logger.info( 72 | f"Dataset is now using image: {image_metadata['image_path']}" 73 | ) 74 | 75 | if type(sample) is TrainingSample: 76 | output_data["conditioning_samples"].append(sample) 77 | continue 78 | else: 79 | output_data["training_samples"].append(image_metadata) 80 | 81 | if "instance_prompt_text" not in image_metadata: 82 | raise ValueError( 83 | f"Instance prompt text must be provided in image metadata. 
Image metadata: {image_metadata}" 84 | ) 85 | return output_data 86 | -------------------------------------------------------------------------------- /helpers/multiaspect/state.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import logging 4 | from multiprocessing.managers import DictProxy 5 | 6 | logger = logging.getLogger("BucketStateManager") 7 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 8 | 9 | 10 | class BucketStateManager: 11 | def __init__(self, id: str): 12 | self.id = id 13 | 14 | def mangle_state_path(self, state_path): 15 | # When saving the state, it goes into the checkpoint dir. 16 | # However, we need to save a single state for each data backend. 17 | # Thus, we split the state_path from its extension, add self.id to the end of the name, and rejoin: 18 | if self.id in os.path.basename(state_path): 19 | return state_path 20 | filename, ext = os.path.splitext(state_path) 21 | return f"{filename}-{self.id}{ext}" 22 | 23 | def load_seen_images(self, state_path: str): 24 | if os.path.exists(state_path): 25 | with open(state_path, "r") as f: 26 | return json.load(f) 27 | else: 28 | return {} 29 | 30 | def save_seen_images(self, seen_images, state_path: str): 31 | with open(state_path, "w") as f: 32 | json.dump(seen_images, f) 33 | 34 | def deep_convert_dict(self, d): 35 | if isinstance(d, dict): 36 | return {key: self.deep_convert_dict(value) for key, value in d.items()} 37 | elif isinstance(d, list): 38 | return [self.deep_convert_dict(value) for value in d] 39 | elif isinstance(d, DictProxy): 40 | return self.deep_convert_dict(dict(d)) 41 | else: 42 | return d 43 | 44 | def save_state(self, state: dict, state_path: str): 45 | if state_path is None: 46 | raise ValueError("state_path must be specified") 47 | state_path = self.mangle_state_path(state_path) 48 | logger.debug(f"Saving trainer state to {state_path}") 49 | final_state = self.deep_convert_dict(state) 50 | with open(state_path, "w") as f: 51 | json.dump(final_state, f) 52 | 53 | def load_state(self, state_path: str): 54 | if state_path is None: 55 | raise ValueError("state_path must be specified") 56 | state_path = self.mangle_state_path(state_path) 57 | if os.path.exists(state_path): 58 | with open(state_path, "r") as f: 59 | return json.load(f) 60 | else: 61 | logger.debug(f"load_state found no file: {state_path}") 62 | return {} 63 | -------------------------------------------------------------------------------- /helpers/multiaspect/video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def resize_video_frames( 6 | video_frames: np.ndarray, dsize=None, fx=None, fy=None 7 | ) -> np.ndarray: 8 | """ 9 | Resize each frame in a video (NumPy array with shape (num_frames, height, width, channels)). 10 | You can either provide a fixed destination size (dsize) or scaling factors (fx and fy). 11 | """ 12 | resized_frames = [] 13 | for frame in video_frames: 14 | # Optionally, add a check to make sure frame is valid. 15 | if frame is None or frame.size == 0: 16 | continue 17 | resized_frame = cv2.resize(frame, dsize=dsize, fx=fx, fy=fy) 18 | resized_frames.append(resized_frame) 19 | 20 | if not resized_frames: 21 | raise ValueError( 22 | "No frames were resized. Check your video data and resize parameters." 
23 | ) 24 | 25 | return np.stack(resized_frames, axis=0) 26 | -------------------------------------------------------------------------------- /helpers/training/__init__.py: -------------------------------------------------------------------------------- 1 | quantised_precision_levels = [ 2 | "no_change", 3 | "int8-quanto", 4 | "int4-quanto", 5 | "int2-quanto", 6 | "int8-torchao", 7 | ] 8 | import torch 9 | 10 | if torch.cuda.is_available(): 11 | quantised_precision_levels.extend( 12 | [ 13 | "nf4-bnb", 14 | # "fp4-bnb", 15 | # "fp8-bnb", 16 | "fp8-quanto", 17 | "fp8uz-quanto", 18 | ] 19 | ) 20 | primary_device = torch.cuda.get_device_properties(0) 21 | if primary_device.major >= 8: 22 | # Hopper! Or blackwell+. 23 | quantised_precision_levels.append("fp8-torchao") 24 | 25 | try: 26 | import pillow_jxl 27 | except ModuleNotFoundError: 28 | pass 29 | from PIL import Image 30 | 31 | supported_extensions = Image.registered_extensions() 32 | image_file_extensions = set( 33 | ext.lower().lstrip(".") for ext, img_format in supported_extensions.items() 34 | if img_format in Image.OPEN 35 | ) 36 | 37 | video_file_extensions = set(["mp4", "avi", "gif", "mov", "webm"]) 38 | 39 | lycoris_defaults = { 40 | "lora": { 41 | "algo": "lora", 42 | "multiplier": 1.0, 43 | "linear_dim": 64, 44 | "linear_alpha": 32, 45 | "apply_preset": { 46 | "target_module": ["Attention", "FeedForward"], 47 | "module_algo_map": { 48 | "Attention": {"factor": 16}, 49 | "FeedForward": {"factor": 8}, 50 | }, 51 | }, 52 | }, 53 | "loha": { 54 | "algo": "loha", 55 | "multiplier": 1.0, 56 | "linear_dim": 32, 57 | "linear_alpha": 16, 58 | "apply_preset": { 59 | "target_module": ["Attention", "FeedForward"], 60 | "module_algo_map": { 61 | "Attention": {"factor": 16}, 62 | "FeedForward": {"factor": 8}, 63 | }, 64 | }, 65 | }, 66 | "lokr": { 67 | "algo": "lokr", 68 | "multiplier": 1.0, 69 | "linear_dim": 10000, # Full dimension 70 | "linear_alpha": 1, # Ignored in full dimension 71 | "factor": 16, 72 | "apply_preset": { 73 | "target_module": ["Attention", "FeedForward"], 74 | "module_algo_map": { 75 | "Attention": {"factor": 16}, 76 | "FeedForward": {"factor": 8}, 77 | }, 78 | }, 79 | }, 80 | "full": { 81 | "algo": "full", 82 | "multiplier": 1.0, 83 | "linear_dim": 1024, # Example full matrix size 84 | "linear_alpha": 512, 85 | "apply_preset": { 86 | "target_module": ["Attention", "FeedForward"], 87 | }, 88 | }, 89 | "ia3": { 90 | "algo": "ia3", 91 | "multiplier": 1.0, 92 | "linear_dim": None, # No network arguments 93 | "linear_alpha": None, 94 | "apply_preset": { 95 | "target_module": ["Attention", "FeedForward"], 96 | }, 97 | }, 98 | "dylora": { 99 | "algo": "dylora", 100 | "multiplier": 1.0, 101 | "linear_dim": 128, 102 | "linear_alpha": 64, 103 | "block_size": 1, # Update one row/col per step 104 | "apply_preset": { 105 | "target_module": ["Attention", "FeedForward"], 106 | "module_algo_map": { 107 | "Attention": {"factor": 16}, 108 | "FeedForward": {"factor": 8}, 109 | }, 110 | }, 111 | }, 112 | "diag-oft": { 113 | "algo": "diag-oft", 114 | "multiplier": 1.0, 115 | "linear_dim": 64, # Block size 116 | "constraint": False, 117 | "rescaled": False, 118 | "apply_preset": { 119 | "target_module": ["Attention", "FeedForward"], 120 | "module_algo_map": { 121 | "Attention": {"factor": 16}, 122 | "FeedForward": {"factor": 8}, 123 | }, 124 | }, 125 | }, 126 | "boft": { 127 | "algo": "boft", 128 | "multiplier": 1.0, 129 | "linear_dim": 64, # Block size 130 | "constraint": False, 131 | "rescaled": False, 132 | "apply_preset": { 133 | 
"target_module": ["Attention", "FeedForward"], 134 | "module_algo_map": { 135 | "Attention": {"factor": 16}, 136 | "FeedForward": {"factor": 8}, 137 | }, 138 | }, 139 | }, 140 | } 141 | 142 | 143 | def steps_remaining_in_epoch(current_step: int, steps_per_epoch: int) -> int: 144 | """ 145 | Calculate the number of steps remaining in the current epoch. 146 | 147 | Args: 148 | current_step (int): The current step within the epoch. 149 | steps_per_epoch (int): Total number of steps in the epoch. 150 | 151 | Returns: 152 | int: Number of steps remaining in the current epoch. 153 | """ 154 | remaining_steps = steps_per_epoch - (current_step % steps_per_epoch) 155 | return remaining_steps 156 | -------------------------------------------------------------------------------- /helpers/training/adapter.py: -------------------------------------------------------------------------------- 1 | import peft 2 | import torch 3 | import safetensors.torch 4 | 5 | 6 | def determine_adapter_target_modules(args, unet, transformer): 7 | if unet is not None: 8 | return ["to_k", "to_q", "to_v", "to_out.0"] 9 | elif transformer is not None: 10 | target_modules = ["to_k", "to_q", "to_v", "to_out.0"] 11 | 12 | if args.model_family.lower() == "flux" and args.flux_lora_target == "all": 13 | # target_modules = mmdit layers here 14 | target_modules = [ 15 | "to_k", 16 | "to_q", 17 | "to_v", 18 | "add_k_proj", 19 | "add_q_proj", 20 | "add_v_proj", 21 | "to_out.0", 22 | "to_add_out", 23 | ] 24 | elif args.flux_lora_target == "context": 25 | # i think these are the text input layers. 26 | target_modules = [ 27 | "add_k_proj", 28 | "add_q_proj", 29 | "add_v_proj", 30 | "to_add_out", 31 | ] 32 | elif args.flux_lora_target == "context+ffs": 33 | # i think these are the text input layers. 34 | target_modules = [ 35 | "add_k_proj", 36 | "add_q_proj", 37 | "add_v_proj", 38 | "to_add_out", 39 | "ff_context.net.0.proj", 40 | "ff_context.net.2", 41 | ] 42 | elif args.flux_lora_target == "all+ffs": 43 | target_modules = [ 44 | "to_k", 45 | "to_q", 46 | "to_v", 47 | "add_k_proj", 48 | "add_q_proj", 49 | "add_v_proj", 50 | "to_out.0", 51 | "to_add_out", 52 | "ff.net.0.proj", 53 | "ff.net.2", 54 | "ff_context.net.0.proj", 55 | "ff_context.net.2", 56 | "proj_mlp", 57 | "proj_out", 58 | ] 59 | elif args.flux_lora_target == "ai-toolkit": 60 | # from ostris' ai-toolkit, possibly required to continue finetuning one. 
61 | target_modules = [ 62 | "to_q", 63 | "to_k", 64 | "to_v", 65 | "add_q_proj", 66 | "add_k_proj", 67 | "add_v_proj", 68 | "to_out.0", 69 | "to_add_out", 70 | "ff.net.0.proj", 71 | "ff.net.2", 72 | "ff_context.net.0.proj", 73 | "ff_context.net.2", 74 | "norm.linear", 75 | "norm1.linear", 76 | "norm1_context.linear", 77 | "proj_mlp", 78 | "proj_out", 79 | ] 80 | elif args.flux_lora_target == "tiny": 81 | # From TheLastBen 82 | # https://www.reddit.com/r/StableDiffusion/comments/1f523bd/good_flux_loras_can_be_less_than_45mb_128_dim/ 83 | target_modules = [ 84 | "single_transformer_blocks.7.proj_out", 85 | "single_transformer_blocks.20.proj_out", 86 | ] 87 | elif args.flux_lora_target == "nano": 88 | # From TheLastBen 89 | # https://www.reddit.com/r/StableDiffusion/comments/1f523bd/good_flux_loras_can_be_less_than_45mb_128_dim/ 90 | target_modules = [ 91 | "single_transformer_blocks.7.proj_out", 92 | ] 93 | 94 | return target_modules 95 | 96 | 97 | @torch.no_grad() 98 | def load_lora_weights(dictionary, filename, loraKey="default", use_dora=False): 99 | additional_keys = set() 100 | state_dict = safetensors.torch.load_file(filename) 101 | for prefix, model in dictionary.items(): 102 | lora_layers = { 103 | (prefix + "." + x): y 104 | for (x, y) in model.named_modules() 105 | if isinstance(y, peft.tuners.lora.layer.Linear) 106 | } 107 | missing_keys = set( 108 | [x + ".lora_A.weight" for x in lora_layers.keys()] 109 | + [x + ".lora_B.weight" for x in lora_layers.keys()] 110 | + ( 111 | [x + ".lora_magnitude_vector.weight" for x in lora_layers.keys()] 112 | if use_dora 113 | else [] 114 | ) 115 | ) 116 | for k, v in state_dict.items(): 117 | if "lora_A" in k: 118 | kk = k.replace(".lora_A.weight", "") 119 | if kk in lora_layers: 120 | lora_layers[kk].lora_A[loraKey].weight.copy_(v) 121 | missing_keys.remove(k) 122 | else: 123 | additional_keys.add(k) 124 | elif "lora_B" in k: 125 | kk = k.replace(".lora_B.weight", "") 126 | if kk in lora_layers: 127 | lora_layers[kk].lora_B[loraKey].weight.copy_(v) 128 | missing_keys.remove(k) 129 | else: 130 | additional_keys.add(k) 131 | elif ".alpha" in k or ".lora_alpha" in k: 132 | kk = k.replace(".lora_alpha", "").replace(".alpha", "") 133 | if kk in lora_layers: 134 | lora_layers[kk].lora_alpha[loraKey] = v 135 | elif ".lora_magnitude_vector" in k: 136 | kk = k.replace(".lora_magnitude_vector.weight", "") 137 | if kk in lora_layers: 138 | lora_layers[kk].lora_magnitude_vector[loraKey].weight.copy_(v) 139 | missing_keys.remove(k) 140 | else: 141 | additional_keys.add(k) 142 | return (additional_keys, missing_keys) 143 | -------------------------------------------------------------------------------- /helpers/training/default_settings/__init__.py: -------------------------------------------------------------------------------- 1 | CURRENT_VERSION = 2 2 | 3 | LATEST_DEFAULTS = {1: {"hash_filenames": False}, 2: {"hash_filenames": True}} 4 | 5 | 6 | def default(setting: str, current_version: int = None, default_value=None): 7 | if current_version <= 0 or current_version is None: 8 | current_version = CURRENT_VERSION 9 | if current_version in LATEST_DEFAULTS: 10 | return LATEST_DEFAULTS[current_version].get(setting, default_value) 11 | return default_value 12 | 13 | 14 | def latest_config_version(): 15 | return CURRENT_VERSION 16 | -------------------------------------------------------------------------------- /helpers/training/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from 
accelerate.logging import get_logger 3 | from helpers.models.common import get_model_config_path 4 | 5 | logger = get_logger(__name__, log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 6 | 7 | target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") 8 | logger.setLevel(target_level) 9 | 10 | 11 | def determine_subfolder(folder_value: str = None): 12 | if folder_value is None or str(folder_value).lower() == "none": 13 | return None 14 | return str(folder_value) 15 | -------------------------------------------------------------------------------- /helpers/training/error_handling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from accelerate.logging import get_logger 4 | 5 | logger = get_logger(__name__, log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 6 | 7 | target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") 8 | logger.setLevel(target_level) 9 | 10 | 11 | def validate_deepspeed_compat_from_args(accelerator, args): 12 | if "lora" in args.model_type: 13 | logger.error( 14 | "LoRA can not be trained with DeepSpeed. Please disable DeepSpeed via 'accelerate config' before reattempting." 15 | ) 16 | sys.exit(1) 17 | if ( 18 | "gradient_accumulation_steps" 19 | in accelerator.state.deepspeed_plugin.deepspeed_config 20 | ): 21 | args.gradient_accumulation_steps = ( 22 | accelerator.state.deepspeed_plugin.deepspeed_config[ 23 | "gradient_accumulation_steps" 24 | ] 25 | ) 26 | logger.info( 27 | f"Updated gradient_accumulation_steps to the value provided by DeepSpeed: {args.gradient_accumulation_steps}" 28 | ) 29 | -------------------------------------------------------------------------------- /helpers/training/evaluation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from torchmetrics.functional.multimodal import clip_score 3 | from torchvision import transforms 4 | import torch, logging, os 5 | import numpy as np 6 | from PIL import Image 7 | from helpers.training.state_tracker import StateTracker 8 | 9 | logger = logging.getLogger("ModelEvaluator") 10 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 11 | 12 | model_evaluator_map = { 13 | "clip": "CLIPModelEvaluator", 14 | } 15 | 16 | 17 | class ModelEvaluator: 18 | def __init__(self, pretrained_model_name_or_path): 19 | raise NotImplementedError( 20 | "Subclasses is incomplete, no __init__ method was found." 
21 | ) 22 | 23 | def evaluate(self, images, prompts): 24 | raise NotImplementedError("Subclasses should implement the evaluate() method.") 25 | 26 | @staticmethod 27 | def from_config(args): 28 | """Instantiate a ModelEvaluator from the training config, if set to do so.""" 29 | if not StateTracker.get_accelerator().is_main_process: 30 | return None 31 | if ( 32 | args.evaluation_type is not None 33 | and args.evaluation_type.lower() != "" 34 | and args.evaluation_type.lower() != "none" 35 | ): 36 | model_evaluator = model_evaluator_map[args.evaluation_type] 37 | return globals()[model_evaluator]( 38 | args.pretrained_evaluation_model_name_or_path 39 | ) 40 | 41 | return None 42 | 43 | 44 | class CLIPModelEvaluator(ModelEvaluator): 45 | def __init__( 46 | self, pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" 47 | ): 48 | self.clip_score_fn = partial( 49 | clip_score, model_name_or_path=pretrained_model_name_or_path 50 | ) 51 | self.preprocess = transforms.Compose([transforms.ToTensor()]) 52 | 53 | def evaluate(self, images, prompts): 54 | # Preprocess images 55 | images_tensor = torch.stack([self.preprocess(img) * 255 for img in images]) 56 | # Compute CLIP scores 57 | result = self.clip_score_fn(images_tensor, prompts).detach().cpu() 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /helpers/training/exceptions.py: -------------------------------------------------------------------------------- 1 | class MultiDatasetExhausted(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /helpers/training/gradient_checkpointing_interval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.checkpoint import checkpoint as original_checkpoint 3 | 4 | 5 | # Global variables to keep track of the checkpointing state 6 | _checkpoint_call_count = 0 7 | _checkpoint_interval = 4 # You can set this to any interval you prefer 8 | 9 | 10 | def reset_checkpoint_counter(): 11 | """Resets the checkpoint call counter. 
Call this at the beginning of the forward pass.""" 12 | global _checkpoint_call_count 13 | _checkpoint_call_count = 0 14 | 15 | 16 | def set_checkpoint_interval(n): 17 | """Sets the interval at which checkpointing is skipped.""" 18 | global _checkpoint_interval 19 | _checkpoint_interval = n 20 | 21 | 22 | def checkpoint_wrapper(function, *args, use_reentrant=True, **kwargs): 23 | """Wrapper function for torch.utils.checkpoint.checkpoint.""" 24 | global _checkpoint_call_count, _checkpoint_interval 25 | _checkpoint_call_count += 1 26 | 27 | if ( 28 | _checkpoint_interval > 0 29 | and (_checkpoint_call_count % _checkpoint_interval) == 0 30 | ): 31 | # Use the original checkpoint function 32 | return original_checkpoint( 33 | function, *args, use_reentrant=use_reentrant, **kwargs 34 | ) 35 | else: 36 | # Skip checkpointing: execute the function directly 37 | # Do not pass 'use_reentrant' to the function 38 | return function(*args, **kwargs) 39 | 40 | 41 | # Monkeypatch torch.utils.checkpoint.checkpoint 42 | torch.utils.checkpoint.checkpoint = checkpoint_wrapper 43 | -------------------------------------------------------------------------------- /helpers/training/min_snr_gamma.py: -------------------------------------------------------------------------------- 1 | # From Diffusers repository: examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py 2 | 3 | 4 | def compute_snr(timesteps, noise_scheduler, use_soft_min: bool = False, sigma_data=1.0): 5 | """ 6 | Computes SNR using two different methods based on the `use_soft_min` flag. 7 | 8 | Args: 9 | timesteps (torch.Tensor): The timesteps at which SNR is computed. 10 | noise_scheduler (NoiseScheduler): An object that contains the alpha_cumprod values. 11 | use_soft_min (bool): If True, use the _weighting_soft_min_snr method to compute SNR. 12 | sigma_data (torch.Tensor or None): The standard deviation of the data used in the soft min weighting method. 13 | 14 | Returns: 15 | torch.Tensor: The computed SNR values. 16 | """ 17 | alphas_cumprod = noise_scheduler.alphas_cumprod 18 | sqrt_alphas_cumprod = alphas_cumprod**0.5 19 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 20 | 21 | # Expand the tensors. 22 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ 23 | timesteps 24 | ].float() 25 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 26 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 27 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 28 | 29 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( 30 | device=timesteps.device 31 | )[timesteps].float() 32 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 33 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 34 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 35 | 36 | # Choose the method to compute SNR 37 | if use_soft_min: 38 | if sigma_data is None: 39 | raise ValueError( 40 | "sigma_data must be provided when using soft min SNR calculation." 
41 | ) 42 | snr = (sigma * sigma_data) ** 2 / (sigma**2 + sigma_data**2) ** 2 43 | else: 44 | # Default SNR computation 45 | snr = (alpha / sigma) ** 2 46 | 47 | return snr 48 | -------------------------------------------------------------------------------- /helpers/training/multi_process.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def _get_rank(): 5 | if dist.is_available() and dist.is_initialized(): 6 | return dist.get_rank() 7 | else: 8 | return 0 9 | 10 | 11 | def rank_info(): 12 | try: 13 | return f"(Rank: {_get_rank()}) " 14 | except: 15 | return "" 16 | 17 | 18 | def should_log(): 19 | return _get_rank() == 0 20 | -------------------------------------------------------------------------------- /helpers/training/optimizers/adamw_bfloat16/stochastic/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, FloatTensor 3 | 4 | 5 | def swap_first_and_last_dims(tensor: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Swap the first dimension with the last dimension of a tensor. 8 | 9 | Args: 10 | tensor (torch.Tensor): The input tensor of any shape. 11 | 12 | Returns: 13 | torch.Tensor: A tensor with the first dimension swapped with the last. 14 | """ 15 | # Get the total number of dimensions 16 | num_dims = len(tensor.shape) 17 | 18 | # Create a new order of dimensions 19 | new_order = list(range(1, num_dims)) + [0] 20 | 21 | # Permute the tensor according to the new order 22 | return tensor.permute(*new_order) 23 | 24 | 25 | def swap_back_first_and_last_dims(tensor: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Swap back the first dimension with the last dimension of a tensor 28 | to its original shape after a swap. 29 | 30 | Args: 31 | tensor (torch.Tensor): The tensor that had its first and last dimensions swapped. 32 | 33 | Returns: 34 | torch.Tensor: A tensor with its original shape restored. 35 | """ 36 | # Get the total number of dimensions 37 | num_dims = len(tensor.shape) 38 | 39 | # Create a new order to reverse the previous swapping 40 | new_order = [num_dims - 1] + list(range(0, num_dims - 1)) 41 | 42 | # Permute the tensor according to the new order 43 | return tensor.permute(*new_order) 44 | 45 | 46 | def copy_stochastic_(target: Tensor, source: Tensor): 47 | """ 48 | copies source into target using stochastic rounding 49 | 50 | Args: 51 | target: the target tensor with dtype=bfloat16 52 | source: the target tensor with dtype=float32 53 | """ 54 | # create a random 16 bit integer 55 | result = torch.randint_like( 56 | source, 57 | dtype=torch.int32, 58 | low=0, 59 | high=(1 << 16), 60 | ) 61 | 62 | # add the random number to the lower 16 bit of the mantissa 63 | result.add_(source.view(dtype=torch.int32)) 64 | 65 | # mask off the lower 16 bit of the mantissa 66 | result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 67 | 68 | # copy the higher 16 bit into the target tensor 69 | target.copy_(result.view(dtype=torch.float32)) 70 | 71 | del result 72 | 73 | 74 | def add_stochastic_(_input: Tensor, other: Tensor, alpha: float = 1.0): 75 | """ 76 | Adds other to input using stochastic rounding. 77 | 78 | There is a hack to fix a bug on MPS where uneven final dimensions cause 79 | a crash. 
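A rough sketch of the intended behaviour (values are illustrative):

    import torch

    p = torch.zeros(4, dtype=torch.bfloat16)
    add_stochastic_(p, torch.full((4,), 1e-3))
    # each element becomes one of the two bfloat16 values bracketing 1e-3,
    # chosen at random so the update is unbiased in expectation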
80 | 81 | Args: 82 | _input: the input tensor with dtype=bfloat16 83 | other: the other tensor 84 | alpha: a multiplier for other 85 | """ 86 | _input_original = _input 87 | if _input.device.type == "mps": 88 | _input = _input.to(dtype=torch.float32) 89 | 90 | if other.dtype == torch.float32: 91 | result = other.clone() 92 | else: 93 | result = other.to(dtype=torch.float32) 94 | 95 | if _input.device.type == "mps": 96 | result.add_(_input, alpha=torch.tensor(alpha, dtype=torch.float32)) 97 | else: 98 | result.add_(_input, alpha=alpha) 99 | 100 | copy_stochastic_(_input, result) 101 | 102 | if _input.device.type == "mps": 103 | _input_original.copy_(_input.view(dtype=torch.float32)) 104 | 105 | 106 | def addcdiv_stochastic_( 107 | _input: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1.0 108 | ): 109 | """ 110 | adds (tensor1 / tensor2 * value) to input using stochastic rounding 111 | 112 | Args: 113 | _input: the input tensor with dtype=bfloat16 114 | tensor1: the numerator tensor 115 | tensor2: the denominator tensor 116 | value: a multiplier for tensor1/tensor2 117 | """ 118 | if _input.dtype == torch.float32: 119 | result = _input.clone() 120 | else: 121 | result = _input.to(dtype=torch.float32) 122 | 123 | result.addcdiv_(tensor1, tensor2, value=value) 124 | copy_stochastic_(_input, result) 125 | -------------------------------------------------------------------------------- /helpers/training/peft_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def approximate_normal_tensor(inp, target, scale=1.0): 5 | device = inp.device 6 | tensor = torch.randn_like(target).to(device) 7 | desired_norm = inp.norm().to(device) 8 | desired_mean = inp.mean().to(device) 9 | desired_std = inp.std().to(device) 10 | 11 | current_norm = tensor.norm() 12 | tensor = tensor * (desired_norm / current_norm) 13 | current_std = tensor.std() 14 | tensor = tensor * (desired_std / current_std) 15 | tensor = tensor - tensor.mean() + desired_mean 16 | tensor.mul_(scale) 17 | 18 | target.copy_(tensor) 19 | 20 | 21 | def init_lokr_network_with_perturbed_normal(lycoris, scale=1e-3): 22 | with torch.no_grad(): 23 | for lora in lycoris.loras: 24 | lora.lokr_w1.fill_(1.0) 25 | approximate_normal_tensor(lora.org_weight, lora.lokr_w2, scale=scale) 26 | -------------------------------------------------------------------------------- /helpers/training/quantisation/quanto_workarounds.py: -------------------------------------------------------------------------------- 1 | import torch, optimum 2 | 3 | if torch.cuda.is_available(): 4 | # the marlin fp8 kernel needs some help with dtype casting for some reason 5 | # see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201 6 | from optimum.quanto.library.extensions.cuda import ext as quanto_ext 7 | 8 | # Save the original operator 9 | original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin 10 | 11 | def fp8_marlin_gemm_wrapper( 12 | a: torch.Tensor, 13 | b_q_weight: torch.Tensor, 14 | b_scales: torch.Tensor, 15 | workspace: torch.Tensor, 16 | num_bits: int, 17 | size_m: int, 18 | size_n: int, 19 | size_k: int, 20 | ) -> torch.Tensor: 21 | # Ensure 'a' has the correct dtype 22 | a = a.to(b_scales.dtype) 23 | # Call the original operator 24 | return original_gemm_f16f8_marlin( 25 | a, 26 | b_q_weight, 27 | b_scales, 28 | workspace, 29 | num_bits, 30 | size_m, 31 | size_n, 32 | size_k, 33 | ) 34 | 35 | # Monkey-patch the operator 36 | torch.ops.quanto.gemm_f16f8_marlin = 
fp8_marlin_gemm_wrapper 37 | 38 | class TinyGemmQBitsLinearFunction( 39 | optimum.quanto.tensor.function.QuantizedLinearFunction 40 | ): 41 | @staticmethod 42 | def forward(ctx, input, other, bias): 43 | ctx.save_for_backward(input, other) 44 | if type(input) is not torch.Tensor: 45 | input = input.dequantize() 46 | in_features = input.shape[-1] 47 | out_features = other.shape[0] 48 | output_shape = input.shape[:-1] + (out_features,) 49 | output = torch._weight_int4pack_mm( 50 | input.view(-1, in_features).to(dtype=other.dtype), 51 | other._data._data, 52 | other._group_size, 53 | other._scale_shift, 54 | ) 55 | output = output.view(output_shape) 56 | if bias is not None: 57 | output = output + bias 58 | return output 59 | 60 | from optimum.quanto.tensor.weights import tinygemm 61 | 62 | tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction 63 | 64 | 65 | class WeightQBytesLinearFunction( 66 | optimum.quanto.tensor.function.QuantizedLinearFunction 67 | ): 68 | @staticmethod 69 | def forward(ctx, input, other, bias=None): 70 | ctx.save_for_backward(input, other) 71 | if isinstance(input, optimum.quanto.tensor.QBytesTensor): 72 | output = torch.ops.quanto.qbytes_mm( 73 | input._data, other._data, input._scale * other._scale 74 | ) 75 | else: 76 | in_features = input.shape[-1] 77 | out_features = other.shape[0] 78 | output_shape = input.shape[:-1] + (out_features,) 79 | output = torch.ops.quanto.qbytes_mm( 80 | input.reshape(-1, in_features), other._data, other._scale 81 | ) 82 | output = output.view(output_shape) 83 | if bias is not None: 84 | output = output + bias 85 | return output 86 | 87 | 88 | optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = ( 89 | WeightQBytesLinearFunction 90 | ) 91 | 92 | 93 | def reshape_qlf_backward(ctx, gO): 94 | # another one where we need .reshape instead of .view 95 | input_gO = other_gO = bias_gO = None 96 | input, other = ctx.saved_tensors 97 | out_features, in_features = other.shape 98 | if ctx.needs_input_grad[0]: 99 | # grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B 100 | input_gO = torch.matmul(gO, other) 101 | if ctx.needs_input_grad[1]: 102 | # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A 103 | other_gO = torch.matmul( 104 | gO.reshape(-1, out_features).t(), 105 | input.to(gO.dtype).reshape(-1, in_features), 106 | ) 107 | if ctx.needs_input_grad[2]: 108 | # Bias gradient is the sum on all dimensions but the last one 109 | dim = tuple(range(gO.ndim - 1)) 110 | bias_gO = gO.sum(dim) 111 | return input_gO, other_gO, bias_gO 112 | 113 | 114 | optimum.quanto.tensor.function.QuantizedLinearFunction.backward = reshape_qlf_backward 115 | -------------------------------------------------------------------------------- /helpers/training/quantisation/torchao_workarounds.py: -------------------------------------------------------------------------------- 1 | # import torchao, torch 2 | 3 | # from torch import Tensor 4 | # from typing import Optional 5 | # from torchao.prototype.quantized_training.int8 import Int8QuantizedTrainingLinearWeight 6 | 7 | 8 | # class _Int8WeightOnlyLinear(torch.autograd.Function): 9 | # @staticmethod 10 | # def forward( 11 | # ctx, 12 | # input: Tensor, 13 | # weight: Int8QuantizedTrainingLinearWeight, 14 | # bias: Optional[Tensor] = None, 15 | # ): 16 | # ctx.save_for_backward(input, weight) 17 | # ctx.bias = bias is not None 18 | 19 | # # NOTE: we have to .T before .to(input.dtype) for torch.compile() mixed matmul to work 20 | # out = (input @ 
weight.int_data.T.to(input.dtype)) * weight.scale 21 | # out = out + bias if bias is not None else out 22 | # return out 23 | 24 | # @staticmethod 25 | # def backward(ctx, grad_output): 26 | # input, weight = ctx.saved_tensors 27 | 28 | # grad_input = (grad_output * weight.scale) @ weight.int_data.to( 29 | # grad_output.dtype 30 | # ) 31 | # # print(f"dtypes: grad_output {grad_output.dtype}, input {input.dtype}, weight {weight.dtype}") 32 | # # here is the patch: we will cast the input to the grad_output dtype. 33 | # grad_weight = grad_output.view(-1, weight.shape[0]).T @ input.to( 34 | # grad_output.dtype 35 | # ).reshape(-1, weight.shape[1]) 36 | # grad_bias = grad_output.view(-1, weight.shape[0]).sum(0) if ctx.bias else None 37 | # return grad_input, grad_weight, grad_bias 38 | 39 | 40 | # torchao.prototype.quantized_training.int8._Int8WeightOnlyLinear = _Int8WeightOnlyLinear 41 | -------------------------------------------------------------------------------- /helpers/training/wrappers.py: -------------------------------------------------------------------------------- 1 | from diffusers.utils.torch_utils import is_compiled_module 2 | 3 | 4 | def unwrap_model(accelerator, model): 5 | model = accelerator.unwrap_model(model) 6 | model = model._orig_mod if is_compiled_module(model) else model 7 | return model 8 | 9 | 10 | def gather_dict_of_tensors_shapes(tensors: dict) -> dict: 11 | if "prompt_embeds" in tensors and isinstance(tensors["prompt_embeds"], list): 12 | # some models like HiDream return a list of batched tensors.. 13 | return {k: [x.shape for x in v] for k, v in tensors.items()} 14 | else: 15 | return {k: v.shape for k, v in tensors.items()} 16 | 17 | 18 | def move_dict_of_tensors_to_device(tensors: dict, device) -> dict: 19 | """ 20 | Move a dictionary of tensors to a specified device, including dictionaries of nested tensors in lists (HiDream outputs). 21 | 22 | Args: 23 | tensors (dict): Dictionary of tensors to move. 24 | device (torch.device): The device to move the tensors to. 25 | 26 | Returns: 27 | dict: Dictionary of tensors moved to the specified device. 
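Example (illustrative shapes; when "prompt_embeds" is a list, every value in the dict is assumed to be a list of tensors):

    import torch

    batch = {
        "prompt_embeds": [torch.zeros(1, 128, 4096), torch.zeros(1, 128, 4096)],
        "pooled_prompt_embeds": [torch.zeros(1, 2048)],
    }
    batch = move_dict_of_tensors_to_device(batch, torch.device("cpu"))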
28 | """ 29 | if "prompt_embeds" in tensors and isinstance(tensors["prompt_embeds"], list): 30 | return {k: [x.to(device) for x in v] for k, v in tensors.items()} 31 | else: 32 | return {k: v.to(device) for k, v in tensors.items()} 33 | -------------------------------------------------------------------------------- /helpers/webhooks/config.py: -------------------------------------------------------------------------------- 1 | from json import load 2 | 3 | supported_webhooks = ["discord", "raw"] 4 | 5 | 6 | def check_discord_webhook_config(config: dict) -> bool: 7 | if "webhook_type" not in config or config["webhook_type"] != "discord": 8 | return False 9 | if "webhook_url" not in config: 10 | raise ValueError("Discord webhook config is missing 'webhook_url' value.") 11 | return True 12 | 13 | 14 | def check_raw_webhook_config(config: dict) -> bool: 15 | if config.get("webhook_type") != "raw": 16 | return False 17 | missing_fields = [] 18 | required_fields = ["callback_url"] 19 | for config_field in required_fields: 20 | if not config.get(config_field): 21 | missing_fields.append(config_field) 22 | if missing_fields: 23 | raise ValueError(f"Missing fields on webhook config: {missing_fields}") 24 | return False 25 | 26 | 27 | class WebhookConfig: 28 | def __init__(self, config_path: str): 29 | self.config_path = config_path 30 | self.values = self.load_config() 31 | if ( 32 | "webhook_type" not in self.values 33 | or self.values["webhook_type"] not in supported_webhooks 34 | ): 35 | raise ValueError( 36 | f"Invalid webhook type specified in config. Supported values: {supported_webhooks}" 37 | ) 38 | if check_discord_webhook_config(self.values): 39 | self.webhook_type = "discord" 40 | elif check_raw_webhook_config(self.values): 41 | self.webhook_type = "raw" 42 | 43 | def load_config(self): 44 | with open(self.config_path, "r") as f: 45 | return load(f) 46 | 47 | def get_config(self): 48 | return self.values 49 | 50 | def __getattr__(self, name): 51 | return self.values.get(name, None) 52 | -------------------------------------------------------------------------------- /helpers/webhooks/mixin.py: -------------------------------------------------------------------------------- 1 | from helpers.webhooks.handler import WebhookHandler 2 | from helpers.training.state_tracker import StateTracker 3 | from helpers.training.multi_process import _get_rank as get_rank 4 | 5 | current_rank = get_rank() 6 | 7 | 8 | class WebhookMixin: 9 | webhook_handler: WebhookHandler = None 10 | 11 | def set_webhook_handler(self, webhook_handler: WebhookHandler): 12 | self.webhook_handler = webhook_handler 13 | 14 | def send_progress_update(self, type: str, progress: int, total: int, current: int): 15 | if total == 1: 16 | return 17 | if int(current_rank) != 0: 18 | return 19 | progress = { 20 | "message_type": "progress_update", 21 | "message": { 22 | "progress_type": type, 23 | "progress": progress, 24 | "total_elements": total, 25 | "current_estimated_index": current, 26 | }, 27 | } 28 | 29 | self.webhook_handler.send_raw( 30 | progress, "progress_update", job_id=StateTracker.get_job_id() 31 | ) 32 | -------------------------------------------------------------------------------- /install/apple/poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = false -------------------------------------------------------------------------------- /install/apple/pyproject.toml: -------------------------------------------------------------------------------- 
1 | [tool.poetry] 2 | name = "simpletuner" 3 | version = "1.1.0" 4 | description = "Stable Diffusion 2.x and XL tuner." 5 | authors = ["bghira"] 6 | license = "AGPLv3" 7 | readme = "README.md" 8 | package-mode = false 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.10,<3.13" 12 | torch = "2.7.0" 13 | torchvision = "^0.22.0" 14 | diffusers = {git = "https://github.com/huggingface/diffusers"} 15 | transformers = "^4.51.3" 16 | datasets = "^3.0.0" 17 | wandb = "^0.19.10" 18 | requests = "^2.32.3" 19 | pillow = "^11.2.1" 20 | opencv-python = "^4.10.0.84" 21 | accelerate = "^1.6.0" 22 | safetensors = "^0.4.5" 23 | compel = "^2.0.1" 24 | clip-interrogator = "^0.6.0" 25 | open-clip-torch = "^2.26.1" 26 | iterutils = "^0.1.6" 27 | scipy = "^1.11.1" 28 | boto3 = "^1.35.83" 29 | pandas = "^2.2.3" 30 | botocore = "^1.35.83" 31 | urllib3 = "<1.27" 32 | torchsde = "^0.2.6" 33 | torchmetrics = "^1.1.1" 34 | colorama = "^0.4.6" 35 | numpy = "^2.2.0" 36 | peft = "^0.15.2" 37 | tensorboard = "^2.18.0" 38 | regex = "^2023.12.25" 39 | huggingface-hub = "^0.30.2" 40 | optimum-quanto = {git = "https://github.com/huggingface/optimum-quanto"} 41 | torch-optimi = "^0.2.1" 42 | lycoris-lora = {git = "https://github.com/kohakublueleaf/lycoris", rev = "dev"} 43 | fastapi = {extras = ["all"], version = "^0.115.12"} 44 | deepspeed = "^0.16.1" 45 | sentencepiece = "^0.2.0" 46 | torchao = "^0.9.0" 47 | torchaudio = "^2.7.0" 48 | atomicwrites = "^1.4.1" 49 | beautifulsoup4 = "^4.12.3" 50 | prodigy-plus-schedule-free = "^1.9.1" 51 | imageio-ffmpeg = "^0.6.0" 52 | imageio = {extras = ["pyav"], version = "^2.37.0"} 53 | aiohttp = "^3.11.18" 54 | aiohappyeyeballs = "^2.6.1" 55 | pytz = "^2025.2" 56 | setuptools = "^79.0.1" 57 | tzdata = "^2025.2" 58 | pydantic = "^2.11.3" 59 | 60 | [tool.poetry.group.jxl.dependencies] 61 | pillow-jxl-plugin = "^1.3.1" 62 | 63 | [build-system] 64 | requires = ["poetry-core"] 65 | build-backend = "poetry.core.masonry.api" 66 | 67 | [[tool.poetry.source]] 68 | priority = "supplemental" 69 | name = "pytorch-nightly" 70 | url = "https://download.pytorch.org/whl/nightly/cpu" 71 | -------------------------------------------------------------------------------- /install/github/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "simpletuner" 3 | version = "1.1.0" 4 | description = "Stable Diffusion 2.x and XL tuner." 
5 | authors = ["bghira"] 6 | license = "AGPLv3" 7 | readme = "README.md" 8 | package-mode = false 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.10,<3.13" 12 | torch = { "version" = ">=2.4.1", "source" = "pytorch" } 13 | torchvision = "^0.19.0" 14 | diffusers = "^0.30.3" 15 | transformers = "^4.44.2" 16 | datasets = "^3.0.0" 17 | wandb = "^0.18.1" 18 | requests = "^2.32.3" 19 | pillow = "^10.4.0" 20 | opencv-python = "^4.10.0.84" 21 | accelerate = "^0.34.2" 22 | safetensors = "^0.4.5" 23 | compel = "^2.0.1" 24 | clip-interrogator = "^0.6.0" 25 | open-clip-torch = "^2.26.1" 26 | iterutils = "^0.1.6" 27 | scipy = "^1.11.1" 28 | boto3 = "^1.35.24" 29 | pandas = "^2.2.3" 30 | botocore = "^1.35.24" 31 | urllib3 = "<1.27" 32 | torchsde = "^0.2.5" 33 | torchmetrics = "^1.1.1" 34 | colorama = "^0.4.6" 35 | numpy = "1.26" 36 | peft = "^0.12.0" 37 | tensorboard = "^2.17.1" 38 | regex = "^2023.12.25" 39 | huggingface-hub = "^0.23.3" 40 | optimum-quanto = {git = "https://github.com/huggingface/optimum-quanto"} 41 | torch-optimi = "^0.2.1" 42 | lycoris-lora = {git = "https://github.com/kohakublueleaf/lycoris", rev = "dev"} 43 | fastapi = {extras = ["standard"], version = "^0.115.0"} 44 | deepspeed = "^0.15.1" 45 | sentencepiece = "^0.2.0" 46 | torchao = "^0.5.0" 47 | 48 | 49 | [build-system] 50 | requires = ["poetry-core"] 51 | build-backend = "poetry.core.masonry.api" 52 | 53 | [[tool.poetry.source]] 54 | priority = "supplemental" 55 | name = "pytorch" 56 | url = "https://download.pytorch.org/whl/cpu" 57 | 58 | [[tool.poetry.source]] 59 | priority = "supplemental" 60 | name = "pytorch-nightly" 61 | url = "https://download.pytorch.org/whl/nightly/cpu" 62 | -------------------------------------------------------------------------------- /install/rocm/poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = false -------------------------------------------------------------------------------- /install/rocm/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "simpletuner" 3 | version = "1.1.0" 4 | description = "Stable Diffusion 2.x and XL tuner." 
5 | authors = ["bghira"] 6 | license = "AGPLv3" 7 | readme = "README.md" 8 | #package-mode = false 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.12" 12 | torch = { url = "https://download.pytorch.org/whl/test/rocm6.3/torch-2.7.0%2Brocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl" } 13 | torchvision = "*" 14 | torchaudio = "*" 15 | pytorch_triton_rocm = "*" 16 | accelerate = "^1.2.1" 17 | boto3 = "^1.35.83" 18 | botocore = "^1.35.83" 19 | clip-interrogator = "^0.6.0" 20 | colorama = "^0.4.6" 21 | compel = "^2" 22 | datasets = "^3.0.0" 23 | deepspeed = "^0.16.1" 24 | diffusers = {git = "https://github.com/huggingface/diffusers"} 25 | iterutils = "^0.1.6" 26 | numpy = "^2.2.0" 27 | open-clip-torch = "^2.26.1" 28 | opencv-python = "^4.10.0.84" 29 | pandas = "^2.2.3" 30 | peft = "^0.14.0" 31 | pillow = "^11.0.0" 32 | requests = "^2.32.3" 33 | safetensors = "^0.4.5" 34 | scipy = "^1" 35 | tensorboard = "^2.18.0" 36 | torchsde = "^0.2.6" 37 | transformers = "^4.49.0" 38 | urllib3 = "<1.27" 39 | wandb = "^0.19.1" 40 | sentencepiece = "^0.2.0" 41 | optimum-quanto = {git = "https://github.com/huggingface/optimum-quanto"} 42 | lycoris-lora = {git = "https://github.com/kohakublueleaf/lycoris", rev = "dev"} 43 | torch-optimi = "^0.2.1" 44 | fastapi = {extras = ["standard"], version = "^0.115.0"} 45 | atomicwrites = "^1.4.1" 46 | torchao = "^0.7.0" 47 | beautifulsoup4 = "^4.12.3" 48 | prodigy-plus-schedule-free = "^1.9.0" 49 | huggingface-hub = "^0.29.1" 50 | imageio-ffmpeg = "^0.6.0" 51 | imageio = {extras = ["pyav"], version = "^2.37.0"} 52 | torchmetrics = "^1.7.1" 53 | 54 | [tool.poetry.group.jxl.dependencies] 55 | pillow-jxl-plugin = "^1.3.1" 56 | 57 | [build-system] 58 | requires = ["poetry-core"] 59 | build-backend = "poetry.core.masonry.api" 60 | 61 | [[tool.poetry.source]] 62 | secondary = true 63 | name = "pytorch-rocm" 64 | url = "https://download.pytorch.org/whl/test/rocm6.3" 65 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = false -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "simpletuner" 3 | version = "1.1.0" 4 | description = "Stable Diffusion 2.x and XL tuner." 
5 | authors = ["bghira"] 6 | license = "AGPLv3" 7 | readme = "README.md" 8 | package-mode = false 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.10,<3.13" 12 | torch = "^2.6.0" 13 | torchvision = "^0.21.0" 14 | diffusers = {git = "https://github.com/huggingface/diffusers"} 15 | transformers = "^4.51.1" 16 | datasets = "^3.0.1" 17 | bitsandbytes = "^0.45.0" 18 | wandb = "^0.19.4" 19 | requests = "^2.32.3" 20 | pillow = "^11.0.0" 21 | opencv-python = "^4.10.0.84" 22 | deepspeed = "^0.16.2" 23 | accelerate = "^1.5.2" 24 | safetensors = "^0.4.5" 25 | compel = "^2.0.1" 26 | clip-interrogator = "^0.6.0" 27 | open-clip-torch = "^2.26.1" 28 | iterutils = "^0.1.6" 29 | scipy = "^1.11.1" 30 | boto3 = "^1.35.83" 31 | pandas = "^2.2.3" 32 | botocore = "^1.35.83" 33 | urllib3 = "<1.27" 34 | torchaudio = "^2.4.1" 35 | triton-library = "^1.0.0rc4" 36 | torchsde = "^0.2.6" 37 | torchmetrics = "^1.1.1" 38 | colorama = "^0.4.6" 39 | numpy = "^2.2.0" 40 | peft = "^0.14.0" 41 | tensorboard = "^2.18.0" 42 | triton = {version = "^3.1.0", source = "pytorch"} 43 | sentencepiece = "^0.2.0" 44 | optimum-quanto = {git = "https://github.com/huggingface/optimum-quanto"} 45 | lycoris-lora = {git = "https://github.com/kohakublueleaf/lycoris", rev = "dev"} 46 | torch-optimi = "^0.2.1" 47 | toml = "^0.10.2" 48 | fastapi = {extras = ["standard"], version = "^0.115.0"} 49 | torchao = "^0.10.0" 50 | lm-eval = "^0.4.4" 51 | nvidia-cudnn-cu12 = "*" 52 | nvidia-nccl-cu12 = "*" 53 | atomicwrites = "^1.4.1" 54 | beautifulsoup4 = "^4.12.3" 55 | prodigy-plus-schedule-free = "^1.9.0" 56 | tokenizers = "^0.21.0" 57 | huggingface-hub = "^0.30.2" 58 | imageio-ffmpeg = "^0.6.0" 59 | imageio = {extras = ["pyav"], version = "^2.37.0"} 60 | 61 | [tool.poetry.group.jxl.dependencies] 62 | pillow-jxl-plugin = "^1.3.1" 63 | 64 | [build-system] 65 | requires = ["poetry-core", "setuptools", "wheel", "torch"] 66 | build-backend = "poetry.core.masonry.api" 67 | 68 | [[tool.poetry.source]] 69 | priority = "supplemental" 70 | name = "pytorch" 71 | url = "https://download.pytorch.org/whl/cu124" 72 | -------------------------------------------------------------------------------- /service_worker.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | 3 | # from simpletuner_sdk import parse_api_args 4 | from simpletuner_sdk.configuration import Configuration 5 | from simpletuner_sdk.training_host import TrainingHost 6 | from fastapi.staticfiles import StaticFiles 7 | import logging, os 8 | 9 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 10 | logger = logging.getLogger("SimpleTunerAPI") 11 | 12 | config_controller = Configuration() 13 | training_host = TrainingHost() 14 | 15 | app = FastAPI() 16 | app.mount("/static", StaticFiles(directory="static"), name="static") 17 | 18 | ##################################################### 19 | # configuration controller for argument handling # 20 | ##################################################### 21 | app.include_router(config_controller.router) 22 | 23 | ##################################################### 24 | # traininghost controller for training job mgmt # 25 | ##################################################### 26 | app.include_router(training_host.router) 27 | 28 | if os.path.exists("templates/ui.template"): 29 | from simpletuner_sdk.interface import WebInterface 30 | 31 | app.include_router(WebInterface().router) 32 | -------------------------------------------------------------------------------- 
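A quick note on how the API surface wired up in `service_worker.py` above is typically exercised: the `TrainingHost` router (defined in `simpletuner_sdk/training_host.py` below) is mounted under `/training` and exposes `GET /training/`, `GET /training/state`, and `POST /training/cancel`. The sketch below is a hypothetical client, assuming the app has been launched with something like `uvicorn service_worker:app --port 8000`; the host, port, and launch command are assumptions, not something these files pin down.

```python
# Minimal sketch of a client for the SimpleTuner service worker API.
# Assumes the FastAPI app is reachable at the placeholder base URL below.
import requests

BASE_URL = "http://127.0.0.1:8000"  # placeholder host/port

# Host state plus the registry of background job threads.
state = requests.get(f"{BASE_URL}/training/state").json()
print(state["result"], state["job_list"])

# Only the currently active job (job_id and job payload), if one is running.
job = requests.get(f"{BASE_URL}/training/").json()
print(job["result"])

# Ask the host to abort the running trainer and mark the job as cancelled.
print(requests.post(f"{BASE_URL}/training/cancel").json())
```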
/simpletuner_sdk/api_state.py: -------------------------------------------------------------------------------- 1 | """ 2 | API State tracker to persist and resume tracker states and configurations. 3 | 4 | If a server were to crash during a training job, we can immediately reload the system state 5 | and continue training from the last checkpoint. 6 | """ 7 | 8 | import os 9 | import json 10 | import logging 11 | 12 | logger = logging.getLogger("SimpleTunerSDK") 13 | logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING")) 14 | 15 | 16 | class APIState: 17 | state = {} 18 | state_file = "api_state.json" 19 | trainer = None 20 | 21 | @classmethod 22 | def load_state(cls): 23 | if os.path.exists(cls.state_file): 24 | with open(cls.state_file, "r") as f: 25 | cls.state = json.load(f) 26 | logger.info(f"Loaded state from {cls.state_file}: {cls.state}") 27 | else: 28 | logger.info(f"No state file found at {cls.state_file}") 29 | 30 | @classmethod 31 | def save_state(cls): 32 | with open(cls.state_file, "w") as f: 33 | json.dump(cls.state, f) 34 | logger.info(f"Saved state to {cls.state_file}: {cls.state}") 35 | 36 | @classmethod 37 | def get_state(cls, key=None): 38 | if not key: 39 | return cls.state 40 | return cls.state.get(key) 41 | 42 | @classmethod 43 | def set_state(cls, key, value): 44 | cls.state[key] = value 45 | cls.save_state() 46 | 47 | @classmethod 48 | def delete_state(cls, key): 49 | if key in cls.state: 50 | del cls.state[key] 51 | cls.save_state() 52 | 53 | @classmethod 54 | def clear_state(cls): 55 | cls.state = {} 56 | cls.save_state() 57 | 58 | @classmethod 59 | def set_job(cls, job_id, job: dict): 60 | cls.set_state("current_job", job) 61 | cls.set_state("current_job_id", job_id) 62 | cls.set_state("status", "running") 63 | 64 | @classmethod 65 | def get_job(cls): 66 | return { 67 | "job_id": cls.get_state("current_job_id"), 68 | "job": cls.get_state("current_job"), 69 | } 70 | 71 | @classmethod 72 | def cancel_job(cls): 73 | cls.delete_state("current_job") 74 | cls.delete_state("current_job_id") 75 | cls.set_state("status", "cancelled") 76 | 77 | @classmethod 78 | def set_trainer(cls, trainer): 79 | cls.trainer = trainer 80 | 81 | @classmethod 82 | def get_trainer(cls): 83 | return cls.trainer 84 | -------------------------------------------------------------------------------- /simpletuner_sdk/interface.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Request 2 | from fastapi.templating import Jinja2Templates 3 | from fastapi.responses import HTMLResponse 4 | import os 5 | 6 | 7 | class WebInterface: 8 | def __init__(self): 9 | self.router = APIRouter(prefix="/web") 10 | self.router.add_api_route( 11 | "/", self.get_page, methods=["GET"], response_class=HTMLResponse 12 | ) 13 | self.templates = Jinja2Templates( 14 | directory="templates" 15 | ) # Define the directory where your templates are stored 16 | self.template_file = "ui.template" # This should correspond to an HTML file in the templates directory 17 | 18 | async def get_page(self, request: Request): 19 | """ 20 | Retrieve ui.template from disk and display it to the user. 21 | If the template file does not exist, display a default message.
22 | """ 23 | if os.path.exists(f"templates/{self.template_file}"): 24 | # Serve the template if it exists 25 | return self.templates.TemplateResponse( 26 | self.template_file, {"request": request} 27 | ) 28 | else: 29 | # Default HTML if template is missing 30 | return HTMLResponse( 31 | content=""" 32 | 33 | 34 |
35 | <html> 36 | <body> 37 | <p>Welcome to SimpleTuner. This installation does not include a compatible web interface. Sorry.</p> 38 | </body> 39 | </html>
40 | 41 | 42 | """, 43 | status_code=200, 44 | ) 45 | -------------------------------------------------------------------------------- /simpletuner_sdk/thread_keeper/__init__.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor, Future 2 | from typing import Dict 3 | import threading 4 | 5 | # We can only really have one thread going at a time anyway. 6 | executor = ThreadPoolExecutor(max_workers=1) 7 | # But, we've designed this for a future where multiple background threads might be managed. 8 | thread_registry: Dict[str, Future] = {} 9 | # So we don't zig while we zag. 10 | lock = threading.Lock() 11 | 12 | 13 | def submit_job(job_id: str, func, *args, **kwargs): 14 | with lock: 15 | if ( 16 | job_id in thread_registry 17 | and get_thread_status(job_id, with_lock=False).lower() == "running" 18 | ): 19 | raise Exception(f"Job with ID {job_id} is already running or pending.") 20 | # Remove the completed or cancelled future from the registry 21 | thread_registry.pop(job_id, None) 22 | # Submit the new job 23 | future = executor.submit(func, *args, **kwargs) 24 | thread_registry[job_id] = future 25 | 26 | 27 | def get_thread_status(job_id: str, with_lock: bool = True) -> str: 28 | if with_lock: 29 | with lock: 30 | future = thread_registry.get(job_id) 31 | if not future: 32 | return "No such job." 33 | if future.running(): 34 | return "Running" 35 | elif future.done(): 36 | if future.exception(): 37 | return f"Failed: {future.exception()}" 38 | return "Completed" 39 | return "Pending" 40 | else: 41 | future = thread_registry.get(job_id) 42 | if not future: 43 | return "No such job." 44 | if future.running(): 45 | return "Running" 46 | elif future.done(): 47 | if future.exception(): 48 | return f"Failed: {future.exception()}" 49 | return "Completed" 50 | return "Pending" 51 | 52 | 53 | def terminate_thread(job_id: str) -> bool: 54 | with lock: 55 | future = thread_registry.get(job_id) 56 | if not future: 57 | print(f"Thread {job_id} not found") 58 | return False 59 | # Attempt to cancel the future if it hasn't started running 60 | cancelled = future.cancel() 61 | if cancelled: 62 | del thread_registry[job_id] 63 | return cancelled 64 | 65 | 66 | def list_threads(): 67 | return {job_id: get_thread_status(job_id) for job_id in thread_registry} 68 | -------------------------------------------------------------------------------- /simpletuner_sdk/training_host.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from pydantic import BaseModel 3 | from simpletuner_sdk.api_state import APIState 4 | from simpletuner_sdk.thread_keeper import list_threads, terminate_thread 5 | 6 | 7 | class TrainingHost: 8 | def __init__(self): 9 | self.router = APIRouter(prefix="/training") 10 | self.router.add_api_route("/", self.get_job, methods=["GET"]) 11 | self.router.add_api_route("/", self.get_job, methods=["GET"]) 12 | self.router.add_api_route("/state", self.get_host_state, methods=["GET"]) 13 | self.router.add_api_route("/cancel", self.cancel_job, methods=["POST"]) 14 | 15 | def get_host_state(self): 16 | """ 17 | Get the current host status using APIState 18 | """ 19 | return {"result": APIState.get_state(), "job_list": list_threads()} 20 | 21 | def get_job(self): 22 | """ 23 | Returns just the currently active job from APIState 24 | """ 25 | return {"result": APIState.get_job()} 26 | 27 | def cancel_job(self): 28 | """ 29 | Cancel the currently active 
job 30 | """ 31 | trainer = APIState.get_trainer() 32 | if not trainer: 33 | return {"status": False, "result": "No job to cancel"} 34 | trainer.abort() 35 | is_terminated = terminate_thread(job_id=APIState.get_state("current_job_id")) 36 | APIState.set_trainer(None) 37 | APIState.cancel_job() 38 | 39 | return {"result": "Job marked for cancellation."} 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bghira/SimpleTuner/ba0a415594f524fc6de51205e2da075a47ea37ee/tests/__init__.py -------------------------------------------------------------------------------- /tests/helpers/data.py: -------------------------------------------------------------------------------- 1 | # Mock data_backend for unit testing 2 | 3 | 4 | class MockDataBackend: 5 | @staticmethod 6 | def read(image_path_str): 7 | # Dummy read method for testing 8 | return b"fake_image_data" 9 | -------------------------------------------------------------------------------- /tests/test_collate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | import numpy as np 4 | import torch 5 | 6 | from helpers.training.collate import ( 7 | collate_fn, 8 | ) # Adjust this import according to your project structure 9 | from helpers.training.state_tracker import StateTracker # Adjust imports as needed 10 | 11 | 12 | class TestCollateFn(unittest.TestCase): 13 | def setUp(self): 14 | # Set up any common variables or mocks used in multiple tests 15 | self.mock_batch = [ 16 | { 17 | "training_samples": [ 18 | { 19 | "image_path": "fake_path_1.png", 20 | "instance_prompt_text": "caption 1", 21 | "luminance": 0.5, 22 | "original_size": (100, 100), 23 | "image_data": MagicMock(), 24 | "crop_coordinates": [0, 0, 100, 100], 25 | "data_backend_id": "foo", 26 | "aspect_ratio": 1.0, 27 | } 28 | ], 29 | "conditioning_samples": [], 30 | }, 31 | # Add more examples as needed 32 | ] 33 | # Mock StateTracker.get_args() to return a mock object with required attributes 34 | StateTracker.set_args( 35 | MagicMock(caption_dropout_probability=0.5, controlnet=False, flux=False) 36 | ) 37 | fake_accelerator = MagicMock(device="cpu") 38 | StateTracker.set_accelerator(fake_accelerator) 39 | 40 | @patch("helpers.training.collate.compute_latents") 41 | @patch("helpers.training.collate.compute_prompt_embeddings") 42 | @patch("helpers.training.collate.gather_conditional_sdxl_size_features") 43 | def test_collate_fn(self, mock_gather, mock_compute_embeds, mock_compute_latents): 44 | # Mock the responses from the compute functions 45 | mock_compute_latents.return_value = torch.randn( 46 | 2, 512 47 | ) # Adjust dimensions as needed 48 | mock_compute_embeds.return_value = { 49 | "prompt_embeds": torch.randn(2, 768), 50 | "pooled_prompt_embeds": torch.randn(2, 768), 51 | } # Example embeddings 52 | mock_gather.return_value = torch.tensor( 53 | [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]] 54 | ) 55 | mock_compute_latents.to = MagicMock(return_value=mock_compute_latents) 56 | 57 | # Call collate_fn with a mock batch 58 | with patch("helpers.training.state_tracker.StateTracker.get_data_backend"): 59 | # Mock get_data_backend() to return a mock object with required attributes 60 | StateTracker.get_data_backend.return_value = MagicMock( 61 | compute_embeddings_for_legacy_prompts=MagicMock() 62 | ) 63 | result = 
collate_fn(self.mock_batch) 64 | 65 | # Assert that the results are as expected 66 | self.assertIn("latent_batch", result) 67 | self.assertIn("prompt_embeds", result) 68 | self.assertIn("add_text_embeds", result) 69 | self.assertIn("batch_time_ids", result) 70 | self.assertIn("batch_luminance", result) 71 | 72 | # Check that the conditioning dropout was correctly applied (random elements should be zeros) 73 | # This can be tricky since the dropout is random; you may want to set a fixed random seed or test the structure more than values 74 | 75 | # You can add more test methods to cover different aspects like different dropout probabilities, edge cases, etc. 76 | 77 | 78 | if __name__ == "__main__": 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /tests/test_cropping.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PIL import Image 3 | from helpers.multiaspect.image import ( 4 | MultiaspectImage, 5 | ) # Adjust import according to your project structure 6 | 7 | 8 | class TestCropping(unittest.TestCase): 9 | def setUp(self): 10 | # Creating a sample image for testing 11 | self.sample_image = Image.new("RGB", (500, 300), "white") 12 | 13 | def test_crop_corner(self): 14 | target_width, target_height = 300, 200 15 | from helpers.image_manipulation.cropping import CornerCropping 16 | 17 | cropper = CornerCropping(self.sample_image) 18 | cropped_image, (top, left) = cropper.set_intermediary_size( 19 | target_width + 10, target_height + 10 20 | ).crop(target_width, target_height) 21 | 22 | # Check if cropped coordinates are within original image bounds 23 | self.assertTrue(0 <= left < self.sample_image.width) 24 | self.assertTrue(0 <= top < self.sample_image.height) 25 | self.assertTrue(left + target_width <= self.sample_image.width) 26 | self.assertTrue(top + target_height <= self.sample_image.height) 27 | self.assertEqual(cropped_image.size, (target_width, target_height)) 28 | 29 | def test_crop_center(self): 30 | from helpers.image_manipulation.cropping import CenterCropping 31 | 32 | cropper = CenterCropping(self.sample_image) 33 | target_width, target_height = 300, 200 34 | cropper.set_intermediary_size(target_width + 10, target_height + 10) 35 | cropped_image, (left, top) = cropper.crop(target_width, target_height) 36 | 37 | # Similar checks as above 38 | self.assertTrue(0 <= left < self.sample_image.width) 39 | self.assertTrue(0 <= top < self.sample_image.height) 40 | self.assertTrue(left + target_width <= self.sample_image.width) 41 | self.assertTrue(top + target_height <= self.sample_image.height) 42 | self.assertEqual(cropped_image.size, (target_width, target_height)) 43 | 44 | def test_crop_random(self): 45 | from helpers.image_manipulation.cropping import RandomCropping 46 | 47 | target_width, target_height = 300, 200 48 | cropped_image, (top, left) = ( 49 | RandomCropping(self.sample_image) 50 | .set_intermediary_size(target_width + 10, target_height + 10) 51 | .crop(target_width, target_height) 52 | ) 53 | 54 | # Similar checks as above 55 | self.assertTrue(0 <= left < self.sample_image.width) 56 | self.assertTrue(0 <= top < self.sample_image.height) 57 | self.assertTrue(left + target_width <= self.sample_image.width) 58 | self.assertTrue(top + target_height <= self.sample_image.height) 59 | self.assertEqual(cropped_image.size, (target_width, target_height)) 60 | 61 | # Add additional tests for other methods as necessary 62 | 63 | 64 | if __name__ == 
"__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /tests/test_custom_schedules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | import torch 4 | import torch.optim as optim 5 | from helpers.training.custom_schedule import ( 6 | get_polynomial_decay_schedule_with_warmup, 7 | enforce_zero_terminal_snr, 8 | patch_scheduler_betas, 9 | segmented_timestep_selection, 10 | ) 11 | 12 | 13 | class TestPolynomialDecayWithWarmup(unittest.TestCase): 14 | def test_polynomial_decay_schedule_with_warmup(self): 15 | optimizer = optim.SGD([torch.randn(2, 2, requires_grad=True)], lr=0.1) 16 | scheduler = get_polynomial_decay_schedule_with_warmup( 17 | optimizer, num_warmup_steps=10, num_training_steps=100 18 | ) 19 | 20 | # Test warmup 21 | ranges = [0, 0.01, 0.02, 0.03, 0.04, 0.05] 22 | first_lr = round(scheduler.get_last_lr()[0], 2) 23 | for step in range(len(ranges)): 24 | last_lr = round(scheduler.get_last_lr()[0], 2) 25 | optimizer.step() 26 | scheduler.step() 27 | self.assertAlmostEqual(last_lr, ranges[-1], places=3) 28 | 29 | # Test decay 30 | for step in range(len(ranges), 100): 31 | optimizer.step() 32 | scheduler.step() 33 | # Implement your decay formula here to check 34 | expected_lr = 1e-7 35 | self.assertAlmostEqual(scheduler.get_last_lr()[0], expected_lr, places=4) 36 | 37 | def test_enforce_zero_terminal_snr(self): 38 | betas = torch.tensor([0.9, 0.8, 0.7]) 39 | new_betas = enforce_zero_terminal_snr(betas) 40 | final_beta = new_betas[-1] 41 | self.assertEqual(final_beta, 1.0) 42 | 43 | def test_patch_scheduler_betas(self): 44 | # Create a dummy scheduler with betas attribute 45 | class DummyScheduler: 46 | def __init__(self): 47 | self.betas = torch.tensor([0.9, 0.8, 0.7]) 48 | 49 | scheduler = DummyScheduler() 50 | # Check value before. 
51 | final_beta = scheduler.betas[-1] 52 | self.assertEqual(final_beta, 0.7) 53 | 54 | patch_scheduler_betas(scheduler) 55 | 56 | final_beta = scheduler.betas[-1] 57 | self.assertEqual(final_beta, 1.0) 58 | 59 | def test_inverted_schedule(self): 60 | with patch( 61 | "helpers.training.state_tracker.StateTracker.get_args", 62 | return_value=MagicMock( 63 | refiner_training=True, 64 | refiner_training_invert_schedule=True, 65 | refiner_training_strength=0.35, 66 | ), 67 | ): 68 | weights = torch.ones(1000) # Uniform weights 69 | selected_timesteps = segmented_timestep_selection( 70 | 1000, 71 | 10, 72 | weights, 73 | config=MagicMock( 74 | refiner_training=True, 75 | refiner_training_invert_schedule=True, 76 | refiner_training_strength=0.35, 77 | ), 78 | use_refiner_range=False, 79 | ) 80 | self.assertTrue( 81 | all(350 <= t <= 999 for t in selected_timesteps), 82 | f"Selected timesteps: {selected_timesteps}", 83 | ) 84 | 85 | def test_normal_schedule(self): 86 | with patch( 87 | "helpers.training.state_tracker.StateTracker.get_args", 88 | return_value=MagicMock( 89 | refiner_training=True, 90 | refiner_training_invert_schedule=False, 91 | refiner_training_strength=0.35, 92 | ), 93 | ): 94 | weights = torch.ones(1000) # Uniform weights 95 | selected_timesteps = segmented_timestep_selection( 96 | 1000, 97 | 10, 98 | weights, 99 | use_refiner_range=False, 100 | config=MagicMock( 101 | refiner_training=True, 102 | refiner_training_invert_schedule=False, 103 | refiner_training_strength=0.35, 104 | ), 105 | ) 106 | self.assertTrue(all(0 <= t < 350 for t in selected_timesteps)) 107 | 108 | 109 | if __name__ == "__main__": 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /tests/test_ema.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import tempfile 4 | import os 5 | from helpers.training.ema import EMAModel 6 | 7 | 8 | class TestEMAModel(unittest.TestCase): 9 | def setUp(self): 10 | # Set up a simple model and its parameters 11 | self.model = torch.nn.Linear(10, 5) # Simple linear model 12 | self.args = type( 13 | "Args", 14 | (), 15 | {"ema_update_interval": None, "ema_device": "cpu", "ema_cpu_only": True}, 16 | ) 17 | self.accelerator = None # For simplicity, assuming no accelerator in tests 18 | self.ema_model = EMAModel( 19 | args=self.args, 20 | accelerator=self.accelerator, 21 | parameters=self.model.parameters(), 22 | decay=0.999, 23 | min_decay=0.999, # Force decay to be 0.999 24 | update_after_step=-1, # Ensure decay is used from step 1 25 | use_ema_warmup=False, # Disable EMA warmup 26 | foreach=False, 27 | ) 28 | 29 | def test_ema_initialization(self): 30 | """Test that the EMA model initializes correctly.""" 31 | self.assertEqual( 32 | len(self.ema_model.shadow_params), len(list(self.model.parameters())) 33 | ) 34 | for shadow_param, model_param in zip( 35 | self.ema_model.shadow_params, self.model.parameters() 36 | ): 37 | self.assertTrue(torch.equal(shadow_param, model_param)) 38 | 39 | def test_ema_step(self): 40 | """Test that the EMA model updates correctly after a step.""" 41 | # Perform a model parameter update 42 | optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01) 43 | dummy_input = torch.randn(1, 10) # Adjust to match input size 44 | dummy_output = self.model(dummy_input) 45 | loss = dummy_output.sum() # A dummy loss function 46 | loss.backward() 47 | optimizer.step() 48 | 49 | # Save a copy of the model parameters after the 
update but before the EMA update. 50 | model_params = [p.clone() for p in self.model.parameters()] 51 | # Save a copy of the shadow parameters before the EMA update. 52 | shadow_params_before = [p.clone() for p in self.ema_model.shadow_params] 53 | 54 | # Perform an EMA update 55 | self.ema_model.step(self.model.parameters(), global_step=1) 56 | decay = self.ema_model.cur_decay_value # This should be 0.999 57 | 58 | # Verify that the decay used is as expected 59 | self.assertAlmostEqual( 60 | decay, 0.999, places=6, msg="Decay value is not as expected." 61 | ) 62 | 63 | # Verify shadow parameters have changed 64 | for shadow_param, shadow_param_before in zip( 65 | self.ema_model.shadow_params, shadow_params_before 66 | ): 67 | self.assertFalse( 68 | torch.equal(shadow_param, shadow_param_before), 69 | "Shadow parameters did not update correctly.", 70 | ) 71 | 72 | # Compute and check expected shadow parameter values 73 | for shadow_param, shadow_param_before, model_param in zip( 74 | self.ema_model.shadow_params, shadow_params_before, self.model.parameters() 75 | ): 76 | expected_shadow = decay * shadow_param_before + (1 - decay) * model_param 77 | self.assertTrue( 78 | torch.allclose(shadow_param, expected_shadow, atol=1e-6), 79 | f"Shadow parameter does not match expected value.", 80 | ) 81 | 82 | def test_save_and_load_state_dict(self): 83 | with tempfile.TemporaryDirectory() as temp_dir: 84 | temp_path = os.path.join(temp_dir, "ema_model_state.pth") 85 | 86 | # Save the state 87 | self.ema_model.save_state_dict(temp_path) 88 | 89 | # Create a new EMA model and load the state 90 | new_ema_model = EMAModel( 91 | args=self.args, 92 | accelerator=self.accelerator, 93 | parameters=self.model.parameters(), 94 | decay=0.999, 95 | ) 96 | new_ema_model.load_state_dict(temp_path) 97 | 98 | # Check that the new EMA model's shadow parameters match the saved state 99 | for shadow_param, new_shadow_param in zip( 100 | self.ema_model.shadow_params, new_ema_model.shadow_params 101 | ): 102 | self.assertTrue(torch.equal(shadow_param, new_shadow_param)) 103 | 104 | 105 | if __name__ == "__main__": 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /tests/test_state.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from helpers.multiaspect.state import BucketStateManager 3 | 4 | 5 | class TestBucketStateManager(unittest.TestCase): 6 | def setUp(self): 7 | pass # TODO: Add setup code if needed 8 | 9 | def test_example(self): 10 | # TODO: Write test cases 11 | self.assertEqual(True, True) 12 | 13 | 14 | if __name__ == "__main__": 15 | unittest.main() 16 | -------------------------------------------------------------------------------- /tests/test_vae.py: -------------------------------------------------------------------------------- 1 | from hashlib import sha256 2 | from helpers.caching.vae import VAECache 3 | 4 | import unittest 5 | from PIL import Image 6 | import numpy as np 7 | from helpers.image_manipulation.training_sample import TrainingSample 8 | from helpers.training.state_tracker import StateTracker 9 | from unittest.mock import MagicMock 10 | 11 | 12 | class TestVaeCache(unittest.TestCase): 13 | def test_filename_mapping(self): 14 | # Test cases 15 | test_cases = [ 16 | # 0 Filepath ends with .pt (no change expected in the path) 17 | {"image_path": "/data/image1.pt", "cache_path": "/data/image1.pt"}, 18 | # 1 Normal filepath 19 | {"image_path": "/data/image1.png", "cache_path": 
"cache/image1.pt"}, 20 | # 2, 3 Nested subdirectories 21 | { 22 | "image_path": "/data/subdir1/subdir2/image2.jpg", 23 | "cache_path": "cache/subdir1/subdir2/image2.pt", 24 | }, 25 | { 26 | "image_path": "data/subdir1/subdir2/image2.jpg", 27 | "cache_path": "cache/subdir1/subdir2/image2.pt", 28 | "instance_dir": "data", 29 | }, 30 | # 4 No instance_data_dir, direct cache dir placement 31 | { 32 | "image_path": "/anotherdir/image3.png", 33 | "cache_path": "cache/image3.pt", 34 | "instance_dir": None, 35 | }, 36 | # 5 Instance data directory is None 37 | { 38 | "image_path": "/data/image4.png", 39 | "cache_path": "cache/image4.pt", 40 | "instance_dir": None, 41 | }, 42 | # 6 Filepath in root directory 43 | {"image_path": "/image5.png", "cache_path": "cache/image5.pt"}, 44 | # 7 Hash filenames enabled 45 | { 46 | "image_path": "/data/image6.png", 47 | "cache_path": "cache/" + sha256("image6".encode()).hexdigest() + ".pt", 48 | "should_hash": True, 49 | }, 50 | # 8 Invalid cache_dir 51 | {"image_path": "/data/image7.png", "cache_path": "cache/image7.pt"}, 52 | ] 53 | 54 | # Running test cases 55 | for i, test_case in enumerate(test_cases, 1): 56 | filepath = test_case["image_path"] 57 | # expected = os.path.abspath(test_case['cache_path']) 58 | expected = test_case["cache_path"] 59 | cache_dir = test_case.get("cache_dir", "cache") 60 | instance_dir = test_case.get("instance_dir", "/data") 61 | should_hash = test_case.get("should_hash", False) 62 | vae_cache = VAECache( 63 | id="test-cache", 64 | vae=None, 65 | accelerator=None, 66 | metadata_backend=None, 67 | image_data_backend=None, 68 | hash_filenames=should_hash, 69 | instance_data_dir=instance_dir, 70 | cache_dir=cache_dir, 71 | model=MagicMock(), 72 | ) 73 | generated = vae_cache.generate_vae_cache_filename(filepath)[0] 74 | self.assertEqual( 75 | generated, expected, f"Test {i} failed: {generated} != {expected}" 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_webhooks.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | from helpers.webhooks.handler import WebhookHandler 4 | from helpers.webhooks.config import WebhookConfig 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | 9 | class TestWebhookHandler(unittest.TestCase): 10 | def setUp(self): 11 | # Create a mock for the WebhookConfig 12 | self.mock_config_instance = MagicMock(spec=WebhookConfig) 13 | self.mock_config_instance.webhook_url = "http://example.com/webhook" 14 | self.mock_config_instance.webhook_type = "discord" 15 | self.mock_config_instance.log_level = "info" 16 | self.mock_config_instance.message_prefix = "TestPrefix" 17 | self.mock_config_instance.values = { 18 | "webhook_url": "http://example.com/webhook", 19 | "webhook_type": "discord", 20 | "log_level": "info", 21 | "message_prefix": "TestPrefix", 22 | } 23 | 24 | # Mock the accelerator object 25 | self.mock_accelerator = MagicMock() 26 | self.mock_accelerator.is_main_process = True 27 | 28 | # Instantiate the handler with the mocked config 29 | self.handler = WebhookHandler( 30 | config_path="dummy_path", 31 | accelerator=self.mock_accelerator, 32 | project_name="TestProject", 33 | mock_webhook_config=self.mock_config_instance, 34 | args=MagicMock(framerate=99), 35 | ) 36 | 37 | @patch("requests.post") 38 | def test_send_message_info_level(self, mock_post): 39 | # Test sending a simple 
info level message 40 | message = "Test message" 41 | self.handler.send(message, message_level="info") 42 | mock_post.assert_called_once() 43 | # Capture the call arguments 44 | args, kwargs = mock_post.call_args 45 | # Assuming the message is sent in 'data' parameter 46 | self.assertIn("data", kwargs) 47 | self.assertIn(message, kwargs["data"].get("content")) 48 | 49 | @patch("requests.post") 50 | def test_debug_message_wont_send(self, mock_post): 51 | # Test that debug logs don't send when the log level is info 52 | self.handler.send("Test message", message_level="debug") 53 | mock_post.assert_not_called() 54 | 55 | @patch("requests.post") 56 | def test_do_not_send_lower_than_configured_level(self, mock_post): 57 | # Set a higher log level and test 58 | self.handler.log_level = 1 # Error level 59 | self.handler.send("Test message", message_level="info") 60 | mock_post.assert_not_called() 61 | 62 | @patch("requests.post") 63 | def test_send_with_images(self, mock_post): 64 | # Test sending messages with images 65 | image = Image.new("RGB", (60, 30), color="red") 66 | message = "Test message with image" 67 | self.handler.send(message, images=[image], message_level="info") 68 | args, kwargs = mock_post.call_args 69 | self.assertIn("files", kwargs) 70 | self.assertEqual(len(kwargs["files"]), 1) 71 | # Check that the message is in the 'data' parameter 72 | content = kwargs.get("data", {}).get("content", "") 73 | self.assertIn(self.mock_config_instance.values.get("message_prefix"), content) 74 | self.assertIn("data", kwargs, f"Check data for contents: {kwargs}") 75 | self.assertIn(message, content) 76 | 77 | @patch("requests.post") 78 | def test_response_storage(self, mock_post): 79 | # Mock response object 80 | mock_response = MagicMock() 81 | mock_response.headers = {"Content-Type": "application/json"} 82 | mock_post.return_value = mock_response 83 | 84 | self.handler.send("Test message", message_level="info", store_response=True) 85 | self.assertEqual(self.handler.stored_response, mock_response.headers) 86 | # Also check that the message is sent 87 | args, kwargs = mock_post.call_args 88 | content = kwargs.get("data", {}).get("content", "") 89 | self.assertIn(self.mock_config_instance.values.get("message_prefix"), content) 90 | self.assertIn("Test message", content) 91 | 92 | 93 | if __name__ == "__main__": 94 | unittest.main() 95 | -------------------------------------------------------------------------------- /toolkit/README.md: -------------------------------------------------------------------------------- 1 | SimpleTuner contains some ad-hoc tools for generating and managing the training data and checkpoints. 2 | 3 | #### Captioning 4 | 5 | When captioning a dataset, relying on a single caption model is a bad practice as it pins the model to whatever the chosen caption model knows. 6 | 7 | A variety of caption options are provided: 8 | 9 | * `caption_with_blip.py` - This is the original BLIP / BLIP2 captioning script which leverages the `interrogate` python library to run the (by default Flan-T5) BLIP model. 10 | * `caption_with_blip3.py` - Built on top of the Phi LLM, BLIP3 aka XGEN-MM is an excellent option for captioning, relatively lightweight and yet very powerful. 11 | * `caption_with_cogvlm_remote.py` - A script used by the volunteer cluster run via the Terminus Research Group 12 | * `caption_with_cogvlm.py` - If you want CogVLM captioning, use this - though there's some potentially erratic results from Cog where it might repeat words. 
13 | * `caption_with_gemini.py` - Set `GEMINI_API_KEY` in your environment from one obtained via [Google AI](https://ai.google.dev) and you can caption images for free using Gemini Pro Vision. 14 | * `caption_with_llava.py` - Use Llava 1.5 or 1.6 and run pretty much the same way the CogVLM script does, albeit in a different style. 15 | * `caption_with_internvl.py` - Uses InternVL2 by default to caption images directly into parquet tables for use by SimpleTuner. 16 | 17 | 18 | #### Datasets 19 | 20 | * `csv_to_s3.py` - given a folder of CSV webdataset files as input, download/caption/transform images before stuffing them into an S3 bucket. 21 | * `clear_s3_bucket.py` - Just a convenient way to clear an S3 bucket that's been used with this tool. 22 | * `dataset_from_kellyc.py` - If you use the KellyC browser extension for image scraping, this will build a dataset from the URL list it saves. 23 | * `dataset_from_csv.py` - Download a chunk of data to local storage from a single csv dataset document. 24 | * `dataset_from_laion.py` - A variant of the above script. 25 | * `analyze_laion_data.py` - After downloading a lot of LAION's data, you can use this to throw a lot of it away. 26 | * `analyze_aspect_ratios_json.py` - Use the output from `analyze_laion_data.py` to nuke images that do not fit our aspect goals. 27 | * `check_latent_corruption.py` - Scan and remove any images that will not load properly. 28 | * `update_parquet.py` - A scaffold for updating the contents of a parquet file. 29 | * `folder_to_parquet.py` - Import a folder of images into a parquet file. 30 | * `discord_scrape.py` - Scrape the Midjourney server into a local folder and/or parquet files. 31 | * `enhance_with_controlnet.py` - An incomplete script which aims to demonstrate improving a dataset using ControlNet Tile before training. 32 | 33 | #### Inference 34 | 35 | * `inference.py` - Generate validation results from the prompts catalogue (`prompts.py`) using DDIMScheduler. 36 | * `inference_ddpm.py` - Use DDPMScheduler to assemble a checkpoint from a base model configuration and run through validation prompts. 37 | * `inference_karras.py` - Use the Karras sigmas with DPM 2M Karras. Useful for testing what might happen in Automatic1111. 38 | * `tile_shortnames.py` - Tile the outputs from the above scripts into strips. 39 | 40 | * `inference_snr_test.py` - Generate a large number of CFG range images, and catalogue the results for tiling. 41 | * `tile_images.py` - Generate large image tiles to compare CFG results for zero SNR training / inference tuning. -------------------------------------------------------------------------------- /toolkit/captioning/caption_backend_server.php: -------------------------------------------------------------------------------- 1 | PDO::ERRMODE_EXCEPTION, 19 | PDO::ATTR_DEFAULT_FETCH_MODE => PDO::FETCH_ASSOC, 20 | ]); 21 | } catch (PDOException $e) { 22 | die('Could not connect to the database (' . $dsn . '): ' .
$e->getMessage()); 23 | } 24 | 25 | // Create the `dataset` table if it does not exist 26 | $pdo->exec("CREATE TABLE IF NOT EXISTS dataset ( 27 | data_id int AUTO_INCREMENT PRIMARY KEY, 28 | URL varchar(255) NOT NULL, 29 | pending tinyint(1) NOT NULL DEFAULT '0', 30 | result longtext, 31 | submitted_at datetime DEFAULT NULL, 32 | attempts int DEFAULT '0', 33 | error text, 34 | client_id varchar(255) DEFAULT NULL, 35 | updated_at datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 36 | job_group varchar(255) DEFAULT NULL 37 | )"); 38 | 39 | $aws_config = json_decode(file_get_contents('/var/www/.aws.json'), true); 40 | $s3_uploader = new S3Uploader( 41 | $aws_config['aws_bucket_name'], 42 | $aws_config['aws_region'], 43 | $aws_config['aws_access_key_id'], 44 | $aws_config['aws_secret_access_key'], 45 | $aws_config['aws_endpoint_url'], 46 | $aws_config['vae_cache_prefix'], 47 | $aws_config['text_cache_prefix'], 48 | $aws_config['image_data_prefix'] 49 | ); 50 | 51 | $backendController = new BackendController($pdo, $s3_uploader); 52 | $result = $backendController->handleRequest(); 53 | 54 | echo json_encode($result); -------------------------------------------------------------------------------- /toolkit/captioning/caption_with_blip.py: -------------------------------------------------------------------------------- 1 | import os, torch, logging, re, random 2 | try: 3 | import pillow_jxl 4 | except ModuleNotFoundError: 5 | pass 6 | from PIL import Image 7 | from clip_interrogator import Config, Interrogator, LabelTable, load_list 8 | from clip_interrogator import clip_interrogator 9 | clip_interrogator.CAPTION_MODELS.update({ 10 | 'unography': 'unography/blip-large-long-cap', # 1.9GB 11 | }) 12 | print(f"Models supported: {clip_interrogator.CAPTION_MODELS}") 13 | 14 | # Directory where the images are located 15 | input_directory_path = "/Volumes/datasets/photo-concept-bucket/image_data" 16 | output_dir = "/Volumes/datasets/photo-concept-bucket/image_data_captioned" 17 | caption_strategy = "text" 18 | 19 | 20 | def content_to_filename(content): 21 | """ 22 | Function to convert content to filename by stripping everything after '--', 23 | replacing non-alphanumeric characters and spaces, converting to lowercase, 24 | removing leading/trailing underscores, and limiting filename length to 128. 
25 | """ 26 | # Split on '--' and take the first part 27 | content = content.split("--", 1)[0] 28 | 29 | # Remove URLs 30 | cleaned_content = re.sub(r"https*://\S*", "", content) 31 | 32 | # Replace non-alphanumeric characters and spaces, convert to lowercase, remove leading/trailing underscores 33 | cleaned_content = re.sub(r"[^a-zA-Z0-9 ]", "", cleaned_content) 34 | cleaned_content = cleaned_content.replace(" ", "_").lower().strip("_") 35 | 36 | # If cleaned_content is empty after removing URLs, generate a random filename 37 | if cleaned_content == "": 38 | cleaned_content = f"midjourney_{random.randint(0, 1000000)}" 39 | 40 | # Limit filename length to 128 41 | cleaned_content = ( 42 | cleaned_content[:128] if len(cleaned_content) > 128 else cleaned_content 43 | ) 44 | 45 | return cleaned_content + ".png" 46 | 47 | 48 | def interrogator( 49 | clip_model_name="ViT-H-14/laion2b_s32b_b79k", blip_model="unography" 50 | ): 51 | # Create an Interrogator instance with the latest CLIP model for Stable Diffusion 2.1 52 | conf = Config( 53 | clip_model_name=clip_model_name, clip_offload=True, caption_offload=True, caption_max_length=170, device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 54 | ) 55 | conf.caption_model_name = blip_model 56 | ci = Interrogator(conf) 57 | return ci 58 | 59 | 60 | def load_terms(filename, interrogator_instance): 61 | # Load your list of terms 62 | table = LabelTable(load_list(filename), "terms", interrogator_instance) 63 | logging.debug(f"Loaded {len(table)} terms from {filename}") 64 | return table 65 | 66 | 67 | def process_directory(image_dir="images", terms_file=None, active_interrogator=None): 68 | if active_interrogator is None: 69 | active_interrogator = interrogator() 70 | if terms_file is not None: 71 | table = load_terms(terms_file, active_interrogator) 72 | 73 | for filename in os.listdir(image_dir): 74 | full_filepath = os.path.join(image_dir, filename) 75 | if os.path.isdir(full_filepath): 76 | process_directory(full_filepath, terms_file, active_interrogator) 77 | elif filename.lower().endswith((".jpg", ".png")): 78 | try: 79 | image = Image.open(full_filepath).convert("RGB") 80 | if terms_file: 81 | best_match = table.rank( 82 | active_interrogator.image_to_features(image), top_count=1 83 | )[0] 84 | else: 85 | best_match = active_interrogator.generate_caption(image) 86 | 87 | logging.info(f"Best match for {filename}: {best_match}") 88 | 89 | # Save based on caption strategy 90 | new_filename = ( 91 | content_to_filename(best_match) 92 | if caption_strategy == "filename" 93 | else filename 94 | ) 95 | new_filepath = os.path.join(output_dir, new_filename) 96 | 97 | if caption_strategy == "text": 98 | with open(new_filepath + ".txt", "w") as f: 99 | f.write(best_match) 100 | else: 101 | # Ensure no overwriting 102 | counter = 1 103 | while os.path.exists(new_filepath): 104 | new_filepath = os.path.join( 105 | output_dir, 106 | f"{new_filename.rsplit('.', 1)[0]}_{counter}.{new_filename.rsplit('.', 1)[1]}", 107 | ) 108 | counter += 1 109 | 110 | image.save(new_filepath) 111 | image.close() 112 | 113 | 114 | except Exception as e: 115 | logging.error(f"Error processing {filename}: {str(e)}") 116 | 117 | 118 | if __name__ == "__main__": 119 | logging.basicConfig(level=logging.INFO) 120 | 121 | # Ensure output directory exists 122 | if not os.path.exists(output_dir): 123 | os.makedirs(output_dir) 124 | 125 | process_directory(input_directory_path) 126 | 
-------------------------------------------------------------------------------- /toolkit/captioning/classes/Authorization.php: -------------------------------------------------------------------------------- 1 | client_id = $_REQUEST['client_id']; 25 | $this->secret = $_REQUEST['secret']; 26 | $this->user_config_path = $user_config_path; 27 | $this->load_user_database(); 28 | if ($test_authorization) $this->authorize(); 29 | } 30 | 31 | /** 32 | * Load the user database from disk. 33 | * 34 | * @return Authorization 35 | */ 36 | private function load_user_database() { 37 | // Load the user database from the file: 38 | try { 39 | $this->users = json_decode(file_get_contents($this->user_config_path), true); 40 | return $this; 41 | } catch (Exception $e) { 42 | error_log($e->getMessage()); 43 | http_response_code(500); 44 | echo 'Internal server error.'; 45 | exit; 46 | } 47 | } 48 | 49 | public function authorize() { 50 | // Check if client_id and secret are valid: 51 | if (!in_array($this->client_id, array_keys($this->users)) || $this->secret !== $this->users[$this->client_id]) { 52 | http_response_code(403); 53 | echo 'Unauthorized.'; 54 | exit; 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /toolkit/captioning/classes/BackendController.php: -------------------------------------------------------------------------------- 1 | pdo = $pdo; 25 | $this->s3_uploader = $s3_uploader; 26 | $this->getParameters(); 27 | } 28 | 29 | public function getParameters() { 30 | // Action handling 31 | $this->action = $_REQUEST['action'] ?? ''; 32 | $this->job_type = $_REQUEST['job_type'] ?? ''; 33 | $this->error = $_REQUEST['error'] ?? ''; 34 | } 35 | 36 | public function handleRequest() { 37 | return $this->{$this->action}(); 38 | } 39 | 40 | public function list_jobs() { 41 | try { 42 | $limit = 500; // Number of rows to fetch and randomize in PHP 43 | $count = $_GET['count'] ?? 
1; // Number of rows to actually return 44 | 45 | // Fetch the rows 46 | if ($this->job_type === 'vae') { 47 | $total_jobs = $this->pdo->query('SELECT COUNT(*) FROM dataset')->fetchColumn(); 48 | $remaining_jobs = $this->pdo->query('SELECT COUNT(*) FROM dataset WHERE pending = 0')->fetchColumn(); 49 | $stmt = $this->pdo->prepare('SELECT * FROM dataset WHERE pending = 0 LIMIT ?'); 50 | } elseif ($this->job_type === 'dataset_upload') { 51 | $total_jobs = $this->pdo->query('SELECT COUNT(*) FROM dataset')->fetchColumn(); 52 | $remaining_jobs = $this->pdo->query('SELECT COUNT(*) FROM dataset WHERE upload_pending = 0 AND result IS NULL')->fetchColumn(); 53 | $stmt = $this->pdo->prepare('SELECT * FROM dataset WHERE result IS NULL LIMIT ?'); 54 | } 55 | $stmt->bindValue(1, $limit, PDO::PARAM_INT); 56 | $stmt->execute(); 57 | $jobs = $stmt->fetchAll(); 58 | 59 | // Shuffle the array in PHP 60 | shuffle($jobs); 61 | 62 | // Slice the array to get only the number of rows specified by $count 63 | $jobs = array_slice($jobs, 0, $count); 64 | 65 | // Update the database for the selected jobs 66 | foreach ($jobs as $idx => $job) { 67 | if ($this->job_type === 'vae') { 68 | $updateStmt = $this->pdo->prepare('UPDATE dataset SET pending = 1, submitted_at = NOW(), attempts = attempts + 1 WHERE data_id = ?'); 69 | } elseif ($this->job_type === 'dataset_upload') { 70 | $updateStmt = $this->pdo->prepare('UPDATE dataset SET upload_pending = 1 WHERE data_id = ?'); 71 | } 72 | $updateStmt->execute([$job['data_id']]); 73 | $jobs[$idx]['total_jobs'] = $total_jobs; 74 | $jobs[$idx]['remaining_jobs'] = $remaining_jobs; // Update remaining jobs count 75 | $jobs[$idx]['completed_jobs'] = $total_jobs - $remaining_jobs; 76 | $jobs[$idx]['job_type'] = $this->job_type; 77 | } 78 | 79 | // Return the selected jobs 80 | return $jobs; 81 | } catch (\Throwable $ex) { 82 | echo 'An error occurred: ' . $ex->getMessage(); 83 | } 84 | } 85 | 86 | public function submit_job() { 87 | try { 88 | $dataId = $_REQUEST['job_id'] ?? ''; 89 | $result = $_REQUEST['result'] ?? ''; 90 | $status = $_REQUEST['status'] ?? 'success'; 91 | if (!$result || !$dataId) { 92 | echo 'Job ID and result are required'; 93 | exit; 94 | } 95 | if ($status == 'error' && !$this->error) { 96 | echo "Error message required for status 'error'"; 97 | exit; 98 | } 99 | 100 | if ($status !== 'error') { 101 | $stmt = $this->pdo->prepare('SELECT data_id FROM dataset WHERE data_id = ?'); 102 | $stmt->execute([$dataId]); 103 | $filename = $stmt->fetchColumn(); 104 | if (!$filename) { 105 | echo 'Job ID not found'; 106 | exit; 107 | } 108 | if ($this->job_type === 'vae') { 109 | if (!in_array('result_file', array_keys($_FILES))) { 110 | echo 'Result files are required for VAE tasks.'; 111 | echo 'Provided files: ' . json_encode($_FILES); 112 | exit; 113 | } 114 | $result = $this->s3_uploader->uploadVAECache($_FILES['result_file']['tmp_name'], $filename . '.pt'); 115 | $updateStmt = $this->pdo->prepare('UPDATE dataset SET client_id = ?, error = ? WHERE data_id = ?'); 116 | $updateStmt->execute([$this->client_id, $this->error, $dataId]); 117 | } elseif ($this->job_type === 'dataset_upload') { 118 | if (in_array('image_file', $_FILES)) $result = $this->s3_uploader->uploadImage($_FILES['image_file']['tmp_name'], $filename . 
'.png'); 119 | $updateStmt = $this->pdo->prepare('UPDATE dataset SET result = ?, upload_pending = 1 WHERE data_id = ?'); 120 | $updateStmt->execute([$result, $dataId]); 121 | } elseif ($this->job_type === 'text') { 122 | $result = $this->s3_uploader->uploadTextCache($_FILES['result_file']['tmp_name'], $filename); 123 | } else { 124 | echo 'Invalid job type: ' . $this->job_type . ' - must be "vae" or "text"'; 125 | exit; 126 | } 127 | } 128 | return ['status' => 'success', 'result' => 'Job submitted successfully']; 129 | } catch (\Throwable $ex) { 130 | echo 'An error occurred for FILES ' . json_encode($_FILES) . ': ' . $ex->getMessage() . ', traceback: ' . $ex->getTraceAsString(); 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /toolkit/captioning/classes/S3Uploader.php: -------------------------------------------------------------------------------- 1 | s3Client = new S3Client([ 22 | 'version' => 'latest', 23 | 'region' => $region, 24 | 'credentials' => [ 25 | 'key' => $key, 26 | 'secret' => $secret, 27 | ], 28 | 'endpoint' => $endpoint, 29 | ]); 30 | $this->bucket = $bucket; 31 | $this->vae_cache_prefix = $vae_cache_prefix; 32 | $this->text_cache_prefix = $text_cache_prefix; 33 | $this->image_data_prefix = $image_data_prefix; 34 | } 35 | 36 | /** 37 | * Upload a VAE cache file to S3, under the embed prefix as a .pt file. 38 | * 39 | * A Client worker will POST this to us, we need to accept and forward to S3. 40 | */ 41 | public function uploadVAECache($file, $key) { 42 | return $this->uploadFile($file, $this->vae_cache_prefix .'/'. $key); 43 | } 44 | 45 | /** 46 | * Upload an image file to S3, under the image data prefix as a .png file. 47 | * 48 | * A Client worker will POST this to us, we need to accept and forward to S3. 49 | */ 50 | public function uploadImage($file, $key) { 51 | return $this->uploadFile($file, $this->image_data_prefix .'/'. $key); 52 | } 53 | 54 | /** 55 | * Upload a text cache file to S3, under the embed prefix as a .pt file. 56 | * 57 | * A Client worker will POST this to us, we need to accept and forward to S3. 58 | */ 59 | public function uploadTextCache($file, $key) { 60 | return $this->uploadFile($file, $this->text_cache_prefix .'/'. 
$key); 61 | } 62 | 63 | public function uploadFile($file, $key) { 64 | try { 65 | $result = $this->s3Client->putObject([ 66 | 'Bucket' => $this->bucket, 67 | 'Key' => $key, 68 | 'SourceFile' => $file, 69 | ]); 70 | return $result['ObjectURL']; 71 | } catch (AwsException $e) { 72 | // Output error message if fails 73 | error_log($e->getMessage()); 74 | return null; 75 | } 76 | } 77 | 78 | public function uploadContent($content, $key, $contentType = 'text/plain') { 79 | try { 80 | $result = $this->s3Client->putObject([ 81 | 'Bucket' => $this->bucket, 82 | 'Key' => $key, 83 | 'Body' => $content, 84 | 'ContentType' => $contentType, 85 | ]); 86 | return $result['ObjectURL']; 87 | } catch (AwsException $e) { 88 | // Output error message if fails 89 | error_log($e->getMessage()); 90 | return null; 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /toolkit/captioning/composer.json: -------------------------------------------------------------------------------- 1 | { 2 | "require": { 3 | "aws/aws-sdk-php": "^3.300" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /toolkit/datasets/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bghira/SimpleTuner/ba0a415594f524fc6de51205e2da075a47ea37ee/toolkit/datasets/README.md -------------------------------------------------------------------------------- /toolkit/datasets/analyze_aspect_ratios_json.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | import threading, logging 3 | from concurrent.futures import ThreadPoolExecutor 4 | 5 | try: 6 | import pillow_jxl 7 | except ModuleNotFoundError: 8 | pass 9 | from PIL import Image 10 | 11 | # Allowed bucket values: 12 | allowed = [1.0, 1.5, 0.67, 0.75, 1.78] 13 | 14 | # Load from JSON. 15 | with open("aspect_ratios.json", "r") as f: 16 | aspect_ratios = json.load(f) 17 | 18 | new_bucket = {} 19 | for bucket, indices in aspect_ratios.items(): 20 | logging.info(f"{bucket}: {len(indices)}") 21 | if float(bucket) in allowed: 22 | logging.info(f"Bucket {bucket} in {allowed}") 23 | new_bucket[bucket] = aspect_ratios[bucket] 24 | 25 | least_amount = None 26 | for bucket, indices in aspect_ratios.items(): 27 | if least_amount is None or len(indices) < least_amount: 28 | least_amount = len(indices) 29 | 30 | # We don't want to limit square image training. 31 | # buckets_to_skip = [ 1.0 ] 32 | # for bucket, files in aspect_ratios.items(): 33 | # if float(bucket) not in buckets_to_skip and len(files) > least_amount: 34 | # logging.info(f'We have to reduce the number of items in the bucket: {bucket}') 35 | # # 'files' is a list of full file paths. we need to delete them randomly until the value of least_amount is reached. 36 | 37 | 38 | # # Get a random sample of files to delete. 
39 | # import random 40 | # random.shuffle(files) 41 | # files_to_delete = files[least_amount:] 42 | # logging.info(f'Files to delete: {len(files_to_delete)}') 43 | # for file in files_to_delete: 44 | # import os 45 | # os.remove(file) 46 | def _resize_for_condition_image(input_image: Image.Image, resolution: int): 47 | input_image = input_image.convert("RGB") 48 | W, H = input_image.size 49 | k = float(resolution) / min(H, W) 50 | H *= k 51 | W *= k 52 | H = int(round(H / 64.0)) * 64 53 | W = int(round(W / 64.0)) * 64 54 | img = input_image.resize((W, H), resample=Image.LANCZOS) 55 | return img 56 | 57 | 58 | def process_file(file): 59 | image = Image.open(file).convert("RGB") 60 | width, height = image.size 61 | if width < 900 or height < 900: 62 | logging.info( 63 | f"Image does not meet minimum size requirements: {file}, size {image.size}" 64 | ) 65 | os.remove(file) 66 | else: 67 | logging.info( 68 | f"Image meets minimum size requirements for conditioning: {file}, size {image.size}" 69 | ) 70 | image = _resize_for_condition_image(image, 1024) 71 | image.save(file) 72 | 73 | 74 | def process_bucket(bucket, files): 75 | logging.info(f"Processing bucket {bucket}: {len(files)} files") 76 | with ThreadPoolExecutor(max_workers=32) as executor: 77 | executor.map(process_file, files) 78 | 79 | 80 | if __name__ == "__main__": 81 | # Load aspect ratios from the JSON file 82 | with open("aspect_ratios.json", "r") as f: 83 | aspect_ratios = json.load(f) 84 | 85 | threads = [] 86 | for bucket, files in aspect_ratios.items(): 87 | thread = threading.Thread(target=process_bucket, args=(bucket, files)) 88 | threads.append(thread) 89 | thread.start() 90 | 91 | for thread in threads: 92 | thread.join() 93 | -------------------------------------------------------------------------------- /toolkit/datasets/analyze_laion_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Walk through a LAION dataset and analyze it.
3 | """ 4 | 5 | import os 6 | import json 7 | import concurrent.futures 8 | try: 9 | import pillow_jxl 10 | except ModuleNotFoundError: 11 | pass 12 | from PIL import Image 13 | 14 | 15 | def get_aspect_ratio(image_path): 16 | try: 17 | with Image.open(image_path) as img: 18 | width, height = img.size 19 | return image_path, width / height 20 | except Exception as e: 21 | os.remove(image_path) 22 | return None 23 | 24 | 25 | def analyze_images(directory_path): 26 | aspect_ratios = {} 27 | good_count = 0 28 | bad_count = 0 29 | image_files = [ 30 | os.path.join(directory_path, filename) 31 | for filename in os.listdir(directory_path) 32 | if filename.endswith(".jpg") or filename.endswith(".png") 33 | ] 34 | 35 | with concurrent.futures.ThreadPoolExecutor(max_workers=64) as executor: 36 | futures = { 37 | executor.submit(get_aspect_ratio, image_path): image_path 38 | for image_path in image_files 39 | } 40 | for future in concurrent.futures.as_completed(futures): 41 | image_path = futures[future] 42 | try: 43 | aspect_ratio = future.result() 44 | if aspect_ratio is not None: 45 | image_path, aspect_ratio = aspect_ratio 46 | aspect_ratio = round(aspect_ratio, 2) # round to 2 decimal places 47 | if aspect_ratio not in aspect_ratios: 48 | aspect_ratios[aspect_ratio] = [] 49 | aspect_ratios[aspect_ratio].append(image_path) 50 | good_count += 1 51 | else: 52 | bad_count += 1 53 | except Exception as e: 54 | pass 55 | print(f"Good images: {good_count}, Bad images: {bad_count}") 56 | return aspect_ratios 57 | 58 | 59 | def write_to_json(data, filename): 60 | with open(filename, "w") as outfile: 61 | json.dump(data, outfile) 62 | 63 | 64 | if __name__ == "__main__": 65 | image_directory = "/notebooks/datasets/laion-high-resolution/downloaded_images" 66 | output_file = "aspect_ratios.json" 67 | aspect_ratios = analyze_images(image_directory) 68 | write_to_json(aspect_ratios, output_file) 69 | -------------------------------------------------------------------------------- /toolkit/datasets/check_latent_corruption.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2024-04-05 17:19:44,198 [DEBUG] (LocalDataBackend) Checking if /Volumes/models/training/vae_cache/sdxl/photo-concept-bucket/image_data/1027365.pt exists = True 3 | 2024-04-05 17:19:44,198 [DEBUG] (LocalDataBackend) Checking if /Volumes/models/training/vae_cache/sdxl/photo-concept-bucket/image_data/10064767.pt exists = True 4 | 2024-04-05 17:19:44,223 [DEBUG] (LocalDataBackend) Checking if /Volumes/models/training/vae_cache/sdxl/photo-concept-bucket/image_data/13997787.pt exists = True 5 | 2024-04-05 17:19:44,223 [DEBUG] (LocalDataBackend) Checking if /Volumes/models/training/vae_cache/sdxl/photo-concept-bucket/image_data/13565183.pt exists = True 6 | """ 7 | 8 | latent_file_paths = ["1027365", "10064767", "13997787", "13565183"] 9 | 10 | prefix = "/Volumes/models/training/vae_cache/sdxl/photo-concept-bucket/image_data/" 11 | 12 | # load the latent_file_paths 13 | import torch 14 | 15 | for latent_file_path in latent_file_paths: 16 | print(f"{prefix}{latent_file_path}.pt") 17 | latent = torch.load( 18 | f"{prefix}{latent_file_path}.pt", map_location=torch.device("cpu") 19 | ) 20 | print(f"Shape: {latent.shape}") 21 | print(f"Mean: {latent.mean()}") 22 | print(f"Std: {latent.std()}") 23 | print(f"Min: {latent.min()}") 24 | print(f"Is corrupt: {torch.isnan(latent).any() or torch.isinf(latent).any()}") 25 | -------------------------------------------------------------------------------- 
/toolkit/datasets/clear_s3_bucket.py: -------------------------------------------------------------------------------- 1 | import boto3, os, logging, argparse, datetime 2 | from botocore.config import Config 3 | 4 | # Set up logging 5 | logging.basicConfig(level=os.getenv("LOGLEVEL", "INFO")) 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def initialize_s3_client(args): 10 | """Initialize the boto3 S3 client using the provided AWS credentials and settings.""" 11 | s3_config = Config(max_pool_connections=100) 12 | 13 | s3_client = boto3.client( 14 | "s3", 15 | endpoint_url=args.aws_endpoint_url, 16 | region_name=args.aws_region_name, 17 | aws_access_key_id=args.aws_access_key_id, 18 | aws_secret_access_key=args.aws_secret_access_key, 19 | config=s3_config, 20 | ) 21 | return s3_client 22 | 23 | 24 | from concurrent.futures import ThreadPoolExecutor 25 | 26 | 27 | def delete_object(s3_client, bucket_name, object_key): 28 | try: 29 | s3_client.delete_object(Bucket=bucket_name, Key=object_key) 30 | logger.info(f"Deleted: {object_key}") 31 | except Exception as e: 32 | logger.error(f"Error deleting {object_key} in bucket {bucket_name}: {e}") 33 | 34 | 35 | def clear_s3_bucket( 36 | s3_client, 37 | bucket_name, 38 | num_workers=10, 39 | search_pattern: str = None, 40 | older_than_date: str = None, 41 | ): 42 | try: 43 | logger.info(f"Clearing out bucket {bucket_name}") 44 | 45 | # Convert the date string to a datetime object 46 | if older_than_date: 47 | target_date = datetime.datetime.strptime(older_than_date, "%Y-%m-%d") 48 | else: 49 | target_date = None 50 | 51 | # Initialize paginator 52 | paginator = s3_client.get_paginator("list_objects_v2") 53 | 54 | # Create a PageIterator from the Paginator 55 | page_iterator = paginator.paginate(Bucket=bucket_name) 56 | 57 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 58 | for page in page_iterator: 59 | if "Contents" not in page: 60 | logger.info(f"No more items in bucket {bucket_name}") 61 | break 62 | 63 | # Filter by the older_than_date if provided 64 | if target_date: 65 | filtered_objects = [ 66 | s3_object 67 | for s3_object in page["Contents"] 68 | if s3_object["LastModified"].replace(tzinfo=None) < target_date 69 | ] 70 | else: 71 | filtered_objects = page["Contents"] 72 | 73 | if search_pattern is not None: 74 | keys_to_delete = [ 75 | s3_object["Key"] 76 | for s3_object in filtered_objects 77 | if search_pattern in s3_object["Key"] 78 | ] 79 | else: 80 | keys_to_delete = [ 81 | s3_object["Key"] for s3_object in filtered_objects 82 | ] 83 | 84 | executor.map( 85 | delete_object, 86 | [s3_client] * len(keys_to_delete), 87 | [bucket_name] * len(keys_to_delete), 88 | keys_to_delete, 89 | ) 90 | 91 | logger.info(f"Cleared out bucket {bucket_name}") 92 | 93 | except Exception as e: 94 | logger.error(f"Error clearing out bucket {bucket_name}: {e}") 95 | 96 | 97 | def parse_args(): 98 | parser = argparse.ArgumentParser(description="Clear out an S3 bucket.") 99 | parser.add_argument( 100 | "--aws_bucket_name", 101 | type=str, 102 | required=True, 103 | help="The AWS bucket name to clear.", 104 | ) 105 | parser.add_argument("--aws_endpoint_url", type=str, help="The AWS server to use.") 106 | parser.add_argument( 107 | "--num_workers", 108 | type=int, 109 | help="Number of workers to use for clearing.", 110 | default=10, 111 | ) 112 | parser.add_argument( 113 | "--search_pattern", 114 | type=str, 115 | help="If provided, files with this in their Content key will be removed only.", 116 | default=None, 117 | ) 118 | 
parser.add_argument("--aws_region_name", type=str, help="The AWS region to use.") 119 | parser.add_argument("--aws_access_key_id", type=str, help="AWS access key ID.") 120 | parser.add_argument( 121 | "--aws_secret_access_key", type=str, help="AWS secret access key." 122 | ) 123 | parser.add_argument( 124 | "--older_than_date", 125 | type=str, 126 | help="If provided, only files older than this date (format: YYYY-MM-DD) will be cleared.", 127 | default=None, 128 | ) 129 | return parser.parse_args() 130 | 131 | 132 | def main(): 133 | args = parse_args() 134 | s3_client = initialize_s3_client(args) 135 | clear_s3_bucket( 136 | s3_client, 137 | args.aws_bucket_name, 138 | num_workers=args.num_workers, 139 | search_pattern=args.search_pattern, 140 | older_than_date=args.older_than_date, 141 | ) 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /toolkit/datasets/controlnet/create_canny_edge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | try: 4 | import pillow_jxl 5 | except ModuleNotFoundError: 6 | pass 7 | from PIL import Image 8 | import numpy as np 9 | 10 | 11 | def generate_canny_edge_dataset(input_dir, output_dir_original, output_dir_edges): 12 | # Create output directories if they do not exist 13 | if not os.path.exists(output_dir_original): 14 | os.makedirs(output_dir_original) 15 | if not os.path.exists(output_dir_edges): 16 | os.makedirs(output_dir_edges) 17 | 18 | # Process each image in the input directory 19 | for filename in os.listdir(input_dir): 20 | if filename.lower().endswith((".png", ".jpg", ".jpeg", ".jxl")): 21 | image_path = os.path.join(input_dir, filename) 22 | original_image = Image.open(image_path) 23 | original_image.save(os.path.join(output_dir_original, filename)) 24 | 25 | # Read image in grayscale 26 | image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) 27 | if image is None: 28 | # If OpenCV fails, try loading the image with Pillow 29 | try: 30 | pil_image = Image.open(image_path) 31 | # Convert Pillow image to a format OpenCV can use 32 | image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) 33 | except Exception as e: 34 | print(f"Failed to load image with Pillow: {e}") 35 | continue 36 | # Apply Canny edge detection 37 | edges = cv2.Canny(image, 100, 200) 38 | 39 | # Save edge image 40 | edge_image_path = os.path.join(output_dir_edges, filename) 41 | cv2.imwrite(edge_image_path, edges) 42 | print(f"Processed {filename}") 43 | 44 | 45 | if __name__ == "__main__": 46 | input_dir = ( 47 | "/Volumes/ml/datasets/animals/antelope" # Update this to your folder path 48 | ) 49 | output_dir_original = "/Volumes/ml/datasets/canny-edge/animals/antelope-data" # Update this to your desired output path for originals 50 | output_dir_edges = "/Volumes/ml/datasets/canny-edge/animals/antelope-conditioning" # Update this to your desired output path for edges 51 | generate_canny_edge_dataset(input_dir, output_dir_original, output_dir_edges) 52 | -------------------------------------------------------------------------------- /toolkit/datasets/dataset_from_kellyc.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import re 3 | import os 4 | import argparse 5 | from urllib.parse import urlparse, parse_qs 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | from tqdm import tqdm 8 | try: 9 | import pillow_jxl 10 | except 
ModuleNotFoundError: 11 | pass 12 | from PIL import Image 13 | 14 | 15 | def get_image_width(url): 16 | """Extract width from the image URL.""" 17 | parsed_url = urlparse(url) 18 | query_params = parse_qs(parsed_url.query) 19 | return int(query_params.get("w", [0])[0]) 20 | 21 | 22 | def get_photo_id(url): 23 | """Extract photo ID from the image URL.""" 24 | match = re.search(r"/photos/(\d+)", url) 25 | return match.group(1) if match else None 26 | 27 | 28 | conn_timeout = 6 29 | read_timeout = 60 30 | timeouts = (conn_timeout, read_timeout) 31 | 32 | 33 | def download_image(url, output_path, minimum_image_size: int, minimum_pixel_area: int): 34 | """Download an image.""" 35 | response = requests.get(url, timeout=timeouts, stream=True) 36 | 37 | if response.status_code == 200: 38 | filename = os.path.basename(url.split("?")[0]) 39 | file_path = os.path.join(output_path, filename) 40 | # Convert path to PNG: 41 | file_path = file_path.replace(".jpg", ".png") 42 | 43 | with open(file_path, "wb") as f: 44 | for chunk in response.iter_content(1024): 45 | f.write(chunk) 46 | # Check if the file meets the minimum size requirements 47 | image = Image.open(file_path) 48 | width, height = image.size 49 | if minimum_image_size > 0 and ( 50 | width < minimum_image_size or height < minimum_image_size 51 | ): 52 | os.remove(file_path) 53 | return f"Nuked tiny image: {url}" 54 | if minimum_pixel_area > 0 and (width * height < minimum_pixel_area): 55 | os.remove(file_path) 56 | return f"Nuked tiny image: {url}" 57 | 58 | return f"Downloaded: {url}" 59 | return f"Failed to download: {url}" 60 | 61 | 62 | def process_urls(urls, output_path, minimum_image_size: int, minimum_pixel_area: int): 63 | """Process a list of URLs.""" 64 | # Simple URL list 65 | results = [] 66 | for url in urls: 67 | result = download_image( 68 | url, output_path, minimum_image_size, minimum_pixel_area 69 | ) 70 | results.append(result) 71 | return "\n".join(results) 72 | 73 | 74 | def main(args): 75 | os.makedirs(args.output_path, exist_ok=True) 76 | 77 | url_groups = {} 78 | 79 | with open(args.file_path, "r") as file: 80 | for line in file: 81 | urls = line.strip().split() 82 | # Treat as a simple URL list 83 | url_groups[line] = urls 84 | 85 | with ThreadPoolExecutor(max_workers=args.workers) as executor: 86 | futures = [ 87 | executor.submit( 88 | process_urls, 89 | urls, 90 | args.output_path, 91 | args.minimum_image_size, 92 | args.minimum_pixel_area, 93 | ) 94 | for urls in url_groups.values() 95 | ] 96 | for future in tqdm( 97 | as_completed(futures), total=len(futures), desc="Downloading images" 98 | ): 99 | if args.debug: 100 | print(future.result()) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser( 105 | description="Download smallest images from Pexels." 106 | ) 107 | parser.add_argument( 108 | "--file_path", type=str, help="Path to the text file containing image URLs." 109 | ) 110 | parser.add_argument( 111 | "--output_path", 112 | type=str, 113 | help="Path to the directory where images will be saved.", 114 | ) 115 | parser.add_argument( 116 | "--minimum_image_size", 117 | type=int, 118 | default=0, 119 | help="Both sides of the image must be larger than this. ZERO disables this.", 120 | ) 121 | parser.add_argument( 122 | "--minimum_pixel_area", 123 | type=int, 124 | default=0, 125 | help="The total number of pixels in the image must be larger than this. ZERO disables this. 
Recommended value: 1024*1024", 126 | ) 127 | parser.add_argument( 128 | "--workers", 129 | type=int, 130 | default=64, 131 | help="Number of worker threads. Default is 64.", 132 | ) 133 | parser.add_argument( 134 | "--debug", 135 | action="store_true", 136 | help="Print debug messages.", 137 | ) 138 | args = parser.parse_args() 139 | main(args) 140 | -------------------------------------------------------------------------------- /toolkit/datasets/enhance_with_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | try: 3 | import pillow_jxl 4 | except ModuleNotFoundError: 5 | pass 6 | from PIL import Image 7 | from diffusers import ControlNetModel, DiffusionPipeline 8 | from diffusers.utils import load_image 9 | 10 | 11 | def resize_for_condition_image(input_image: Image.Image, resolution: int): 12 | input_image = input_image.convert("RGB") 13 | W, H = input_image.size 14 | k = float(resolution) / min(H, W) 15 | H *= k 16 | W *= k 17 | H = int(round(H / 64.0)) * 64 18 | W = int(round(W / 64.0)) * 64 19 | img = input_image.resize((W, H), resample=Image.LANCZOS) 20 | return img 21 | 22 | 23 | controlnet = ControlNetModel.from_pretrained( 24 | "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.bfloat16 25 | ) 26 | pipe = DiffusionPipeline.from_pretrained( 27 | "SG161222/Realistic_Vision_V5.0_noVAE", 28 | custom_pipeline="stable_diffusion_controlnet_img2img", 29 | controlnet=controlnet, 30 | torch_dtype=torch.bfloat16, 31 | ).to("cuda" if torch.cuda.is_available() else "cpu") 32 | from diffusers import DDIMScheduler 33 | 34 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 35 | # pipe.unet.set_attention_slice(1) 36 | source_image = load_image( 37 | "/Volumes/models/training/datasets/animals/antelope/0e17715606.jpg" 38 | ) 39 | 40 | condition_image = resize_for_condition_image(source_image, 1024) 41 | image = pipe( 42 | prompt="best quality", 43 | negative_prompt="deformed eyes, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated", 44 | image=condition_image, 45 | controlnet_conditioning_image=condition_image, 46 | width=condition_image.size[0], 47 | height=condition_image.size[1], 48 | strength=1.0, 49 | generator=torch.manual_seed(20), 50 | num_inference_steps=32, 51 | ).images[0] 52 | 53 | image.save("output.png") 54 | -------------------------------------------------------------------------------- /toolkit/datasets/folder_to_parquet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script exists to scan a folder and import all of the image data into equally sized parquet files. 
3 | 4 | Fields collected: 5 | 6 | filename 7 | image hash 8 | width 9 | height 10 | luminance 11 | image_data 12 | """ 13 | 14 | import os, argparse 15 | try: 16 | import pillow_jxl 17 | except ModuleNotFoundError: 18 | pass 19 | from PIL import Image 20 | import numpy as np 21 | import pandas as pd 22 | from tqdm import tqdm 23 | 24 | 25 | def get_image_hash(image): 26 | """Calculate the hash of an image.""" 27 | image = image.convert("L").resize((8, 8)) 28 | pixels = list(image.getdata()) 29 | avg = sum(pixels) / len(pixels) 30 | bits = "".join("1" if pixel > avg else "0" for pixel in pixels) 31 | return int(bits, 2) 32 | 33 | 34 | def get_image_luminance(image): 35 | """Calculate the luminance of an image.""" 36 | image = image.convert("L") 37 | pixels = list(image.getdata()) 38 | return sum(pixels) / len(pixels) 39 | 40 | 41 | def get_size(image): 42 | """Get the size of an image.""" 43 | return image.size 44 | 45 | 46 | argparser = argparse.ArgumentParser() 47 | argparser.add_argument("input_folder", help="Folder to scan for images") 48 | argparser.add_argument("output_folder", help="Folder to save parquet files") 49 | 50 | args = argparser.parse_args() 51 | 52 | os.makedirs(args.output_folder, exist_ok=True) 53 | 54 | data = [] 55 | for root, _, files in os.walk(args.input_folder): 56 | for file in tqdm(files, desc="Processing images"): 57 | try: 58 | image = Image.open(os.path.join(root, file)) 59 | except: 60 | continue 61 | 62 | width, height = get_size(image) 63 | luminance = get_image_luminance(image) 64 | image_hash = get_image_hash(image) 65 | # Get the smallest original compressed representation of the image 66 | file_data = open(os.path.join(root, file), "rb").read() 67 | image_data = np.frombuffer(file_data, dtype=np.uint8) 68 | 69 | data.append((file, image_hash, width, height, luminance, image_data)) 70 | 71 | df = pd.DataFrame( 72 | data, columns=["filename", "image_hash", "width", "height", "luminance", "image"] 73 | ) 74 | df.to_parquet(os.path.join(args.output_folder, "images.parquet"), index=False) 75 | 76 | print("Done!") 77 | -------------------------------------------------------------------------------- /toolkit/datasets/masked_loss/generate_dataset_masks_via_huggingface.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from gradio_client import Client, handle_file 5 | 6 | 7 | def main(): 8 | # Set up argument parser 9 | parser = argparse.ArgumentParser( 10 | description="Mask images in a directory using Florence SAM Masking." 
11 | ) 12 | parser.add_argument( 13 | "--input_dir", 14 | type=str, 15 | required=True, 16 | help="Path to the input directory containing images.", 17 | ) 18 | parser.add_argument( 19 | "--output_dir", 20 | type=str, 21 | required=True, 22 | help="Path to the output directory to save masked images.", 23 | ) 24 | parser.add_argument( 25 | "--text_input", 26 | type=str, 27 | default="person", 28 | help='Text prompt for masking (default: "person").', 29 | ) 30 | parser.add_argument( 31 | "--model", 32 | type=str, 33 | default="SkalskiP/florence-sam-masking", 34 | help='Model name to use (default: "SkalskiP/florence-sam-masking").', 35 | ) 36 | args = parser.parse_args() 37 | 38 | input_path = args.input_dir 39 | output_path = args.output_dir 40 | text_input = args.text_input 41 | model_name = args.model 42 | 43 | # Create the output directory if it doesn't exist 44 | os.makedirs(output_path, exist_ok=True) 45 | 46 | # Initialize the Gradio client 47 | client = Client(model_name) 48 | 49 | # Get all files in the input directory 50 | files = os.listdir(input_path) 51 | 52 | # Iterate over all files 53 | for file in files: 54 | # Construct the full file path 55 | full_path = os.path.join(input_path, file) 56 | # Check if the file is an image 57 | if os.path.isfile(full_path) and full_path.lower().endswith( 58 | (".jpg", ".jpeg", ".png", ".webp") 59 | ): 60 | # Define the path for the output mask 61 | mask_path = os.path.join(output_path, file) 62 | # Skip if the mask already exists 63 | if os.path.exists(mask_path): 64 | print(f"Mask already exists for {file}, skipping.") 65 | continue 66 | # Predict the mask 67 | try: 68 | mask_filename = client.predict( 69 | image_input=handle_file(full_path), 70 | text_input=text_input, 71 | api_name="/process_image", 72 | ) 73 | # Move the generated mask to the output directory 74 | shutil.move(mask_filename, mask_path) 75 | print(f"Saved mask to {mask_path}") 76 | except Exception as e: 77 | print(f"Failed to process {file}: {e}") 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /toolkit/datasets/masked_loss/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | einops 3 | timm 4 | transformers 5 | samv2 6 | supervision 7 | opencv-python 8 | pytest 9 | -------------------------------------------------------------------------------- /toolkit/datasets/random_recrop_for_json_image_metadata.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script is used to recrop your JSON image metadata dataset so that you can 3 | safely delete the VAE cache and then recreate it with new crops. Just point it 4 | at your multidatabackend.json file and it will take care of the rest, then 5 | delete your VAE cache folder and let ST recache. 
6 | ''' 7 | 8 | import sys 9 | import json 10 | import random 11 | import shutil 12 | import os 13 | 14 | 15 | def update_crop_coordinates_from_multidatabackend(multidatabackend_file): 16 | # Read the multidatabackend.json file 17 | with open(multidatabackend_file, 'r') as f: 18 | datasets = json.load(f) 19 | 20 | # Ensure datasets is a list 21 | if not isinstance(datasets, list): 22 | datasets = [datasets] 23 | 24 | for dataset in datasets: 25 | # Skip datasets that are disabled 26 | if dataset.get('disabled', False): 27 | continue 28 | 29 | # Get required fields 30 | instance_data_dir = dataset.get('instance_data_dir') 31 | cache_file_suffix = dataset.get('cache_file_suffix') 32 | 33 | if not instance_data_dir or not cache_file_suffix: 34 | print(f"Skipping dataset {dataset.get('id', 'unknown')} due to missing 'instance_data_dir' or 'cache_file_suffix'") 35 | continue 36 | 37 | # Build the metadata file path 38 | metadata_file = os.path.join(instance_data_dir, f'aspect_ratio_bucket_metadata_{cache_file_suffix}.json') 39 | 40 | # Check if metadata file exists 41 | if not os.path.exists(metadata_file): 42 | print(f"Metadata file {metadata_file} does not exist, skipping") 43 | continue 44 | 45 | # Now process the metadata file 46 | with open(metadata_file, 'r') as f: 47 | data = json.load(f) 48 | 49 | for key in data: 50 | metadata = data[key] 51 | inter_size = metadata.get('intermediary_size') 52 | target_size = metadata.get('target_size') 53 | if inter_size is None or target_size is None: 54 | continue 55 | 56 | # Assuming sizes are in (height, width) format 57 | inter_height, inter_width = inter_size 58 | target_height, target_width = target_size 59 | 60 | max_crop_top = max(inter_height - target_height, 0) 61 | max_crop_left = max(inter_width - target_width, 0) 62 | 63 | crop_top = random.randint(0, max_crop_top) 64 | crop_left = random.randint(0, max_crop_left) 65 | 66 | # Update the crop_coordinates 67 | metadata['crop_coordinates'] = [crop_top, crop_left] 68 | 69 | # Backup the original metadata file 70 | backup_file = metadata_file + '.bak' 71 | shutil.copyfile(metadata_file, backup_file) 72 | 73 | # Write the updated data back to the metadata file 74 | with open(metadata_file, 'w') as f: 75 | json.dump(data, f, indent=2) 76 | 77 | print(f"Updated crop_coordinates in {metadata_file}, backup saved as {backup_file}") 78 | 79 | 80 | if __name__ == "__main__": 81 | if len(sys.argv) != 2: 82 | print("Usage: python update_crop_coordinates.py multidatabackend.json") 83 | sys.exit(1) 84 | multidatabackend_file = sys.argv[1] 85 | update_crop_coordinates_from_multidatabackend(multidatabackend_file) -------------------------------------------------------------------------------- /toolkit/datasets/update_parquet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from PIL import Image 4 | from concurrent.futures import ProcessPoolExecutor 5 | from tqdm import tqdm 6 | 7 | # set 'fork' spawn mode 8 | import multiprocessing as mp 9 | 10 | mp.set_start_method("fork") 11 | 12 | PARQUET_FILE = "photo-concept-bucket.parquet" 13 | IMAGE_DATA = "output_dir" 14 | 15 | # Load the parquet file 16 | df = pd.read_parquet(PARQUET_FILE, engine="pyarrow") 17 | 18 | # Function to process a chunk of IDs 19 | import json 20 | 21 | 22 | def process_images(ids_chunk): 23 | summary = { 24 | "id": [], 25 | "old_width": [], 26 | "new_width": [], 27 | "old_height": [], 28 | "new_height": [], 29 | "old_aspect_ratio": [], 30 | "new_aspect_ratio": 
[], 31 | } 32 | 33 | for id in ids_chunk: 34 | metadata_path = os.path.join(IMAGE_DATA, f"{id}.json") 35 | if not os.path.exists(metadata_path): 36 | continue 37 | # Use the simpletuner data if the image is not found 38 | try: 39 | with open(metadata_path) as f: 40 | row = json.load(f) 41 | width, height = row["image_size"] 42 | aspect_ratio = row["aspect_ratio"] 43 | except KeyError: 44 | print(f"Image {metadata_path} not found in simpletuner data") 45 | continue 46 | 47 | # Locate the row in the DataFrame 48 | row = df.loc[df["id"] == id] 49 | 50 | # Check for differences 51 | if not row.empty and ( 52 | row.iloc[0]["width"] != width or row.iloc[0]["height"] != height 53 | ): 54 | print( 55 | f"Updated image {id}: {row.iloc[0]['width']}x{row.iloc[0]['height']} -> {width}x{height}" 56 | ) 57 | summary["id"].append(id) 58 | summary["old_width"].append(row.iloc[0]["width"]) 59 | summary["new_width"].append(width) 60 | summary["old_height"].append(row.iloc[0]["height"]) 61 | summary["new_height"].append(height) 62 | summary["old_aspect_ratio"].append(row.iloc[0]["aspect_ratio"]) 63 | summary["new_aspect_ratio"].append(aspect_ratio) 64 | 65 | return summary 66 | 67 | 68 | # Split IDs into chunks for parallel processing 69 | ids = df["id"].values 70 | num_processes = os.cpu_count() 71 | chunk_size = len(ids) // num_processes + (len(ids) % num_processes > 0) 72 | id_chunks = [ids[i : i + chunk_size] for i in range(0, len(ids), chunk_size)] 73 | 74 | # Process the images in parallel 75 | with ProcessPoolExecutor(max_workers=num_processes) as executor: 76 | results = list(tqdm(executor.map(process_images, id_chunks), total=len(id_chunks))) 77 | 78 | # Combine results from all processes 79 | combined_summary = pd.DataFrame() 80 | for result in results: 81 | combined_summary = pd.concat([combined_summary, pd.DataFrame(result)]) 82 | 83 | # Update the DataFrame based on the combined summary 84 | for index, row in combined_summary.iterrows(): 85 | idx = df.index[df["id"] == row["id"]].tolist()[0] 86 | df.at[idx, "width"] = row["new_width"] 87 | df.at[idx, "height"] = row["new_height"] 88 | df.at[idx, "aspect_ratio"] = row["new_aspect_ratio"] 89 | 90 | # Save the updated DataFrame to the parquet file 91 | df.to_parquet(PARQUET_FILE, engine="pyarrow") 92 | -------------------------------------------------------------------------------- /toolkit/inference/inference_ddpm.py: -------------------------------------------------------------------------------- 1 | # Use Pytorch 2! 2 | import torch 3 | from diffusers import ( 4 | StableDiffusionPipeline, 5 | DiffusionPipeline, 6 | AutoencoderKL, 7 | UNet2DConditionModel, 8 | DDPMScheduler, 9 | ) 10 | from transformers import CLIPTextModel 11 | 12 | # Any model currently on Huggingface Hub. 13 | # model_id = 'junglerally/digital-diffusion' 14 | # model_id = 'ptx0/realism-engine' 15 | # model_id = 'ptx0/artius_v21' 16 | # model_id = 'ptx0/pseudo-journey' 17 | model_id = "ptx0/pseudo-journey-v2" 18 | pipeline = DiffusionPipeline.from_pretrained(model_id) 19 | 20 | # Optimize! 21 | pipeline.unet = torch.compile(pipeline.unet) 22 | scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") 23 | 24 | # Remove this if you get an error. 
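# (torch.set_float32_matmul_precision("high") permits TF32 matmuls on GPUs that support it, such as Ampere and newer; it trades a little precision for speed.)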
25 | torch.set_float32_matmul_precision("high") 26 | 27 | pipeline.to("cuda") 28 | prompts = { 29 | "woman": "a woman, hanging out on the beach", 30 | "man": "a man playing guitar in a park", 31 | "lion": "Explore the ++majestic beauty++ of untamed ++lion prides++ as they roam the African plains --captivating expressions-- in the wildest national geographic adventure", 32 | "child": "a child flying a kite on a sunny day", 33 | "bear": "best quality ((bear)) in the swiss alps cinematic 8k highly detailed sharp focus intricate fur", 34 | "alien": "an alien exploring the Mars surface", 35 | "robot": "a robot serving coffee in a cafe", 36 | "knight": "a knight protecting a castle", 37 | "menn": "a group of smiling and happy men", 38 | "bicycle": "a bicycle, on a mountainside, on a sunny day", 39 | "cosmic": "cosmic entity, sitting in an impossible position, quantum reality, colours", 40 | "wizard": "a mage wizard, bearded and gray hair, blue star hat with wand and mystical haze", 41 | "wizarddd": "digital art, fantasy, portrait of an old wizard, detailed", 42 | "macro": "a dramatic city-scape at sunset or sunrise", 43 | "micro": "RNA and other molecular machinery of life", 44 | "gecko": "a leopard gecko stalking a cricket", 45 | } 46 | for shortname, prompt in prompts.items(): 47 | # old prompt: '' 48 | image = pipeline( 49 | prompt=prompt, 50 | negative_prompt="malformed, disgusting, overexposed, washed-out", 51 | num_inference_steps=32, 52 | generator=torch.Generator(device="cuda").manual_seed(1641421826), 53 | width=1152, 54 | height=768, 55 | guidance_scale=7.5, 56 | ).images[0] 57 | image.save(f"test/{shortname}_nobetas.png", format="PNG") 58 | -------------------------------------------------------------------------------- /toolkit/inference/inference_karras.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionKDiffusionPipeline, AutoencoderKL 3 | 4 | pipe = StableDiffusionKDiffusionPipeline.from_pretrained( 5 | "/models/pseudo-test", 6 | torch_dtype=torch.float16, 7 | safety_checker=None, 8 | ) 9 | pipe.set_scheduler("sample_dpmpp_2m") 10 | vae = AutoencoderKL.from_pretrained( 11 | "stabilityai/sd-vae-ft-mse", use_safetensors=True, torch_dtype=torch.float16 12 | ) 13 | pipe.vae = vae 14 | pipe.to("cuda") 15 | image = pipe( 16 | prompt="best quality ((bear)) in the swiss alps cinematic 8k highly detailed sharp focus intricate fur", 17 | negative_prompt="malformed, disgusting", 18 | num_inference_steps=50, 19 | generator=torch.Generator(device="cuda").manual_seed(42), 20 | width=1280, 21 | height=720, 22 | guidance_scale=3.0, 23 | use_karras_sigmas=True, 24 | ).images[0] 25 | image.save("bear.png", format="PNG") 26 | -------------------------------------------------------------------------------- /toolkit/inference/inference_sigma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import Transformer2DModel 3 | from sigma import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline 4 | 5 | setattr(Transformer2DModel, "_init_patched_inputs", pixart_sigma_init_patched_inputs) 6 | device = torch.device( 7 | "cuda:0" 8 | if torch.cuda.is_available() 9 | else "mps" if torch.backends.mps.is_available() else "cpu" 10 | ) 11 | 12 | transformer = Transformer2DModel.from_pretrained( 13 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 14 | subfolder="transformer", 15 | use_safetensors=True, 16 | ) 17 | pipe = PixArtSigmaPipeline.from_pretrained( 18 
| "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", 19 | transformer=transformer, 20 | use_safetensors=True, 21 | ) 22 | pipe.to(device=device, dtype=torch.bfloat16) 23 | 24 | # Enable memory optimizations. 25 | # pipe.enable_model_cpu_offload() 26 | 27 | prompt = "A small cactus with a happy face in the Sahara desert." 28 | image = pipe(prompt).images[0] 29 | image.save("./catcus.png") 30 | -------------------------------------------------------------------------------- /toolkit/inference/inference_snr_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | try: 3 | import pillow_jxl 4 | except ModuleNotFoundError: 5 | pass 6 | from PIL import Image 7 | from diffusers import ( 8 | StableDiffusionPipeline, 9 | DiffusionPipeline, 10 | AutoencoderKL, 11 | UNet2DConditionModel, 12 | DDPMScheduler, 13 | DDIMScheduler, 14 | ) 15 | from transformers import CLIPTextModel 16 | from helpers.prompts import prompts 17 | 18 | model_id = "/notebooks/datasets/models/pseudo-realism" 19 | # model_id = 'stabilityai/stable-diffusion-2-1' 20 | pipe = StableDiffusionPipeline.from_pretrained(model_id) 21 | pipe.unet = torch.compile(pipe.unet) 22 | 23 | scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") 24 | torch.set_float32_matmul_precision("high") 25 | pipe.to("cuda") 26 | negative_prompt = "cropped, out-of-frame, low quality, low res, oorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, synthetic, rendering" 27 | 28 | 29 | def create_image_grid(images, ncols): 30 | # Assuming all images are the same size, get dimensions of the first image 31 | width, height = images[0].size 32 | 33 | # Create a new image of size that can fit all the small images in a grid 34 | grid_image = Image.new( 35 | "RGB", (width * ncols, height * ((len(images) + ncols - 1) // ncols)) 36 | ) 37 | 38 | # Loop through all images and paste them into the grid image 39 | for index, image in enumerate(images): 40 | row = index // ncols 41 | col = index % ncols 42 | grid_image.paste(image, (col * width, row * height)) 43 | 44 | return grid_image 45 | 46 | 47 | all_images = [] 48 | 49 | for shortname, prompt in prompts.items(): 50 | for TIMESTEP_TYPE in ["trailing", "leading"]: 51 | for RESCALE_BETAS_ZEROS_SNR in [True, False]: 52 | for GUIDANCE_RESCALE in [0, 0.3, 0.5, 0.7]: 53 | for GUIDANCE in [5, 6, 7, 8, 9]: 54 | pipe.scheduler = DDIMScheduler.from_config( 55 | pipe.scheduler.config, 56 | timestep_spacing=TIMESTEP_TYPE, 57 | rescale_betas_zero_snr=RESCALE_BETAS_ZEROS_SNR, 58 | ) 59 | generator = torch.Generator(device="cpu").manual_seed(0) 60 | image = pipe( 61 | prompt=prompt, 62 | width=1152, 63 | height=768, 64 | negative_prompt=negative_prompt, 65 | generator=generator, 66 | num_images_per_prompt=1, 67 | num_inference_steps=50, 68 | guidance_scale=GUIDANCE, 69 | guidance_rescale=GUIDANCE_RESCALE, 70 | ).images[0] 71 | image_path = f"test/{shortname}_{TIMESTEP_TYPE}_{RESCALE_BETAS_ZEROS_SNR}_{GUIDANCE}g_{GUIDANCE_RESCALE}r.png" 72 | image.save(image_path, format="PNG") 73 | all_images.append(Image.open(image_path)) 74 | 75 | # create image grid after all images are generated 76 | image_grid = create_image_grid(all_images, 4) # 4 is the number of columns 77 | image_grid.save("image_comparison_grid.png", format="PNG") 78 | -------------------------------------------------------------------------------- 
/toolkit/inference/tile_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | try: 3 | import pillow_jxl 4 | except ModuleNotFoundError: 5 | pass 6 | from PIL import Image 7 | 8 | # Define the image size 9 | img_size = (768, 768) 10 | 11 | # Define the directory 12 | directory = os.getcwd() 13 | 14 | # Get a list of all the image files 15 | files = [f for f in os.listdir(directory) if f.endswith(".png")] 16 | 17 | # Extract subjects from files 18 | subjects = list(set([f.split("-")[0] for f in files])) 19 | 20 | # For each subject, sort the files and combine them 21 | for subject in subjects: 22 | # Get all images of the current subject 23 | subject_files = [f for f in files if f.startswith(subject)] 24 | subject_files.sort(key=lambda x: int(x.split("-")[1].split(".")[0])) 25 | 26 | # Create a new blank image to paste the others onto 27 | new_image = Image.new("RGB", (len(subject_files) * img_size[0], img_size[1])) 28 | 29 | # For each image file 30 | for i, file in enumerate(subject_files): 31 | # Open the image file 32 | img = Image.open(file) 33 | 34 | # Paste the image into the new image 35 | new_image.paste(img, (i * img_size[0], 0)) 36 | 37 | # Save the new image 38 | new_image.save(f"{subject}-combined.png") 39 | -------------------------------------------------------------------------------- /toolkit/inference/tile_samplers.py: -------------------------------------------------------------------------------- 1 | try: 2 | import pillow_jxl 3 | except ModuleNotFoundError: 4 | pass 5 | from PIL import Image, ImageDraw, ImageFont 6 | import requests 7 | from io import BytesIO 8 | 9 | # Placeholder URL 10 | url = "https://sa1s3optim.patientpop.com/assets/images/provider/photos/2353184.jpg" 11 | 12 | # Download the image from the URL 13 | response = requests.get(url) 14 | original_image = Image.open(BytesIO(response.content)) 15 | 16 | # Define target size (1 megapixel) 17 | target_width = 1000 18 | target_height = 1000 19 | 20 | # Resize the image using different samplers 21 | samplers = { 22 | "NEAREST": Image.NEAREST, 23 | "BOX": Image.BOX, 24 | "HAMMING": Image.HAMMING, 25 | "BILINEAR": Image.BILINEAR, 26 | "BICUBIC": Image.BICUBIC, 27 | "LANCZOS": Image.LANCZOS, 28 | } 29 | 30 | # Create a new image to combine the results 31 | combined_width = target_width * len(samplers) 32 | combined_height = target_height + 50 # Extra space for labels 33 | combined_image = Image.new("RGB", (combined_width, combined_height), "white") 34 | draw = ImageDraw.Draw(combined_image) 35 | 36 | # Load a default font 37 | try: 38 | font = ImageFont.load_default() 39 | except IOError: 40 | font = None 41 | 42 | # Resize and add each sampler result to the combined image 43 | for i, (label, sampler) in enumerate(samplers.items()): 44 | resized_image = original_image.resize((target_width, target_height), sampler) 45 | combined_image.paste(resized_image, (i * target_width, 50)) 46 | 47 | # Draw the label 48 | text_position = (i * target_width + 20, 15) 49 | draw.text(text_position, label, fill="black", font=font) 50 | 51 | # Save or display the combined image 52 | combined_image_path = "downsampled_image_comparison.png" 53 | combined_image.save(combined_image_path) 54 | combined_image_path 55 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # Quiet down, you. 
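# The two getLogger calls below fetch the DeepSpeed and torch.distributed.elastic loggers and force them to ERROR before the remaining imports run, so their startup chatter does not drown out SimpleTuner's own logging.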
4 | ds_logger1 = logging.getLogger("DeepSpeed") 5 | ds_logger2 = logging.getLogger("torch.distributed.elastic.multiprocessing.redirects") 6 | ds_logger1.setLevel("ERROR") 7 | ds_logger2.setLevel("ERROR") 8 | import logging.config 9 | 10 | logging.config.dictConfig( 11 | { 12 | "version": 1, 13 | "disable_existing_loggers": True, 14 | } 15 | ) 16 | from os import environ 17 | 18 | environ["ACCELERATE_LOG_LEVEL"] = "WARNING" 19 | 20 | from helpers.training.trainer import Trainer 21 | from helpers.training.state_tracker import StateTracker 22 | from helpers import log_format 23 | 24 | logger = logging.getLogger("SimpleTuner") 25 | logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) 26 | 27 | if __name__ == "__main__": 28 | trainer = None 29 | try: 30 | import multiprocessing 31 | 32 | multiprocessing.set_start_method("fork") 33 | except Exception as e: 34 | logger.error( 35 | "Failed to set the multiprocessing start method to 'fork'. Unexpected behaviour such as high memory overhead or poor performance may result." 36 | f"\nError: {e}" 37 | ) 38 | try: 39 | trainer = Trainer( 40 | exit_on_error=True, 41 | ) 42 | trainer.configure_webhook() 43 | trainer.init_noise_schedule() 44 | trainer.init_seed() 45 | 46 | trainer.init_huggingface_hub() 47 | 48 | trainer.init_preprocessing_models() 49 | trainer.init_precision(preprocessing_models_only=True) 50 | trainer.init_data_backend() 51 | # trainer.init_validation_prompts() 52 | trainer.init_unload_text_encoder() 53 | trainer.init_unload_vae() 54 | 55 | trainer.init_load_base_model() 56 | trainer.init_controlnet_model() 57 | trainer.init_precision() 58 | trainer.init_freeze_models() 59 | trainer.init_trainable_peft_adapter() 60 | trainer.init_ema_model() 61 | # EMA must be quantised if the base model is as well. 62 | trainer.init_precision(ema_only=True) 63 | 64 | trainer.move_models(destination="accelerator") 65 | trainer.init_validations() 66 | trainer.init_benchmark_base_model() 67 | 68 | trainer.resume_and_prepare() 69 | 70 | trainer.init_trackers() 71 | trainer.train() 72 | except KeyboardInterrupt: 73 | if StateTracker.get_webhook_handler() is not None: 74 | StateTracker.get_webhook_handler().send( 75 | message="Training has been interrupted by user action (lost terminal, or ctrl+C)." 76 | ) 77 | except Exception as e: 78 | import traceback 79 | 80 | if StateTracker.get_webhook_handler() is not None: 81 | StateTracker.get_webhook_handler().send( 82 | message=f"Training has failed. Please check the logs for more information: {e}" 83 | ) 84 | print(e) 85 | print(traceback.format_exc()) 86 | if trainer is not None and trainer.bf is not None: 87 | trainer.bf.stop_fetching() 88 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Pull config from config.env 4 | [ -f "config/config.env" ] && source config/config.env 5 | 6 | # If the user has not provided VENV_PATH, we will assume $(pwd)/.venv 7 | if [ -z "${VENV_PATH}" ]; then 8 | # what if we have VIRTUAL_ENV? 
use that instead 9 | if [ -n "${VIRTUAL_ENV}" ]; then 10 | export VENV_PATH="${VIRTUAL_ENV}" 11 | elif [ -d "$PWD/.venv" ]; then 12 | export VENV_PATH="$PWD/.venv" 13 | elif [ -d "$PWD/venv" ]; then 14 | export VENV_PATH="$PWD/venv" 15 | fi 16 | fi 17 | 18 | # If a venv hasn't already been activated, activate it now 19 | if [[ -z "${VIRTUAL_ENV}" ]]; then 20 | source "${VENV_PATH}/bin/activate" 21 | fi 22 | 23 | if [ -z "${DISABLE_LD_OVERRIDE}" ]; then 24 | export NVJITLINK_PATH="$(find "${VENV_PATH}" -name nvjitlink -type d)/lib" 25 | # if it's not empty, we will add it to LD_LIBRARY_PATH at the front: 26 | if [ -n "${NVJITLINK_PATH}" ]; then 27 | export LD_LIBRARY_PATH="${NVJITLINK_PATH}:${LD_LIBRARY_PATH}" 28 | fi 29 | fi 30 | 31 | if [ -z "${TQDM_NCOLS}" ]; then 32 | export TQDM_NCOLS=125 33 | fi 34 | if [ -z "${TQDM_LEAVE}" ]; then 35 | export TQDM_LEAVE=false 36 | fi 37 | 38 | export TOKENIZERS_PARALLELISM=false 39 | export PLATFORM 40 | PLATFORM=$(uname -s) 41 | if [[ "$PLATFORM" == "Darwin" ]]; then 42 | export MIXED_PRECISION="no" 43 | fi 44 | 45 | if [ -z "${ACCELERATE_EXTRA_ARGS}" ]; then 46 | ACCELERATE_EXTRA_ARGS="" 47 | fi 48 | 49 | if [ -z "${TRAINING_NUM_PROCESSES}" ]; then 50 | echo "Set custom env vars permanently in config/config.env:" 51 | printf "TRAINING_NUM_PROCESSES not set, defaulting to 1.\n" 52 | TRAINING_NUM_PROCESSES=1 53 | fi 54 | 55 | if [ -z "${TRAINING_NUM_MACHINES}" ]; then 56 | printf "TRAINING_NUM_MACHINES not set, defaulting to 1.\n" 57 | TRAINING_NUM_MACHINES=1 58 | fi 59 | 60 | if [ -z "${MIXED_PRECISION}" ]; then 61 | printf "MIXED_PRECISION not set, defaulting to bf16.\n" 62 | MIXED_PRECISION=bf16 63 | fi 64 | 65 | if [ -z "${TRAINING_DYNAMO_BACKEND}" ]; then 66 | printf "TRAINING_DYNAMO_BACKEND not set, defaulting to no.\n" 67 | TRAINING_DYNAMO_BACKEND="no" 68 | fi 69 | 70 | if [ -z "${ENV}" ]; then 71 | printf "ENV not set, defaulting to default.\n" 72 | export ENV="default" 73 | fi 74 | export ENV_PATH="" 75 | if [[ "$ENV" != "default" ]]; then 76 | export ENV_PATH="${ENV}/" 77 | fi 78 | 79 | if [ -z "${CONFIG_BACKEND}" ]; then 80 | if [ -n "${CONFIG_TYPE}" ]; then 81 | export CONFIG_BACKEND="${CONFIG_TYPE}" 82 | fi 83 | fi 84 | 85 | if [ -z "${CONFIG_BACKEND}" ]; then 86 | export CONFIG_BACKEND="env" 87 | export CONFIG_PATH="config/${ENV_PATH}config" 88 | if [ -f "${CONFIG_PATH}.json" ]; then 89 | export CONFIG_BACKEND="json" 90 | elif [ -f "${CONFIG_PATH}.toml" ]; then 91 | export CONFIG_BACKEND="toml" 92 | elif [ -f "${CONFIG_PATH}.env" ]; then 93 | export CONFIG_BACKEND="env" 94 | fi 95 | echo "Using ${CONFIG_BACKEND} backend: ${CONFIG_PATH}.${CONFIG_BACKEND}" 96 | fi 97 | 98 | # Update dependencies 99 | if [ -z "${DISABLE_UPDATES}" ]; then 100 | echo 'Updating dependencies. Set DISABLE_UPDATES to prevent this.' 101 | if [ -f "pyproject.toml" ] && [ -f "poetry.lock" ]; then 102 | nvidia-smi > /dev/null 2>&1 && poetry install 103 | uname -s | grep -q Darwin && poetry install -C install/apple 104 | rocm-smi > /dev/null 2>&1 && poetry install -C install/rocm 105 | fi 106 | fi 107 | if [[ -z "${ACCELERATE_CONFIG_PATH}" ]]; then 108 | # Look for accelerate config in HF_HOME first, otherwise fallback to $HOME 109 | if [[ -f "${HF_HOME}/accelerate/default_config.yaml" ]]; then 110 | ACCELERATE_CONFIG_PATH="${HF_HOME}/accelerate/default_config.yaml" 111 | else 112 | ACCELERATE_CONFIG_PATH="${HOME}/.cache/huggingface/accelerate/default_config.yaml" 113 | fi 114 | fi 115 | # Run the training script. 
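# If an Accelerate config file was located above it takes precedence; otherwise the launch flags are assembled from the config.env values resolved earlier in this script.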
116 | if [ -f "${ACCELERATE_CONFIG_PATH}" ]; then 117 | echo "Using Accelerate config file: ${ACCELERATE_CONFIG_PATH}" 118 | accelerate launch --config_file="${ACCELERATE_CONFIG_PATH}" train.py 119 | else 120 | echo "Accelerate config file not found: ${ACCELERATE_CONFIG_PATH}. Using values from config.env." 121 | accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train.py 122 | 123 | fi 124 | 125 | exit 0 126 | --------------------------------------------------------------------------------
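A hypothetical config/config.env for the launcher above; every variable is optional and merely pre-empts the defaults that train.sh falls back to, and the values shown are illustrative only.

# config/config.env (example values, not shipped with the repository)
TRAINING_NUM_PROCESSES=1
TRAINING_NUM_MACHINES=1
MIXED_PRECISION=bf16
TRAINING_DYNAMO_BACKEND=no
ENV=default
CONFIG_BACKEND=json
# DISABLE_UPDATES=1          # skip the poetry install step at startup
# ACCELERATE_EXTRA_ARGS=""   # extra flags passed straight to `accelerate launch`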