├── .dockerignore ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets └── VAE_test1.jpg ├── build_and_push_docker.yaml ├── cog.yaml ├── config ├── examples │ ├── extract.example.yml │ ├── generate.example.yaml │ ├── mod_lora_scale.yaml │ ├── train_lora_flux_24gb.yaml │ └── train_slider.example.yml └── replicate.yml ├── docker └── Dockerfile ├── extensions └── example │ ├── ExampleMergeModels.py │ ├── __init__.py │ └── config │ └── config.example.yaml ├── extensions_built_in ├── advanced_generator │ ├── Img2ImgGenerator.py │ ├── PureLoraGenerator.py │ ├── ReferenceGenerator.py │ ├── __init__.py │ └── config │ │ └── train.example.yaml ├── concept_replacer │ ├── ConceptReplacer.py │ ├── __init__.py │ └── config │ │ └── train.example.yaml ├── dataset_tools │ ├── DatasetTools.py │ ├── SuperTagger.py │ ├── SyncFromCollection.py │ ├── __init__.py │ └── tools │ │ ├── caption.py │ │ ├── dataset_tools_config_modules.py │ │ ├── fuyu_utils.py │ │ ├── image_tools.py │ │ ├── llava_utils.py │ │ └── sync_tools.py ├── image_reference_slider_trainer │ ├── ImageReferenceSliderTrainerProcess.py │ ├── __init__.py │ └── config │ │ └── train.example.yaml ├── sd_trainer │ ├── SDTrainer.py │ ├── __init__.py │ └── config │ │ └── train.example.yaml └── ultimate_slider_trainer │ ├── UltimateSliderTrainerProcess.py │ ├── __init__.py │ └── config │ └── train.example.yaml ├── info.py ├── jobs ├── BaseJob.py ├── ExtensionJob.py ├── ExtractJob.py ├── GenerateJob.py ├── MergeJob.py ├── ModJob.py ├── TrainJob.py ├── __init__.py └── process │ ├── BaseExtensionProcess.py │ ├── BaseExtractProcess.py │ ├── BaseMergeProcess.py │ ├── BaseProcess.py │ ├── BaseSDTrainProcess.py │ ├── BaseTrainProcess.py │ ├── ExtractLoconProcess.py │ ├── ExtractLoraProcess.py │ ├── GenerateProcess.py │ ├── MergeLoconProcess.py │ ├── ModRescaleLoraProcess.py │ ├── TrainESRGANProcess.py │ ├── TrainFineTuneProcess.py │ ├── TrainSDRescaleProcess.py │ ├── TrainSliderProcess.py │ ├── TrainSliderProcessOld.py │ ├── TrainVAEProcess.py │ ├── __init__.py │ └── models │ └── vgg19_critic.py ├── lora-license.md ├── notebooks ├── FLUX_1_dev_LoRA_Training.ipynb └── SliderTraining.ipynb ├── output └── .gitkeep ├── predict.py ├── requirements.txt ├── run.py ├── scripts ├── convert_cog.py ├── convert_lora_to_peft_format.py ├── generate_sampler_step_scales.py ├── make_diffusers_model.py ├── make_lcm_sdxl_model.py ├── patch_te_adapter.py └── repair_dataset_folder.py ├── testing ├── compare_keys.py ├── generate_lora_mapping.py ├── generate_weight_mappings.py ├── merge_in_text_encoder_adapter.py ├── shrink_pixart.py ├── shrink_pixart2.py ├── shrink_pixart_sm.py ├── shrink_pixart_sm2.py ├── shrink_pixart_sm3.py ├── test_bucket_dataloader.py ├── test_model_load_save.py ├── test_vae.py └── test_vae_cycle.py ├── toolkit ├── __init__.py ├── basic.py ├── buckets.py ├── civitai.py ├── clip_vision_adapter.py ├── config.py ├── config_modules.py ├── cuda_malloc.py ├── custom_adapter.py ├── data_loader.py ├── data_transfer_object │ └── data_loader.py ├── dataloader_mixins.py ├── ema.py ├── embedding.py ├── esrgan_utils.py ├── extension.py ├── guidance.py ├── image_utils.py ├── inversion_utils.py ├── ip_adapter.py ├── job.py ├── keymaps │ ├── stable_diffusion_refiner.json │ ├── stable_diffusion_refiner_ldm_base.safetensors │ ├── stable_diffusion_refiner_unmatched.json │ ├── stable_diffusion_sd1.json │ ├── stable_diffusion_sd1_ldm_base.safetensors │ ├── stable_diffusion_sd2.json │ ├── stable_diffusion_sd2_ldm_base.safetensors │ ├── stable_diffusion_sd2_unmatched.json 
│ ├── stable_diffusion_sdxl.json │ ├── stable_diffusion_sdxl_ldm_base.safetensors │ ├── stable_diffusion_sdxl_unmatched.json │ ├── stable_diffusion_ssd.json │ ├── stable_diffusion_ssd_ldm_base.safetensors │ ├── stable_diffusion_ssd_unmatched.json │ ├── stable_diffusion_vega.json │ └── stable_diffusion_vega_ldm_base.safetensors ├── kohya_model_util.py ├── layers.py ├── llvae.py ├── lora_special.py ├── lorm.py ├── losses.py ├── lycoris_special.py ├── lycoris_utils.py ├── metadata.py ├── models │ ├── DoRA.py │ ├── LoRAFormer.py │ ├── RRDB.py │ ├── auraflow.py │ ├── block.py │ ├── clip_fusion.py │ ├── clip_pre_processor.py │ ├── ilora.py │ ├── ilora2.py │ ├── single_value_adapter.py │ ├── size_agnostic_feature_encoder.py │ ├── te_adapter.py │ ├── te_aug_adapter.py │ ├── vd_adapter.py │ └── zipper_resampler.py ├── network_mixins.py ├── optimizer.py ├── orig_configs │ └── sd_xl_refiner.yaml ├── paths.py ├── photomaker.py ├── photomaker_pipeline.py ├── pipelines.py ├── progress_bar.py ├── prompt_utils.py ├── reference_adapter.py ├── resampler.py ├── sampler.py ├── samplers │ ├── custom_flowmatch_sampler.py │ └── custom_lcm_scheduler.py ├── saving.py ├── scheduler.py ├── sd_device_states_presets.py ├── stable_diffusion_model.py ├── style.py ├── timer.py ├── train_pipelines.py ├── train_tools.py └── util │ ├── adafactor_stochastic_rounding.py │ └── inverse_cfg.py └── train.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | /env.sh 163 | /models 164 | /custom/* 165 | !/custom/.gitkeep 166 | /.tmp 167 | /venv.bkp 168 | /venv.* 169 | /config/* 170 | !/config/examples 171 | !/config/_PUT_YOUR_CONFIGS_HERE).txt 172 | /output/* 173 | !/output/.gitkeep 174 | /extensions/* 175 | !/extensions/example 176 | /temp -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "repositories/sd-scripts"] 2 | path = repositories/sd-scripts 3 | url = https://github.com/kohya-ss/sd-scripts.git 4 | [submodule "repositories/leco"] 5 | path = repositories/leco 6 | url = https://github.com/p1atdev/LECO 7 | [submodule "repositories/batch_annotator"] 8 | path = repositories/batch_annotator 9 | url = https://github.com/ostris/batch-annotator 10 | [submodule "repositories/ipadapter"] 11 | path = repositories/ipadapter 12 | url = https://github.com/tencent-ailab/IP-Adapter.git 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ostris, LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/VAE_test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/assets/VAE_test1.jpg -------------------------------------------------------------------------------- /build_and_push_docker.yaml: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo." 4 | # wait 2 seconds 5 | sleep 2 6 | docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:latest -f docker/Dockerfile . 
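# CACHEBUST pairs with the `ARG CACHEBUST=1` line in docker/Dockerfile, which sits just before
# the `git clone` of ai-toolkit. Changed build-arg values invalidate Docker's layer cache for the
# RUN instructions that follow them, so passing the current timestamp forces the clone to run
# again and pull the latest code instead of reusing a stale cached layer.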
7 | docker tag aitoolkit:latest ostris/aitoolkit:latest 8 | docker push ostris/aitoolkit:latest -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | gpu: true 6 | cuda: "12.1" 7 | python_version: "3.10.14" 8 | python_packages: 9 | - "torch>=2.3" 10 | - "torchvision" 11 | - "safetensors==0.4.4" 12 | - "diffusers==0.30.0" 13 | - "transformers==4.44.0" 14 | - "lycoris-lora==1.8.3" 15 | - "flatten_json==0.1.14" 16 | - "pyyaml==6.0.1" 17 | - "oyaml==1.0" 18 | - "tensorboard==2.17.0" 19 | - "kornia==0.7.3" 20 | - "einops==0.8.0" 21 | - "accelerate==0.33.0" 22 | - "toml==0.10.2" 23 | - "albumentations==1.4.3" 24 | - "opencv-python==4.10.0.84" 25 | - "pillow==10.4.0" 26 | - "pydantic==1.10.17" 27 | - "omegaconf==2.3.0" 28 | - "k-diffusion" 29 | - "open_clip_torch==2.26.1" 30 | - "timm==1.0.8" 31 | - "prodigyopt==1.0" 32 | - "controlnet_aux==0.0.7" 33 | - "bitsandbytes==0.43.3" 34 | - "scikit-image==0.24.0" 35 | - "huggingface-hub==0.24.5" 36 | - "hf_transfer==0.1.8" 37 | - "lpips==0.1.4" 38 | - "optimum-quanto==0.2.4" 39 | - "sentencepiece==0.2.0" 40 | 41 | # predict.py defines how predictions are run on your model 42 | predict: "predict.py:Predictor" 43 | train: "train.py:train" 44 | -------------------------------------------------------------------------------- /config/examples/extract.example.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # this is in yaml format. You can use json if you prefer 3 | # I like both but yaml is easier to read and write 4 | # plus it has comments which is nice for documentation 5 | job: extract # tells the runner what to do 6 | config: 7 | # the name will be used to create a folder in the output folder 8 | # it will also replace any [name] token in the rest of this config 9 | name: name_of_your_model 10 | # can be hugging face model, a .ckpt, or a .safetensors 11 | base_model: "/path/to/base/model.safetensors" 12 | # can be hugging face model, a .ckpt, or a .safetensors 13 | extract_model: "/path/to/model/to/extract/trained.safetensors" 14 | # we will create folder here with name above so. This will create /path/to/output/folder/name_of_your_model 15 | output_folder: "/path/to/output/folder" 16 | is_v2: false 17 | dtype: fp16 # saved dtype 18 | device: cpu # cpu, cuda:0, etc 19 | 20 | # processes can be chained like this to run multiple in a row 21 | # they must all use same models above, but great for testing different 22 | # sizes and typed of extractions. 
It is much faster as we already have the models loaded 23 | process: 24 | # process 1 25 | - type: locon # locon or lora (locon is lycoris) 26 | filename: "[name]_64_32.safetensors" # will be put in output folder 27 | dtype: fp16 28 | mode: fixed 29 | linear: 64 30 | conv: 32 31 | 32 | # process 2 33 | - type: locon 34 | output_path: "/absolute/path/for/this/output.safetensors" # can be absolute 35 | mode: ratio 36 | linear: 0.2 37 | conv: 0.2 38 | 39 | # process 3 40 | - type: locon 41 | filename: "[name]_ratio_02.safetensors" 42 | mode: quantile 43 | linear: 0.5 44 | conv: 0.5 45 | 46 | # process 4 47 | - type: lora # traditional lora extraction (lierla) with linear layers only 48 | filename: "[name]_4.safetensors" 49 | mode: fixed # fixed, ratio, quantile supported for lora as well 50 | linear: 4 # lora dim or rank 51 | # no conv for lora 52 | 53 | # process 5 54 | - type: lora 55 | filename: "[name]_q05.safetensors" 56 | mode: quantile 57 | linear: 0.5 58 | 59 | # you can put any information you want here, and it will be saved in the model 60 | # the below is an example. I recommend doing trigger words at a minimum 61 | # in the metadata. The software will include this plus some other information 62 | meta: 63 | name: "[name]" # [name] gets replaced with the name above 64 | description: A short description of your model 65 | trigger_words: 66 | - put 67 | - trigger 68 | - words 69 | - here 70 | version: '0.1' 71 | creator: 72 | name: Your Name 73 | email: your@email.com 74 | website: https://yourwebsite.com 75 | any: All meta data above is arbitrary, it can be whatever you want. 76 | -------------------------------------------------------------------------------- /config/examples/generate.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | job: generate # tells the runner what to do 4 | config: 5 | name: "generate" # this is not really used anywhere currently but required by runner 6 | process: 7 | # process 1 8 | - type: to_folder # process images to a folder 9 | output_folder: "output/gen" 10 | device: cuda:0 # cpu, cuda:0, etc 11 | generate: 12 | # these are your defaults you can override most of them with flags 13 | sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now 14 | width: 1024 15 | height: 1024 16 | neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" 17 | seed: -1 # -1 is random 18 | guidance_scale: 7 19 | sample_steps: 20 20 | ext: ".png" # .png, .jpg, .jpeg, .webp 21 | 22 | # here ate the flags you can use for prompts. Always start with 23 | # your prompt first then add these flags after. You can use as many 24 | # like 25 | # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20 26 | # we will try to support all sd-scripts flags where we can 27 | 28 | # FROM SD-SCRIPTS 29 | # --n Treat everything until the next option as a negative prompt. 30 | # --w Specify the width of the generated image. 31 | # --h Specify the height of the generated image. 32 | # --d Specify the seed for the generated image. 33 | # --l Specify the CFG scale for the generated image. 34 | # --s Specify the number of steps during generation. 
35 | 36 | # OURS and some QOL additions 37 | # --p2 Prompt for the second text encoder (SDXL only) 38 | # --n2 Negative prompt for the second text encoder (SDXL only) 39 | # --gr Specify the guidance rescale for the generated image (SDXL only) 40 | # --seed Specify the seed for the generated image same as --d 41 | # --cfg Specify the CFG scale for the generated image same as --l 42 | # --steps Specify the number of steps during generation same as --s 43 | 44 | prompt_file: false # if true a txt file will be created next to images with prompt strings used 45 | # prompts can also be a path to a text file with one prompt per line 46 | # prompts: "/path/to/prompts.txt" 47 | prompts: 48 | - "photo of batman" 49 | - "photo of superman" 50 | - "photo of spiderman" 51 | - "photo of a superhero --n batman superman spiderman" 52 | 53 | model: 54 | # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt 55 | # name_or_path: "runwayml/stable-diffusion-v1-5" 56 | name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors" 57 | is_v2: false # for v2 models 58 | is_v_pred: false # for v-prediction models (most v2 models) 59 | is_xl: false # for SDXL models 60 | dtype: bf16 61 | -------------------------------------------------------------------------------- /config/examples/mod_lora_scale.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: mod 3 | config: 4 | name: name_of_your_model_v1 5 | process: 6 | - type: rescale_lora 7 | # path to your current lora model 8 | input_path: "/path/to/lora/lora.safetensors" 9 | # output path for your new lora model, can be the same as input_path to replace 10 | output_path: "/path/to/lora/output_lora_v1.safetensors" 11 | # replaces meta with the meta below (plus minimum meta fields) 12 | # if false, we will leave the meta alone except for updating hashes (sd-script hashes) 13 | replace_meta: true 14 | # how to adjust, we can scale the up_down weights or the alpha 15 | # up_down is the default and probably the best, they will both net the same outputs 16 | # would only affect rare NaN cases and maybe merging with old merge tools 17 | scale_target: 'up_down' 18 | # precision to save, fp16 is the default and standard 19 | save_dtype: fp16 20 | # current_weight is the ideal weight you use as a multiplier when using the lora 21 | # IE in automatic1111 the 6.0 is the current_weight 22 | # you can do negatives here too if you want to flip the lora 23 | current_weight: 6.0 24 | # target_weight is the ideal weight you use as a multiplier when using the lora 25 | # instead of the one above. IE in automatic1111 instead of using 26 | # we want to use so 1.0 is the target_weight 27 | target_weight: 1.0 28 | 29 | # base model for the lora 30 | # this is just used to add meta so automatic111 knows which model it is for 31 | # assume v1.5 if these are not set 32 | is_xl: false 33 | is_v2: false 34 | meta: 35 | # this is only used if you set replace_meta to true above 36 | name: "[name]" # [name] gets replaced with the name above 37 | description: A short description of your lora 38 | trigger_words: 39 | - put 40 | - trigger 41 | - words 42 | - here 43 | version: '0.1' 44 | creator: 45 | name: Your Name 46 | email: your@email.com 47 | website: https://yourwebsite.com 48 | any: All meta data above is arbitrary, it can be whatever you want. 
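To make the current_weight / target_weight math above concrete, here is a minimal sketch of the idea. It is not the repo's ModRescaleLoraProcess; the rescale_lora name, the "lora_up"/"lora_down" key matching, and the even square-root split are illustrative assumptions. The LoRA delta is proportional to up @ down, so multiplying each of the two matrices by sqrt(current_weight / target_weight) makes the saved file behave at a strength of 1.0 the way the original did at 6.0.

import math
from safetensors.torch import load_file, save_file

def rescale_lora(input_path, output_path, current_weight=6.0, target_weight=1.0):
    # factor is how much stronger the baked-in effect becomes; a negative value flips the LoRA
    factor = current_weight / target_weight
    per_matrix = math.sqrt(abs(factor))  # split the factor evenly across up and down
    sign = -1.0 if factor < 0 else 1.0
    state = load_file(input_path)
    for key, tensor in state.items():
        if "lora_up" in key:
            state[key] = tensor * per_matrix * sign  # apply the sign only once
        elif "lora_down" in key:
            state[key] = tensor * per_matrix
    save_file(state, output_path)

Scaling alpha by the same factor would net the same output, which is exactly the choice the scale_target option above makes between 'up_down' and 'alpha'.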
49 | -------------------------------------------------------------------------------- /config/examples/train_lora_flux_24gb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | # this name will be the folder and filename name 5 | name: "my_first_flux_lora_v1" 6 | process: 7 | - type: 'sd_trainer' 8 | # root folder to save training sessions/samples/weights 9 | training_folder: "output" 10 | # uncomment to see performance stats in the terminal every N steps 11 | # performance_log_every: 1000 12 | device: cuda:0 13 | # if a trigger word is specified, it will be added to captions of training data if it does not already exist 14 | # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word 15 | # trigger_word: "p3r5on" 16 | network: 17 | type: "lora" 18 | linear: 16 19 | linear_alpha: 16 20 | save: 21 | dtype: float16 # precision to save 22 | save_every: 250 # save every this many steps 23 | max_step_saves_to_keep: 4 # how many intermittent saves to keep 24 | datasets: 25 | # datasets are a folder of images. captions need to be txt files with the same name as the image 26 | # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently 27 | # images will automatically be resized and bucketed into the resolution specified 28 | - folder_path: "/path/to/images/folder" 29 | caption_ext: "txt" 30 | caption_dropout_rate: 0.05 # will drop out the caption 5% of time 31 | shuffle_tokens: false # shuffle caption order, split by commas 32 | cache_latents_to_disk: true # leave this true unless you know what you're doing 33 | resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions 34 | train: 35 | batch_size: 1 36 | steps: 4000 # total number of steps to train 500 - 4000 is a good range 37 | gradient_accumulation_steps: 1 38 | train_unet: true 39 | train_text_encoder: false # probably won't work with flux 40 | content_or_style: balanced # content, style, balanced 41 | gradient_checkpointing: true # need the on unless you have a ton of vram 42 | noise_scheduler: "flowmatch" # for training only 43 | optimizer: "adamw8bit" 44 | lr: 4e-4 45 | # uncomment this to skip the pre training sample 46 | # skip_first_sample: true 47 | 48 | # ema will smooth out learning, but could slow it down. Recommended to leave on. 49 | ema_config: 50 | use_ema: true 51 | ema_decay: 0.99 52 | 53 | # will probably need this if gpu supports it for flux, other dtypes may not work correctly 54 | dtype: bf16 55 | model: 56 | # huggingface model name or path 57 | name_or_path: "black-forest-labs/FLUX.1-dev" 58 | is_flux: true 59 | quantize: true # run 8bit mixed precision 60 | # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. 
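      # usage note (an assumption about the entry point, since it is not spelled out in this file):
      # copy this config, point folder_path above at your image folder, and launch it from the
      # repo root with run.py, e.g. `python run.py config/examples/train_lora_flux_24gb.yaml`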
61 | sample: 62 | sampler: "flowmatch" # must match train.noise_scheduler 63 | sample_every: 250 # sample every this many steps 64 | width: 1024 65 | height: 1024 66 | prompts: 67 | # you can add [trigger] to the prompts here and it will be replaced with the trigger word 68 | # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ 69 | - "woman with red hair, playing chess at the park, bomb going off in the background" 70 | - "a woman holding a coffee cup, in a beanie, sitting at a cafe" 71 | - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" 72 | - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" 73 | - "a bear building a log cabin in the snow covered mountains" 74 | - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" 75 | - "hipster man with a beard, building a chair, in a wood shop" 76 | - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" 77 | - "a man holding a sign that says, 'this is a sign'" 78 | - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" 79 | neg: "" # not used on flux 80 | seed: 42 81 | walk_seed: true 82 | guidance_scale: 4 83 | sample_steps: 20 84 | # you can add any additional meta info here. [name] is replaced with config name at top 85 | meta: 86 | name: "[name]" 87 | version: '1.0' 88 | -------------------------------------------------------------------------------- /config/replicate.yml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | name: "flux_train_replicate" 5 | process: 6 | - type: 'sd_trainer' 7 | training_folder: "output" 8 | device: cuda:0 9 | trigger_word: "TOK" 10 | network: 11 | type: "lora" 12 | linear: 16 13 | linear_alpha: 16 14 | save: 15 | dtype: float16 # precision to save 16 | save_every: 1001 # save every this many steps 17 | max_step_saves_to_keep: 1 # how many intermittent saves to keep 18 | datasets: 19 | - folder_path: "input_images" 20 | caption_ext: "filename" 21 | caption_dropout_rate: 0.05 # will drop out the caption 5% of time 22 | shuffle_tokens: false # shuffle caption order, split by commas 23 | cache_latents_to_disk: true # leave this true unless you know what you're doing 24 | resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions 25 | train: 26 | batch_size: 1 27 | steps: 1000 28 | gradient_accumulation_steps: 1 29 | train_unet: true 30 | train_text_encoder: false # probably won't work with flux 31 | content_or_style: balanced # content, style, balanced 32 | gradient_checkpointing: true # need the on unless you have a ton of vram 33 | noise_scheduler: "flowmatch" # for training only 34 | optimizer: "adamw8bit" 35 | lr: 0.0004 36 | ema_config: 37 | use_ema: true 38 | ema_decay: 0.99 39 | dtype: bf16 40 | model: 41 | name_or_path: 'black-forest-labs/FLUX.1-dev' 42 | is_flux: true 43 | quantize: true # run 8bit mixed precision 44 | sample: 45 | sampler: "flowmatch" # must match train.noise_scheduler 46 | sample_every: 250 # sample every this many steps 47 | width: 1024 48 | height: 1024 49 | prompts: 50 | # you can add [trigger] to the prompts here and it will be replaced with the trigger word 51 | - "a sign that says 'I LOVE PROMPTS!' 
in the style of [trigger]" 52 | neg: "" # not used on flux 53 | seed: 42 54 | walk_seed: true 55 | guidance_scale: 4 56 | sample_steps: 20 57 | meta: 58 | name: "[name]" 59 | version: '1.0' 60 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM runpod/base:0.6.2-cuda12.1.0 2 | LABEL authors="jaret" 3 | 4 | # Install dependencies 5 | RUN apt-get update 6 | 7 | WORKDIR /app 8 | ARG CACHEBUST=1 9 | RUN git clone https://github.com/ostris/ai-toolkit.git && \ 10 | cd ai-toolkit && \ 11 | git submodule update --init --recursive 12 | 13 | WORKDIR /app/ai-toolkit 14 | 15 | RUN ln -s /usr/bin/python3 /usr/bin/python 16 | RUN python -m pip install -r requirements.txt 17 | 18 | RUN apt-get install -y tmux nvtop htop 19 | 20 | WORKDIR / 21 | CMD ["/start.sh"] -------------------------------------------------------------------------------- /extensions/example/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an example extension for custom training. It is great for experimenting with new ideas. 2 | from toolkit.extension import Extension 3 | 4 | 5 | # We make a subclass of Extension 6 | class ExampleMergeExtension(Extension): 7 | # uid must be unique, it is how the extension is identified 8 | uid = "example_merge_extension" 9 | 10 | # name is the name of the extension for printing 11 | name = "Example Merge Extension" 12 | 13 | # This is where your process class is loaded 14 | # keep your imports in here so they don't slow down the rest of the program 15 | @classmethod 16 | def get_process(cls): 17 | # import your process class here so it is only loaded when needed and return it 18 | from .ExampleMergeModels import ExampleMergeModels 19 | return ExampleMergeModels 20 | 21 | 22 | AI_TOOLKIT_EXTENSIONS = [ 23 | # you can put a list of extensions here 24 | ExampleMergeExtension 25 | ] 26 | -------------------------------------------------------------------------------- /extensions/example/config/config.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Always include at least one example config file to show how to use your extension. 3 | # use plenty of comments so users know how to use it and what everything does 4 | 5 | # all extensions will use this job name 6 | job: extension 7 | config: 8 | name: 'my_awesome_merge' 9 | process: 10 | # Put your example processes here. This will be passed 11 | # to your extension process in the config argument. 12 | # the type MUST match your extension uid 13 | - type: "example_merge_extension" 14 | # save path for the merged model 15 | save_path: "output/merge/[name].safetensors" 16 | # save type 17 | dtype: fp16 18 | # device to run it on 19 | device: cuda:0 20 | # input models can only be SD1.x and SD2.x models for this example (currently) 21 | models_to_merge: 22 | # weights are relative, total weights will be normalized 23 | # for example. If you have 2 models with weight 1.0, they will 24 | # both be weighted 0.5. 
If you have 1 model with weight 1.0 and 25 | # another with weight 2.0, the first will be weighted 1/3 and the 26 | # second will be weighted 2/3 27 | - name_or_path: "input/model1.safetensors" 28 | weight: 1.0 29 | - name_or_path: "input/model2.safetensors" 30 | weight: 1.0 31 | - name_or_path: "input/model3.safetensors" 32 | weight: 0.3 33 | - name_or_path: "input/model4.safetensors" 34 | weight: 1.0 35 | 36 | 37 | # you can put any information you want here, and it will be saved in the model 38 | # the below is an example. I recommend doing trigger words at a minimum 39 | # in the metadata. The software will include this plus some other information 40 | meta: 41 | name: "[name]" # [name] gets replaced with the name above 42 | description: A short description of your model 43 | version: '0.1' 44 | creator: 45 | name: Your Name 46 | email: your@email.com 47 | website: https://yourwebsite.com 48 | any: All meta data above is arbitrary, it can be whatever you want. -------------------------------------------------------------------------------- /extensions_built_in/advanced_generator/PureLoraGenerator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig 5 | from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm 6 | from toolkit.sd_device_states_presets import get_train_sd_device_state_preset 7 | from toolkit.stable_diffusion_model import StableDiffusion 8 | import gc 9 | import torch 10 | from jobs.process import BaseExtensionProcess 11 | from toolkit.train_tools import get_torch_dtype 12 | 13 | 14 | def flush(): 15 | torch.cuda.empty_cache() 16 | gc.collect() 17 | 18 | 19 | class PureLoraGenerator(BaseExtensionProcess): 20 | 21 | def __init__(self, process_id: int, job, config: OrderedDict): 22 | super().__init__(process_id, job, config) 23 | self.output_folder = self.get_conf('output_folder', required=True) 24 | self.device = self.get_conf('device', 'cuda') 25 | self.device_torch = torch.device(self.device) 26 | self.model_config = ModelConfig(**self.get_conf('model', required=True)) 27 | self.generate_config = SampleConfig(**self.get_conf('sample', required=True)) 28 | self.dtype = self.get_conf('dtype', 'float16') 29 | self.torch_dtype = get_torch_dtype(self.dtype) 30 | lorm_config = self.get_conf('lorm', None) 31 | self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None 32 | 33 | self.device_state_preset = get_train_sd_device_state_preset( 34 | device=torch.device(self.device), 35 | ) 36 | 37 | self.progress_bar = None 38 | self.sd = StableDiffusion( 39 | device=self.device, 40 | model_config=self.model_config, 41 | dtype=self.dtype, 42 | ) 43 | 44 | def run(self): 45 | super().run() 46 | print("Loading model...") 47 | with torch.no_grad(): 48 | self.sd.load_model() 49 | self.sd.unet.eval() 50 | self.sd.unet.to(self.device_torch) 51 | if isinstance(self.sd.text_encoder, list): 52 | for te in self.sd.text_encoder: 53 | te.eval() 54 | te.to(self.device_torch) 55 | else: 56 | self.sd.text_encoder.eval() 57 | self.sd.to(self.device_torch) 58 | 59 | print(f"Converting to LoRM UNet") 60 | # replace the unet with LoRMUnet 61 | convert_diffusers_unet_to_lorm( 62 | self.sd.unet, 63 | config=self.lorm_config, 64 | ) 65 | 66 | sample_folder = os.path.join(self.output_folder) 67 | gen_img_config_list = [] 68 | 69 | sample_config = self.generate_config 70 | start_seed = 
sample_config.seed 71 | current_seed = start_seed 72 | for i in range(len(sample_config.prompts)): 73 | if sample_config.walk_seed: 74 | current_seed = start_seed + i 75 | 76 | filename = f"[time]_[count].{self.generate_config.ext}" 77 | output_path = os.path.join(sample_folder, filename) 78 | prompt = sample_config.prompts[i] 79 | extra_args = {} 80 | gen_img_config_list.append(GenerateImageConfig( 81 | prompt=prompt, # it will autoparse the prompt 82 | width=sample_config.width, 83 | height=sample_config.height, 84 | negative_prompt=sample_config.neg, 85 | seed=current_seed, 86 | guidance_scale=sample_config.guidance_scale, 87 | guidance_rescale=sample_config.guidance_rescale, 88 | num_inference_steps=sample_config.sample_steps, 89 | network_multiplier=sample_config.network_multiplier, 90 | output_path=output_path, 91 | output_ext=sample_config.ext, 92 | adapter_conditioning_scale=sample_config.adapter_conditioning_scale, 93 | **extra_args 94 | )) 95 | 96 | # send to be generated 97 | self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) 98 | print("Done generating images") 99 | # cleanup 100 | del self.sd 101 | gc.collect() 102 | torch.cuda.empty_cache() 103 | -------------------------------------------------------------------------------- /extensions_built_in/advanced_generator/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an example extension for custom training. It is great for experimenting with new ideas. 2 | from toolkit.extension import Extension 3 | 4 | 5 | # This is for generic training (LoRA, Dreambooth, FineTuning) 6 | class AdvancedReferenceGeneratorExtension(Extension): 7 | # uid must be unique, it is how the extension is identified 8 | uid = "reference_generator" 9 | 10 | # name is the name of the extension for printing 11 | name = "Reference Generator" 12 | 13 | # This is where your process class is loaded 14 | # keep your imports in here so they don't slow down the rest of the program 15 | @classmethod 16 | def get_process(cls): 17 | # import your process class here so it is only loaded when needed and return it 18 | from .ReferenceGenerator import ReferenceGenerator 19 | return ReferenceGenerator 20 | 21 | 22 | # This is for generic training (LoRA, Dreambooth, FineTuning) 23 | class PureLoraGenerator(Extension): 24 | # uid must be unique, it is how the extension is identified 25 | uid = "pure_lora_generator" 26 | 27 | # name is the name of the extension for printing 28 | name = "Pure LoRA Generator" 29 | 30 | # This is where your process class is loaded 31 | # keep your imports in here so they don't slow down the rest of the program 32 | @classmethod 33 | def get_process(cls): 34 | # import your process class here so it is only loaded when needed and return it 35 | from .PureLoraGenerator import PureLoraGenerator 36 | return PureLoraGenerator 37 | 38 | 39 | # This is for generic training (LoRA, Dreambooth, FineTuning) 40 | class Img2ImgGeneratorExtension(Extension): 41 | # uid must be unique, it is how the extension is identified 42 | uid = "batch_img2img" 43 | 44 | # name is the name of the extension for printing 45 | name = "Img2ImgGeneratorExtension" 46 | 47 | # This is where your process class is loaded 48 | # keep your imports in here so they don't slow down the rest of the program 49 | @classmethod 50 | def get_process(cls): 51 | # import your process class here so it is only loaded when needed and return it 52 | from .Img2ImgGenerator import Img2ImgGenerator 53 | return 
Img2ImgGenerator 54 | 55 | 56 | AI_TOOLKIT_EXTENSIONS = [ 57 | # you can put a list of extensions here 58 | AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension 59 | ] 60 | -------------------------------------------------------------------------------- /extensions_built_in/advanced_generator/config/train.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | name: test_v1 5 | process: 6 | - type: 'textual_inversion_trainer' 7 | training_folder: "out/TI" 8 | device: cuda:0 9 | # for tensorboard logging 10 | log_dir: "out/.tensorboard" 11 | embedding: 12 | trigger: "your_trigger_here" 13 | tokens: 12 14 | init_words: "man with short brown hair" 15 | save_format: "safetensors" # 'safetensors' or 'pt' 16 | save: 17 | dtype: float16 # precision to save 18 | save_every: 100 # save every this many steps 19 | max_step_saves_to_keep: 5 # only affects step counts 20 | datasets: 21 | - folder_path: "/path/to/dataset" 22 | caption_ext: "txt" 23 | default_caption: "[trigger]" 24 | buckets: true 25 | resolution: 512 26 | train: 27 | noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" 28 | steps: 3000 29 | weight_jitter: 0.0 30 | lr: 5e-5 31 | train_unet: false 32 | gradient_checkpointing: true 33 | train_text_encoder: false 34 | optimizer: "adamw" 35 | # optimizer: "prodigy" 36 | optimizer_params: 37 | weight_decay: 1e-2 38 | lr_scheduler: "constant" 39 | max_denoising_steps: 1000 40 | batch_size: 4 41 | dtype: bf16 42 | xformers: true 43 | min_snr_gamma: 5.0 44 | # skip_first_sample: true 45 | noise_offset: 0.0 # not needed for this 46 | model: 47 | # objective reality v2 48 | name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" 49 | is_v2: false # for v2 models 50 | is_xl: false # for SDXL models 51 | is_v_pred: false # for v-prediction models (most v2 models) 52 | sample: 53 | sampler: "ddpm" # must match train.noise_scheduler 54 | sample_every: 100 # sample every this many steps 55 | width: 512 56 | height: 512 57 | prompts: 58 | - "photo of [trigger] laughing" 59 | - "photo of [trigger] smiling" 60 | - "[trigger] close up" 61 | - "dark scene [trigger] frozen" 62 | - "[trigger] nighttime" 63 | - "a painting of [trigger]" 64 | - "a drawing of [trigger]" 65 | - "a cartoon of [trigger]" 66 | - "[trigger] pixar style" 67 | - "[trigger] costume" 68 | neg: "" 69 | seed: 42 70 | walk_seed: false 71 | guidance_scale: 7 72 | sample_steps: 20 73 | network_multiplier: 1.0 74 | 75 | logging: 76 | log_every: 10 # log every this many steps 77 | use_wandb: false # not supported yet 78 | verbose: false 79 | 80 | # You can put any information you want here, and it will be saved in the model. 81 | # The below is an example, but you can put your grocery list in it if you want. 82 | # It is saved in the model so be aware of that. The software will include this 83 | # plus some other information for you automatically 84 | meta: 85 | # [name] gets replaced with the name above 86 | name: "[name]" 87 | # version: '1.0' 88 | # creator: 89 | # name: Your Name 90 | # email: your@gmail.com 91 | # website: https://your.website 92 | -------------------------------------------------------------------------------- /extensions_built_in/concept_replacer/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an example extension for custom training. It is great for experimenting with new ideas. 
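# The uid set a few lines below ("concept_replacer") is what a process config points at through
# its `type:` field, in the same way the example configs elsewhere in this repo use
# `- type: 'sd_trainer'`; the class here only resolves that uid to a process class lazily
# via get_process().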
2 | from toolkit.extension import Extension 3 | 4 | 5 | # This is for generic training (LoRA, Dreambooth, FineTuning) 6 | class ConceptReplacerExtension(Extension): 7 | # uid must be unique, it is how the extension is identified 8 | uid = "concept_replacer" 9 | 10 | # name is the name of the extension for printing 11 | name = "Concept Replacer" 12 | 13 | # This is where your process class is loaded 14 | # keep your imports in here so they don't slow down the rest of the program 15 | @classmethod 16 | def get_process(cls): 17 | # import your process class here so it is only loaded when needed and return it 18 | from .ConceptReplacer import ConceptReplacer 19 | return ConceptReplacer 20 | 21 | 22 | 23 | AI_TOOLKIT_EXTENSIONS = [ 24 | # you can put a list of extensions here 25 | ConceptReplacerExtension, 26 | ] 27 | -------------------------------------------------------------------------------- /extensions_built_in/concept_replacer/config/train.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | name: test_v1 5 | process: 6 | - type: 'textual_inversion_trainer' 7 | training_folder: "out/TI" 8 | device: cuda:0 9 | # for tensorboard logging 10 | log_dir: "out/.tensorboard" 11 | embedding: 12 | trigger: "your_trigger_here" 13 | tokens: 12 14 | init_words: "man with short brown hair" 15 | save_format: "safetensors" # 'safetensors' or 'pt' 16 | save: 17 | dtype: float16 # precision to save 18 | save_every: 100 # save every this many steps 19 | max_step_saves_to_keep: 5 # only affects step counts 20 | datasets: 21 | - folder_path: "/path/to/dataset" 22 | caption_ext: "txt" 23 | default_caption: "[trigger]" 24 | buckets: true 25 | resolution: 512 26 | train: 27 | noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" 28 | steps: 3000 29 | weight_jitter: 0.0 30 | lr: 5e-5 31 | train_unet: false 32 | gradient_checkpointing: true 33 | train_text_encoder: false 34 | optimizer: "adamw" 35 | # optimizer: "prodigy" 36 | optimizer_params: 37 | weight_decay: 1e-2 38 | lr_scheduler: "constant" 39 | max_denoising_steps: 1000 40 | batch_size: 4 41 | dtype: bf16 42 | xformers: true 43 | min_snr_gamma: 5.0 44 | # skip_first_sample: true 45 | noise_offset: 0.0 # not needed for this 46 | model: 47 | # objective reality v2 48 | name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" 49 | is_v2: false # for v2 models 50 | is_xl: false # for SDXL models 51 | is_v_pred: false # for v-prediction models (most v2 models) 52 | sample: 53 | sampler: "ddpm" # must match train.noise_scheduler 54 | sample_every: 100 # sample every this many steps 55 | width: 512 56 | height: 512 57 | prompts: 58 | - "photo of [trigger] laughing" 59 | - "photo of [trigger] smiling" 60 | - "[trigger] close up" 61 | - "dark scene [trigger] frozen" 62 | - "[trigger] nighttime" 63 | - "a painting of [trigger]" 64 | - "a drawing of [trigger]" 65 | - "a cartoon of [trigger]" 66 | - "[trigger] pixar style" 67 | - "[trigger] costume" 68 | neg: "" 69 | seed: 42 70 | walk_seed: false 71 | guidance_scale: 7 72 | sample_steps: 20 73 | network_multiplier: 1.0 74 | 75 | logging: 76 | log_every: 10 # log every this many steps 77 | use_wandb: false # not supported yet 78 | verbose: false 79 | 80 | # You can put any information you want here, and it will be saved in the model. 81 | # The below is an example, but you can put your grocery list in it if you want. 82 | # It is saved in the model so be aware of that. 
The software will include this 83 | # plus some other information for you automatically 84 | meta: 85 | # [name] gets replaced with the name above 86 | name: "[name]" 87 | # version: '1.0' 88 | # creator: 89 | # name: Your Name 90 | # email: your@gmail.com 91 | # website: https://your.website 92 | -------------------------------------------------------------------------------- /extensions_built_in/dataset_tools/DatasetTools.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import gc 3 | import torch 4 | from jobs.process import BaseExtensionProcess 5 | 6 | 7 | def flush(): 8 | torch.cuda.empty_cache() 9 | gc.collect() 10 | 11 | 12 | class DatasetTools(BaseExtensionProcess): 13 | 14 | def __init__(self, process_id: int, job, config: OrderedDict): 15 | super().__init__(process_id, job, config) 16 | 17 | def run(self): 18 | super().run() 19 | 20 | raise NotImplementedError("This extension is not yet implemented") 21 | -------------------------------------------------------------------------------- /extensions_built_in/dataset_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from toolkit.extension import Extension 2 | 3 | 4 | class DatasetToolsExtension(Extension): 5 | uid = "dataset_tools" 6 | 7 | # name is the name of the extension for printing 8 | name = "Dataset Tools" 9 | 10 | # This is where your process class is loaded 11 | # keep your imports in here so they don't slow down the rest of the program 12 | @classmethod 13 | def get_process(cls): 14 | # import your process class here so it is only loaded when needed and return it 15 | from .DatasetTools import DatasetTools 16 | return DatasetTools 17 | 18 | 19 | class SyncFromCollectionExtension(Extension): 20 | uid = "sync_from_collection" 21 | name = "Sync from Collection" 22 | 23 | @classmethod 24 | def get_process(cls): 25 | # import your process class here so it is only loaded when needed and return it 26 | from .SyncFromCollection import SyncFromCollection 27 | return SyncFromCollection 28 | 29 | 30 | class SuperTaggerExtension(Extension): 31 | uid = "super_tagger" 32 | name = "Super Tagger" 33 | 34 | @classmethod 35 | def get_process(cls): 36 | # import your process class here so it is only loaded when needed and return it 37 | from .SuperTagger import SuperTagger 38 | return SuperTagger 39 | 40 | 41 | AI_TOOLKIT_EXTENSIONS = [ 42 | SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension 43 | ] 44 | -------------------------------------------------------------------------------- /extensions_built_in/dataset_tools/tools/caption.py: -------------------------------------------------------------------------------- 1 | 2 | caption_manipulation_steps = ['caption', 'caption_short'] 3 | 4 | default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.' 
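# For reference, clean_caption() further down, together with default_replacements below, turns
#   "The image features a red car. The car is parked.\nThe car is parked."
# into
#   "a red car, the car is parked"
# (newlines and periods become commas, the boilerplate lead-ins are stripped, text is lowercased,
# and duplicate comma-separated fragments are dropped).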
5 | default_short_prompt = 'caption this image in less than ten words' 6 | 7 | default_replacements = [ 8 | ("the image features", ""), 9 | ("the image shows", ""), 10 | ("the image depicts", ""), 11 | ("the image is", ""), 12 | ("in this image", ""), 13 | ("in the image", ""), 14 | ] 15 | 16 | 17 | def clean_caption(cap, replacements=None): 18 | if replacements is None: 19 | replacements = default_replacements 20 | 21 | # remove any newlines 22 | cap = cap.replace("\n", ", ") 23 | cap = cap.replace("\r", ", ") 24 | cap = cap.replace(".", ",") 25 | cap = cap.replace("\"", "") 26 | 27 | # remove unicode characters 28 | cap = cap.encode('ascii', 'ignore').decode('ascii') 29 | 30 | # make lowercase 31 | cap = cap.lower() 32 | # remove any extra spaces 33 | cap = " ".join(cap.split()) 34 | 35 | for replacement in replacements: 36 | if replacement[0].startswith('*'): 37 | # we are removing all text if it starts with this and the rest matches 38 | search_text = replacement[0][1:] 39 | if cap.startswith(search_text): 40 | cap = "" 41 | else: 42 | cap = cap.replace(replacement[0].lower(), replacement[1].lower()) 43 | 44 | cap_list = cap.split(",") 45 | # trim whitespace 46 | cap_list = [c.strip() for c in cap_list] 47 | # remove empty strings 48 | cap_list = [c for c in cap_list if c != ""] 49 | # remove duplicates 50 | cap_list = list(dict.fromkeys(cap_list)) 51 | # join back together 52 | cap = ", ".join(cap_list) 53 | return cap -------------------------------------------------------------------------------- /extensions_built_in/dataset_tools/tools/fuyu_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer 2 | 3 | from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption 4 | import torch 5 | from PIL import Image 6 | 7 | 8 | class FuyuImageProcessor: 9 | def __init__(self, device='cuda'): 10 | from transformers import FuyuProcessor, FuyuForCausalLM 11 | self.device = device 12 | self.model: FuyuForCausalLM = None 13 | self.processor: FuyuProcessor = None 14 | self.dtype = torch.bfloat16 15 | self.tokenizer: AutoTokenizer 16 | self.is_loaded = False 17 | 18 | def load_model(self): 19 | from transformers import FuyuProcessor, FuyuForCausalLM 20 | model_path = "adept/fuyu-8b" 21 | kwargs = {"device_map": self.device} 22 | kwargs['load_in_4bit'] = True 23 | kwargs['quantization_config'] = BitsAndBytesConfig( 24 | load_in_4bit=True, 25 | bnb_4bit_compute_dtype=self.dtype, 26 | bnb_4bit_use_double_quant=True, 27 | bnb_4bit_quant_type='nf4' 28 | ) 29 | self.processor = FuyuProcessor.from_pretrained(model_path) 30 | self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 31 | self.is_loaded = True 32 | 33 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 34 | self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs) 35 | self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer) 36 | 37 | def generate_caption( 38 | self, image: Image, 39 | prompt: str = default_long_prompt, 40 | replacements=default_replacements, 41 | max_new_tokens=512 42 | ): 43 | # prepare inputs for the model 44 | # text_prompt = f"{prompt}\n" 45 | 46 | # image = image.convert('RGB') 47 | model_inputs = self.processor(text=prompt, images=[image]) 48 | model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for 
k, v in 49 | model_inputs.items()} 50 | 51 | generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens) 52 | prompt_len = model_inputs["input_ids"].shape[-1] 53 | output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True) 54 | output = clean_caption(output, replacements=replacements) 55 | return output 56 | 57 | # inputs = self.processor(text=text_prompt, images=image, return_tensors="pt") 58 | # for k, v in inputs.items(): 59 | # inputs[k] = v.to(self.device) 60 | 61 | # # autoregressively generate text 62 | # generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens) 63 | # generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True) 64 | # output = generation_text[0] 65 | # 66 | # return clean_caption(output, replacements=replacements) 67 | -------------------------------------------------------------------------------- /extensions_built_in/dataset_tools/tools/image_tools.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Type, TYPE_CHECKING, Union 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image, ImageOps 6 | 7 | Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch'] 8 | 9 | img_manipulation_steps = ['contrast_stretch'] 10 | 11 | img_ext = ['.jpg', '.jpeg', '.png', '.webp'] 12 | 13 | if TYPE_CHECKING: 14 | from .llava_utils import LLaVAImageProcessor 15 | from .fuyu_utils import FuyuImageProcessor 16 | 17 | ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor'] 18 | 19 | 20 | def pil_to_cv2(image): 21 | """Convert a PIL image to a cv2 image.""" 22 | return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 23 | 24 | 25 | def cv2_to_pil(image): 26 | """Convert a cv2 image to a PIL image.""" 27 | return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 28 | 29 | 30 | def load_image(img_path: str): 31 | image = Image.open(img_path).convert('RGB') 32 | try: 33 | # transpose with exif data 34 | image = ImageOps.exif_transpose(image) 35 | except Exception as e: 36 | pass 37 | return image 38 | 39 | 40 | def resize_to_max(image, max_width=1024, max_height=1024): 41 | width, height = image.size 42 | if width <= max_width and height <= max_height: 43 | return image 44 | 45 | scale = min(max_width / width, max_height / height) 46 | width = int(width * scale) 47 | height = int(height * scale) 48 | 49 | return image.resize((width, height), Image.LANCZOS) 50 | -------------------------------------------------------------------------------- /extensions_built_in/dataset_tools/tools/llava_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption 3 | 4 | import torch 5 | from PIL import Image, ImageOps 6 | 7 | from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor 8 | 9 | img_ext = ['.jpg', '.jpeg', '.png', '.webp'] 10 | 11 | 12 | class LLaVAImageProcessor: 13 | def __init__(self, device='cuda'): 14 | try: 15 | from llava.model import LlavaLlamaForCausalLM 16 | except ImportError: 17 | # print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") 18 | print( 19 | "You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") 20 | raise 21 | self.device = device 22 | self.model: 
LlavaLlamaForCausalLM = None 23 | self.tokenizer: AutoTokenizer = None 24 | self.image_processor: CLIPImageProcessor = None 25 | self.is_loaded = False 26 | 27 | def load_model(self): 28 | from llava.model import LlavaLlamaForCausalLM 29 | 30 | model_path = "4bit/llava-v1.5-13b-3GB" 31 | # kwargs = {"device_map": "auto"} 32 | kwargs = {"device_map": self.device} 33 | kwargs['load_in_4bit'] = True 34 | kwargs['quantization_config'] = BitsAndBytesConfig( 35 | load_in_4bit=True, 36 | bnb_4bit_compute_dtype=torch.float16, 37 | bnb_4bit_use_double_quant=True, 38 | bnb_4bit_quant_type='nf4' 39 | ) 40 | self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 41 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 42 | vision_tower = self.model.get_vision_tower() 43 | if not vision_tower.is_loaded: 44 | vision_tower.load_model() 45 | vision_tower.to(device=self.device) 46 | self.image_processor = vision_tower.image_processor 47 | self.is_loaded = True 48 | 49 | def generate_caption( 50 | self, image: 51 | Image, prompt: str = default_long_prompt, 52 | replacements=default_replacements, 53 | max_new_tokens=512 54 | ): 55 | from llava.conversation import conv_templates, SeparatorStyle 56 | from llava.utils import disable_torch_init 57 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 58 | from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria 59 | # question = "how many dogs are in the picture?" 60 | disable_torch_init() 61 | conv_mode = "llava_v0" 62 | conv = conv_templates[conv_mode].copy() 63 | roles = conv.roles 64 | image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda() 65 | 66 | inp = f"{roles[0]}: {prompt}" 67 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 68 | conv.append_message(conv.roles[0], inp) 69 | conv.append_message(conv.roles[1], None) 70 | raw_prompt = conv.get_prompt() 71 | input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, 72 | return_tensors='pt').unsqueeze(0).cuda() 73 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 74 | keywords = [stop_str] 75 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) 76 | with torch.inference_mode(): 77 | output_ids = self.model.generate( 78 | input_ids, images=image_tensor, do_sample=True, temperature=0.1, 79 | max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria], 80 | top_p=0.8 81 | ) 82 | outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 83 | conv.messages[-1][-1] = outputs 84 | output = outputs.rsplit('', 1)[0] 85 | return clean_caption(output, replacements=replacements) 86 | -------------------------------------------------------------------------------- /extensions_built_in/image_reference_slider_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an example extension for custom training. It is great for experimenting with new ideas. 
2 | from toolkit.extension import Extension 3 | 4 | 5 | # We make a subclass of Extension 6 | class ImageReferenceSliderTrainer(Extension): 7 | # uid must be unique, it is how the extension is identified 8 | uid = "image_reference_slider_trainer" 9 | 10 | # name is the name of the extension for printing 11 | name = "Image Reference Slider Trainer" 12 | 13 | # This is where your process class is loaded 14 | # keep your imports in here so they don't slow down the rest of the program 15 | @classmethod 16 | def get_process(cls): 17 | # import your process class here so it is only loaded when needed and return it 18 | from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess 19 | return ImageReferenceSliderTrainerProcess 20 | 21 | 22 | AI_TOOLKIT_EXTENSIONS = [ 23 | # you can put a list of extensions here 24 | ImageReferenceSliderTrainer 25 | ] 26 | -------------------------------------------------------------------------------- /extensions_built_in/image_reference_slider_trainer/config/train.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | name: example_name 5 | process: 6 | - type: 'image_reference_slider_trainer' 7 | training_folder: "/mnt/Train/out/LoRA" 8 | device: cuda:0 9 | # for tensorboard logging 10 | log_dir: "/home/jaret/Dev/.tensorboard" 11 | network: 12 | type: "lora" 13 | linear: 8 14 | linear_alpha: 8 15 | train: 16 | noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" 17 | steps: 5000 18 | lr: 1e-4 19 | train_unet: true 20 | gradient_checkpointing: true 21 | train_text_encoder: true 22 | optimizer: "adamw" 23 | optimizer_params: 24 | weight_decay: 1e-2 25 | lr_scheduler: "constant" 26 | max_denoising_steps: 1000 27 | batch_size: 1 28 | dtype: bf16 29 | xformers: true 30 | skip_first_sample: true 31 | noise_offset: 0.0 32 | model: 33 | name_or_path: "/path/to/model.safetensors" 34 | is_v2: false # for v2 models 35 | is_xl: false # for SDXL models 36 | is_v_pred: false # for v-prediction models (most v2 models) 37 | save: 38 | dtype: float16 # precision to save 39 | save_every: 1000 # save every this many steps 40 | max_step_saves_to_keep: 2 # only affects step counts 41 | sample: 42 | sampler: "ddpm" # must match train.noise_scheduler 43 | sample_every: 100 # sample every this many steps 44 | width: 512 45 | height: 512 46 | prompts: 47 | - "photo of a woman with red hair taking a selfie --m -3" 48 | - "photo of a woman with red hair taking a selfie --m -1" 49 | - "photo of a woman with red hair taking a selfie --m 1" 50 | - "photo of a woman with red hair taking a selfie --m 3" 51 | - "close up photo of a man smiling at the camera, in a tank top --m -3" 52 | - "close up photo of a man smiling at the camera, in a tank top--m -1" 53 | - "close up photo of a man smiling at the camera, in a tank top --m 1" 54 | - "close up photo of a man smiling at the camera, in a tank top --m 3" 55 | - "photo of a blonde woman smiling, barista --m -3" 56 | - "photo of a blonde woman smiling, barista --m -1" 57 | - "photo of a blonde woman smiling, barista --m 1" 58 | - "photo of a blonde woman smiling, barista --m 3" 59 | - "photo of a Christina Hendricks --m -1" 60 | - "photo of a Christina Hendricks --m -1" 61 | - "photo of a Christina Hendricks --m 1" 62 | - "photo of a Christina Hendricks --m 3" 63 | - "photo of a Christina Ricci --m -3" 64 | - "photo of a Christina Ricci --m -1" 65 | - "photo of a Christina Ricci --m 1" 66 | - "photo of a Christina Ricci --m 3" 67 | neg: 
"cartoon, fake, drawing, illustration, cgi, animated, anime" 68 | seed: 42 69 | walk_seed: false 70 | guidance_scale: 7 71 | sample_steps: 20 72 | network_multiplier: 1.0 73 | 74 | logging: 75 | log_every: 10 # log every this many steps 76 | use_wandb: false # not supported yet 77 | verbose: false 78 | 79 | slider: 80 | datasets: 81 | - pair_folder: "/path/to/folder/side/by/side/images" 82 | network_weight: 2.0 83 | target_class: "" # only used as default if caption txt are not present 84 | size: 512 85 | - pair_folder: "/path/to/folder/side/by/side/images" 86 | network_weight: 4.0 87 | target_class: "" # only used as default if caption txt are not present 88 | size: 512 89 | 90 | 91 | # you can put any information you want here, and it will be saved in the model 92 | # the below is an example. I recommend doing trigger words at a minimum 93 | # in the metadata. The software will include this plus some other information 94 | meta: 95 | name: "[name]" # [name] gets replaced with the name above 96 | description: A short description of your model 97 | trigger_words: 98 | - put 99 | - trigger 100 | - words 101 | - here 102 | version: '0.1' 103 | creator: 104 | name: Your Name 105 | email: your@email.com 106 | website: https://yourwebsite.com 107 | any: All meta data above is arbitrary, it can be whatever you want. -------------------------------------------------------------------------------- /extensions_built_in/sd_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an example extension for custom training. It is great for experimenting with new ideas. 2 | from toolkit.extension import Extension 3 | 4 | 5 | # This is for generic training (LoRA, Dreambooth, FineTuning) 6 | class SDTrainerExtension(Extension): 7 | # uid must be unique, it is how the extension is identified 8 | uid = "sd_trainer" 9 | 10 | # name is the name of the extension for printing 11 | name = "SD Trainer" 12 | 13 | # This is where your process class is loaded 14 | # keep your imports in here so they don't slow down the rest of the program 15 | @classmethod 16 | def get_process(cls): 17 | # import your process class here so it is only loaded when needed and return it 18 | from .SDTrainer import SDTrainer 19 | return SDTrainer 20 | 21 | 22 | # for backwards compatability 23 | class TextualInversionTrainer(SDTrainerExtension): 24 | uid = "textual_inversion_trainer" 25 | 26 | 27 | AI_TOOLKIT_EXTENSIONS = [ 28 | # you can put a list of extensions here 29 | SDTrainerExtension, TextualInversionTrainer 30 | ] 31 | -------------------------------------------------------------------------------- /extensions_built_in/sd_trainer/config/train.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | name: test_v1 5 | process: 6 | - type: 'textual_inversion_trainer' 7 | training_folder: "out/TI" 8 | device: cuda:0 9 | # for tensorboard logging 10 | log_dir: "out/.tensorboard" 11 | embedding: 12 | trigger: "your_trigger_here" 13 | tokens: 12 14 | init_words: "man with short brown hair" 15 | save_format: "safetensors" # 'safetensors' or 'pt' 16 | save: 17 | dtype: float16 # precision to save 18 | save_every: 100 # save every this many steps 19 | max_step_saves_to_keep: 5 # only affects step counts 20 | datasets: 21 | - folder_path: "/path/to/dataset" 22 | caption_ext: "txt" 23 | default_caption: "[trigger]" 24 | buckets: true 25 | resolution: 512 26 | train: 27 | noise_scheduler: "ddpm" # or 
"ddpm", "lms", "euler_a" 28 | steps: 3000 29 | weight_jitter: 0.0 30 | lr: 5e-5 31 | train_unet: false 32 | gradient_checkpointing: true 33 | train_text_encoder: false 34 | optimizer: "adamw" 35 | # optimizer: "prodigy" 36 | optimizer_params: 37 | weight_decay: 1e-2 38 | lr_scheduler: "constant" 39 | max_denoising_steps: 1000 40 | batch_size: 4 41 | dtype: bf16 42 | xformers: true 43 | min_snr_gamma: 5.0 44 | # skip_first_sample: true 45 | noise_offset: 0.0 # not needed for this 46 | model: 47 | # objective reality v2 48 | name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" 49 | is_v2: false # for v2 models 50 | is_xl: false # for SDXL models 51 | is_v_pred: false # for v-prediction models (most v2 models) 52 | sample: 53 | sampler: "ddpm" # must match train.noise_scheduler 54 | sample_every: 100 # sample every this many steps 55 | width: 512 56 | height: 512 57 | prompts: 58 | - "photo of [trigger] laughing" 59 | - "photo of [trigger] smiling" 60 | - "[trigger] close up" 61 | - "dark scene [trigger] frozen" 62 | - "[trigger] nighttime" 63 | - "a painting of [trigger]" 64 | - "a drawing of [trigger]" 65 | - "a cartoon of [trigger]" 66 | - "[trigger] pixar style" 67 | - "[trigger] costume" 68 | neg: "" 69 | seed: 42 70 | walk_seed: false 71 | guidance_scale: 7 72 | sample_steps: 20 73 | network_multiplier: 1.0 74 | 75 | logging: 76 | log_every: 10 # log every this many steps 77 | use_wandb: false # not supported yet 78 | verbose: false 79 | 80 | # You can put any information you want here, and it will be saved in the model. 81 | # The below is an example, but you can put your grocery list in it if you want. 82 | # It is saved in the model so be aware of that. The software will include this 83 | # plus some other information for you automatically 84 | meta: 85 | # [name] gets replaced with the name above 86 | name: "[name]" 87 | # version: '1.0' 88 | # creator: 89 | # name: Your Name 90 | # email: your@gmail.com 91 | # website: https://your.website 92 | -------------------------------------------------------------------------------- /extensions_built_in/ultimate_slider_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an example extension for custom training. It is great for experimenting with new ideas. 
2 | from toolkit.extension import Extension 3 | 4 | 5 | # We make a subclass of Extension 6 | class UltimateSliderTrainer(Extension): 7 | # uid must be unique, it is how the extension is identified 8 | uid = "ultimate_slider_trainer" 9 | 10 | # name is the name of the extension for printing 11 | name = "Ultimate Slider Trainer" 12 | 13 | # This is where your process class is loaded 14 | # keep your imports in here so they don't slow down the rest of the program 15 | @classmethod 16 | def get_process(cls): 17 | # import your process class here so it is only loaded when needed and return it 18 | from .UltimateSliderTrainerProcess import UltimateSliderTrainerProcess 19 | return UltimateSliderTrainerProcess 20 | 21 | 22 | AI_TOOLKIT_EXTENSIONS = [ 23 | # you can put a list of extensions here 24 | UltimateSliderTrainer 25 | ] 26 | -------------------------------------------------------------------------------- /extensions_built_in/ultimate_slider_trainer/config/train.example.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | job: extension 3 | config: 4 | name: example_name 5 | process: 6 | - type: 'image_reference_slider_trainer' 7 | training_folder: "/mnt/Train/out/LoRA" 8 | device: cuda:0 9 | # for tensorboard logging 10 | log_dir: "/home/jaret/Dev/.tensorboard" 11 | network: 12 | type: "lora" 13 | linear: 8 14 | linear_alpha: 8 15 | train: 16 | noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" 17 | steps: 5000 18 | lr: 1e-4 19 | train_unet: true 20 | gradient_checkpointing: true 21 | train_text_encoder: true 22 | optimizer: "adamw" 23 | optimizer_params: 24 | weight_decay: 1e-2 25 | lr_scheduler: "constant" 26 | max_denoising_steps: 1000 27 | batch_size: 1 28 | dtype: bf16 29 | xformers: true 30 | skip_first_sample: true 31 | noise_offset: 0.0 32 | model: 33 | name_or_path: "/path/to/model.safetensors" 34 | is_v2: false # for v2 models 35 | is_xl: false # for SDXL models 36 | is_v_pred: false # for v-prediction models (most v2 models) 37 | save: 38 | dtype: float16 # precision to save 39 | save_every: 1000 # save every this many steps 40 | max_step_saves_to_keep: 2 # only affects step counts 41 | sample: 42 | sampler: "ddpm" # must match train.noise_scheduler 43 | sample_every: 100 # sample every this many steps 44 | width: 512 45 | height: 512 46 | prompts: 47 | - "photo of a woman with red hair taking a selfie --m -3" 48 | - "photo of a woman with red hair taking a selfie --m -1" 49 | - "photo of a woman with red hair taking a selfie --m 1" 50 | - "photo of a woman with red hair taking a selfie --m 3" 51 | - "close up photo of a man smiling at the camera, in a tank top --m -3" 52 | - "close up photo of a man smiling at the camera, in a tank top--m -1" 53 | - "close up photo of a man smiling at the camera, in a tank top --m 1" 54 | - "close up photo of a man smiling at the camera, in a tank top --m 3" 55 | - "photo of a blonde woman smiling, barista --m -3" 56 | - "photo of a blonde woman smiling, barista --m -1" 57 | - "photo of a blonde woman smiling, barista --m 1" 58 | - "photo of a blonde woman smiling, barista --m 3" 59 | - "photo of a Christina Hendricks --m -1" 60 | - "photo of a Christina Hendricks --m -1" 61 | - "photo of a Christina Hendricks --m 1" 62 | - "photo of a Christina Hendricks --m 3" 63 | - "photo of a Christina Ricci --m -3" 64 | - "photo of a Christina Ricci --m -1" 65 | - "photo of a Christina Ricci --m 1" 66 | - "photo of a Christina Ricci --m 3" 67 | neg: "cartoon, fake, drawing, illustration, cgi, animated, 
anime" 68 | seed: 42 69 | walk_seed: false 70 | guidance_scale: 7 71 | sample_steps: 20 72 | network_multiplier: 1.0 73 | 74 | logging: 75 | log_every: 10 # log every this many steps 76 | use_wandb: false # not supported yet 77 | verbose: false 78 | 79 | slider: 80 | datasets: 81 | - pair_folder: "/path/to/folder/side/by/side/images" 82 | network_weight: 2.0 83 | target_class: "" # only used as default if caption txt are not present 84 | size: 512 85 | - pair_folder: "/path/to/folder/side/by/side/images" 86 | network_weight: 4.0 87 | target_class: "" # only used as default if caption txt are not present 88 | size: 512 89 | 90 | 91 | # you can put any information you want here, and it will be saved in the model 92 | # the below is an example. I recommend doing trigger words at a minimum 93 | # in the metadata. The software will include this plus some other information 94 | meta: 95 | name: "[name]" # [name] gets replaced with the name above 96 | description: A short description of your model 97 | trigger_words: 98 | - put 99 | - trigger 100 | - words 101 | - here 102 | version: '0.1' 103 | creator: 104 | name: Your Name 105 | email: your@email.com 106 | website: https://yourwebsite.com 107 | any: All meta data above is arbitrary, it can be whatever you want. -------------------------------------------------------------------------------- /info.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | v = OrderedDict() 4 | v["name"] = "ai-toolkit" 5 | v["repo"] = "https://github.com/ostris/ai-toolkit" 6 | v["version"] = "0.1.0" 7 | 8 | software_meta = v 9 | -------------------------------------------------------------------------------- /jobs/BaseJob.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from collections import OrderedDict 3 | from typing import List 4 | 5 | from jobs.process import BaseProcess 6 | 7 | 8 | class BaseJob: 9 | 10 | def __init__(self, config: OrderedDict): 11 | if not config: 12 | raise ValueError('config is required') 13 | self.process: List[BaseProcess] 14 | 15 | self.config = config['config'] 16 | self.raw_config = config 17 | self.job = config['job'] 18 | self.torch_profiler = self.get_conf('torch_profiler', False) 19 | self.name = self.get_conf('name', required=True) 20 | if 'meta' in config: 21 | self.meta = config['meta'] 22 | else: 23 | self.meta = OrderedDict() 24 | 25 | def get_conf(self, key, default=None, required=False): 26 | if key in self.config: 27 | return self.config[key] 28 | elif required: 29 | raise ValueError(f'config file error. Missing "config.{key}" key') 30 | else: 31 | return default 32 | 33 | def run(self): 34 | print("") 35 | print(f"#############################################") 36 | print(f"# Running job: {self.name}") 37 | print(f"#############################################") 38 | print("") 39 | # implement in child class 40 | # be sure to call super().run() first 41 | pass 42 | 43 | def load_processes(self, process_dict: dict): 44 | # only call if you have processes in this job type 45 | if 'process' not in self.config: 46 | raise ValueError('config file is invalid. Missing "config.process" key') 47 | if len(self.config['process']) == 0: 48 | raise ValueError('config file is invalid. 
"config.process" must be a list of processes') 49 | 50 | module = importlib.import_module('jobs.process') 51 | 52 | # add the processes 53 | self.process = [] 54 | for i, process in enumerate(self.config['process']): 55 | if 'type' not in process: 56 | raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key') 57 | 58 | # check if dict key is process type 59 | if process['type'] in process_dict: 60 | if isinstance(process_dict[process['type']], str): 61 | ProcessClass = getattr(module, process_dict[process['type']]) 62 | else: 63 | # it is the class 64 | ProcessClass = process_dict[process['type']] 65 | self.process.append(ProcessClass(i, self, process)) 66 | else: 67 | raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}') 68 | 69 | def cleanup(self): 70 | # if you implement this in child clas, 71 | # be sure to call super().cleanup() LAST 72 | del self 73 | -------------------------------------------------------------------------------- /jobs/ExtensionJob.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from jobs import BaseJob 4 | from toolkit.extension import get_all_extensions_process_dict 5 | from toolkit.paths import CONFIG_ROOT 6 | 7 | class ExtensionJob(BaseJob): 8 | 9 | def __init__(self, config: OrderedDict): 10 | super().__init__(config) 11 | self.device = self.get_conf('device', 'cpu') 12 | self.process_dict = get_all_extensions_process_dict() 13 | self.load_processes(self.process_dict) 14 | 15 | def run(self): 16 | super().run() 17 | 18 | print("") 19 | print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") 20 | 21 | for process in self.process: 22 | process.run() 23 | -------------------------------------------------------------------------------- /jobs/ExtractJob.py: -------------------------------------------------------------------------------- 1 | from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint 2 | from collections import OrderedDict 3 | from jobs import BaseJob 4 | from toolkit.train_tools import get_torch_dtype 5 | 6 | process_dict = { 7 | 'locon': 'ExtractLoconProcess', 8 | 'lora': 'ExtractLoraProcess', 9 | } 10 | 11 | 12 | class ExtractJob(BaseJob): 13 | 14 | def __init__(self, config: OrderedDict): 15 | super().__init__(config) 16 | self.base_model_path = self.get_conf('base_model', required=True) 17 | self.model_base = None 18 | self.model_base_text_encoder = None 19 | self.model_base_vae = None 20 | self.model_base_unet = None 21 | self.extract_model_path = self.get_conf('extract_model', required=True) 22 | self.model_extract = None 23 | self.model_extract_text_encoder = None 24 | self.model_extract_vae = None 25 | self.model_extract_unet = None 26 | self.extract_unet = self.get_conf('extract_unet', True) 27 | self.extract_text_encoder = self.get_conf('extract_text_encoder', True) 28 | self.dtype = self.get_conf('dtype', 'fp16') 29 | self.torch_dtype = get_torch_dtype(self.dtype) 30 | self.output_folder = self.get_conf('output_folder', required=True) 31 | self.is_v2 = self.get_conf('is_v2', False) 32 | self.device = self.get_conf('device', 'cpu') 33 | 34 | # loads the processes from the config 35 | self.load_processes(process_dict) 36 | 37 | def run(self): 38 | super().run() 39 | # load models 40 | print(f"Loading models for extraction") 41 | print(f" - Loading base model: {self.base_model_path}") 42 | # (text_model, vae, unet) 43 | self.model_base = 
load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) 44 | self.model_base_text_encoder = self.model_base[0] 45 | self.model_base_vae = self.model_base[1] 46 | self.model_base_unet = self.model_base[2] 47 | 48 | print(f" - Loading extract model: {self.extract_model_path}") 49 | self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) 50 | self.model_extract_text_encoder = self.model_extract[0] 51 | self.model_extract_vae = self.model_extract[1] 52 | self.model_extract_unet = self.model_extract[2] 53 | 54 | print("") 55 | print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") 56 | 57 | for process in self.process: 58 | process.run() 59 | -------------------------------------------------------------------------------- /jobs/GenerateJob.py: -------------------------------------------------------------------------------- 1 | from jobs import BaseJob 2 | from collections import OrderedDict 3 | from typing import List 4 | from jobs.process import GenerateProcess 5 | from toolkit.paths import REPOS_ROOT 6 | 7 | import sys 8 | 9 | sys.path.append(REPOS_ROOT) 10 | 11 | process_dict = { 12 | 'to_folder': 'GenerateProcess', 13 | } 14 | 15 | 16 | class GenerateJob(BaseJob): 17 | 18 | def __init__(self, config: OrderedDict): 19 | super().__init__(config) 20 | self.device = self.get_conf('device', 'cpu') 21 | 22 | # loads the processes from the config 23 | self.load_processes(process_dict) 24 | 25 | def run(self): 26 | super().run() 27 | print("") 28 | print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") 29 | 30 | for process in self.process: 31 | process.run() 32 | -------------------------------------------------------------------------------- /jobs/MergeJob.py: -------------------------------------------------------------------------------- 1 | from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint 2 | from collections import OrderedDict 3 | from jobs import BaseJob 4 | from toolkit.train_tools import get_torch_dtype 5 | 6 | process_dict = { 7 | } 8 | 9 | 10 | class MergeJob(BaseJob): 11 | 12 | def __init__(self, config: OrderedDict): 13 | super().__init__(config) 14 | self.dtype = self.get_conf('dtype', 'fp16') 15 | self.torch_dtype = get_torch_dtype(self.dtype) 16 | self.is_v2 = self.get_conf('is_v2', False) 17 | self.device = self.get_conf('device', 'cpu') 18 | 19 | # loads the processes from the config 20 | self.load_processes(process_dict) 21 | 22 | def run(self): 23 | super().run() 24 | 25 | print("") 26 | print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") 27 | 28 | for process in self.process: 29 | process.run() 30 | -------------------------------------------------------------------------------- /jobs/ModJob.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from jobs import BaseJob 4 | from toolkit.metadata import get_meta_for_safetensors 5 | from toolkit.train_tools import get_torch_dtype 6 | 7 | process_dict = { 8 | 'rescale_lora': 'ModRescaleLoraProcess', 9 | } 10 | 11 | 12 | class ModJob(BaseJob): 13 | 14 | def __init__(self, config: OrderedDict): 15 | super().__init__(config) 16 | self.device = self.get_conf('device', 'cpu') 17 | 18 | # loads the processes from the config 19 | self.load_processes(process_dict) 20 | 21 | def run(self): 22 | super().run() 23 | 24 | print("") 25 | print(f"Running 
{len(self.process)} process{'' if len(self.process) == 1 else 'es'}") 26 | 27 | for process in self.process: 28 | process.run() 29 | -------------------------------------------------------------------------------- /jobs/TrainJob.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from jobs import BaseJob 5 | from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint 6 | from collections import OrderedDict 7 | from typing import List 8 | from jobs.process import BaseExtractProcess, TrainFineTuneProcess 9 | from datetime import datetime 10 | import yaml 11 | from toolkit.paths import REPOS_ROOT 12 | 13 | import sys 14 | 15 | sys.path.append(REPOS_ROOT) 16 | 17 | process_dict = { 18 | 'vae': 'TrainVAEProcess', 19 | 'slider': 'TrainSliderProcess', 20 | 'slider_old': 'TrainSliderProcessOld', 21 | 'lora_hack': 'TrainLoRAHack', 22 | 'rescale_sd': 'TrainSDRescaleProcess', 23 | 'esrgan': 'TrainESRGANProcess', 24 | 'reference': 'TrainReferenceProcess', 25 | } 26 | 27 | 28 | class TrainJob(BaseJob): 29 | 30 | def __init__(self, config: OrderedDict): 31 | super().__init__(config) 32 | self.training_folder = self.get_conf('training_folder', required=True) 33 | self.is_v2 = self.get_conf('is_v2', False) 34 | self.device = self.get_conf('device', 'cpu') 35 | # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1) 36 | # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 37 | self.log_dir = self.get_conf('log_dir', None) 38 | 39 | # loads the processes from the config 40 | self.load_processes(process_dict) 41 | 42 | 43 | def run(self): 44 | super().run() 45 | print("") 46 | print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") 47 | 48 | for process in self.process: 49 | process.run() 50 | -------------------------------------------------------------------------------- /jobs/__init__.py: -------------------------------------------------------------------------------- 1 | from .BaseJob import BaseJob 2 | from .ExtractJob import ExtractJob 3 | from .TrainJob import TrainJob 4 | from .MergeJob import MergeJob 5 | from .ModJob import ModJob 6 | from .GenerateJob import GenerateJob 7 | from .ExtensionJob import ExtensionJob 8 | -------------------------------------------------------------------------------- /jobs/process/BaseExtensionProcess.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import ForwardRef 3 | from jobs.process.BaseProcess import BaseProcess 4 | 5 | 6 | class BaseExtensionProcess(BaseProcess): 7 | def __init__( 8 | self, 9 | process_id: int, 10 | job, 11 | config: OrderedDict 12 | ): 13 | super().__init__(process_id, job, config) 14 | self.process_id: int 15 | self.config: OrderedDict 16 | self.progress_bar: ForwardRef('tqdm') = None 17 | 18 | def run(self): 19 | super().run() 20 | -------------------------------------------------------------------------------- /jobs/process/BaseExtractProcess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | from safetensors.torch import save_file 5 | 6 | from jobs.process.BaseProcess import BaseProcess 7 | from toolkit.metadata import get_meta_for_safetensors 8 | 9 | from typing import ForwardRef 10 | 11 | from toolkit.train_tools import get_torch_dtype 12 | 13 | 14 | class 
BaseExtractProcess(BaseProcess): 15 | 16 | def __init__( 17 | self, 18 | process_id: int, 19 | job, 20 | config: OrderedDict 21 | ): 22 | super().__init__(process_id, job, config) 23 | self.config: OrderedDict 24 | self.output_folder: str 25 | self.output_filename: str 26 | self.output_path: str 27 | self.process_id = process_id 28 | self.job = job 29 | self.config = config 30 | self.dtype = self.get_conf('dtype', self.job.dtype) 31 | self.torch_dtype = get_torch_dtype(self.dtype) 32 | self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet) 33 | self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder) 34 | 35 | def run(self): 36 | # here instead of init because child init needs to go first 37 | self.output_path = self.get_output_path() 38 | # implement in child class 39 | # be sure to call super().run() first 40 | pass 41 | 42 | # you can override this in the child class if you want 43 | # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this 44 | def get_output_path(self, prefix=None, suffix=None): 45 | config_output_path = self.get_conf('output_path', None) 46 | config_filename = self.get_conf('filename', None) 47 | # replace [name] with name 48 | 49 | if config_output_path is not None: 50 | config_output_path = config_output_path.replace('[name]', self.job.name) 51 | return config_output_path 52 | 53 | if config_output_path is None and config_filename is not None: 54 | # build the output path from the output folder and filename 55 | return os.path.join(self.job.output_folder, config_filename) 56 | 57 | # build our own 58 | 59 | if suffix is None: 60 | # we will just add process it to the end of the filename if there is more than one process 61 | # and no other suffix was given 62 | suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else '' 63 | 64 | if prefix is None: 65 | prefix = '' 66 | 67 | output_filename = f"{prefix}{self.output_filename}{suffix}" 68 | 69 | return os.path.join(self.job.output_folder, output_filename) 70 | 71 | def save(self, state_dict): 72 | # prepare meta 73 | save_meta = get_meta_for_safetensors(self.meta, self.job.name) 74 | 75 | # save 76 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True) 77 | 78 | for key in list(state_dict.keys()): 79 | v = state_dict[key] 80 | v = v.detach().clone().to("cpu").to(self.torch_dtype) 81 | state_dict[key] = v 82 | 83 | # having issues with meta 84 | save_file(state_dict, self.output_path, save_meta) 85 | 86 | print(f"Saved to {self.output_path}") 87 | -------------------------------------------------------------------------------- /jobs/process/BaseMergeProcess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | from safetensors.torch import save_file 5 | 6 | from jobs.process.BaseProcess import BaseProcess 7 | from toolkit.metadata import get_meta_for_safetensors 8 | from toolkit.train_tools import get_torch_dtype 9 | 10 | 11 | class BaseMergeProcess(BaseProcess): 12 | 13 | def __init__( 14 | self, 15 | process_id: int, 16 | job, 17 | config: OrderedDict 18 | ): 19 | super().__init__(process_id, job, config) 20 | self.process_id: int 21 | self.config: OrderedDict 22 | self.output_path = self.get_conf('output_path', required=True) 23 | self.dtype = self.get_conf('dtype', self.job.dtype) 24 | self.torch_dtype = get_torch_dtype(self.dtype) 25 | 26 | def run(self): 27 | # implement in child class 28 | # be sure to call 
super().run() first 29 | pass 30 | 31 | def save(self, state_dict): 32 | # prepare meta 33 | save_meta = get_meta_for_safetensors(self.meta, self.job.name) 34 | 35 | # save 36 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True) 37 | 38 | for key in list(state_dict.keys()): 39 | v = state_dict[key] 40 | v = v.detach().clone().to("cpu").to(self.torch_dtype) 41 | state_dict[key] = v 42 | 43 | # having issues with meta 44 | save_file(state_dict, self.output_path, save_meta) 45 | 46 | print(f"Saved to {self.output_path}") 47 | -------------------------------------------------------------------------------- /jobs/process/BaseProcess.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from collections import OrderedDict 4 | 5 | from toolkit.timer import Timer 6 | 7 | 8 | class BaseProcess(object): 9 | 10 | def __init__( 11 | self, 12 | process_id: int, 13 | job: 'BaseJob', 14 | config: OrderedDict 15 | ): 16 | self.process_id = process_id 17 | self.meta: OrderedDict 18 | self.job = job 19 | self.config = config 20 | self.raw_process_config = config 21 | self.name = self.get_conf('name', self.job.name) 22 | self.meta = copy.deepcopy(self.job.meta) 23 | self.timer: Timer = Timer(f'{self.name} Timer') 24 | self.performance_log_every = self.get_conf('performance_log_every', 0) 25 | 26 | print(json.dumps(self.config, indent=4)) 27 | 28 | def get_conf(self, key, default=None, required=False, as_type=None): 29 | # split key by '.' and recursively get the value 30 | keys = key.split('.') 31 | 32 | # see if it exists in the config 33 | value = self.config 34 | for subkey in keys: 35 | if subkey in value: 36 | value = value[subkey] 37 | else: 38 | value = None 39 | break 40 | 41 | if value is not None: 42 | if as_type is not None: 43 | value = as_type(value) 44 | return value 45 | elif required: 46 | raise ValueError(f'config file error. 
Missing "config.process[{self.process_id}].{key}" key') 47 | else: 48 | if as_type is not None and default is not None: 49 | return as_type(default) 50 | return default 51 | 52 | def run(self): 53 | # implement in child class 54 | # be sure to call super().run() first incase something is added here 55 | pass 56 | 57 | def add_meta(self, additional_meta: OrderedDict): 58 | self.meta.update(additional_meta) 59 | 60 | 61 | from jobs import BaseJob 62 | -------------------------------------------------------------------------------- /jobs/process/BaseTrainProcess.py: -------------------------------------------------------------------------------- 1 | import random 2 | from datetime import datetime 3 | import os 4 | from collections import OrderedDict 5 | from typing import TYPE_CHECKING, Union 6 | 7 | import torch 8 | import yaml 9 | 10 | from jobs.process.BaseProcess import BaseProcess 11 | 12 | if TYPE_CHECKING: 13 | from jobs import TrainJob, BaseJob, ExtensionJob 14 | from torch.utils.tensorboard import SummaryWriter 15 | from tqdm import tqdm 16 | 17 | 18 | class BaseTrainProcess(BaseProcess): 19 | 20 | def __init__( 21 | self, 22 | process_id: int, 23 | job, 24 | config: OrderedDict 25 | ): 26 | super().__init__(process_id, job, config) 27 | self.process_id: int 28 | self.config: OrderedDict 29 | self.writer: 'SummaryWriter' 30 | self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob'] 31 | self.progress_bar: 'tqdm' = None 32 | 33 | self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None) 34 | # if training seed is set, use it 35 | if self.training_seed is not None: 36 | torch.manual_seed(self.training_seed) 37 | if torch.cuda.is_available(): 38 | torch.cuda.manual_seed(self.training_seed) 39 | random.seed(self.training_seed) 40 | 41 | self.progress_bar = None 42 | self.writer = None 43 | self.training_folder = self.get_conf('training_folder', 44 | self.job.training_folder if hasattr(self.job, 'training_folder') else None) 45 | self.save_root = os.path.join(self.training_folder, self.name) 46 | self.step = 0 47 | self.first_step = 0 48 | self.log_dir = self.get_conf('log_dir', self.job.log_dir if hasattr(self.job, 'log_dir') else None) 49 | self.setup_tensorboard() 50 | self.save_training_config() 51 | 52 | def run(self): 53 | super().run() 54 | # implement in child class 55 | # be sure to call super().run() first 56 | pass 57 | 58 | # def print(self, message, **kwargs): 59 | def print(self, *args): 60 | if self.progress_bar is not None: 61 | self.progress_bar.write(' '.join(map(str, args))) 62 | self.progress_bar.update() 63 | else: 64 | print(*args) 65 | 66 | def setup_tensorboard(self): 67 | if self.log_dir: 68 | from torch.utils.tensorboard import SummaryWriter 69 | now = datetime.now() 70 | time_str = now.strftime('%Y%m%d-%H%M%S') 71 | summary_name = f"{self.name}_{time_str}" 72 | summary_dir = os.path.join(self.log_dir, summary_name) 73 | self.writer = SummaryWriter(summary_dir) 74 | 75 | def save_training_config(self): 76 | os.makedirs(self.save_root, exist_ok=True) 77 | save_dif = os.path.join(self.save_root, f'config.yaml') 78 | with open(save_dif, 'w') as f: 79 | yaml.dump(self.job.raw_config, f) 80 | -------------------------------------------------------------------------------- /jobs/process/ExtractLoconProcess.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from toolkit.lycoris_utils import extract_diff 3 | from 
.BaseExtractProcess import BaseExtractProcess 4 | 5 | mode_dict = { 6 | 'fixed': { 7 | 'linear': 64, 8 | 'conv': 32, 9 | 'type': int 10 | }, 11 | 'threshold': { 12 | 'linear': 0, 13 | 'conv': 0, 14 | 'type': float 15 | }, 16 | 'ratio': { 17 | 'linear': 0.5, 18 | 'conv': 0.5, 19 | 'type': float 20 | }, 21 | 'quantile': { 22 | 'linear': 0.5, 23 | 'conv': 0.5, 24 | 'type': float 25 | } 26 | } 27 | 28 | 29 | class ExtractLoconProcess(BaseExtractProcess): 30 | def __init__(self, process_id: int, job, config: OrderedDict): 31 | super().__init__(process_id, job, config) 32 | self.mode = self.get_conf('mode', 'fixed') 33 | self.use_sparse_bias = self.get_conf('use_sparse_bias', False) 34 | self.sparsity = self.get_conf('sparsity', 0.98) 35 | self.disable_cp = self.get_conf('disable_cp', False) 36 | 37 | # set modes 38 | if self.mode not in list(mode_dict.keys()): 39 | raise ValueError(f"Unknown mode: {self.mode}") 40 | self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) 41 | self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type']) 42 | 43 | def run(self): 44 | super().run() 45 | print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}") 46 | 47 | state_dict, extract_diff_meta = extract_diff( 48 | self.job.model_base, 49 | self.job.model_extract, 50 | self.mode, 51 | self.linear_param, 52 | self.conv_param, 53 | self.job.device, 54 | self.use_sparse_bias, 55 | self.sparsity, 56 | not self.disable_cp, 57 | extract_unet=self.extract_unet, 58 | extract_text_encoder=self.extract_text_encoder 59 | ) 60 | 61 | self.add_meta(extract_diff_meta) 62 | self.save(state_dict) 63 | 64 | def get_output_path(self, prefix=None, suffix=None): 65 | if suffix is None: 66 | suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}" 67 | return super().get_output_path(prefix, suffix) 68 | 69 | -------------------------------------------------------------------------------- /jobs/process/ExtractLoraProcess.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from toolkit.lycoris_utils import extract_diff 3 | from .BaseExtractProcess import BaseExtractProcess 4 | 5 | 6 | mode_dict = { 7 | 'fixed': { 8 | 'linear': 4, 9 | 'conv': 0, 10 | 'type': int 11 | }, 12 | 'threshold': { 13 | 'linear': 0, 14 | 'conv': 0, 15 | 'type': float 16 | }, 17 | 'ratio': { 18 | 'linear': 0.5, 19 | 'conv': 0, 20 | 'type': float 21 | }, 22 | 'quantile': { 23 | 'linear': 0.5, 24 | 'conv': 0, 25 | 'type': float 26 | } 27 | } 28 | 29 | CLAMP_QUANTILE = 0.99 30 | MIN_DIFF = 1e-6 31 | 32 | 33 | class ExtractLoraProcess(BaseExtractProcess): 34 | 35 | def __init__(self, process_id: int, job, config: OrderedDict): 36 | super().__init__(process_id, job, config) 37 | self.mode = self.get_conf('mode', 'fixed') 38 | 39 | # set modes 40 | if self.mode not in list(mode_dict.keys()): 41 | raise ValueError(f"Unknown mode: {self.mode}") 42 | self.linear = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) 43 | self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) 44 | self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type']) 45 | self.use_sparse_bias = self.get_conf('use_sparse_bias', False) 46 | self.sparsity = self.get_conf('sparsity', 0.98) 47 | 48 | def run(self): 49 | super().run() 50 | 
print(f"Running process: {self.mode}, dim: {self.linear_param}") 51 | 52 | state_dict, extract_diff_meta = extract_diff( 53 | self.job.model_base, 54 | self.job.model_extract, 55 | self.mode, 56 | self.linear_param, 57 | self.conv_param, 58 | self.job.device, 59 | self.use_sparse_bias, 60 | self.sparsity, 61 | small_conv=False, 62 | linear_only=self.conv_param > 0.0000000001, 63 | extract_unet=self.extract_unet, 64 | extract_text_encoder=self.extract_text_encoder 65 | ) 66 | 67 | self.add_meta(extract_diff_meta) 68 | self.save(state_dict) 69 | 70 | def get_output_path(self, prefix=None, suffix=None): 71 | if suffix is None: 72 | suffix = f"_{self.linear_param}" 73 | return super().get_output_path(prefix, suffix) 74 | -------------------------------------------------------------------------------- /jobs/process/MergeLoconProcess.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from toolkit.lycoris_utils import extract_diff 3 | from .BaseExtractProcess import BaseExtractProcess 4 | 5 | 6 | class MergeLoconProcess(BaseExtractProcess): 7 | def __init__(self, process_id: int, job, config: OrderedDict): 8 | super().__init__(process_id, job, config) 9 | 10 | def run(self): 11 | super().run() 12 | new_state_dict = {} 13 | raise NotImplementedError("This is not implemented yet") 14 | 15 | 16 | def get_output_path(self, prefix=None, suffix=None): 17 | if suffix is None: 18 | suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}" 19 | return super().get_output_path(prefix, suffix) 20 | 21 | -------------------------------------------------------------------------------- /jobs/process/TrainFineTuneProcess.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from jobs import TrainJob 3 | from jobs.process import BaseTrainProcess 4 | 5 | 6 | class TrainFineTuneProcess(BaseTrainProcess): 7 | def __init__(self,process_id: int, job: TrainJob, config: OrderedDict): 8 | super().__init__(process_id, job, config) 9 | 10 | def run(self): 11 | # implement in child class 12 | # be sure to call super().run() first 13 | pass 14 | -------------------------------------------------------------------------------- /jobs/process/__init__.py: -------------------------------------------------------------------------------- 1 | from .BaseExtractProcess import BaseExtractProcess 2 | from .ExtractLoconProcess import ExtractLoconProcess 3 | from .ExtractLoraProcess import ExtractLoraProcess 4 | from .BaseProcess import BaseProcess 5 | from .BaseTrainProcess import BaseTrainProcess 6 | from .TrainVAEProcess import TrainVAEProcess 7 | from .BaseMergeProcess import BaseMergeProcess 8 | from .TrainSliderProcess import TrainSliderProcess 9 | from .TrainSliderProcessOld import TrainSliderProcessOld 10 | from .TrainSDRescaleProcess import TrainSDRescaleProcess 11 | from .ModRescaleLoraProcess import ModRescaleLoraProcess 12 | from .GenerateProcess import GenerateProcess 13 | from .BaseExtensionProcess import BaseExtensionProcess 14 | from .TrainESRGANProcess import TrainESRGANProcess 15 | from .BaseSDTrainProcess import BaseSDTrainProcess 16 | -------------------------------------------------------------------------------- /lora-license.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: other 3 | license_name: flux-1-dev-non-commercial-license 4 | license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md 5 |
--- -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/output/.gitkeep -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://cog.run/python 3 | 4 | from cog import BasePredictor, Input, Path 5 | import os 6 | from typing import Optional 7 | 8 | class Predictor(BasePredictor): 9 | def setup(self): 10 | """Load the model into memory to make running multiple predictions efficient""" 11 | 12 | def predict( 13 | self, 14 | prompt: str = Input( 15 | description="Please check the Train Tab at the top of the model to train a LoRA." 16 | ), 17 | ) -> Path: 18 | """Run a single prediction on the model""" 19 | os.system("touch empty.zip") 20 | return Path("empty.zip") 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | safetensors 4 | git+https://github.com/huggingface/diffusers.git 5 | transformers 6 | lycoris-lora==1.8.3 7 | flatten_json 8 | pyyaml 9 | oyaml 10 | tensorboard 11 | kornia 12 | invisible-watermark 13 | einops 14 | accelerate 15 | toml 16 | albumentations 17 | pydantic 18 | omegaconf 19 | k-diffusion 20 | open_clip_torch 21 | timm 22 | prodigyopt 23 | controlnet_aux==0.0.7 24 | python-dotenv 25 | bitsandbytes 26 | hf_transfer 27 | lpips 28 | pytorch_fid 29 | optimum-quanto 30 | sentencepiece -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 3 | import sys 4 | from typing import Union, OrderedDict 5 | from dotenv import load_dotenv 6 | # Load the .env file if it exists 7 | load_dotenv() 8 | 9 | sys.path.insert(0, os.getcwd()) 10 | # must come before ANY torch or fastai imports 11 | # import toolkit.cuda_malloc 12 | 13 | # turn off diffusers telemetry until I can figure out how to make it opt-in 14 | os.environ['DISABLE_TELEMETRY'] = 'YES' 15 | 16 | # check if we have DEBUG_TOOLKIT in env 17 | if os.environ.get("DEBUG_TOOLKIT", "0") == "1": 18 | # set torch to trace mode 19 | import torch 20 | torch.autograd.set_detect_anomaly(True) 21 | import argparse 22 | from toolkit.job import get_job 23 | 24 | 25 | def print_end_message(jobs_completed, jobs_failed): 26 | failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" 27 | completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" 28 | 29 | print("") 30 | print("========================================") 31 | print("Result:") 32 | if len(completed_string) > 0: 33 | print(f" - {completed_string}") 34 | if len(failure_string) > 0: 35 | print(f" - {failure_string}") 36 | print("========================================") 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | 42 | # require at lease one config file 43 | parser.add_argument( 44 | 'config_file_list', 45 | nargs='+', 46 | type=str, 47 | help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, 
you can pass multiple config files and run them all sequentially' 48 | ) 49 | 50 | # flag to continue if failed job 51 | parser.add_argument( 52 | '-r', '--recover', 53 | action='store_true', 54 | help='Continue running additional jobs even if a job fails' 55 | ) 56 | 57 | # flag to continue if failed job 58 | parser.add_argument( 59 | '-n', '--name', 60 | type=str, 61 | default=None, 62 | help='Name to replace [name] tag in config file, useful for shared config file' 63 | ) 64 | args = parser.parse_args() 65 | 66 | config_file_list = args.config_file_list 67 | if len(config_file_list) == 0: 68 | raise Exception("You must provide at least one config file") 69 | 70 | jobs_completed = 0 71 | jobs_failed = 0 72 | 73 | print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") 74 | 75 | for config_file in config_file_list: 76 | try: 77 | job = get_job(config_file, args.name) 78 | job.run() 79 | job.cleanup() 80 | jobs_completed += 1 81 | except Exception as e: 82 | print(f"Error running job: {e}") 83 | jobs_failed += 1 84 | if not args.recover: 85 | print_end_message(jobs_completed, jobs_failed) 86 | raise e 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /scripts/convert_cog.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | import os 4 | import torch 5 | from safetensors import safe_open 6 | from safetensors.torch import save_file 7 | 8 | device = torch.device('cpu') 9 | 10 | # [diffusers] -> kohya 11 | embedding_mapping = { 12 | 'text_encoders_0': 'clip_l', 13 | 'text_encoders_1': 'clip_g' 14 | } 15 | 16 | PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps') 18 | sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json') 19 | 20 | # load keymap 21 | with open(sdxl_keymap_path, 'r') as f: 22 | ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap'] 23 | 24 | # invert the item / key pairs 25 | diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()} 26 | 27 | 28 | def get_ldm_key(diffuser_key): 29 | diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}" 30 | diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight') 31 | diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight') 32 | diffuser_key = diffuser_key.replace('_alpha', '.alpha') 33 | diffuser_key = diffuser_key.replace('_processor_to_', '_to_') 34 | diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.') 35 | if diffuser_key in diffusers_ldm_keymap: 36 | return diffusers_ldm_keymap[diffuser_key] 37 | else: 38 | raise KeyError(f"Key {diffuser_key} not found in keymap") 39 | 40 | 41 | def convert_cog(lora_path, embedding_path): 42 | embedding_state_dict = OrderedDict() 43 | lora_state_dict = OrderedDict() 44 | 45 | # # normal dict 46 | # normal_dict = OrderedDict() 47 | # example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors" 48 | # with safe_open(example_path, framework="pt", device='cpu') as f: 49 | # keys = list(f.keys()) 50 | # for key in keys: 51 | # normal_dict[key] = f.get_tensor(key) 52 | 53 | with safe_open(embedding_path, framework="pt", device='cpu') as f: 54 | keys = list(f.keys()) 55 | for key in keys: 56 | new_key = embedding_mapping[key] 57 | embedding_state_dict[new_key] = f.get_tensor(key) 58 
| 59 | with safe_open(lora_path, framework="pt", device='cpu') as f: 60 | keys = list(f.keys()) 61 | lora_rank = None 62 | 63 | # get the lora dim first. Check first 3 linear layers just to be safe 64 | for key in keys: 65 | new_key = get_ldm_key(key) 66 | tensor = f.get_tensor(key) 67 | num_checked = 0 68 | if len(tensor.shape) == 2: 69 | this_dim = min(tensor.shape) 70 | if lora_rank is None: 71 | lora_rank = this_dim 72 | elif lora_rank != this_dim: 73 | raise ValueError(f"lora rank is not consistent, got {tensor.shape}") 74 | else: 75 | num_checked += 1 76 | if num_checked >= 3: 77 | break 78 | 79 | for key in keys: 80 | new_key = get_ldm_key(key) 81 | tensor = f.get_tensor(key) 82 | if new_key.endswith('.lora_down.weight'): 83 | alpha_key = new_key.replace('.lora_down.weight', '.alpha') 84 | # diffusers does not have alpha, they usa an alpha multiplier of 1 which is a tensor weight of the dims 85 | # assume first smallest dim is the lora rank if shape is 2 86 | lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank 87 | 88 | lora_state_dict[new_key] = tensor 89 | 90 | return lora_state_dict, embedding_state_dict 91 | 92 | 93 | if __name__ == "__main__": 94 | import argparse 95 | 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument( 98 | 'lora_path', 99 | type=str, 100 | help='Path to lora file' 101 | ) 102 | parser.add_argument( 103 | 'embedding_path', 104 | type=str, 105 | help='Path to embedding file' 106 | ) 107 | 108 | parser.add_argument( 109 | '--lora_output', 110 | type=str, 111 | default="lora_output", 112 | ) 113 | 114 | parser.add_argument( 115 | '--embedding_output', 116 | type=str, 117 | default="embedding_output", 118 | ) 119 | 120 | args = parser.parse_args() 121 | 122 | lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path) 123 | 124 | # save them 125 | save_file(lora_state_dict, args.lora_output) 126 | save_file(embedding_state_dict, args.embedding_output) 127 | print(f"Saved lora to {args.lora_output}") 128 | print(f"Saved embedding to {args.embedding_output}") 129 | -------------------------------------------------------------------------------- /scripts/convert_lora_to_peft_format.py: -------------------------------------------------------------------------------- 1 | # currently only works with flux as support is not quite there yet 2 | 3 | import argparse 4 | import os.path 5 | from collections import OrderedDict 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | 'input_path', 10 | type=str, 11 | help='Path to original sdxl model' 12 | ) 13 | parser.add_argument( 14 | 'output_path', 15 | type=str, 16 | help='output path' 17 | ) 18 | args = parser.parse_args() 19 | args.input_path = os.path.abspath(args.input_path) 20 | args.output_path = os.path.abspath(args.output_path) 21 | 22 | from safetensors.torch import load_file, save_file 23 | 24 | meta = OrderedDict() 25 | meta['format'] = 'pt' 26 | 27 | state_dict = load_file(args.input_path) 28 | 29 | # peft doesnt have an alpha so we need to scale the weights 30 | alpha_keys = [ 31 | 'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha' # flux 32 | ] 33 | 34 | # keys where the rank is in the first dimension 35 | rank_idx0_keys = [ 36 | 'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight' 37 | # 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight' 38 | ] 39 | 40 | alpha = None 41 | rank = None 42 | 43 | for key in rank_idx0_keys: 44 | if key in state_dict: 45 | rank = 
int(state_dict[key].shape[0]) 46 | break 47 | 48 | if rank is None: 49 | raise ValueError(f'Could not find rank in state dict') 50 | 51 | for key in alpha_keys: 52 | if key in state_dict: 53 | alpha = int(state_dict[key]) 54 | break 55 | 56 | if alpha is None: 57 | # set to rank if not found 58 | alpha = rank 59 | 60 | 61 | up_multiplier = alpha / rank 62 | 63 | new_state_dict = {} 64 | 65 | for key, value in state_dict.items(): 66 | if key.endswith('.alpha'): 67 | continue 68 | 69 | orig_dtype = value.dtype 70 | 71 | new_val = value.float() * up_multiplier 72 | 73 | new_key = key 74 | new_key = new_key.replace('lora_transformer_', 'transformer.') 75 | for i in range(100): 76 | new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.') 77 | new_key = new_key.replace('lora_down', 'lora_A') 78 | new_key = new_key.replace('lora_up', 'lora_B') 79 | new_key = new_key.replace('_lora', '.lora') 80 | new_key = new_key.replace('attn_', 'attn.') 81 | new_key = new_key.replace('ff_', 'ff.') 82 | new_key = new_key.replace('context_net_', 'context.net.') 83 | new_key = new_key.replace('0_proj', '0.proj') 84 | new_key = new_key.replace('norm_linear', 'norm.linear') 85 | new_key = new_key.replace('norm_out_linear', 'norm_out.linear') 86 | new_key = new_key.replace('to_out_', 'to_out.') 87 | 88 | new_state_dict[new_key] = new_val.to(orig_dtype) 89 | 90 | save_file(new_state_dict, args.output_path, meta) 91 | print(f'Saved to {args.output_path}') 92 | -------------------------------------------------------------------------------- /scripts/generate_sampler_step_scales.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | from diffusers import StableDiffusionPipeline 5 | import sys 6 | 7 | PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | # add project root to path 9 | sys.path.append(PROJECT_ROOT) 10 | 11 | SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales') 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Process some images.') 15 | add_arg = parser.add_argument 16 | add_arg('--model', type=str, required=True, help='Path to model') 17 | add_arg('--sampler', type=str, required=True, help='Name of sampler') 18 | 19 | args = parser.parse_args() 20 | 21 | -------------------------------------------------------------------------------- /scripts/make_diffusers_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import OrderedDict 3 | import sys 4 | import os 5 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(ROOT_DIR) 7 | 8 | import torch 9 | 10 | from toolkit.config_modules import ModelConfig 11 | from toolkit.stable_diffusion_model import StableDiffusion 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | 'input_path', 17 | type=str, 18 | help='Path to original sdxl model' 19 | ) 20 | parser.add_argument( 21 | 'output_path', 22 | type=str, 23 | help='output path' 24 | ) 25 | parser.add_argument('--sdxl', action='store_true', help='is sdxl model') 26 | parser.add_argument('--refiner', action='store_true', help='is refiner model') 27 | parser.add_argument('--ssd', action='store_true', help='is ssd model') 28 | parser.add_argument('--sd2', action='store_true', help='is sd 2 model') 29 | 30 | args = parser.parse_args() 31 | device = torch.device('cpu') 32 | dtype = torch.float32 33 | 34 | 
print(f"Loading model from {args.input_path}") 35 | 36 | 37 | diffusers_model_config = ModelConfig( 38 | name_or_path=args.input_path, 39 | is_xl=args.sdxl, 40 | is_v2=args.sd2, 41 | is_ssd=args.ssd, 42 | dtype=dtype, 43 | ) 44 | diffusers_sd = StableDiffusion( 45 | model_config=diffusers_model_config, 46 | device=device, 47 | dtype=dtype, 48 | ) 49 | diffusers_sd.load_model() 50 | 51 | 52 | print(f"Loaded model from {args.input_path}") 53 | 54 | diffusers_sd.pipeline.fuse_lora() 55 | 56 | meta = OrderedDict() 57 | 58 | diffusers_sd.save(args.output_path, meta=meta) 59 | 60 | 61 | print(f"Saved to {args.output_path}") 62 | -------------------------------------------------------------------------------- /scripts/make_lcm_sdxl_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import OrderedDict 3 | 4 | import torch 5 | 6 | from toolkit.config_modules import ModelConfig 7 | from toolkit.stable_diffusion_model import StableDiffusion 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | 'input_path', 13 | type=str, 14 | help='Path to original sdxl model' 15 | ) 16 | parser.add_argument( 17 | 'output_path', 18 | type=str, 19 | help='output path' 20 | ) 21 | parser.add_argument('--sdxl', action='store_true', help='is sdxl model') 22 | parser.add_argument('--refiner', action='store_true', help='is refiner model') 23 | parser.add_argument('--ssd', action='store_true', help='is ssd model') 24 | parser.add_argument('--sd2', action='store_true', help='is sd 2 model') 25 | 26 | args = parser.parse_args() 27 | device = torch.device('cpu') 28 | dtype = torch.float32 29 | 30 | print(f"Loading model from {args.input_path}") 31 | 32 | if args.sdxl: 33 | adapter_id = "latent-consistency/lcm-lora-sdxl" 34 | if args.refiner: 35 | adapter_id = "latent-consistency/lcm-lora-sdxl" 36 | elif args.ssd: 37 | adapter_id = "latent-consistency/lcm-lora-ssd-1b" 38 | else: 39 | adapter_id = "latent-consistency/lcm-lora-sdv1-5" 40 | 41 | 42 | diffusers_model_config = ModelConfig( 43 | name_or_path=args.input_path, 44 | is_xl=args.sdxl, 45 | is_v2=args.sd2, 46 | is_ssd=args.ssd, 47 | dtype=dtype, 48 | ) 49 | diffusers_sd = StableDiffusion( 50 | model_config=diffusers_model_config, 51 | device=device, 52 | dtype=dtype, 53 | ) 54 | diffusers_sd.load_model() 55 | 56 | 57 | print(f"Loaded model from {args.input_path}") 58 | 59 | diffusers_sd.pipeline.load_lora_weights(adapter_id) 60 | diffusers_sd.pipeline.fuse_lora() 61 | 62 | meta = OrderedDict() 63 | 64 | diffusers_sd.save(args.output_path, meta=meta) 65 | 66 | 67 | print(f"Saved to {args.output_path}") 68 | -------------------------------------------------------------------------------- /scripts/patch_te_adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import save_file, load_file 3 | from collections import OrderedDict 4 | meta = OrderedDict() 5 | meta["format"] ="pt" 6 | 7 | attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors") 8 | state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors") 9 | 10 | attn_list = [] 11 | for key, value in state_dict.items(): 12 | if "attn1" in key: 13 | attn_list.append(key) 14 | 15 | attn_names = ['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 
'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'] 16 | 17 | adapter_names = [] 18 | for i in range(100): 19 | if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict: 20 | adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter") 21 | 22 | 23 | for i in range(len(adapter_names)): 24 | adapter_name = adapter_names[i] 25 | attn_name = attn_names[i] 26 | adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight' 27 | adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight' 28 | state_k_name = attn_name.replace(".processor", ".to_k.weight") 29 | state_v_name = attn_name.replace(".processor", ".to_v.weight") 30 | if adapter_k_name in attn_dict: 31 | state_dict[state_k_name] = attn_dict[adapter_k_name] 32 | state_dict[state_v_name] = attn_dict[adapter_v_name] 33 | else: 34 | print("adapter_k_name", adapter_k_name) 35 | print("state_k_name", state_k_name) 36 | 37 | for key, value in state_dict.items(): 38 | state_dict[key] = value.cpu().to(torch.float16) 39 | 40 | save_file(state_dict, "/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta) 41 | 42 | print("Done") 43 | -------------------------------------------------------------------------------- /scripts/repair_dataset_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | from PIL.ImageOps import exif_transpose 4 | from tqdm import tqdm 5 | import os 6 | 7 | parser = argparse.ArgumentParser(description='Process some images.') 8 | parser.add_argument("input_folder", type=str, help="Path to folder containing images") 9 | 10 | args = parser.parse_args() 11 | 12 | img_types = ['.jpg', '.jpeg', '.png', '.webp'] 13 | 14 | # find all images in the input folder 15 | images = [] 16 | for root, _, files in os.walk(args.input_folder): 17 | for file in files: 18 | if file.lower().endswith(tuple(img_types)): 19 | images.append(os.path.join(root, file)) 20 | print(f"Found {len(images)} images") 21 | 22 | num_skipped = 0 23 | num_repaired = 0 24 | num_deleted = 0 25 | 26 | pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image") 27 | for img_path in images: 28 | filename = os.path.basename(img_path) 29 | filename_no_ext, file_extension = os.path.splitext(filename) 30 | # if it is jpg, ignore 31 | if file_extension.lower() == '.jpg': 32 | num_skipped += 1 33 | pbar.update(1) 34 | 35 | continue 36 | 37 | try: 38 | img = Image.open(img_path) 39 | except Exception as e: 40 | print(f"Error opening {img_path}: {e}") 41 | # delete it 42 | os.remove(img_path) 43 | num_deleted += 1 44 | pbar.update(1) 45 | 
pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}") 46 | continue 47 | 48 | 49 | try: 50 | img = exif_transpose(img) 51 | except Exception as e: 52 | print(f"Error rotating {img_path}: {e}") 53 | 54 | new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg') 55 | 56 | img = img.convert("RGB") 57 | img.save(new_path, quality=95) 58 | # remove the old file 59 | os.remove(img_path) 60 | num_repaired += 1 61 | pbar.update(1) 62 | # update pbar 63 | pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}") 64 | 65 | print("Done") -------------------------------------------------------------------------------- /testing/compare_keys.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from diffusers.loaders import LoraLoaderMixin 6 | from safetensors.torch import load_file 7 | from collections import OrderedDict 8 | import json 9 | # this was just used to match the vae keys to the diffusers keys 10 | # you probably wont need this. Unless they change them.... again... again 11 | # on second thought, you probably will 12 | 13 | device = torch.device('cpu') 14 | dtype = torch.float32 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | # require at lease one config file 19 | parser.add_argument( 20 | 'file_1', 21 | nargs='+', 22 | type=str, 23 | help='Path to first safe tensor file' 24 | ) 25 | 26 | parser.add_argument( 27 | 'file_2', 28 | nargs='+', 29 | type=str, 30 | help='Path to second safe tensor file' 31 | ) 32 | 33 | args = parser.parse_args() 34 | 35 | find_matches = False 36 | 37 | state_dict_file_1 = load_file(args.file_1[0]) 38 | state_dict_1_keys = list(state_dict_file_1.keys()) 39 | 40 | state_dict_file_2 = load_file(args.file_2[0]) 41 | state_dict_2_keys = list(state_dict_file_2.keys()) 42 | keys_in_both = [] 43 | 44 | keys_not_in_state_dict_2 = [] 45 | for key in state_dict_1_keys: 46 | if key not in state_dict_2_keys: 47 | keys_not_in_state_dict_2.append(key) 48 | 49 | keys_not_in_state_dict_1 = [] 50 | for key in state_dict_2_keys: 51 | if key not in state_dict_1_keys: 52 | keys_not_in_state_dict_1.append(key) 53 | 54 | keys_in_both = [] 55 | for key in state_dict_1_keys: 56 | if key in state_dict_2_keys: 57 | keys_in_both.append(key) 58 | 59 | # sort them 60 | keys_not_in_state_dict_2.sort() 61 | keys_not_in_state_dict_1.sort() 62 | keys_in_both.sort() 63 | 64 | 65 | json_data = { 66 | "both": keys_in_both, 67 | "not_in_state_dict_2": keys_not_in_state_dict_2, 68 | "not_in_state_dict_1": keys_not_in_state_dict_1 69 | } 70 | json_data = json.dumps(json_data, indent=4) 71 | 72 | remaining_diffusers_values = OrderedDict() 73 | for key in keys_not_in_state_dict_1: 74 | remaining_diffusers_values[key] = state_dict_file_2[key] 75 | 76 | # print(remaining_diffusers_values.keys()) 77 | 78 | remaining_ldm_values = OrderedDict() 79 | for key in keys_not_in_state_dict_2: 80 | remaining_ldm_values[key] = state_dict_file_1[key] 81 | 82 | # print(json_data) 83 | 84 | project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 85 | json_save_path = os.path.join(project_root, 'config', 'keys.json') 86 | json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') 87 | json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') 88 | state_dict_1_filename = os.path.basename(args.file_1[0]) 89 | state_dict_2_filename = os.path.basename(args.file_2[0]) 
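# A hypothetical run of this key comparison (a sketch -- both file names are placeholders):
#
#   python testing/compare_keys.py first_model.safetensors second_model.safetensors
#
# The writes below then dump each file's key list to config/<filename>.json and the
# "both" / "not_in_state_dict_1" / "not_in_state_dict_2" summary to config/keys.json.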
90 | # save key names for each in own file 91 | with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f: 92 | f.write(json.dumps(state_dict_1_keys, indent=4)) 93 | 94 | with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f: 95 | f.write(json.dumps(state_dict_2_keys, indent=4)) 96 | 97 | 98 | with open(json_save_path, 'w') as f: 99 | f.write(json_data) -------------------------------------------------------------------------------- /testing/generate_lora_mapping.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from safetensors.torch import load_file 5 | import argparse 6 | import os 7 | import json 8 | 9 | PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | 11 | keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json') 12 | 13 | # load keymap 14 | with open(keymap_path, 'r') as f: 15 | keymap = json.load(f) 16 | 17 | lora_keymap = OrderedDict() 18 | 19 | # convert keymap to lora key naming 20 | for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items(): 21 | if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'): 22 | # skip it 23 | continue 24 | # sdxl has same te for locon with kohya and ours 25 | if ldm_key.startswith('conditioner'): 26 | #skip it 27 | continue 28 | # ignore vae 29 | if ldm_key.startswith('first_stage_model'): 30 | continue 31 | ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_') 32 | ldm_key = ldm_key.replace('.weight', '') 33 | ldm_key = ldm_key.replace('.', '_') 34 | 35 | diffusers_key = diffusers_key.replace('unet_', 'lora_unet_') 36 | diffusers_key = diffusers_key.replace('.weight', '') 37 | diffusers_key = diffusers_key.replace('.', '_') 38 | 39 | lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha" 40 | lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight" 41 | lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight" 42 | 43 | 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("input", help="input file") 46 | parser.add_argument("input2", help="input2 file") 47 | 48 | args = parser.parse_args() 49 | 50 | # name = args.name 51 | # if args.sdxl: 52 | # name += '_sdxl' 53 | # elif args.sd2: 54 | # name += '_sd2' 55 | # else: 56 | # name += '_sd1' 57 | name = 'stable_diffusion_locon_sdxl' 58 | 59 | locon_save = load_file(args.input) 60 | our_save = load_file(args.input2) 61 | 62 | our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys())) 63 | locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys())) 64 | 65 | print(f"we have {len(our_extra_keys)} extra keys") 66 | print(f"locon has {len(locon_extra_keys)} extra keys") 67 | 68 | save_dtype = torch.float16 69 | print(f"our extra keys: {our_extra_keys}") 70 | print(f"locon extra keys: {locon_extra_keys}") 71 | 72 | 73 | def export_state_dict(our_save): 74 | converted_state_dict = OrderedDict() 75 | for key, value in our_save.items(): 76 | # test encoders share keys for some reason 77 | if key.startswith('lora_te'): 78 | converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype) 79 | else: 80 | converted_key = key 81 | for ldm_key, diffusers_key in lora_keymap.items(): 82 | if converted_key == diffusers_key: 83 | converted_key = ldm_key 84 | 85 | converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype) 86 | return 
converted_state_dict 87 | 88 | def import_state_dict(loaded_state_dict): 89 | converted_state_dict = OrderedDict() 90 | for key, value in loaded_state_dict.items(): 91 | if key.startswith('lora_te'): 92 | converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype) 93 | else: 94 | converted_key = key 95 | for ldm_key, diffusers_key in lora_keymap.items(): 96 | if converted_key == ldm_key: 97 | converted_key = diffusers_key 98 | 99 | converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype) 100 | return converted_state_dict 101 | 102 | 103 | # check it again 104 | converted_state_dict = export_state_dict(our_save) 105 | converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys())) 106 | locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys())) 107 | 108 | 109 | print(f"we have {len(converted_extra_keys)} extra keys") 110 | print(f"locon has {len(locon_extra_keys)} extra keys") 111 | 112 | print(f"our extra keys: {converted_extra_keys}") 113 | 114 | # convert back 115 | cycle_state_dict = import_state_dict(converted_state_dict) 116 | cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys())) 117 | our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys())) 118 | 119 | print(f"we have {len(our_extra_keys)} extra keys") 120 | print(f"cycle has {len(cycle_extra_keys)} extra keys") 121 | 122 | # save keymap 123 | to_save = OrderedDict() 124 | to_save['ldm_diffusers_keymap'] = lora_keymap 125 | 126 | with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f: 127 | json.dump(to_save, f, indent=4) 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /testing/shrink_pixart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | from collections import OrderedDict 4 | 5 | model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors" 6 | output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" 7 | 8 | state_dict = load_file(model_path) 9 | 10 | meta = OrderedDict() 11 | meta["format"] = "pt" 12 | 13 | new_state_dict = {} 14 | 15 | # Move non-blocks over 16 | for key, value in state_dict.items(): 17 | if not key.startswith("transformer_blocks."): 18 | new_state_dict[key] = value 19 | 20 | block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', 21 | 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', 22 | 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', 23 | 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', 24 | 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', 25 | 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', 26 | 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', 27 | 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight', 28 | 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', 29 | 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', 30 | 
'transformer_blocks.{idx}.scale_shift_table'] 31 | 32 | # New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27 33 | 34 | current_idx = 0 35 | for i in range(28): 36 | if i not in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]: 37 | # todo merge in with previous block 38 | for name in block_names: 39 | try: 40 | new_state_dict_key = name.format(idx=current_idx - 1) 41 | old_state_dict_key = name.format(idx=i) 42 | new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5) 43 | except KeyError: 44 | raise KeyError(f"KeyError: {name.format(idx=current_idx)}") 45 | else: 46 | for name in block_names: 47 | new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)] 48 | current_idx += 1 49 | 50 | 51 | # make sure they are all fp16 and on cpu 52 | for key, value in new_state_dict.items(): 53 | new_state_dict[key] = value.to(torch.float16).cpu() 54 | 55 | # save the new state dict 56 | save_file(new_state_dict, output_path, metadata=meta) 57 | 58 | new_param_count = sum([v.numel() for v in new_state_dict.values()]) 59 | old_param_count = sum([v.numel() for v in state_dict.values()]) 60 | 61 | print(f"Old param count: {old_param_count:,}") 62 | print(f"New param count: {new_param_count:,}") -------------------------------------------------------------------------------- /testing/shrink_pixart2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | from collections import OrderedDict 4 | 5 | model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors" 6 | output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" 7 | 8 | state_dict = load_file(model_path) 9 | 10 | meta = OrderedDict() 11 | meta["format"] = "pt" 12 | 13 | new_state_dict = {} 14 | 15 | # Move non-blocks over 16 | for key, value in state_dict.items(): 17 | if not key.startswith("transformer_blocks."): 18 | new_state_dict[key] = value 19 | 20 | block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', 21 | 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', 22 | 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', 23 | 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', 24 | 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', 25 | 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', 26 | 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', 27 | 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight', 28 | 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', 29 | 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', 30 | 'transformer_blocks.{idx}.scale_shift_table'] 31 | 32 | # Blocks to keep 33 | # keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27] 34 | keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27] 35 | 36 | 37 | def weighted_merge(kept_block, removed_block, weight): 38 | return kept_block * (1 - weight) + removed_block * weight 39 | 40 | 41 | # First, copy all kept blocks to new_state_dict 
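# (A worked example of the two passes below, assuming the 28 transformer blocks
#  indexed here: kept blocks are first copied verbatim, then each dropped block is
#  blended into its two nearest kept neighbours with a distance-based weight --
#  e.g. dropped block 3 lies halfway between kept blocks 2 and 4, so
#  weight = (3 - 2) / (4 - 2) = 0.5 and it contributes equally to both.)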
42 | for i, old_idx in enumerate(keep_blocks): 43 | for name in block_names: 44 | old_key = name.format(idx=old_idx) 45 | new_key = name.format(idx=i) 46 | new_state_dict[new_key] = state_dict[old_key].clone() 47 | 48 | # Then, merge information from removed blocks 49 | for i in range(28): 50 | if i not in keep_blocks: 51 | # Find the nearest kept blocks 52 | prev_kept = max([b for b in keep_blocks if b < i]) 53 | next_kept = min([b for b in keep_blocks if b > i]) 54 | 55 | # Calculate the weight based on position 56 | weight = (i - prev_kept) / (next_kept - prev_kept) 57 | 58 | for name in block_names: 59 | removed_key = name.format(idx=i) 60 | prev_new_key = name.format(idx=keep_blocks.index(prev_kept)) 61 | next_new_key = name.format(idx=keep_blocks.index(next_kept)) 62 | 63 | # Weighted merge for previous kept block 64 | new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight) 65 | 66 | # Weighted merge for next kept block 67 | new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key], 68 | 1 - weight) 69 | 70 | # Convert to fp16 and move to CPU 71 | for key, value in new_state_dict.items(): 72 | new_state_dict[key] = value.to(torch.float16).cpu() 73 | 74 | # Save the new state dict 75 | save_file(new_state_dict, output_path, metadata=meta) 76 | 77 | new_param_count = sum([v.numel() for v in new_state_dict.values()]) 78 | old_param_count = sum([v.numel() for v in state_dict.values()]) 79 | 80 | print(f"Old param count: {old_param_count:,}") 81 | print(f"New param count: {new_param_count:,}") -------------------------------------------------------------------------------- /testing/shrink_pixart_sm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | from collections import OrderedDict 4 | 5 | meta = OrderedDict() 6 | meta['format'] = "pt" 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def reduce_weight(weight, target_size): 12 | weight = weight.to(device, torch.float32) 13 | original_shape = weight.shape 14 | flattened = weight.view(-1, original_shape[-1]) 15 | 16 | if flattened.shape[1] <= target_size: 17 | return weight 18 | 19 | U, S, V = torch.svd(flattened) 20 | reduced = torch.mm(U[:, :target_size], torch.diag(S[:target_size])) 21 | 22 | if reduced.shape[1] < target_size: 23 | padding = torch.zeros(reduced.shape[0], target_size - reduced.shape[1], device=device) 24 | reduced = torch.cat((reduced, padding), dim=1) 25 | 26 | return reduced.view(original_shape[:-1] + (target_size,)) 27 | 28 | 29 | def reduce_bias(bias, target_size): 30 | bias = bias.to(device, torch.float32) 31 | original_size = bias.shape[0] 32 | 33 | if original_size <= target_size: 34 | return torch.nn.functional.pad(bias, (0, target_size - original_size)) 35 | else: 36 | return bias.view(-1, original_size // target_size).mean(dim=1)[:target_size] 37 | 38 | 39 | # Load your original state dict 40 | state_dict = load_file( 41 | "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") 42 | 43 | # Create a new state dict for the reduced model 44 | new_state_dict = {} 45 | 46 | source_hidden_size = 1152 47 | target_hidden_size = 1024 48 | 49 | for key, value in state_dict.items(): 50 | value = value.to(device, torch.float32) 51 | if 'weight' in key or 'scale_shift_table' in key: 52 | if value.shape[0] == 
source_hidden_size: 53 | value = value[:target_hidden_size] 54 | elif value.shape[0] == source_hidden_size * 4: 55 | value = value[:target_hidden_size * 4] 56 | elif value.shape[0] == source_hidden_size * 6: 57 | value = value[:target_hidden_size * 6] 58 | 59 | if len(value.shape) > 1 and value.shape[ 60 | 1] == source_hidden_size and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: 61 | value = value[:, :target_hidden_size] 62 | elif len(value.shape) > 1 and value.shape[1] == source_hidden_size * 4: 63 | value = value[:, :target_hidden_size * 4] 64 | 65 | elif 'bias' in key: 66 | if value.shape[0] == source_hidden_size: 67 | value = value[:target_hidden_size] 68 | elif value.shape[0] == source_hidden_size * 4: 69 | value = value[:target_hidden_size * 4] 70 | elif value.shape[0] == source_hidden_size * 6: 71 | value = value[:target_hidden_size * 6] 72 | 73 | new_state_dict[key] = value 74 | 75 | # Move all to CPU and convert to float16 76 | for key, value in new_state_dict.items(): 77 | new_state_dict[key] = value.cpu().to(torch.float16) 78 | 79 | # Save the new state dict 80 | save_file(new_state_dict, 81 | "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", 82 | metadata=meta) 83 | 84 | print("Done!") 85 | -------------------------------------------------------------------------------- /testing/shrink_pixart_sm2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | from collections import OrderedDict 4 | 5 | meta = OrderedDict() 6 | meta['format'] = "pt" 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def reduce_weight(weight, target_size): 12 | weight = weight.to(device, torch.float32) 13 | original_shape = weight.shape 14 | 15 | if len(original_shape) == 1: 16 | # For 1D tensors, simply truncate 17 | return weight[:target_size] 18 | 19 | if original_shape[0] <= target_size: 20 | return weight 21 | 22 | # Reshape the tensor to 2D 23 | flattened = weight.reshape(original_shape[0], -1) 24 | 25 | # Perform SVD 26 | U, S, V = torch.svd(flattened) 27 | 28 | # Reduce the dimensions 29 | reduced = torch.mm(U[:target_size, :], torch.diag(S)).mm(V.t()) 30 | 31 | # Reshape back to the original shape with reduced first dimension 32 | new_shape = (target_size,) + original_shape[1:] 33 | return reduced.reshape(new_shape) 34 | 35 | 36 | def reduce_bias(bias, target_size): 37 | bias = bias.to(device, torch.float32) 38 | return bias[:target_size] 39 | 40 | 41 | # Load your original state dict 42 | state_dict = load_file( 43 | "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") 44 | 45 | # Create a new state dict for the reduced model 46 | new_state_dict = {} 47 | 48 | for key, value in state_dict.items(): 49 | value = value.to(device, torch.float32) 50 | 51 | if 'weight' in key or 'scale_shift_table' in key: 52 | if value.shape[0] == 1152: 53 | if len(value.shape) == 4: 54 | orig_shape = value.shape 55 | output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3]) # reshape to (1152, -1) 56 | # reshape to (1152, -1) 57 | value = value.view(value.shape[0], -1) 58 | value = reduce_weight(value, 512) 59 | value = value.view(output_shape) 60 | else: 61 | # value = reduce_weight(value.t(), 576).t().contiguous() 62 | value = reduce_weight(value, 512) 63 | pass 64 | elif value.shape[0] == 4608: 65 | if 
len(value.shape) == 4: 66 | orig_shape = value.shape 67 | output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3]) 68 | value = value.view(value.shape[0], -1) 69 | value = reduce_weight(value, 2048) 70 | value = value.view(output_shape) 71 | else: 72 | value = reduce_weight(value, 2048) 73 | elif value.shape[0] == 6912: 74 | if len(value.shape) == 4: 75 | orig_shape = value.shape 76 | output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3]) 77 | value = value.view(value.shape[0], -1) 78 | value = reduce_weight(value, 3072) 79 | value = value.view(output_shape) 80 | else: 81 | value = reduce_weight(value, 3072) 82 | 83 | if len(value.shape) > 1 and value.shape[ 84 | 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: 85 | value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction 86 | pass 87 | elif len(value.shape) > 1 and value.shape[1] == 4608: 88 | value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction 89 | pass 90 | 91 | elif 'bias' in key: 92 | if value.shape[0] == 1152: 93 | value = reduce_bias(value, 512) 94 | elif value.shape[0] == 4608: 95 | value = reduce_bias(value, 2048) 96 | elif value.shape[0] == 6912: 97 | value = reduce_bias(value, 3072) 98 | 99 | new_state_dict[key] = value 100 | 101 | # Move all to CPU and convert to float16 102 | for key, value in new_state_dict.items(): 103 | new_state_dict[key] = value.cpu().to(torch.float16) 104 | 105 | # Save the new state dict 106 | save_file(new_state_dict, 107 | "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", 108 | metadata=meta) 109 | 110 | print("Done!") -------------------------------------------------------------------------------- /testing/shrink_pixart_sm3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | from collections import OrderedDict 4 | 5 | meta = OrderedDict() 6 | meta['format'] = "pt" 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def reduce_weight(weight, target_size): 12 | weight = weight.to(device, torch.float32) 13 | # resize so target_size is the first dimension 14 | tmp_weight = weight.view(1, 1, weight.shape[0], weight.shape[1]) 15 | 16 | # use interpolate to resize the tensor 17 | new_weight = torch.nn.functional.interpolate(tmp_weight, size=(target_size, weight.shape[1]), mode='bicubic', align_corners=True) 18 | 19 | # reshape back to original shape 20 | return new_weight.view(target_size, weight.shape[1]) 21 | 22 | 23 | def reduce_bias(bias, target_size): 24 | bias = bias.view(1, 1, bias.shape[0], 1) 25 | 26 | new_bias = torch.nn.functional.interpolate(bias, size=(target_size, 1), mode='bicubic', align_corners=True) 27 | 28 | return new_bias.view(target_size) 29 | 30 | 31 | # Load your original state dict 32 | state_dict = load_file( 33 | "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") 34 | 35 | # Create a new state dict for the reduced model 36 | new_state_dict = {} 37 | 38 | for key, value in state_dict.items(): 39 | value = value.to(device, torch.float32) 40 | 41 | if 'weight' in key or 'scale_shift_table' in key: 42 | if value.shape[0] == 1152: 43 | if len(value.shape) == 4: 44 | orig_shape = value.shape 45 | output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3]) # reshape to 
(1152, -1) 46 | # reshape to (1152, -1) 47 | value = value.view(value.shape[0], -1) 48 | value = reduce_weight(value, 512) 49 | value = value.view(output_shape) 50 | else: 51 | # value = reduce_weight(value.t(), 576).t().contiguous() 52 | value = reduce_weight(value, 512) 53 | pass 54 | elif value.shape[0] == 4608: 55 | if len(value.shape) == 4: 56 | orig_shape = value.shape 57 | output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3]) 58 | value = value.view(value.shape[0], -1) 59 | value = reduce_weight(value, 2048) 60 | value = value.view(output_shape) 61 | else: 62 | value = reduce_weight(value, 2048) 63 | elif value.shape[0] == 6912: 64 | if len(value.shape) == 4: 65 | orig_shape = value.shape 66 | output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3]) 67 | value = value.view(value.shape[0], -1) 68 | value = reduce_weight(value, 3072) 69 | value = value.view(output_shape) 70 | else: 71 | value = reduce_weight(value, 3072) 72 | 73 | if len(value.shape) > 1 and value.shape[ 74 | 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: 75 | value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction 76 | pass 77 | elif len(value.shape) > 1 and value.shape[1] == 4608: 78 | value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction 79 | pass 80 | 81 | elif 'bias' in key: 82 | if value.shape[0] == 1152: 83 | value = reduce_bias(value, 512) 84 | elif value.shape[0] == 4608: 85 | value = reduce_bias(value, 2048) 86 | elif value.shape[0] == 6912: 87 | value = reduce_bias(value, 3072) 88 | 89 | new_state_dict[key] = value 90 | 91 | # Move all to CPU and convert to float16 92 | for key, value in new_state_dict.items(): 93 | new_state_dict[key] = value.cpu().to(torch.float16) 94 | 95 | # Save the new state dict 96 | save_file(new_state_dict, 97 | "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", 98 | metadata=meta) 99 | 100 | print("Done!") -------------------------------------------------------------------------------- /testing/test_bucket_dataloader.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | import sys 8 | import os 9 | import cv2 10 | import random 11 | from transformers import CLIPImageProcessor 12 | 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | from toolkit.paths import SD_SCRIPTS_ROOT 15 | import torchvision.transforms.functional 16 | from toolkit.image_utils import show_img 17 | 18 | sys.path.append(SD_SCRIPTS_ROOT) 19 | 20 | from library.model_util import load_vae 21 | from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO 22 | from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \ 23 | trigger_dataloader_setup_epoch 24 | from toolkit.config_modules import DatasetConfig 25 | import argparse 26 | from tqdm import tqdm 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('dataset_folder', type=str, default='input') 30 | parser.add_argument('--epochs', type=int, default=1) 31 | 32 | 33 | 34 | args = parser.parse_args() 35 | 36 | dataset_folder = args.dataset_folder 37 | resolution = 512 38 | bucket_tolerance = 64 39 | batch_size = 1 40 | 41 | clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16") 42 
| 43 | class FakeAdapter: 44 | def __init__(self): 45 | self.clip_image_processor = clip_processor 46 | 47 | 48 | ## make fake sd 49 | class FakeSD: 50 | def __init__(self): 51 | self.adapter = FakeAdapter() 52 | 53 | 54 | 55 | 56 | dataset_config = DatasetConfig( 57 | dataset_path=dataset_folder, 58 | clip_image_path=dataset_folder, 59 | square_crop=True, 60 | resolution=resolution, 61 | # caption_ext='json', 62 | default_caption='default', 63 | # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/', 64 | buckets=True, 65 | bucket_tolerance=bucket_tolerance, 66 | # poi='person', 67 | # shuffle_augmentations=True, 68 | # augmentations=[ 69 | # { 70 | # 'method': 'Posterize', 71 | # 'num_bits': [(0, 4), (0, 4), (0, 4)], 72 | # 'p': 1.0 73 | # }, 74 | # 75 | # ] 76 | ) 77 | 78 | dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD()) 79 | 80 | 81 | # run through an epoch ang check sizes 82 | dataloader_iterator = iter(dataloader) 83 | for epoch in range(args.epochs): 84 | for batch in tqdm(dataloader): 85 | batch: 'DataLoaderBatchDTO' 86 | img_batch = batch.tensor 87 | batch_size, channels, height, width = img_batch.shape 88 | 89 | # img_batch = color_block_imgs(img_batch, neg1_1=True) 90 | 91 | chunks = torch.chunk(img_batch, batch_size, dim=0) 92 | # put them so they are size by side 93 | big_img = torch.cat(chunks, dim=3) 94 | big_img = big_img.squeeze(0) 95 | 96 | control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0) 97 | big_control_img = torch.cat(control_chunks, dim=3) 98 | big_control_img = big_control_img.squeeze(0) * 2 - 1 99 | 100 | 101 | # resize control image 102 | big_control_img = torchvision.transforms.Resize((width, height))(big_control_img) 103 | 104 | big_img = torch.cat([big_img, big_control_img], dim=2) 105 | 106 | min_val = big_img.min() 107 | max_val = big_img.max() 108 | 109 | big_img = (big_img / 2 + 0.5).clamp(0, 1) 110 | 111 | # convert to image 112 | img = transforms.ToPILImage()(big_img) 113 | 114 | show_img(img) 115 | 116 | time.sleep(1.0) 117 | # if not last epoch 118 | if epoch < args.epochs - 1: 119 | trigger_dataloader_setup_epoch(dataloader) 120 | 121 | cv2.destroyAllWindows() 122 | 123 | print('done') 124 | -------------------------------------------------------------------------------- /testing/test_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | import torch 5 | from torchvision.transforms import Resize, ToTensor 6 | from diffusers import AutoencoderKL 7 | from pytorch_fid import fid_score 8 | from skimage.metrics import peak_signal_noise_ratio as psnr 9 | import lpips 10 | from tqdm import tqdm 11 | from torchvision import transforms 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | def load_images(folder_path): 16 | images = [] 17 | for filename in os.listdir(folder_path): 18 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')): 19 | img_path = os.path.join(folder_path, filename) 20 | images.append(img_path) 21 | return images 22 | 23 | 24 | def paramiter_count(model): 25 | state_dict = model.state_dict() 26 | paramiter_count = 0 27 | for key in state_dict: 28 | paramiter_count += torch.numel(state_dict[key]) 29 | return int(paramiter_count) 30 | 31 | 32 | def calculate_metrics(vae, images, max_imgs=-1): 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | vae = vae.to(device) 35 | lpips_model 
= lpips.LPIPS(net='alex').to(device) 36 | 37 | rfid_scores = [] 38 | psnr_scores = [] 39 | lpips_scores = [] 40 | 41 | # transform = transforms.Compose([ 42 | # transforms.Resize(256, antialias=True), 43 | # transforms.CenterCrop(256) 44 | # ]) 45 | # needs values between -1 and 1 46 | to_tensor = ToTensor() 47 | 48 | if max_imgs > 0 and len(images) > max_imgs: 49 | images = images[:max_imgs] 50 | 51 | for img_path in tqdm(images): 52 | try: 53 | img = Image.open(img_path).convert('RGB') 54 | # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device) 55 | img_tensor = to_tensor(img).unsqueeze(0).to(device) 56 | img_tensor = 2 * img_tensor - 1 57 | # if width or height is not divisible by 8, crop it 58 | if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0: 59 | img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8] 60 | 61 | except Exception as e: 62 | print(f"Error processing {img_path}: {e}") 63 | continue 64 | 65 | 66 | with torch.no_grad(): 67 | reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample 68 | 69 | # Calculate rFID 70 | # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed) 71 | # rfid_scores.append(rfid) 72 | 73 | # Calculate PSNR 74 | psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy()) 75 | psnr_scores.append(psnr_val) 76 | 77 | # Calculate LPIPS 78 | lpips_val = lpips_model(img_tensor, reconstructed).item() 79 | lpips_scores.append(lpips_val) 80 | 81 | # avg_rfid = sum(rfid_scores) / len(rfid_scores) 82 | avg_rfid = 0 83 | avg_psnr = sum(psnr_scores) / len(psnr_scores) 84 | avg_lpips = sum(lpips_scores) / len(lpips_scores) 85 | 86 | return avg_rfid, avg_psnr, avg_lpips 87 | 88 | 89 | def main(): 90 | parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions") 91 | parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") 92 | parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") 93 | parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") 94 | args = parser.parse_args() 95 | 96 | if os.path.isfile(args.vae_path): 97 | vae = AutoencoderKL.from_single_file(args.vae_path) 98 | else: 99 | vae = AutoencoderKL.from_pretrained(args.vae_path) 100 | vae.eval() 101 | vae = vae.to(device) 102 | print(f"Model has {paramiter_count(vae)} parameters") 103 | images = load_images(args.image_folder) 104 | 105 | avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) 106 | 107 | # print(f"Average rFID: {avg_rfid}") 108 | print(f"Average PSNR: {avg_psnr}") 109 | print(f"Average LPIPS: {avg_lpips}") 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /testing/test_vae_cycle.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from safetensors.torch import load_file 5 | from collections import OrderedDict 6 | from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers 7 | import json 8 | # this was just used to match the vae keys to the diffusers keys 9 | # you probably wont need this. Unless they change them.... again... 
again 10 | # on second thought, you probably will 11 | 12 | device = torch.device('cpu') 13 | dtype = torch.float32 14 | vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors' 15 | 16 | find_matches = False 17 | 18 | state_dict_ldm = load_file(vae_path) 19 | diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device) 20 | 21 | ldm_keys = state_dict_ldm.keys() 22 | 23 | matched_keys = {} 24 | duplicated_keys = { 25 | 26 | } 27 | 28 | if find_matches: 29 | # find values that match with a very low mse 30 | for ldm_key in ldm_keys: 31 | ldm_value = state_dict_ldm[ldm_key] 32 | for diffusers_key in list(diffusers_vae.state_dict().keys()): 33 | diffusers_value = diffusers_vae.state_dict()[diffusers_key] 34 | if diffusers_key in vae_keys_squished_on_diffusers: 35 | diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1) 36 | # if they are not same shape, skip 37 | if ldm_value.shape != diffusers_value.shape: 38 | continue 39 | mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value) 40 | if mse < 1e-6: 41 | if ldm_key in list(matched_keys.keys()): 42 | print(f'{ldm_key} already matched to {matched_keys[ldm_key]}') 43 | if ldm_key in duplicated_keys: 44 | duplicated_keys[ldm_key].append(diffusers_key) 45 | else: 46 | duplicated_keys[ldm_key] = [diffusers_key] 47 | continue 48 | matched_keys[ldm_key] = diffusers_key 49 | is_matched = True 50 | break 51 | 52 | print(f'Found {len(matched_keys)} matches') 53 | 54 | dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae) 55 | dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys()) 56 | keys_in_both = [] 57 | 58 | keys_not_in_diffusers = [] 59 | for key in ldm_keys: 60 | if key not in dif_to_ldm_state_dict_keys: 61 | keys_not_in_diffusers.append(key) 62 | 63 | keys_not_in_ldm = [] 64 | for key in dif_to_ldm_state_dict_keys: 65 | if key not in ldm_keys: 66 | keys_not_in_ldm.append(key) 67 | 68 | keys_in_both = [] 69 | for key in ldm_keys: 70 | if key in dif_to_ldm_state_dict_keys: 71 | keys_in_both.append(key) 72 | 73 | # sort them 74 | keys_not_in_diffusers.sort() 75 | keys_not_in_ldm.sort() 76 | keys_in_both.sort() 77 | 78 | # print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}') 79 | # print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}') 80 | # print(f'Keys in both: {len(keys_in_both)}{keys_in_both}') 81 | 82 | json_data = { 83 | "both": keys_in_both, 84 | "ldm": keys_not_in_diffusers, 85 | "diffusers": keys_not_in_ldm 86 | } 87 | json_data = json.dumps(json_data, indent=4) 88 | 89 | remaining_diffusers_values = OrderedDict() 90 | for key in keys_not_in_ldm: 91 | remaining_diffusers_values[key] = dif_to_ldm_state_dict[key] 92 | 93 | # print(remaining_diffusers_values.keys()) 94 | 95 | remaining_ldm_values = OrderedDict() 96 | for key in keys_not_in_diffusers: 97 | remaining_ldm_values[key] = state_dict_ldm[key] 98 | 99 | # print(json_data) 100 | 101 | project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 102 | json_save_path = os.path.join(project_root, 'config', 'keys.json') 103 | json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') 104 | json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') 105 | 106 | with open(json_save_path, 'w') as f: 107 | f.write(json_data) 108 | if find_matches: 109 | with open(json_matched_save_path, 'w') as f: 110 | f.write(json.dumps(matched_keys, indent=4)) 111 | with 
open(json_duped_save_path, 'w') as f: 112 | f.write(json.dumps(duplicated_keys, indent=4)) 113 | -------------------------------------------------------------------------------- /toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/toolkit/__init__.py -------------------------------------------------------------------------------- /toolkit/basic.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch 4 | 5 | 6 | def value_map(inputs, min_in, max_in, min_out, max_out): 7 | return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out 8 | 9 | 10 | def flush(garbage_collect=True): 11 | torch.cuda.empty_cache() 12 | if garbage_collect: 13 | gc.collect() 14 | 15 | 16 | def get_mean_std(tensor): 17 | if len(tensor.shape) == 3: 18 | tensor = tensor.unsqueeze(0) 19 | elif len(tensor.shape) != 4: 20 | raise Exception("Expected tensor of shape (batch_size, channels, width, height)") 21 | mean, variance = torch.mean( 22 | tensor, dim=[2, 3], keepdim=True 23 | ), torch.var( 24 | tensor, dim=[2, 3], 25 | keepdim=True 26 | ) 27 | std = torch.sqrt(variance + 1e-5) 28 | return mean, std 29 | 30 | 31 | def adain(content_features, style_features): 32 | # Assumes that the content and style features are of shape (batch_size, channels, width, height) 33 | 34 | dims = [2, 3] 35 | if len(content_features.shape) == 3: 36 | # content_features = content_features.unsqueeze(0) 37 | # style_features = style_features.unsqueeze(0) 38 | dims = [1] 39 | 40 | # Step 1: Calculate mean and variance of content features 41 | content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features, 42 | dim=dims, 43 | keepdim=True) 44 | # Step 2: Calculate mean and variance of style features 45 | style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims, 46 | keepdim=True) 47 | 48 | # Step 3: Normalize content features 49 | content_std = torch.sqrt(content_var + 1e-5) 50 | normalized_content = (content_features - content_mean) / content_std 51 | 52 | # Step 4: Scale and shift normalized content with style's statistics 53 | style_std = torch.sqrt(style_var + 1e-5) 54 | stylized_content = normalized_content * style_std + style_mean 55 | 56 | return stylized_content 57 | -------------------------------------------------------------------------------- /toolkit/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Union 4 | 5 | import oyaml as yaml 6 | import re 7 | from collections import OrderedDict 8 | 9 | from toolkit.paths import TOOLKIT_ROOT 10 | 11 | possible_extensions = ['.json', '.jsonc', '.yaml', '.yml'] 12 | 13 | 14 | def get_cwd_abs_path(path): 15 | if not os.path.isabs(path): 16 | path = os.path.join(os.getcwd(), path) 17 | return path 18 | 19 | 20 | def replace_env_vars_in_string(s: str) -> str: 21 | """ 22 | Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable. 23 | If the environment variable is not set, raise an error. 24 | """ 25 | 26 | def replacer(match): 27 | var_name = match.group(1) 28 | value = os.environ.get(var_name) 29 | 30 | if value is None: 31 | raise ValueError(f"Environment variable {var_name} not set. 
Please ensure it's defined before proceeding.") 32 | 33 | return value 34 | 35 | return re.sub(r'\$\{([^}]+)\}', replacer, s) 36 | 37 | 38 | def preprocess_config(config: OrderedDict, name: str = None): 39 | if "job" not in config: 40 | raise ValueError("config file must have a job key") 41 | if "config" not in config: 42 | raise ValueError("config file must have a config section") 43 | if "name" not in config["config"] and name is None: 44 | raise ValueError("config file must have a config.name key") 45 | # we need to replace tags. For now just [name] 46 | if name is None: 47 | name = config["config"]["name"] 48 | config_string = json.dumps(config) 49 | config_string = config_string.replace("[name]", name) 50 | config = json.loads(config_string, object_pairs_hook=OrderedDict) 51 | return config 52 | 53 | 54 | # Fixes issue where yaml doesnt load exponents correctly 55 | fixed_loader = yaml.SafeLoader 56 | fixed_loader.add_implicit_resolver( 57 | u'tag:yaml.org,2002:float', 58 | re.compile(u'''^(?: 59 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 60 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 61 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 62 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 63 | |[-+]?\\.(?:inf|Inf|INF) 64 | |\\.(?:nan|NaN|NAN))$''', re.X), 65 | list(u'-+0123456789.')) 66 | 67 | 68 | def get_config( 69 | config_file_path_or_dict: Union[str, dict, OrderedDict], 70 | name=None 71 | ): 72 | # if we got a dict, process it and return it 73 | if isinstance(config_file_path_or_dict, dict) or isinstance(config_file_path_or_dict, OrderedDict): 74 | config = config_file_path_or_dict 75 | return preprocess_config(config, name) 76 | 77 | config_file_path = config_file_path_or_dict 78 | 79 | # first check if it is in the config folder 80 | config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path) 81 | # see if it is in the config folder with any of the possible extensions if it doesnt have one 82 | real_config_path = None 83 | if not os.path.exists(config_path): 84 | for ext in possible_extensions: 85 | if os.path.exists(config_path + ext): 86 | real_config_path = config_path + ext 87 | break 88 | 89 | # if we didn't find it there, check if it is a full path 90 | if not real_config_path: 91 | if os.path.exists(config_file_path): 92 | real_config_path = config_file_path 93 | elif os.path.exists(get_cwd_abs_path(config_file_path)): 94 | real_config_path = get_cwd_abs_path(config_file_path) 95 | 96 | if not real_config_path: 97 | raise ValueError(f"Could not find config file {config_file_path}") 98 | 99 | # if we found it, check if it is a json or yaml file 100 | with open(real_config_path, 'r', encoding='utf-8') as f: 101 | content = f.read() 102 | content_with_env_replaced = replace_env_vars_in_string(content) 103 | if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'): 104 | config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict) 105 | elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'): 106 | config = yaml.load(content_with_env_replaced, Loader=fixed_loader) 107 | else: 108 | raise ValueError(f"Config file {config_file_path} must be a json or yaml file") 109 | 110 | return preprocess_config(config, name) 111 | -------------------------------------------------------------------------------- /toolkit/cuda_malloc.py: -------------------------------------------------------------------------------- 1 | # ref comfy ui 2 | import os 3 | import importlib.util 4 | 5 | 6 | # Can't use pytorch to get the GPU names 
because the cuda malloc has to be set before the first import. 7 | def get_gpu_names(): 8 | if os.name == 'nt': 9 | import ctypes 10 | 11 | # Define necessary C structures and types 12 | class DISPLAY_DEVICEA(ctypes.Structure): 13 | _fields_ = [ 14 | ('cb', ctypes.c_ulong), 15 | ('DeviceName', ctypes.c_char * 32), 16 | ('DeviceString', ctypes.c_char * 128), 17 | ('StateFlags', ctypes.c_ulong), 18 | ('DeviceID', ctypes.c_char * 128), 19 | ('DeviceKey', ctypes.c_char * 128) 20 | ] 21 | 22 | # Load user32.dll 23 | user32 = ctypes.windll.user32 24 | 25 | # Call EnumDisplayDevicesA 26 | def enum_display_devices(): 27 | device_info = DISPLAY_DEVICEA() 28 | device_info.cb = ctypes.sizeof(device_info) 29 | device_index = 0 30 | gpu_names = set() 31 | 32 | while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0): 33 | device_index += 1 34 | gpu_names.add(device_info.DeviceString.decode('utf-8')) 35 | return gpu_names 36 | 37 | return enum_display_devices() 38 | else: 39 | return set() 40 | 41 | 42 | blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", 43 | "GeForce 945M", 44 | "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", 45 | "Quadro K620", 46 | "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", 47 | "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", 48 | "Quadro M5500", "Quadro M6000", 49 | "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", 50 | "GeForce GTX 1650", "GeForce GTX 1630" 51 | } 52 | 53 | 54 | def cuda_malloc_supported(): 55 | try: 56 | names = get_gpu_names() 57 | except: 58 | names = set() 59 | for x in names: 60 | if "NVIDIA" in x: 61 | for b in blacklist: 62 | if b in x: 63 | return False 64 | return True 65 | 66 | 67 | cuda_malloc = False 68 | 69 | if not cuda_malloc: 70 | try: 71 | version = "" 72 | torch_spec = importlib.util.find_spec("torch") 73 | for folder in torch_spec.submodule_search_locations: 74 | ver_file = os.path.join(folder, "version.py") 75 | if os.path.isfile(ver_file): 76 | spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) 77 | module = importlib.util.module_from_spec(spec) 78 | spec.loader.exec_module(module) 79 | version = module.__version__ 80 | if int(version[0]) >= 2: # enable by default for torch version 2.0 and up 81 | cuda_malloc = cuda_malloc_supported() 82 | except: 83 | pass 84 | 85 | if cuda_malloc: 86 | env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) 87 | if env_var is None: 88 | env_var = "backend:cudaMallocAsync" 89 | else: 90 | env_var += ",backend:cudaMallocAsync" 91 | 92 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var 93 | print("CUDA Malloc Async Enabled") 94 | -------------------------------------------------------------------------------- /toolkit/esrgan_utils.py: -------------------------------------------------------------------------------- 1 | 2 | to_basicsr_dict = { 3 | 'model.0.weight': 'conv_first.weight', 4 | 'model.0.bias': 'conv_first.bias', 5 | 'model.1.sub.23.weight': 'conv_body.weight', 6 | 'model.1.sub.23.bias': 'conv_body.bias', 7 | 'model.3.weight': 'conv_up1.weight', 8 | 'model.3.bias': 'conv_up1.bias', 9 | 'model.6.weight': 'conv_up2.weight', 10 | 'model.6.bias': 'conv_up2.bias', 11 | 'model.8.weight': 'conv_hr.weight', 12 | 'model.8.bias': 'conv_hr.bias', 13 | 'model.10.bias': 
'conv_last.bias', 14 | 'model.10.weight': 'conv_last.weight', 15 | # 'model.1.sub.0.RDB1.conv1.0.weight': 'body.0.rdb1.conv1.weight' 16 | } 17 | 18 | def convert_state_dict_to_basicsr(state_dict): 19 | new_state_dict = {} 20 | for k, v in state_dict.items(): 21 | if k in to_basicsr_dict: 22 | new_state_dict[to_basicsr_dict[k]] = v 23 | elif k.startswith('model.1.sub.'): 24 | bsr_name = k.replace('model.1.sub.', 'body.').lower() 25 | bsr_name = bsr_name.replace('.0.weight', '.weight') 26 | bsr_name = bsr_name.replace('.0.bias', '.bias') 27 | new_state_dict[bsr_name] = v 28 | else: 29 | new_state_dict[k] = v 30 | return new_state_dict 31 | 32 | 33 | # just matching a commonly used format 34 | def convert_basicsr_state_dict_to_save_format(state_dict): 35 | new_state_dict = {} 36 | to_basicsr_dict_values = list(to_basicsr_dict.values()) 37 | for k, v in state_dict.items(): 38 | if k in to_basicsr_dict_values: 39 | for key, value in to_basicsr_dict.items(): 40 | if value == k: 41 | new_state_dict[key] = v 42 | 43 | elif k.startswith('body.'): 44 | bsr_name = k.replace('body.', 'model.1.sub.').lower() 45 | bsr_name = bsr_name.replace('rdb', 'RDB') 46 | bsr_name = bsr_name.replace('.weight', '.0.weight') 47 | bsr_name = bsr_name.replace('.bias', '.0.bias') 48 | new_state_dict[bsr_name] = v 49 | else: 50 | new_state_dict[k] = v 51 | return new_state_dict 52 | -------------------------------------------------------------------------------- /toolkit/extension.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import pkgutil 4 | from typing import List 5 | 6 | from toolkit.paths import TOOLKIT_ROOT 7 | 8 | 9 | class Extension(object): 10 | """Base class for extensions. 11 | 12 | Extensions are registered with the ExtensionManager, which is 13 | responsible for calling the extension's load() and unload() 14 | methods at the appropriate times. 15 | 16 | """ 17 | 18 | name: str = None 19 | uid: str = None 20 | 21 | @classmethod 22 | def get_process(cls): 23 | # extend in subclass 24 | pass 25 | 26 | 27 | def get_all_extensions() -> List[Extension]: 28 | extension_folders = ['extensions', 'extensions_built_in'] 29 | 30 | # This will hold the classes from all extension modules 31 | all_extension_classes: List[Extension] = [] 32 | 33 | # Iterate over all directories (i.e., packages) in the "extensions" directory 34 | for sub_dir in extension_folders: 35 | extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir) 36 | for (_, name, _) in pkgutil.iter_modules([extensions_dir]): 37 | try: 38 | # Import the module 39 | module = importlib.import_module(f"{sub_dir}.{name}") 40 | # Get the value of the AI_TOOLKIT_EXTENSIONS variable 41 | extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None) 42 | # Check if the value is a list 43 | if isinstance(extensions, list): 44 | # Iterate over the list and add the classes to the main list 45 | all_extension_classes.extend(extensions) 46 | except ImportError as e: 47 | print(f"Failed to import the {name} module. 
Error: {str(e)}") 48 | 49 | return all_extension_classes 50 | 51 | 52 | def get_all_extensions_process_dict(): 53 | all_extensions = get_all_extensions() 54 | process_dict = {} 55 | for extension in all_extensions: 56 | process_dict[extension.uid] = extension.get_process() 57 | return process_dict 58 | -------------------------------------------------------------------------------- /toolkit/job.py: -------------------------------------------------------------------------------- 1 | from typing import Union, OrderedDict 2 | 3 | from toolkit.config import get_config 4 | 5 | 6 | def get_job( 7 | config_path: Union[str, dict, OrderedDict], 8 | name=None 9 | ): 10 | config = get_config(config_path, name) 11 | if not config['job']: 12 | raise ValueError('config file is invalid. Missing "job" key') 13 | 14 | job = config['job'] 15 | if job == 'extract': 16 | from jobs import ExtractJob 17 | return ExtractJob(config) 18 | if job == 'train': 19 | from jobs import TrainJob 20 | return TrainJob(config) 21 | if job == 'mod': 22 | from jobs import ModJob 23 | return ModJob(config) 24 | if job == 'generate': 25 | from jobs import GenerateJob 26 | return GenerateJob(config) 27 | if job == 'extension': 28 | from jobs import ExtensionJob 29 | return ExtensionJob(config) 30 | 31 | # elif job == 'train': 32 | # from jobs import TrainJob 33 | # return TrainJob(config) 34 | else: 35 | raise ValueError(f'Unknown job type {job}') 36 | 37 | 38 | def run_job( 39 | config: Union[str, dict, OrderedDict], 40 | name=None 41 | ): 42 | job = get_job(config, name) 43 | job.run() 44 | job.cleanup() 45 | -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_refiner_unmatched.json: -------------------------------------------------------------------------------- 1 | { 2 | "ldm": { 3 | "conditioner.embedders.0.model.logit_scale": { 4 | "shape": [], 5 | "min": 4.60546875, 6 | "max": 4.60546875 7 | }, 8 | "conditioner.embedders.0.model.text_projection": { 9 | "shape": [ 10 | 1280, 11 | 1280 12 | ], 13 | "min": -0.15966796875, 14 | "max": 0.230712890625 15 | } 16 | }, 17 | "diffusers": { 18 | "te1_text_projection.weight": { 19 | "shape": [ 20 | 1280, 21 | 1280 22 | ], 23 | "min": -0.15966796875, 24 | "max": 0.230712890625 25 | } 26 | } 27 | } -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_sdxl_unmatched.json: -------------------------------------------------------------------------------- 1 | { 2 | "ldm": { 3 | "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": { 4 | "shape": [ 5 | 1, 6 | 77 7 | ], 8 | "min": 0.0, 9 | "max": 76.0 10 | }, 11 | "conditioner.embedders.1.model.logit_scale": { 12 | "shape": [], 13 | "min": 4.60546875, 14 | "max": 4.60546875 15 | }, 16 | "conditioner.embedders.1.model.text_projection": { 17 | "shape": [ 18 | 1280, 19 | 1280 20 | ], 21 | "min": -0.15966796875, 22 | "max": 0.230712890625 23 | } 24 | }, 25 | "diffusers": { 26 | "te1_text_projection.weight": { 27 | "shape": [ 28 | 1280, 29 | 1280 30 | ], 31 | "min": -0.15966796875, 32 | "max": 0.230712890625 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-ai-toolkit/12c8336fbc7d772c83789fa2e19ca04c15452999/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_ssd_unmatched.json: -------------------------------------------------------------------------------- 1 | { 2 | "ldm": { 3 | "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": { 4 | "shape": [ 5 | 1, 6 | 77 7 | ], 8 | "min": 0.0, 9 | "max": 76.0 10 | }, 11 | "conditioner.embedders.1.model.text_model.embeddings.position_ids": { 12 | "shape": [ 13 | 1, 14 | 77 15 | ], 16 | "min": 0.0, 17 | "max": 76.0 18 | } 19 | }, 20 | "diffusers": {} 21 | } -------------------------------------------------------------------------------- /toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /toolkit/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | 7 | class ReductionKernel(nn.Module): 8 | # Tensorflow 9 | def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None): 10 | if device is None: 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | super(ReductionKernel, self).__init__() 13 | self.kernel_size = kernel_size 14 | self.in_channels = in_channels 15 | numpy_kernel = self.build_kernel() 16 | self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) 17 | 18 | def build_kernel(self): 19 | # tensorflow kernel is (height, width, in_channels, out_channels) 20 | # pytorch kernel is (out_channels, in_channels, height, width) 21 | kernel_size = self.kernel_size 22 | channels = self.in_channels 23 | kernel_shape = [channels, channels, kernel_size, kernel_size] 24 | kernel = np.zeros(kernel_shape, np.float32) 25 | 26 | kernel_value = 1.0 / (kernel_size * kernel_size) 27 | for i in range(0, channels): 28 | kernel[i, i, :, :] = kernel_value 29 | return kernel 30 | 31 | def forward(self, x): 32 | return 
nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1) 33 | 34 | 35 | class CheckpointGradients(nn.Module): 36 | def __init__(self, is_gradient_checkpointing=True): 37 | super(CheckpointGradients, self).__init__() 38 | self.is_gradient_checkpointing = is_gradient_checkpointing 39 | 40 | def forward(self, module, *args, num_chunks=1): 41 | if self.is_gradient_checkpointing: 42 | return checkpoint(module, *args) 43 | else: 44 | return module(*args) 45 | -------------------------------------------------------------------------------- /toolkit/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .llvae import LosslessLatentEncoder 3 | 4 | 5 | def total_variation(image): 6 | """ 7 | Compute normalized total variation. 8 | Inputs: 9 | - image: PyTorch Tensor of shape (N, C, H, W) 10 | Returns: 11 | - TV: total variation normalized by the number of elements 12 | """ 13 | n_elements = image.shape[1] * image.shape[2] * image.shape[3] 14 | return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + 15 | torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) 16 | 17 | 18 | class ComparativeTotalVariation(torch.nn.Module): 19 | """ 20 | Compute the comparative loss in TV between two images, to match their TV. 21 | """ 22 | 23 | def forward(self, pred, target): 24 | return torch.abs(total_variation(pred) - total_variation(target)) 25 | 26 | 27 | # Gradient penalty 28 | def get_gradient_penalty(critic, real, fake, device): 29 | with torch.autocast(device_type='cuda'): 30 | real = real.float() 31 | fake = fake.float() 32 | alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float() 33 | interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True) 34 | if torch.isnan(interpolates).any(): 35 | print('interpolates is nan') 36 | d_interpolates = critic(interpolates) 37 | fake = torch.ones(real.size(0), 1, device=device) 38 | 39 | if torch.isnan(d_interpolates).any(): 40 | print('d_interpolates is nan') 41 | gradients = torch.autograd.grad( 42 | outputs=d_interpolates, 43 | inputs=interpolates, 44 | grad_outputs=fake, 45 | create_graph=True, 46 | retain_graph=True, 47 | only_inputs=True, 48 | )[0] 49 | 50 | # see if any are nan 51 | if torch.isnan(gradients).any(): 52 | print('gradients is nan') 53 | 54 | gradients = gradients.view(gradients.size(0), -1) 55 | gradient_norm = gradients.norm(2, dim=1) 56 | gradient_penalty = ((gradient_norm - 1) ** 2).mean() 57 | return gradient_penalty.float() 58 | 59 | 60 | class PatternLoss(torch.nn.Module): 61 | def __init__(self, pattern_size=4, dtype=torch.float32): 62 | super().__init__() 63 | self.pattern_size = pattern_size 64 | self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype) 65 | 66 | def forward(self, pred, target): 67 | pred_latents = self.llvae_encoder(pred) 68 | target_latents = self.llvae_encoder(target) 69 | 70 | matrix_pixels = self.pattern_size * self.pattern_size 71 | 72 | color_chans = pred_latents.shape[1] // 3 73 | # pytorch 74 | r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans, color_chans, color_chans], 1) 75 | r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans, color_chans, color_chans], 1) 76 | 77 | def separated_chan_loss(latent_chan): 78 | nonlocal matrix_pixels 79 | chan_mean = torch.mean(latent_chan, dim=[1, 2, 3]) 80 | chan_splits = torch.split(latent_chan, [1 for i in range(matrix_pixels)], 1) 81 |
chan_loss = None 82 | for chan in chan_splits: 83 | this_mean = torch.mean(chan, dim=[1, 2, 3]) 84 | this_chan_loss = torch.abs(this_mean - chan_mean) 85 | if chan_loss is None: 86 | chan_loss = this_chan_loss 87 | else: 88 | chan_loss = chan_loss + this_chan_loss 89 | chan_loss = chan_loss * (1 / matrix_pixels) 90 | return chan_loss 91 | 92 | r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target)) 93 | g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target)) 94 | b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target)) 95 | return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333 96 | 97 | 98 | -------------------------------------------------------------------------------- /toolkit/metadata.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | from io import BytesIO 4 | 5 | import safetensors 6 | from safetensors import safe_open 7 | 8 | from info import software_meta 9 | from toolkit.train_tools import addnet_hash_legacy 10 | from toolkit.train_tools import addnet_hash_safetensors 11 | 12 | 13 | def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict: 14 | # stringify the meta and reparse OrderedDict to replace [name] with name 15 | meta_string = json.dumps(meta) 16 | if name is not None: 17 | meta_string = meta_string.replace("[name]", name) 18 | save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict) 19 | if add_software_info: 20 | save_meta["software"] = software_meta 21 | # safetensors can only be one level deep 22 | for key, value in save_meta.items(): 23 | # if not float, int, bool, or str, convert to json string 24 | if not isinstance(value, str): 25 | save_meta[key] = json.dumps(value) 26 | # add the pt format 27 | save_meta["format"] = "pt" 28 | return save_meta 29 | 30 | 31 | def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict: 32 | """Precalculate the model hashes needed by sd-webui-additional-networks to 33 | save time on indexing the model later.""" 34 | 35 | # Because writing user metadata to the file can change the result of 36 | # sd_models.model_hash(), only retain the training metadata for purposes of 37 | # calculating the hash, as they are meant to be immutable 38 | metadata = {k: v for k, v in meta.items() if k.startswith("ss_")} 39 | 40 | bytes = safetensors.torch.save(state_dict, metadata) 41 | b = BytesIO(bytes) 42 | 43 | model_hash = addnet_hash_safetensors(b) 44 | legacy_hash = addnet_hash_legacy(b) 45 | meta["sshs_model_hash"] = model_hash 46 | meta["sshs_legacy_hash"] = legacy_hash 47 | return meta 48 | 49 | 50 | def add_base_model_info_to_meta( 51 | meta: OrderedDict, 52 | base_model: str = None, 53 | is_v1: bool = False, 54 | is_v2: bool = False, 55 | is_xl: bool = False, 56 | ) -> OrderedDict: 57 | if base_model is not None: 58 | meta['ss_base_model'] = base_model 59 | elif is_v2: 60 | meta['ss_v2'] = True 61 | meta['ss_base_model_version'] = 'sd_2.1' 62 | 63 | elif is_xl: 64 | meta['ss_base_model_version'] = 'sdxl_1.0' 65 | else: 66 | # default to v1.5 67 | meta['ss_base_model_version'] = 'sd_1.5' 68 | return meta 69 | 70 | 71 | def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict: 72 | parsed_meta = OrderedDict() 73 | for key, value in meta.items(): 74 | try: 75 | parsed_meta[key] = json.loads(value) 76 | except json.decoder.JSONDecodeError: 77 | parsed_meta[key] = value 
78 | return parsed_meta 79 | 80 | 81 | def load_metadata_from_safetensors(file_path: str) -> OrderedDict: 82 | try: 83 | with safe_open(file_path, framework="pt") as f: 84 | metadata = f.metadata() 85 | return parse_metadata_from_safetensors(metadata) 86 | except Exception as e: 87 | print(f"Error loading metadata from {file_path}: {e}") 88 | return OrderedDict() 89 | -------------------------------------------------------------------------------- /toolkit/models/clip_pre_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class UpsampleBlock(nn.Module): 6 | def __init__( 7 | self, 8 | in_channels: int, 9 | out_channels: int, 10 | ): 11 | super().__init__() 12 | self.in_channels = in_channels 13 | self.out_channels = out_channels 14 | self.conv_in = nn.Sequential( 15 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), 16 | nn.GELU() 17 | ) 18 | self.conv_up = nn.Sequential( 19 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), 20 | nn.GELU() 21 | ) 22 | 23 | self.conv_out = nn.Sequential( 24 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 25 | ) 26 | 27 | def forward(self, x): 28 | x = self.conv_in(x) 29 | x = self.conv_up(x) 30 | x = self.conv_out(x) 31 | return x 32 | 33 | 34 | class CLIPImagePreProcessor(nn.Module): 35 | def __init__( 36 | self, 37 | input_size=896, 38 | clip_input_size=224, 39 | downscale_factor: int = 16, 40 | ): 41 | super().__init__() 42 | # make sure they are evenly divisible 43 | assert input_size % clip_input_size == 0 44 | in_channels = 3 45 | 46 | self.input_size = input_size 47 | self.clip_input_size = clip_input_size 48 | self.downscale_factor = downscale_factor 49 | 50 | subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 16 ** 2 = 768 51 | channels = subpixel_channels 52 | 53 | upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 16 / (896 / 224) = 4 54 | 55 | num_upsample_blocks = int(upscale_factor // 2) # 4 // 2 = 2 56 | 57 | # make the residual down up blocks 58 | self.upsample_blocks = nn.ModuleList() 59 | self.subpixel_blocks = nn.ModuleList() 60 | current_channels = channels 61 | current_downscale = downscale_factor 62 | for _ in range(num_upsample_blocks): 63 | # determine the reshuffled channel count for this dimension 64 | output_downscale = current_downscale // 2 65 | out_channels = in_channels * output_downscale ** 2 66 | # out_channels = current_channels // 2 67 | self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels)) 68 | current_channels = out_channels 69 | current_downscale = output_downscale 70 | self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale)) 71 | 72 | # (bs, 768, 56, 56) -> (bs, 192, 112, 112) 73 | # (bs, 192, 112, 112) -> (bs, 48, 224, 224) 74 | 75 | self.conv_out = nn.Conv2d( 76 | current_channels, 77 | out_channels=3, 78 | kernel_size=3, 79 | padding=1 80 | ) # (bs, 48, 224, 224) -> (bs, 3, 224, 224) 81 | 82 | # do a pooling layer to downscale the input to 1/3 of the size 83 | # (bs, 3, 896, 896) -> (bs, 3, 224, 224) 84 | kernel_size = input_size // clip_input_size 85 | self.res_down = nn.AvgPool2d( 86 | kernel_size=kernel_size, 87 | stride=kernel_size 88 | ) # (bs, 3, 896, 896) -> (bs, 3, 224, 224) 89 | 90 | # make a blending for output residual with near 0 weight 91 | self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224) 92 | 93 | self.unshuffle = 
nn.PixelUnshuffle(downscale_factor) # (bs, 3, 896, 896) -> (bs, 768, 56, 56) 94 | 95 | self.conv_in = nn.Sequential( 96 | nn.Conv2d( 97 | subpixel_channels, 98 | channels, 99 | kernel_size=3, 100 | padding=1 101 | ), 102 | nn.GELU() 103 | ) # (bs, 768, 56, 56) -> (bs, 768, 56, 56) 104 | 105 | # make 2 deep blocks 106 | 107 | def forward(self, x): 108 | inputs = x 109 | # resize to input_size x input_size 110 | x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic') 111 | 112 | res = self.res_down(inputs) 113 | 114 | x = self.unshuffle(x) 115 | x = self.conv_in(x) 116 | for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks): 117 | x = up(x) 118 | block_res = subpixel(inputs) 119 | x = x + block_res 120 | x = self.conv_out(x) 121 | # blend residual 122 | x = x * self.res_blend + res 123 | return x 124 | -------------------------------------------------------------------------------- /toolkit/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Adafactor, AdamW 3 | 4 | 5 | def get_optimizer( 6 | params, 7 | optimizer_type='adam', 8 | learning_rate=1e-6, 9 | optimizer_params=None 10 | ): 11 | if optimizer_params is None: 12 | optimizer_params = {} 13 | lower_type = optimizer_type.lower() 14 | if lower_type.startswith("dadaptation"): 15 | # dadaptation optimizer does not use standard learning rate. 1 is the default value 16 | import dadaptation 17 | print("Using DAdaptAdam optimizer") 18 | use_lr = learning_rate 19 | if use_lr < 0.1: 20 | # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0 21 | use_lr = 1.0 22 | if lower_type.endswith('lion'): 23 | optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params) 24 | elif lower_type.endswith('adam'): 25 | optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params) 26 | elif lower_type == 'dadaptation': 27 | # backwards compatibility 28 | optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params) 29 | # warn user that dadaptation is deprecated 30 | print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.") 31 | elif lower_type.startswith("prodigy"): 32 | from prodigyopt import Prodigy 33 | 34 | print("Using Prodigy optimizer") 35 | use_lr = learning_rate 36 | if use_lr < 0.1: 37 | # prodigy uses different lr that is values of 0.1 to 1.0.
default to 1.0 38 | use_lr = 1.0 39 | 40 | print(f"Using lr {use_lr}") 41 | # let net be the neural network you want to train 42 | # you can choose weight decay value based on your problem, 0 by default 43 | optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params) 44 | elif lower_type.endswith("8bit"): 45 | import bitsandbytes 46 | 47 | if lower_type == "adam8bit": 48 | return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) 49 | elif lower_type == "adamw8bit": 50 | return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) 51 | elif lower_type == "lion8bit": 52 | return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params) 53 | else: 54 | raise ValueError(f'Unknown optimizer type {optimizer_type}') 55 | elif lower_type == 'adam': 56 | optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) 57 | elif lower_type == 'adamw': 58 | optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) 59 | elif lower_type == 'lion': 60 | try: 61 | from lion_pytorch import Lion 62 | return Lion(params, lr=learning_rate, **optimizer_params) 63 | except ImportError: 64 | raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch") 65 | elif lower_type == 'adagrad': 66 | optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) 67 | elif lower_type == 'adafactor': 68 | # hack in stochastic rounding 69 | if 'relative_step' not in optimizer_params: 70 | optimizer_params['relative_step'] = False 71 | if 'scale_parameter' not in optimizer_params: 72 | optimizer_params['scale_parameter'] = False 73 | if 'warmup_init' not in optimizer_params: 74 | optimizer_params['warmup_init'] = False 75 | optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) 76 | from toolkit.util.adafactor_stochastic_rounding import step_adafactor 77 | optimizer.step = step_adafactor.__get__(optimizer, Adafactor) 78 | else: 79 | raise ValueError(f'Unknown optimizer type {optimizer_type}') 80 | return optimizer 81 | -------------------------------------------------------------------------------- /toolkit/orig_configs/sd_xl_refiner.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | weighting_config: 13 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 14 | scaling_config: 15 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 16 | discretization_config: 17 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 18 | 19 | network_config: 20 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | adm_in_channels: 2560 23 | num_classes: sequential 24 | use_checkpoint: True 25 | in_channels: 4 26 | out_channels: 4 27 | model_channels: 384 28 | attention_resolutions: [4, 2] 29 | num_res_blocks: 2 30 | channel_mult: [1, 2, 4, 4] 31 | num_head_channels: 64 32 | use_spatial_transformer: True 33 | use_linear_in_transformer: True 34 | transformer_depth: 4 35 | context_dim: [1280, 1280, 1280, 1280] # 1280 36 | spatial_transformer_attn_type: softmax-xformers 37 | legacy: False 38 | 39 | conditioner_config: 40 
| target: sgm.modules.GeneralConditioner 41 | params: 42 | emb_models: 43 | # crossattn and vector cond 44 | - is_trainable: False 45 | input_key: txt 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 47 | params: 48 | arch: ViT-bigG-14 49 | version: laion2b_s39b_b160k 50 | legacy: False 51 | freeze: True 52 | layer: penultimate 53 | always_return_pooled: True 54 | # vector cond 55 | - is_trainable: False 56 | input_key: original_size_as_tuple 57 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 58 | params: 59 | outdim: 256 # multiplied by two 60 | # vector cond 61 | - is_trainable: False 62 | input_key: crop_coords_top_left 63 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 64 | params: 65 | outdim: 256 # multiplied by two 66 | # vector cond 67 | - is_trainable: False 68 | input_key: aesthetic_score 69 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 70 | params: 71 | outdim: 256 # multiplied by one 72 | 73 | first_stage_config: 74 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 75 | params: 76 | embed_dim: 4 77 | monitor: val/rec_loss 78 | ddconfig: 79 | attn_type: vanilla-xformers 80 | double_z: true 81 | z_channels: 4 82 | resolution: 256 83 | in_channels: 3 84 | out_ch: 3 85 | ch: 128 86 | ch_mult: [1, 2, 4, 4] 87 | num_res_blocks: 2 88 | attn_resolutions: [] 89 | dropout: 0.0 90 | lossconfig: 91 | target: torch.nn.Identity 92 | -------------------------------------------------------------------------------- /toolkit/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config') 5 | SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts") 6 | REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories") 7 | KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps") 8 | ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs") 9 | DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs") 10 | 11 | # check if ENV variable is set 12 | if 'MODELS_PATH' in os.environ: 13 | MODELS_PATH = os.environ['MODELS_PATH'] 14 | else: 15 | MODELS_PATH = os.path.join(TOOLKIT_ROOT, "models") 16 | 17 | 18 | def get_path(path): 19 | # we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root 20 | if not os.path.isabs(path): 21 | path = os.path.join(TOOLKIT_ROOT, path) 22 | return path 23 | -------------------------------------------------------------------------------- /toolkit/progress_bar.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import time 3 | 4 | 5 | class ToolkitProgressBar(tqdm): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.paused = False 9 | self.last_time = self._time() 10 | 11 | def pause(self): 12 | if not self.paused: 13 | self.paused = True 14 | self.last_time = self._time() 15 | 16 | def unpause(self): 17 | if self.paused: 18 | self.paused = False 19 | cur_t = self._time() 20 | self.start_t += cur_t - self.last_time 21 | self.last_print_t = cur_t 22 | 23 | def update(self, *args, **kwargs): 24 | if not self.paused: 25 | super().update(*args, **kwargs) 26 | -------------------------------------------------------------------------------- /toolkit/samplers/custom_flowmatch_sampler.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from diffusers import FlowMatchEulerDiscreteScheduler 4 | import torch 5 | 6 | 7 | class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): 8 | 9 | def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: 10 | sigmas = self.sigmas.to(device=device, dtype=dtype) 11 | schedule_timesteps = self.timesteps.to(device) 12 | timesteps = timesteps.to(device) 13 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 14 | 15 | sigma = sigmas[step_indices].flatten() 16 | while len(sigma.shape) < n_dim: 17 | sigma = sigma.unsqueeze(-1) 18 | 19 | return sigma 20 | 21 | def add_noise( 22 | self, 23 | original_samples: torch.Tensor, 24 | noise: torch.Tensor, 25 | timesteps: torch.Tensor, 26 | ) -> torch.Tensor: 27 | ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 28 | ## Add noise according to flow matching. 29 | ## zt = (1 - texp) * x + texp * z1 30 | 31 | # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) 32 | # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise 33 | 34 | # timestep needs to be in [0, 1], we store them in [0, 1000] 35 | # noisy_sample = (1 - timestep) * latent + timestep * noise 36 | t_01 = (timesteps / 1000).to(original_samples.device) 37 | noisy_model_input = (1 - t_01) * original_samples + t_01 * noise 38 | 39 | # n_dim = original_samples.ndim 40 | # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) 41 | # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise 42 | return noisy_model_input 43 | 44 | def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: 45 | return sample 46 | 47 | def set_train_timesteps(self, num_timesteps, device, linear=False): 48 | if linear: 49 | timesteps = torch.linspace(1000, 0, num_timesteps, device=device) 50 | self.timesteps = timesteps 51 | return timesteps 52 | else: 53 | # distribute them closer to center. 
Inference distributes them with a bias toward the first steps 54 | # Generate values from 0 to 1 55 | t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) 56 | 57 | # Scale and reverse the values to go from 1000 to 0 58 | timesteps = ((1 - t) * 1000) 59 | 60 | # Sort the timesteps in descending order 61 | timesteps, _ = torch.sort(timesteps, descending=True) 62 | 63 | self.timesteps = timesteps.to(device=device) 64 | 65 | return timesteps 66 | -------------------------------------------------------------------------------- /toolkit/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup 4 | 5 | 6 | def get_lr_scheduler( 7 | name: Optional[str], 8 | optimizer: torch.optim.Optimizer, 9 | **kwargs, 10 | ): 11 | if name == "cosine": 12 | if 'total_iters' in kwargs: 13 | kwargs['T_max'] = kwargs.pop('total_iters') 14 | return torch.optim.lr_scheduler.CosineAnnealingLR( 15 | optimizer, **kwargs 16 | ) 17 | elif name == "cosine_with_restarts": 18 | if 'total_iters' in kwargs: 19 | kwargs['T_0'] = kwargs.pop('total_iters') 20 | return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 21 | optimizer, **kwargs 22 | ) 23 | elif name == "step": 24 | 25 | return torch.optim.lr_scheduler.StepLR( 26 | optimizer, **kwargs 27 | ) 28 | elif name == "constant": 29 | if 'factor' not in kwargs: 30 | kwargs['factor'] = 1.0 31 | 32 | return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs) 33 | elif name == "linear": 34 | 35 | return torch.optim.lr_scheduler.LinearLR( 36 | optimizer, **kwargs 37 | ) 38 | elif name == 'constant_with_warmup': 39 | # see if num_warmup_steps is in kwargs 40 | if 'num_warmup_steps' not in kwargs: 41 | print(f"WARNING: num_warmup_steps not in kwargs.
Using default value of 1000") 42 | kwargs['num_warmup_steps'] = 1000 43 | del kwargs['total_iters'] 44 | return get_constant_schedule_with_warmup(optimizer, **kwargs) 45 | else: 46 | # try to use a diffusers scheduler 47 | print(f"Trying to use diffusers scheduler {name}") 48 | try: 49 | name = SchedulerType(name) 50 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 51 | return schedule_func(optimizer, **kwargs) 52 | except Exception as e: 53 | print(e) 54 | pass 55 | raise ValueError( 56 | "Scheduler must be cosine, cosine_with_restarts, step, linear or constant" 57 | ) 58 | -------------------------------------------------------------------------------- /toolkit/sd_device_states_presets.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import copy 5 | 6 | empty_preset = { 7 | 'vae': { 8 | 'training': False, 9 | 'device': 'cpu', 10 | }, 11 | 'unet': { 12 | 'training': False, 13 | 'requires_grad': False, 14 | 'device': 'cpu', 15 | }, 16 | 'text_encoder': { 17 | 'training': False, 18 | 'requires_grad': False, 19 | 'device': 'cpu', 20 | }, 21 | 'adapter': { 22 | 'training': False, 23 | 'requires_grad': False, 24 | 'device': 'cpu', 25 | }, 26 | 'refiner_unet': { 27 | 'training': False, 28 | 'requires_grad': False, 29 | 'device': 'cpu', 30 | }, 31 | } 32 | 33 | 34 | def get_train_sd_device_state_preset( 35 | device: Union[str, torch.device], 36 | train_unet: bool = False, 37 | train_text_encoder: bool = False, 38 | cached_latents: bool = False, 39 | train_lora: bool = False, 40 | train_adapter: bool = False, 41 | train_embedding: bool = False, 42 | train_refiner: bool = False, 43 | ): 44 | preset = copy.deepcopy(empty_preset) 45 | if not cached_latents: 46 | preset['vae']['device'] = device 47 | 48 | if train_unet: 49 | preset['unet']['training'] = True 50 | preset['unet']['requires_grad'] = True 51 | preset['unet']['device'] = device 52 | else: 53 | preset['unet']['device'] = device 54 | 55 | if train_text_encoder: 56 | preset['text_encoder']['training'] = True 57 | preset['text_encoder']['requires_grad'] = True 58 | preset['text_encoder']['device'] = device 59 | else: 60 | preset['text_encoder']['device'] = device 61 | 62 | if train_embedding: 63 | preset['text_encoder']['training'] = True 64 | preset['text_encoder']['requires_grad'] = True 65 | preset['text_encoder']['training'] = True 66 | preset['unet']['training'] = True 67 | 68 | if train_refiner: 69 | preset['refiner_unet']['training'] = True 70 | preset['refiner_unet']['requires_grad'] = True 71 | preset['refiner_unet']['device'] = device 72 | # if not training unet, move that to cpu 73 | if not train_unet: 74 | preset['unet']['device'] = 'cpu' 75 | 76 | if train_lora: 77 | # preset['text_encoder']['requires_grad'] = False 78 | preset['unet']['requires_grad'] = False 79 | if train_refiner: 80 | preset['refiner_unet']['requires_grad'] = False 81 | 82 | if train_adapter: 83 | preset['adapter']['requires_grad'] = True 84 | preset['adapter']['training'] = True 85 | preset['adapter']['device'] = device 86 | preset['unet']['training'] = True 87 | preset['unet']['requires_grad'] = False 88 | preset['unet']['device'] = device 89 | preset['text_encoder']['device'] = device 90 | 91 | return preset 92 | -------------------------------------------------------------------------------- /toolkit/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import OrderedDict, deque 3 | 4 | 5 | 
class Timer: 6 | def __init__(self, name='Timer', max_buffer=10): 7 | self.name = name 8 | self.max_buffer = max_buffer 9 | self.timers = OrderedDict() 10 | self.active_timers = {} 11 | self.current_timer = None # Used for the context manager functionality 12 | 13 | def start(self, timer_name): 14 | if timer_name not in self.timers: 15 | self.timers[timer_name] = deque(maxlen=self.max_buffer) 16 | self.active_timers[timer_name] = time.time() 17 | 18 | def cancel(self, timer_name): 19 | """Cancel an active timer.""" 20 | if timer_name in self.active_timers: 21 | del self.active_timers[timer_name] 22 | 23 | def stop(self, timer_name): 24 | if timer_name not in self.active_timers: 25 | raise ValueError(f"Timer '{timer_name}' was not started!") 26 | 27 | elapsed_time = time.time() - self.active_timers[timer_name] 28 | self.timers[timer_name].append(elapsed_time) 29 | 30 | # Clean up active timers 31 | del self.active_timers[timer_name] 32 | 33 | # Check if this timer's buffer exceeds max_buffer and remove the oldest if it does 34 | if len(self.timers[timer_name]) > self.max_buffer: 35 | self.timers[timer_name].popleft() 36 | 37 | def print(self): 38 | print(f"\nTimer '{self.name}':") 39 | # sort by longest at top 40 | for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True): 41 | avg_time = sum(timings) / len(timings) 42 | print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}") 43 | 44 | print('') 45 | 46 | def reset(self): 47 | self.timers.clear() 48 | self.active_timers.clear() 49 | 50 | def __call__(self, timer_name): 51 | """Enable the use of the Timer class as a context manager.""" 52 | self.current_timer = timer_name 53 | self.start(timer_name) 54 | return self 55 | 56 | def __enter__(self): 57 | pass 58 | 59 | def __exit__(self, exc_type, exc_value, traceback): 60 | if exc_type is None: 61 | # No exceptions, stop the timer normally 62 | self.stop(self.current_timer) 63 | else: 64 | # There was an exception, cancel the timer 65 | self.cancel(self.current_timer) 66 | -------------------------------------------------------------------------------- /toolkit/util/inverse_cfg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def inverse_classifier_guidance( 5 | noise_pred_cond: torch.Tensor, 6 | noise_pred_uncond: torch.Tensor, 7 | guidance_scale: torch.Tensor 8 | ): 9 | """ 10 | Adjust the noise_pred_cond for the classifier free guidance algorithm 11 | to ensure that the final noise prediction equals the original noise_pred_cond. 12 | """ 13 | # To make noise_pred equal noise_pred_cond_orig, we adjust noise_pred_cond 14 | # based on the formula used in the algorithm. 15 | # We derive the formula to find the correct adjustment for noise_pred_cond: 16 | # noise_pred_cond = (noise_pred_cond_orig - noise_pred_uncond * guidance_scale) / (guidance_scale - 1) 17 | # It's important to check if guidance_scale is not 1 to avoid division by zero. 18 | if guidance_scale == 1: 19 | # If guidance_scale is 1, adjusting is not needed or possible in the same way, 20 | # since it would lead to division by zero. This also means the algorithm inherently 21 | # doesn't alter the noise_pred_cond in relation to noise_pred_uncond. 22 | # Thus, we return the original values, though this situation might need special handling. 
23 | return noise_pred_cond 24 | adjusted_noise_pred_cond = (noise_pred_cond - noise_pred_uncond) / guidance_scale 25 | return adjusted_noise_pred_cond 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://cog.run/python 3 | 4 | from cog import BaseModel, Input, Path, Secret 5 | import os 6 | import yaml 7 | import subprocess 8 | from zipfile import ZipFile 9 | from huggingface_hub import HfApi 10 | 11 | class TrainingOutput(BaseModel): 12 | weights: Path 13 | 14 | # Run in lucataco/sandbox2 15 | def train( 16 | images: Path = Input( 17 | description="A zip/tar file containing the images that will be used for training. File names must be their captions: a_photo_of_TOK.png, etc. Min 12 images required." 18 | ), 19 | model_name: str = Input(description="Model name", default="black-forest-labs/FLUX.1-dev"), 20 | hf_token: Secret = Input(description="HuggingFace token to use for accessing model"), 21 | steps: int = Input(description="Number of training steps. Recommended range 500-4000", ge=10, le=4000, default=1000), 22 | learning_rate: float = Input(description="Learning rate", default=4e-4), 23 | batch_size: int = Input(description="Batch size", default=1), 24 | resolution: str = Input(description="Image resolutions for training", default="512,768,1024"), 25 | lora_linear: int = Input(description="LoRA linear value", default=16), 26 | lora_linear_alpha: int = Input(description="LoRA linear alpha value", default=16), 27 | repo_id: str = Input(description="Enter HuggingFace repo id to upload LoRA to HF. Will return zip file if left empty.Ex: lucataco/flux-dev-lora", default=None), 28 | ) -> TrainingOutput: 29 | """Run a single prediction on the model""" 30 | print("Starting Training") 31 | # Cleanup previous runs 32 | os.system("rm -rf output") 33 | # Cleanup training images (from canceled training runs) 34 | input_dir = "input_images" 35 | os.system(f"rm -rf {input_dir}") 36 | 37 | # Set huggingface token via huggingface-cli login 38 | os.system(f"huggingface-cli login --token {hf_token.get_secret_value()}") 39 | 40 | # Update the config file using YAML 41 | config_path = "config/replicate.yml" 42 | with open(config_path, 'r') as file: 43 | config = yaml.safe_load(file) 44 | 45 | # Update the configuration 46 | config['config']['process'][0]['model']['name_or_path'] = model_name 47 | config['config']['process'][0]['train']['steps'] = steps 48 | config['config']['process'][0]['save']['save_every'] = steps + 1 49 | config['config']['process'][0]['train']['lr'] = learning_rate 50 | config['config']['process'][0]['train']['batch_size'] = batch_size 51 | config['config']['process'][0]['datasets'][0]['resolution'] = [int(res) for res in resolution.split(',')] 52 | config['config']['process'][0]['network']['linear'] = lora_linear 53 | config['config']['process'][0]['network']['linear_alpha'] = lora_linear_alpha 54 | 55 | # Save config changes 56 | with open(config_path, 'w') as file: 57 | yaml.dump(config, file) 58 | 59 | # Unzip images from input images file to the input_images folder 60 | input_images = str(images) 61 | if input_images.endswith(".zip"): 62 | print("Detected zip file") 63 | os.makedirs(input_dir, exist_ok=True) 64 | with ZipFile(input_images, "r") as zip_ref: 65 | zip_ref.extractall(input_dir+"/") 66 | print("Extracted zip file") 67 | elif input_images.endswith(".tar"): 68 | print("Detected tar file") 69 | 
os.makedirs(input_dir, exist_ok=True) 70 | os.system(f"tar -xvf {input_images} -C {input_dir}") 71 | print("Extracted tar file") 72 | 73 | # Run the training job via run.py with the replicate config 74 | subprocess.check_call(["python", "run.py", "config/replicate.yml"], close_fds=False) 75 | 76 | # Zip up the output folder 77 | output_lora = "output/flux_train_replicate" 78 | # copy license file to output folder 79 | os.system(f"cp lora-license.md {output_lora}/README.md") 80 | output_zip_path = "/tmp/output.zip" 81 | os.system(f"zip -r {output_zip_path} {output_lora}") 82 | 83 | # cleanup input_images folder 84 | os.system(f"rm -rf {input_dir}") 85 | 86 | if hf_token is not None and repo_id is not None: 87 | api = HfApi() 88 | api.upload_folder( 89 | repo_id=repo_id, 90 | folder_path=output_lora, 91 | repo_type="model", 92 | token=hf_token.get_secret_value() 93 | ) 94 | return TrainingOutput(weights=Path(output_zip_path)) 95 | --------------------------------------------------------------------------------
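Note: toolkit/extension.py (above) discovers any package under extensions/ or extensions_built_in/ that exposes a module-level AI_TOOLKIT_EXTENSIONS list, and get_all_extensions_process_dict() maps each extension's uid to the process class returned by get_process(). As a rough illustration only — the package, class, and uid names below are hypothetical and not part of this repository — a minimal extension package might look like this:

# extensions/my_extension/__init__.py  (hypothetical sketch, not a file in this repo)
from toolkit.extension import Extension


class MyExtension(Extension):
    # uid is the key that get_all_extensions_process_dict() registers this process under
    uid = "my_extension"
    name = "My Extension"

    @classmethod
    def get_process(cls):
        # import lazily so heavy dependencies only load when the process is requested
        from .MyProcess import MyProcess  # hypothetical process class
        return MyProcess


# toolkit/extension.py looks for this exact module-level variable in each package
AI_TOOLKIT_EXTENSIONS = [MyExtension]

A config whose job key is "extension" would then be dispatched by toolkit/job.py to ExtensionJob, which can look the uid up in the process dict shown above and run the corresponding process.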